001/* 002 AUTOMATICALLY GENERATED BY jTemp FROM 003 /Users/jon/Work/openimaj/tags/openimaj-1.3.1/machine-learning/clustering/src/main/jtemp/org/openimaj/ml/clustering/kmeans/#T#KMeans.jtemp 004*/ 005/** 006 * Copyright (c) 2011, The University of Southampton and the individual contributors. 007 * All rights reserved. 008 * 009 * Redistribution and use in source and binary forms, with or without modification, 010 * are permitted provided that the following conditions are met: 011 * 012 * * Redistributions of source code must retain the above copyright notice, 013 * this list of conditions and the following disclaimer. 014 * 015 * * Redistributions in binary form must reproduce the above copyright notice, 016 * this list of conditions and the following disclaimer in the documentation 017 * and/or other materials provided with the distribution. 018 * 019 * * Neither the name of the University of Southampton nor the names of its 020 * contributors may be used to endorse or promote products derived from this 021 * software without specific prior written permission. 022 * 023 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 024 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 025 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 026 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 027 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 028 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 029 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 030 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 031 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 032 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 */

package org.openimaj.ml.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;

import org.openimaj.data.DataSource;
import org.openimaj.data.FloatArrayBackedDataSource;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SpatialClusterer;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.ml.clustering.assignment.hard.KDTreeFloatEuclideanAssigner;
import org.openimaj.ml.clustering.assignment.hard.ExactFloatAssigner;
import org.openimaj.ml.clustering.FloatCentroidsResult;
import org.openimaj.knn.FloatNearestNeighbours;
import org.openimaj.knn.FloatNearestNeighboursExact;
import org.openimaj.knn.FloatNearestNeighboursProvider;
import org.openimaj.knn.NearestNeighboursFactory;
import org.openimaj.knn.approximate.FloatNearestNeighboursKDTree;
import org.openimaj.util.pair.IntFloatPair;

