/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.clustering.incremental;

import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.log4j.Logger;
import org.openimaj.experiment.evaluation.cluster.analyser.FScoreClusterAnalyser;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SparseMatrixClusterer;
import org.openimaj.util.pair.IntDoublePair;

import ch.akuhn.matrix.SparseMatrix;

/**
 * An incremental clusterer which holds old {@link SparseMatrix} instances internally,
 * only forgetting rows once they have been clustered and are relatively stable.
 *
 * The criterion for row removal is cluster stability. A cluster is considered stable
 * when its maximum F1-score against the clusters of the previous round reaches a
 * threshold. Once one round of stability is achieved, the cluster is deemed stable
 * and its elements are removed from the active window. A minimal usage sketch is
 * given at the end of this file.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class IncrementalSparseClusterer implements SparseMatrixClusterer<IndexClusters> {

	private SparseMatrixClusterer<? extends IndexClusters> clusterer;
	private int window;
	protected double threshold;
	private int maxwindow = -1;
	private final static Logger logger = Logger.getLogger(IncrementalSparseClusterer.class);

	/**
	 * @param clusterer the underlying clusterer
	 * @param window the number of new rows added to the active matrix in each round
	 */
	public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window) {
		this.clusterer = clusterer;
		this.window = window;
		this.threshold = 1.;
	}

	/**
	 * @param clusterer the underlying clusterer
	 * @param window the number of new rows added to the active matrix in each round
	 * @param threshold the F1-score a cluster must achieve against the previous round to be considered stable
	 */
	public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, double threshold) {
		this.clusterer = clusterer;
		this.window = window;
		this.threshold = threshold;
	}

	/**
	 * @param clusterer the underlying clusterer
	 * @param window the number of new rows added to the active matrix in each round
	 * @param maxwindow the maximum number of active rows; forced to be at least 2 * window, or unlimited if non-positive
	 */
	@SuppressWarnings("unused")
	private IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, int maxwindow) {
		this.clusterer = clusterer;
		this.window = window;
		if (maxwindow > 0) {
			if (maxwindow < window * 2)
				maxwindow = window * 2;
		}
		if (maxwindow <= 0) {
			maxwindow = -1;
		}
		this.maxwindow = maxwindow;
		this.threshold = 1.;
	}

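	/**
	 * A sub-matrix containing only the rows and columns of a {@link SparseMatrix}
	 * that are still active (i.e. not yet part of a completed cluster), together
	 * with a mapping from sub-matrix indices back to the original row indices.
	 */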
	class WindowedSparseMatrix {
		SparseMatrix window;
		Map<Integer, Integer> indexCorrection;

		public WindowedSparseMatrix(SparseMatrix sm, int nextwindow, TIntSet inactive) {
			TIntArrayList active = new TIntArrayList(nextwindow);
			indexCorrection = new HashMap<Integer, Integer>();
			for (int i = 0; i < nextwindow; i++) {
				if (!inactive.contains(i)) {
					indexCorrection.put(active.size(), i);
					active.add(i);
				}
			}
			window = MatlibMatrixUtils.subMatrix(sm, active, active);
		}

		public void correctClusters(IndexClusters clstrs) {
			// map the sub-matrix indices in each cluster back to the original row indices
			int[][] clusters = clstrs.clusters();
			for (int i = 0; i < clusters.length; i++) {
				int[] cluster = clusters[i];
				for (int j = 0; j < cluster.length; j++) {
					cluster[j] = indexCorrection.get(cluster[j]);
				}
			}
		}
	}

	@Override
	public IndexClusters cluster(SparseMatrix data) {
		if (window >= data.rowCount()) window = data.rowCount();
		SparseMatrix seen = MatlibMatrixUtils.subMatrix(data, 0, window, 0, window);
		int seenrows = window;
		TIntSet inactiveRows = new TIntHashSet(window);
		logger.debug("First clustering!: " + seen.rowCount() + "x" + seen.columnCount());
		IndexClusters oldClusters = clusterer.cluster(seen);
		logger.debug("First clusters:\n" + oldClusters);
		List<int[]> completedClusters = new ArrayList<int[]>();
		while (seenrows < data.rowCount()) {
			int nextwindow = seenrows + window;
			if (nextwindow >= data.rowCount()) nextwindow = data.rowCount();
			if (this.maxwindow > 0 && nextwindow - inactiveRows.size() > this.maxwindow) {
				logger.debug(String.format("Window size (%d) without inactive (%d) = (%d), greater than maximum (%d)", nextwindow, inactiveRows.size(), nextwindow - inactiveRows.size(), this.maxwindow));
				deactiveOldItemsAsNoise(nextwindow, inactiveRows, completedClusters);
			}
			WindowedSparseMatrix wsp = new WindowedSparseMatrix(data, nextwindow, inactiveRows);
			logger.debug("Clustering: " + wsp.window.rowCount() + "x" + wsp.window.columnCount());
			IndexClusters newClusters = clusterer.cluster(wsp.window);
			wsp.correctClusters(newClusters);
			logger.debug("New clusters:\n" + newClusters);
			// if a cluster's stability reaches the threshold it was unchanged since the
			// last window; its items should not be included in the next round
			detectInactive(oldClusters, newClusters, inactiveRows, completedClusters);

			oldClusters = newClusters;
			seenrows += window;
			logger.debug("Seen rows: " + seenrows);
			logger.debug("Inactive rows: " + inactiveRows.size());
		}
		// whatever remains active after the final round is emitted as-is
		for (int i = 0; i < oldClusters.clusters().length; i++) {
			int[] cluster = oldClusters.clusters()[i];
			if (cluster.length != 0)
				completedClusters.add(cluster);
		}

		return new IndexClusters(completedClusters);
	}

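	/**
	 * Force the oldest still-active rows to be retired as singleton (noise)
	 * clusters until the number of active rows fits within the maximum window.
	 *
	 * @param nextwindow the number of rows that will have been seen after this round
	 * @param inactiveRows the set of retired rows; deactivated rows are added to it
	 * @param completedClusters the finished clusters; each deactivated row is added as a singleton
	 */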
	private void deactiveOldItemsAsNoise(int nextwindow, TIntSet inactiveRows, List<int[]> completedClusters) {
		int toDeactivate = 0;
		while (nextwindow - inactiveRows.size() > this.maxwindow) {
			if (!inactiveRows.contains(toDeactivate)) {
				logger.debug("Forcing the deactivation of: " + toDeactivate);
				inactiveRows.add(toDeactivate);
				completedClusters.add(new int[]{toDeactivate});
			}
			toDeactivate++;
		}
	}

	/**
	 * Given the old and new clusters, decide which rows are now inactive and
	 * therefore which clusters are now complete.
	 *
	 * @param oldClusters
	 * @param newClusters
	 * @param inactiveRows
	 * @param completedClusters
	 */
	protected void detectInactive(IndexClusters oldClusters, IndexClusters newClusters, TIntSet inactiveRows, List<int[]> completedClusters) {
		Map<Integer, IntDoublePair> stability = calculateStability(oldClusters, newClusters, inactiveRows);
		for (Entry<Integer, IntDoublePair> e : stability.entrySet()) {
			if (e.getValue().second >= threshold) {
				int[] completedCluster = oldClusters.clusters()[e.getKey()];
				inactiveRows.addAll(completedCluster);
				completedClusters.add(completedCluster);
				if (threshold == 1) {
					// the new cluster matched the completed one exactly; clear it so its
					// rows are not emitted again
					newClusters.clusters()[e.getValue().first] = new int[0];
				}
			}
		}
	}

	/**
	 * For each cluster in <code>c1</code>, find the best-matching cluster in
	 * <code>c2</code> (ignoring rows that are already inactive) and report its index
	 * and score as an {@link IntDoublePair}. Singleton clusters are compared
	 * directly; otherwise the {@link FScoreClusterAnalyser} score is used.
	 */
	protected Map<Integer, IntDoublePair> calculateStability(IndexClusters c1, IndexClusters c2, TIntSet inactiveRows) {
		Map<Integer, IntDoublePair> stability = new HashMap<Integer, IntDoublePair>();
		int[][] clusters1 = c1.clusters();
		int[][] clusters2 = c2.clusters();
		for (int i = 0; i < clusters1.length; i++) {
			if (clusters1[i].length == 0) continue;
			double maxscore = 0;
			int maxj = -1;
			TIntArrayList cluster = new TIntArrayList(clusters1[i].length);
			for (int j = 0; j < clusters1[i].length; j++) {
				if (inactiveRows.contains(clusters1[i][j]))
					continue;
				cluster.add(clusters1[i][j]);
			}
			int[][] correct = new int[][]{cluster.toArray()};
			for (int j = 0; j < clusters2.length; j++) {
				int[][] estimated = new int[][]{clusters2[j]};
				double score = 0;
				if (correct[0].length == 1 && estimated[0].length == 1) {
					// both are singletons: either they are the same element or not
					score = correct[0][0] == estimated[0][0] ? 1 : 0;
				} else {
					score = new FScoreClusterAnalyser().analyse(correct, estimated).score();
				}
				if (!Double.isNaN(score)) {
					if (score > maxscore) {
						maxscore = score;
						maxj = j;
					}
				}
			}
			stability.put(i, IntDoublePair.pair(maxj, maxscore));
		}
		logger.debug(String.format("The stability is:\n%s", stability));
		return stability;
	}

	@Override
	public int[][] performClustering(SparseMatrix data) {
		return this.cluster(data).clusters();
	}
}
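
/*
 * A minimal usage sketch (not part of the original class): it shows how an
 * IncrementalSparseClusterer wraps a batch SparseMatrixClusterer and is fed a
 * sparse similarity matrix. The `base` clusterer and the window/threshold values
 * are illustrative placeholders; any SparseMatrixClusterer implementation can be
 * substituted.
 */
class IncrementalSparseClustererUsageSketch {
	static IndexClusters clusterIncrementally(SparseMatrixClusterer<? extends IndexClusters> base, SparseMatrix similarity) {
		// consume 100 new rows per round; a cluster whose best F1-score against the
		// previous round reaches 0.9 is frozen and its rows leave the active window
		final IncrementalSparseClusterer incremental = new IncrementalSparseClusterer(base, 100, 0.9);
		return incremental.cluster(similarity);
	}
}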