001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.incremental;
031
032import gnu.trove.list.array.TIntArrayList;
033import gnu.trove.set.TIntSet;
034import gnu.trove.set.hash.TIntHashSet;
035
036import java.util.ArrayList;
037import java.util.HashMap;
038import java.util.List;
039import java.util.Map;
040import java.util.Map.Entry;
041
042import org.apache.log4j.Logger;
043import org.openimaj.experiment.evaluation.cluster.analyser.FScoreClusterAnalyser;
044import org.openimaj.math.matrix.MatlibMatrixUtils;
045import org.openimaj.ml.clustering.IndexClusters;
046import org.openimaj.ml.clustering.SparseMatrixClusterer;
047import org.openimaj.util.pair.IntDoublePair;
048
049import ch.akuhn.matrix.SparseMatrix;
050
051/**
052 *
053 * An incremental clusterer which holds old {@link SparseMatrix} instances internally, 
054 * only forgetting rows once they have been clustered and are relatively stable.
055 * 
056 * The criteria for row removal is cluster stability.
057 * The defenition of cluster stability is maximum f1-score achieving a threshold between
058 * clusters in the previous round and the current round. Once one round of stability is achieved
059 * the cluster is stable and its elements are removed.
060 * 
061 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
062 */
063public class IncrementalSparseClusterer implements SparseMatrixClusterer<IndexClusters>{
064        
065        private SparseMatrixClusterer<? extends IndexClusters> clusterer;
066        private int window;
067        protected double threshold;
068        private int maxwindow = -1;
069        private final static Logger logger = Logger.getLogger(IncrementalSparseClusterer.class);
070        
071
072        /**
073         * @param clusterer the underlying clusterer
074         * @param window 
075         */
076        public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window) {
077                this.clusterer = clusterer;
078                this.window = window;
079                this.threshold = 1.;
080        }
081        
082        /**
083         * @param clusterer the underlying clusterer
084         * @param window 
085         * @param threshold 
086         */
087        public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, double threshold) {
088                this.clusterer = clusterer;
089                this.window = window;
090                this.threshold = threshold;
091        }
092        
093        /**
094         * @param clusterer the underlying clusterer
095         * @param window 
096         * @param maxwindow 
097         */
098        @SuppressWarnings("unused")
099        private IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, int maxwindow) {
100                this.clusterer = clusterer;
101                this.window = window;
102                if(maxwindow>0 ){
103                        if(maxwindow < window * 2)
104                                maxwindow = window * 2;
105                }
106                if(maxwindow <= 0){
107                        maxwindow = -1;
108                }
109                this.maxwindow  = maxwindow;
110                this.threshold = 1.;
111        }
112        
113        class WindowedSparseMatrix{
114                SparseMatrix window;
115                Map<Integer,Integer> indexCorrection;
116                
117                public WindowedSparseMatrix(SparseMatrix sm, int nextwindow, TIntSet inactive) {
118                        TIntArrayList active = new TIntArrayList(nextwindow);
119                        indexCorrection = new HashMap<Integer, Integer>();
120                        for (int i = 0; i < nextwindow; i++) {
121                                if(!inactive.contains(i)){
122                                        indexCorrection.put(active.size(), i);
123                                        active.add(i);
124                                }
125                        }
126                        window = MatlibMatrixUtils.subMatrix(sm, active, active);
127                }
128                
129                public void correctClusters(IndexClusters clstrs){
130                        int[][] clusters = clstrs.clusters();
131                        for (int i = 0; i < clusters.length; i++) {
132                                int[] cluster = clusters[i];
133                                for (int j = 0; j < cluster.length; j++) {
134                                        cluster[j] = indexCorrection.get(cluster[j]);
135                                }
136                        }
137                }
138        }
139
140        @Override
141        public IndexClusters cluster(SparseMatrix data) {
142                if(window >= data.rowCount()) window = data.rowCount();
143                SparseMatrix seen = MatlibMatrixUtils.subMatrix(data, 0, window, 0, window);
144                int seenrows = window;
145                TIntSet inactiveRows = new TIntHashSet(window);
146                logger.debug("First clustering!: " + seen.rowCount() + "x" + seen.columnCount());
147                IndexClusters oldClusters = clusterer.cluster(seen);
148                logger.debug("First clusters:\n" + oldClusters);
149                List<int[]> completedClusters = new ArrayList<int[]>();
150                while(seenrows < data.rowCount()){
151                        int nextwindow = seenrows + window;
152                        if(nextwindow >= data.rowCount()) nextwindow = data.rowCount();
153                        if(this.maxwindow > 0 && nextwindow - inactiveRows.size() > this.maxwindow){
154                                logger.debug(String.format("Window size (%d) without inactive (%d) = (%d), greater than maximum (%d)",nextwindow, inactiveRows.size(), nextwindow - inactiveRows.size(), this.maxwindow));
155                                deactiveOldItemsAsNoise(nextwindow,inactiveRows,completedClusters);
156                        }
157                        WindowedSparseMatrix wsp = new WindowedSparseMatrix(data, nextwindow, inactiveRows);
158                        logger.debug("Clustering: " + wsp.window.rowCount() + "x" + wsp.window.columnCount());
159                        IndexClusters newClusters = clusterer.cluster(wsp.window);
160                        wsp.correctClusters(newClusters);
161                        logger.debug("New clusters:\n" + newClusters);
162                        // if stability == 1 for any cluster, it was the same last window, we should not include those items next round
163                        detectInactive(oldClusters, newClusters, inactiveRows, completedClusters);
164                        
165                        oldClusters = newClusters;
166                        seenrows += window;
167                        logger.debug("Seen rows: " + seenrows);
168                        logger.debug("Inactive rows: " + inactiveRows.size());
169                }
170                for (int i = 0; i < oldClusters.clusters().length; i++) {
171                        int[] cluster = oldClusters.clusters()[i];
172                        if(cluster.length!=0) 
173                                completedClusters.add(cluster);
174                }
175                
176                return new IndexClusters(completedClusters);
177        }
178
179        private void deactiveOldItemsAsNoise(int nextwindow, TIntSet inactiveRows, List<int[]> completedClusters) {
180                int toDeactivate = 0;
181                while(nextwindow - inactiveRows.size() > this.maxwindow){
182                        if(!inactiveRows.contains(toDeactivate)){
183                                logger.debug("Forcing the deactivation of: " + toDeactivate);
184                                inactiveRows.add(toDeactivate);
185                                completedClusters.add(new int[]{toDeactivate});
186                        }
187                        toDeactivate++;
188                }
189        }
190
191        /**
192         * Given the old and new clusters, make a decision as to which rows are now inactive,
193         * and therefore which clusters are now completed
194         * @param oldClusters
195         * @param newClusters
196         * @param inactiveRows
197         * @param completedClusters
198         */
199        protected void detectInactive(IndexClusters oldClusters, IndexClusters newClusters, TIntSet inactiveRows, List<int[]> completedClusters) {
200                Map<Integer, IntDoublePair> stability = calculateStability(oldClusters,newClusters,inactiveRows);
201                for (Entry<Integer, IntDoublePair> e : stability.entrySet()) {
202                        if(e.getValue().second >= threshold){
203                                int[] completedCluster = oldClusters.clusters()[e.getKey()];
204                                inactiveRows.addAll(completedCluster);
205                                completedClusters.add(completedCluster);
206                                if(threshold == 1){
207                                        newClusters.clusters()[e.getValue().first] = new int[0];
208                                }
209                        }
210                }
211        }
212
213        protected Map<Integer, IntDoublePair> calculateStability(IndexClusters c1, IndexClusters c2, TIntSet inactiveRows) {
214                
215                Map<Integer, IntDoublePair> stability = new HashMap<Integer, IntDoublePair>();
216                int[][] clusters1 = c1.clusters();
217                int[][] clusters2 = c2.clusters();
218                for (int i = 0; i < clusters1.length; i++) {
219                        if(clusters1[i].length == 0) continue;
220                        double maxnmi = 0;
221                        int maxj = -1;
222                        TIntArrayList cluster = new TIntArrayList(clusters1[i].length);
223                        for (int j = 0; j < clusters1[i].length; j++) {
224                                if(inactiveRows.contains(clusters1[i][j]))
225                                        continue;
226                                cluster.add(clusters1[i][j]);
227                        }
228                        int[][] correct = new int[][]{cluster.toArray()};
229                        for (int j = 0; j < clusters2.length; j++) {
230                                int[][] estimated = new int[][]{clusters2[j]};
231//                              NMIAnalysis nmi = new NMIClusterAnalyser().analyse(correct, estimated);
232                                double score = 0;
233                                if(correct[0].length == 1 && estimated[0].length == 1){
234                                        // BOTH 1, either they are the same or not!
235                                        score = correct[0][0] == estimated[0][0] ? 1 : 0;
236                                }
237                                else{                                   
238                                        score = new FScoreClusterAnalyser().analyse(correct, estimated).score();
239                                }
240                                if(!Double.isNaN(score))
241                                {
242                                        if(score > maxnmi){
243                                                maxnmi = score;
244                                                maxj = j;
245                                        }
246                                }
247                        }
248                        stability.put(i, IntDoublePair.pair(maxj, maxnmi));
249                }
250                logger.debug(String.format("The stability is:\n%s",stability));
251                return stability;
252        }
253
254        @Override
255        public int[][] performClustering(SparseMatrix data) {
256                return this.cluster(data).clusters();
257        }
258
259        
260}