/**
 * Fast, parallel implementation of the K-Means algorithm with support for
 * bigger-than-memory data. Various flavors of K-Means are supported through the
 * selection of different subclasses of {@link FloatNearestNeighbours}; for
 * example, approximate K-Means can be achieved using a
 * {@link FloatNearestNeighboursKDTree} whilst exact K-Means can be achieved
 * using a {@link FloatNearestNeighboursExact}. The specific choice of
 * nearest-neighbour object is controlled through the
 * {@link NearestNeighboursFactory} provided to the {@link KMeansConfiguration}
 * used to construct instances of this class. The choice of
 * {@link FloatNearestNeighbours} affects the speed of clustering; using
 * approximate nearest-neighbours algorithms for the K-Means can produce
 * comparable results to the exact KMeans algorithm in much shorter time.
072 * The choice and configuration of {@link FloatNearestNeighbours} can also 073 * control the type of distance function being used in the clustering. 074 * <p> 075 * The algorithm is implemented as follows: Clustering is initiated using a 076 * {@link FloatKMeansInit} and is iterative. In each round, batches of 077 * samples are assigned to centroids in parallel. The centroid assignment is 078 * performed using the pre-configured {@link FloatNearestNeighbours} instances 079 * created from the {@link KMeansConfiguration}. Once all samples are assigned 080 * new centroids are calculated and the next round started. Data point pushing 081 * is performed using the same techniques as center point assignment. 082 * <p> 083 * This implementation is able to deal with larger-than-memory datasets by 084 * streaming the samples from disk using an appropriate {@link DataSource}. The 085 * only requirement is that there is enough memory to hold all the centroids 086 * plus working memory for the batches of samples being assigned. 
087 * 088 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 089 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 090 */ 091 public class FloatKMeans implements SpatialClusterer<FloatCentroidsResult, float[]> { 092 private static class CentroidAssignmentJob implements Callable<Boolean> { 093 private final DataSource<float[]> ds; 094 private final int startRow; 095 private final int stopRow; 096 private final FloatNearestNeighbours nno; 097 private final float [][] centroids_accum; 098 private final int [] counts; 099 100 public CentroidAssignmentJob(DataSource<float[]> ds, int startRow, int stopRow, FloatNearestNeighbours nno, float [][] centroids_accum, int [] counts) { 101 this.ds = ds; 102 this.startRow = startRow; 103 this.stopRow = stopRow; 104 this.nno = nno; 105 this.centroids_accum = centroids_accum; 106 this.counts = counts; 107 } 108 109 @Override 110 public Boolean call() { 111 try { 112 int D = nno.numDimensions(); 113 114 float [][] points = new float[stopRow-startRow][D]; 115 ds.getData(startRow, stopRow, points); 116 117 int [] argmins = new int[points.length]; 118 float [] mins = new float[points.length]; 119 120 nno.searchNN(points, argmins, mins); 121 122 synchronized(centroids_accum){ 123 for (int i=0; i < points.length; ++i) { 124 int k = argmins[i]; 125 for (int d=0; d < D; ++d) { 126 centroids_accum[k][d] += points[i][d]; 127 } 128 counts[k] += 1; 129 } 130 } 131 } catch(Exception e) { 132 e.printStackTrace(); 133 } 134 return true; 135 } 136 } 137 138 private static class Result extends FloatCentroidsResult implements FloatNearestNeighboursProvider { 139 protected FloatNearestNeighbours nn; 140 141 @Override 142 public HardAssigner<float[], float[], IntFloatPair> defaultHardAssigner() { 143 if (nn instanceof FloatNearestNeighboursExact) 144 return new ExactFloatAssigner(this); 145 146 return new KDTreeFloatEuclideanAssigner(this); 147 } 148 149 @Override 150 public FloatNearestNeighbours getNearestNeighbours() { 151 return nn; 152 } 153 } 154 155 
private FloatKMeansInit init = new FloatKMeansInit.RANDOM(); 156 private KMeansConfiguration<FloatNearestNeighbours, float[]> conf; 157 private Random rng = new Random(); 158 159 /** 160 * Construct the clusterer with the the given configuration. 161 * 162 * @param conf The configuration. 163 */ 164 public FloatKMeans(KMeansConfiguration<FloatNearestNeighbours, float[]> conf) { 165 this.conf = conf; 166 } 167 168 /** 169 * A completely default {@link FloatKMeans} used primarily as a convenience function for reading. 170 */ 171 protected FloatKMeans() { 172 this(new KMeansConfiguration<FloatNearestNeighbours, float[]>()); 173 } 174 175 /** 176 * Get the current initialisation algorithm 177 * 178 * @return the init algorithm being used 179 */ 180 public FloatKMeansInit getInit() { 181 return init; 182 } 183 184 /** 185 * Set the current initialisation algorithm 186 * 187 * @param init the init algorithm to be used 188 */ 189 public void setInit(FloatKMeansInit init) { 190 this.init = init; 191 } 192 193 /** 194 * Set the seed for the internal random number generator. 195 * 196 * @param seed the random seed for init random sample selection, no seed if seed < -1 197 */ 198 public void seed(long seed) { 199 if(seed < 0) 200 this.rng = new Random(); 201 else 202 this.rng = new Random(seed); 203 } 204 205 @Override 206 public FloatCentroidsResult cluster(float[][] data) { 207 DataSource<float[]> ds = new FloatArrayBackedDataSource(data, rng); 208 209 try { 210 Result result = cluster(ds, conf.K); 211 result.nn = conf.factory.create(result.centroids); 212 213 return result; 214 } catch (Exception e) { 215 throw new RuntimeException(e); 216 } 217 } 218 219 @Override 220 public int[][] performClustering(float[][] data) { 221 FloatCentroidsResult clusters = this.cluster(data); 222 return new IndexClusters(clusters.defaultHardAssigner().assign(data)).clusters(); 223 } 224 225 /** 226 * Initiate clustering with the given data and number of clusters. 
227 * Internally this method constructs the array to hold the centroids 228 * and calls {@link #cluster(DataSource, float [][])}. 229 * 230 * @param data data source to cluster with 231 * @param K number of clusters to find 232 * @return cluster centroids 233 */ 234 protected Result cluster(DataSource<float[]> data, int K) throws Exception { 235 int D = data.numDimensions(); 236 237 Result result = new Result(); 238 result.centroids = new float[K][D]; 239 240 init.initKMeans(data, result.centroids); 241 242 cluster(data, result); 243 244 return result; 245 } 246 247 /** 248 * Main clustering algorithm. A number of threads as specified are 249 * started each containing an assignment job and a reference to 250 * the same set of FloatNearestNeighbours object (i.e. Exact or KDTree). 251 * Each thread is added to a job pool and started in parallel. 252 * A single accumulator is shared between all threads and locked on update. 253 * 254 * @param data the data to be clustered 255 * @param centroids the centroids to be found 256 */ 257 protected void cluster(DataSource<float[]> data, Result result) throws Exception { 258 final float[][] centroids = result.centroids; 259 final int K = centroids.length; 260 final int D = centroids[0].length; 261 final int N = data.numRows(); 262 float [][] centroids_accum = new float[K][D]; 263 int [] new_counts = new int[K]; 264 265 ExecutorService service = conf.threadpool; 266 267 for (int i=0; i<conf.niters; i++) { 268 for (int j=0; j<K; j++) Arrays.fill(centroids_accum[j], 0); 269 Arrays.fill(new_counts, 0); 270 271 FloatNearestNeighbours nno = conf.factory.create(centroids); 272 273 List<CentroidAssignmentJob> jobs = new ArrayList<CentroidAssignmentJob>(); 274 for (int bl = 0; bl < N; bl += conf.blockSize) { 275 int br = Math.min(bl + conf.blockSize, N); 276 jobs.add(new CentroidAssignmentJob(data, bl, br, nno, centroids_accum, new_counts)); 277 } 278 279 service.invokeAll(jobs); 280 281 for (int k=0; k < K; ++k) { 282 if 
(new_counts[k] == 0) { 283 // If there's an empty cluster we replace it with a random point. 284 new_counts[k] = 1; 285 286 float [][] rnd = new float[][] {centroids[k]}; 287 data.getRandomRows(rnd); 288 } else { 289 for (int d=0; d < D; ++d) { 290 centroids[k][d] = (float)((float)roundFloat((double)centroids_accum[k][d] / (double)new_counts[k])); 291 } 292 } 293 } 294 } 295 } 296 297 protected float roundFloat(double value) { return (float) value; } 298 protected double roundDouble(double value) { return value; } 299 protected long roundLong(double value) { return (long)Math.round(value); } 300 protected int roundInt(double value) { return (int)Math.round(value); } 301 302 @Override 303 public FloatCentroidsResult cluster(DataSource<float[]> ds) { 304 try { 305 Result result = cluster(ds, conf.K); 306 result.nn = conf.factory.create(result.centroids); 307 308 return result; 309 } catch (Exception e) { 310 throw new RuntimeException(e); 311 } 312 } 313 314 /** 315 * Get the configuration 316 * 317 * @return the configuration 318 */ 319 public KMeansConfiguration<FloatNearestNeighbours, float[]> getConfiguration() { 320 return conf; 321 } 322 323 /** 324 * Set the configuration 325 * 326 * @param conf 327 * the configuration to set 328 */ 329 public void setConfiguration(KMeansConfiguration<FloatNearestNeighbours, float[]> conf) { 330 this.conf = conf; 331 } 332 333 /** 334 * Convenience method to quickly create an exact {@link FloatKMeans}. All 335 * parameters other than the number of clusters are set 336 * at their defaults, but can be manipulated through the configuration 337 * returned by {@link #getConfiguration()}. 338 * <p> 339 * Euclidean distance is used to measure the distance between points. 
340 * 341 * @param K 342 * the number of clusters 343 * @return a {@link FloatKMeans} instance configured for exact k-means 344 */ 345 public static FloatKMeans createExact(int K) { 346 final KMeansConfiguration<FloatNearestNeighbours, float[]> conf = 347 new KMeansConfiguration<FloatNearestNeighbours, float[]>(K, new FloatNearestNeighboursExact.Factory()); 348 349 return new FloatKMeans(conf); 350 } 351 352 /** 353 * Convenience method to quickly create an exact {@link FloatKMeans}. All 354 * parameters other than the number of clusters and number 355 * of iterations are set at their defaults, but can be manipulated through 356 * the configuration returned by {@link #getConfiguration()}. 357 * <p> 358 * Euclidean distance is used to measure the distance between points. 359 * 360 * @param K 361 * the number of clusters 362 * @param niters 363 * maximum number of iterations 364 * @return a {@link FloatKMeans} instance configured for exact k-means 365 */ 366 public static FloatKMeans createExact(int K, int niters) { 367 final KMeansConfiguration<FloatNearestNeighbours, float[]> conf = 368 new KMeansConfiguration<FloatNearestNeighbours, float[]>(K, new FloatNearestNeighboursExact.Factory(), niters); 369 370 return new FloatKMeans(conf); 371 } 372 373 /** 374 * Convenience method to quickly create an approximate {@link FloatKMeans} 375 * using an ensemble of KD-Trees to perform nearest-neighbour lookup. All 376 * parameters other than the number of clusters are set 377 * at their defaults, but can be manipulated through the configuration 378 * returned by {@link #getConfiguration()}. 379 * <p> 380 * Euclidean distance is used to measure the distance between points. 
381 * 382 * @param K 383 * the number of clusters 384 * @return a {@link FloatKMeans} instance configured for approximate k-means 385 * using an ensemble of KD-Trees 386 */ 387 public static FloatKMeans createKDTreeEnsemble(int K) { 388 final KMeansConfiguration<FloatNearestNeighbours, float[]> conf = 389 new KMeansConfiguration<FloatNearestNeighbours, float[]>(K, new FloatNearestNeighboursKDTree.Factory()); 390 391 return new FloatKMeans(conf); 392 } 393 394 @Override 395 public String toString() { 396 return String.format("%s: {K=%d, NN=%s}", this.getClass().getSimpleName(), this.conf.K, this.conf.getNearestNeighbourFactory().getClass().getSimpleName()); 397 } 398}