001/*
002        AUTOMATICALLY GENERATED BY jTemp FROM
003        /Users/jon/Work/openimaj/tags/openimaj-1.3.1/machine-learning/clustering/src/main/jtemp/org/openimaj/ml/clustering/kmeans/#T#KMeans.jtemp
004*/
005/**
006 * Copyright (c) 2011, The University of Southampton and the individual contributors.
007 * All rights reserved.
008 *
009 * Redistribution and use in source and binary forms, with or without modification,
010 * are permitted provided that the following conditions are met:
011 *
012 *   *  Redistributions of source code must retain the above copyright notice,
013 *      this list of conditions and the following disclaimer.
014 *
015 *   *  Redistributions in binary form must reproduce the above copyright notice,
016 *      this list of conditions and the following disclaimer in the documentation
017 *      and/or other materials provided with the distribution.
018 *
019 *   *  Neither the name of the University of Southampton nor the names of its
020 *      contributors may be used to endorse or promote products derived from this
021 *      software without specific prior written permission.
022 *
023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
033 */
034 
035package org.openimaj.ml.clustering.kmeans;
036
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
043
044import org.openimaj.data.DataSource;
045import org.openimaj.data.FloatArrayBackedDataSource;
046import org.openimaj.ml.clustering.IndexClusters;
047import org.openimaj.ml.clustering.SpatialClusterer;
048import org.openimaj.ml.clustering.assignment.HardAssigner;
049import org.openimaj.ml.clustering.assignment.hard.KDTreeFloatEuclideanAssigner;
050import org.openimaj.ml.clustering.assignment.hard.ExactFloatAssigner;
051import org.openimaj.ml.clustering.FloatCentroidsResult;
052import org.openimaj.knn.FloatNearestNeighbours;
053import org.openimaj.knn.FloatNearestNeighboursExact;
054import org.openimaj.knn.FloatNearestNeighboursProvider;
055import org.openimaj.knn.NearestNeighboursFactory;
056import org.openimaj.knn.approximate.FloatNearestNeighboursKDTree;
057import org.openimaj.util.pair.IntFloatPair;
058
059/**
060 * Fast, parallel implementation of the K-Means algorithm with support for
061 * bigger-than-memory data. Various flavors of K-Means are supported through the
062 * selection of different subclasses of {@link FloatNearestNeighbours}; for
063 * example, approximate K-Means can be achieved using a
064 * {@link FloatNearestNeighboursKDTree} whilst exact K-Means can be achieved
065 * using an {@link FloatNearestNeighboursExact}. The specific choice of
066 * nearest-neighbour object is controlled through the
067 * {@link NearestNeighboursFactory} provided to the {@link KMeansConfiguration}
068 * used to construct instances of this class. The choice of
 * {@link FloatNearestNeighbours} affects the speed of clustering; using
 * approximate nearest-neighbour algorithms for K-Means can produce
 * results comparable to the exact K-Means algorithm in much less time.
072 * The choice and configuration of {@link FloatNearestNeighbours} can also
073 * control the type of distance function being used in the clustering.
074 * <p>
075 * The algorithm is implemented as follows: Clustering is initiated using a
076 * {@link FloatKMeansInit} and is iterative. In each round, batches of
077 * samples are assigned to centroids in parallel. The centroid assignment is
078 * performed using the pre-configured {@link FloatNearestNeighbours} instances
079 * created from the {@link KMeansConfiguration}. Once all samples are assigned
080 * new centroids are calculated and the next round started. Data point pushing
081 * is performed using the same techniques as center point assignment.
082 * <p>
083 * This implementation is able to deal with larger-than-memory datasets by
084 * streaming the samples from disk using an appropriate {@link DataSource}. The
085 * only requirement is that there is enough memory to hold all the centroids
086 * plus working memory for the batches of samples being assigned.
087 * 
088 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
089 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
090 */
091 public class FloatKMeans implements SpatialClusterer<FloatCentroidsResult, float[]> {
092        private static class CentroidAssignmentJob implements Callable<Boolean> {
093                private final DataSource<float[]> ds;
094                private final int startRow;
095                private final int stopRow;
096                private final FloatNearestNeighbours nno;
097                private final float [][] centroids_accum;
098                private final int [] counts;
099
100                public CentroidAssignmentJob(DataSource<float[]> ds, int startRow, int stopRow, FloatNearestNeighbours nno, float [][] centroids_accum, int [] counts) {
101                        this.ds = ds; 
102                        this.startRow = startRow;
103                        this.stopRow = stopRow;
104                        this.nno = nno;
105                        this.centroids_accum = centroids_accum;
106                        this.counts = counts;
107                }
108                
109                @Override
110                public Boolean call() {
111                        try {
112                                int D = nno.numDimensions();
113
114                                float [][] points = new float[stopRow-startRow][D]; 
115                                ds.getData(startRow, stopRow, points);
116
117                                int [] argmins = new int[points.length];
118                                float [] mins = new float[points.length];
119
120                                nno.searchNN(points, argmins, mins);
121
122                                synchronized(centroids_accum){
123                                        for (int i=0; i < points.length; ++i) {
124                                                int k = argmins[i];
125                                                for (int d=0; d < D; ++d) {
126                                                        centroids_accum[k][d] += points[i][d];
127                                                }
128                                                counts[k] += 1;
129                                        }
130                                }
131                        } catch(Exception e) {
132                                e.printStackTrace();
133                        }
134                        return true;
135                }
136        }
137        
138        private static class Result extends FloatCentroidsResult implements FloatNearestNeighboursProvider {
139                protected FloatNearestNeighbours nn;
140                
141                @Override
142                public HardAssigner<float[], float[], IntFloatPair> defaultHardAssigner() {
143                        if (nn instanceof FloatNearestNeighboursExact)
144                                return new ExactFloatAssigner(this);
145                
146                        return new KDTreeFloatEuclideanAssigner(this);
147                }
148                
149                @Override
150                public FloatNearestNeighbours getNearestNeighbours() {
151                        return nn;
152                }
153        }
154        
155        private FloatKMeansInit init = new FloatKMeansInit.RANDOM(); 
156        private KMeansConfiguration<FloatNearestNeighbours, float[]> conf;
157        private Random rng = new Random();
158        
159        /**
160         * Construct the clusterer with the the given configuration.
161         * 
162         * @param conf The configuration.
163         */
164        public FloatKMeans(KMeansConfiguration<FloatNearestNeighbours, float[]> conf) {
165                this.conf = conf;
166        }
167        
168        /**
169         * A completely default {@link FloatKMeans} used primarily as a convenience function for reading.
170         */
171        protected FloatKMeans() {
172                this(new KMeansConfiguration<FloatNearestNeighbours, float[]>());
173        }
174        
175        /**
176         * Get the current initialisation algorithm
177         *
178         * @return the init algorithm being used
179         */
180        public FloatKMeansInit getInit() {
181                return init;
182        }
183
184        /**
185         * Set the current initialisation algorithm
186         *
187         * @param init the init algorithm to be used
188         */
189        public void setInit(FloatKMeansInit init) {
190                this.init = init;
191        }
192        
193        /**
194         * Set the seed for the internal random number generator.
195         *
196         * @param seed the random seed for init random sample selection, no seed if seed < -1
197         */
198        public void seed(long seed) {
199                if(seed < 0)
200                        this.rng = new Random();
201                else
202                        this.rng = new Random(seed);
203        }
204                
205        @Override
206        public FloatCentroidsResult cluster(float[][] data) {
207                DataSource<float[]> ds = new FloatArrayBackedDataSource(data, rng);
208                
209                try {
210                        Result result = cluster(ds, conf.K);
211                        result.nn = conf.factory.create(result.centroids);
212                                                
213                        return result;
214                } catch (Exception e) {
215                        throw new RuntimeException(e);
216                }
217        }
218        
219        @Override
220        public int[][] performClustering(float[][] data) {
221                FloatCentroidsResult clusters = this.cluster(data);
222                return new IndexClusters(clusters.defaultHardAssigner().assign(data)).clusters();
223        }
224        
225        /**
226         * Initiate clustering with the given data and number of clusters.
227         * Internally this method constructs the array to hold the centroids 
228         * and calls {@link #cluster(DataSource, float [][])}.
229         *
230         * @param data data source to cluster with
231         * @param K number of clusters to find
232         * @return cluster centroids
233         */
234        protected Result cluster(DataSource<float[]> data, int K) throws Exception {
235                int D = data.numDimensions();
236                
237                Result result = new Result();
238                result.centroids = new float[K][D];
239        
240                init.initKMeans(data, result.centroids);
241        
242                cluster(data, result);
243
244                return result;
245        }
246        
247        /**
248         * Main clustering algorithm. A number of threads as specified are 
249         * started each containing an assignment job and a reference to
250         * the same set of FloatNearestNeighbours object (i.e. Exact or KDTree). 
251         * Each thread is added to a job pool and started in parallel. 
252         * A single accumulator is shared between all threads and locked on update.
253         *
254         * @param data the data to be clustered
255         * @param centroids the centroids to be found
256         */
257        protected void cluster(DataSource<float[]> data, Result result) throws Exception {
258                final float[][] centroids = result.centroids;
259                final int K = centroids.length;
260                final int D = centroids[0].length;
261                final int N = data.numRows();
262                float [][] centroids_accum = new float[K][D];
263                int [] new_counts = new int[K];
264
265                ExecutorService service = conf.threadpool;
266
267                for (int i=0; i<conf.niters; i++) {
268                        for (int j=0; j<K; j++) Arrays.fill(centroids_accum[j], 0);
269                        Arrays.fill(new_counts, 0);
270
271                        FloatNearestNeighbours nno = conf.factory.create(centroids);
272                        
273                        List<CentroidAssignmentJob> jobs = new ArrayList<CentroidAssignmentJob>();
274                        for (int bl = 0; bl < N; bl += conf.blockSize) {
275                                int br = Math.min(bl + conf.blockSize, N);
276                                jobs.add(new CentroidAssignmentJob(data, bl, br, nno, centroids_accum, new_counts));
277                        }
278
279                        service.invokeAll(jobs);
280
281                        for (int k=0; k < K; ++k) {
282                                if (new_counts[k] == 0) {
283                                        // If there's an empty cluster we replace it with a random point.
284                                        new_counts[k] = 1;
285
286                                        float [][] rnd = new float[][] {centroids[k]};
287                                        data.getRandomRows(rnd);
288                                } else {
289                                        for (int d=0; d < D; ++d) {
290                                                centroids[k][d] = (float)((float)roundFloat((double)centroids_accum[k][d] / (double)new_counts[k]));
291                                        }
292                                }
293                        } 
294                }
295        }
296        
297        protected float roundFloat(double value) { return (float) value; }
298        protected double roundDouble(double value) { return value; }
299        protected long roundLong(double value) { return (long)Math.round(value); }
300        protected int roundInt(double value) { return (int)Math.round(value); }
301        
302        @Override
303        public FloatCentroidsResult cluster(DataSource<float[]> ds) {
304                try {
305                        Result result = cluster(ds, conf.K);
306                        result.nn = conf.factory.create(result.centroids);
307                        
308                        return result;
309                } catch (Exception e) {
310                        throw new RuntimeException(e);
311                }
312        }
313
314    /**
315         * Get the configuration
316         * 
317         * @return the configuration
318         */
319    public KMeansConfiguration<FloatNearestNeighbours, float[]> getConfiguration() {
320        return conf;
321    }
322    
323    /**
324         * Set the configuration
325         * 
326         * @param conf
327         *            the configuration to set
328         */
329    public void setConfiguration(KMeansConfiguration<FloatNearestNeighbours, float[]> conf) {
330        this.conf = conf;
331    }
332        
333        /**
334         * Convenience method to quickly create an exact {@link FloatKMeans}. All
335         * parameters other than the number of clusters are set
336         * at their defaults, but can be manipulated through the configuration
337         * returned by {@link #getConfiguration()}.
338         * <p>
339         * Euclidean distance is used to measure the distance between points.
340         * 
341         * @param K
342         *            the number of clusters
343         * @return a {@link FloatKMeans} instance configured for exact k-means
344         */
345        public static FloatKMeans createExact(int K) {
346                final KMeansConfiguration<FloatNearestNeighbours, float[]> conf =
347                                new KMeansConfiguration<FloatNearestNeighbours, float[]>(K, new FloatNearestNeighboursExact.Factory());
348
349                return new FloatKMeans(conf);
350        }
351
352        /**
353         * Convenience method to quickly create an exact {@link FloatKMeans}. All
354         * parameters other than the number of clusters and number
355         * of iterations are set at their defaults, but can be manipulated through
356         * the configuration returned by {@link #getConfiguration()}.
357         * <p>
358         * Euclidean distance is used to measure the distance between points.
359         * 
360         * @param K
361         *            the number of clusters
362         * @param niters
363         *            maximum number of iterations
364         * @return a {@link FloatKMeans} instance configured for exact k-means
365         */
366        public static FloatKMeans createExact(int K, int niters) {
367                final KMeansConfiguration<FloatNearestNeighbours, float[]> conf =
368                                new KMeansConfiguration<FloatNearestNeighbours, float[]>(K, new FloatNearestNeighboursExact.Factory(), niters);
369
370                return new FloatKMeans(conf);
371        }
372        
373        /**
374         * Convenience method to quickly create an approximate {@link FloatKMeans}
375         * using an ensemble of KD-Trees to perform nearest-neighbour lookup. All
376         * parameters other than the number of clusters are set
377         * at their defaults, but can be manipulated through the configuration
378         * returned by {@link #getConfiguration()}. 
379         * <p>
380         * Euclidean distance is used to measure the distance between points.
381         * 
382         * @param K
383         *            the number of clusters
384         * @return a {@link FloatKMeans} instance configured for approximate k-means 
385         *              using an ensemble of KD-Trees
386         */
387        public static FloatKMeans createKDTreeEnsemble(int K) {
388                final KMeansConfiguration<FloatNearestNeighbours, float[]> conf =
389                                new KMeansConfiguration<FloatNearestNeighbours, float[]>(K, new FloatNearestNeighboursKDTree.Factory());
390
391                return new FloatKMeans(conf);
392        }
393        
394        @Override
395        public String toString() {
396                return String.format("%s: {K=%d, NN=%s}", this.getClass().getSimpleName(), this.conf.K, this.conf.getNearestNeighbourFactory().getClass().getSimpleName());
397        }
398}