/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.ml.gmm;

import gnu.trove.list.array.TDoubleArrayList;

import java.util.Arrays;
import java.util.EnumSet;

import org.apache.commons.math.util.MathUtils;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.statistics.MeanAndCovariance;
import org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian;
import org.openimaj.math.statistics.distribution.DiagonalMultivariateGaussian;
import org.openimaj.math.statistics.distribution.FullMultivariateGaussian;
import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
import org.openimaj.math.statistics.distribution.MultivariateGaussian;
import org.openimaj.math.statistics.distribution.SphericalMultivariateGaussian;
import org.openimaj.ml.clustering.DoubleCentroidsResult;
import org.openimaj.ml.clustering.kmeans.DoubleKMeans;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.pair.IndependentPair;

import Jama.Matrix;

/**
 * Gaussian mixture model learning using the EM algorithm. Supports a range of
 * differently shaped Gaussians through different covariance matrix forms. An
 * initialisation step is used to learn the initial means using K-Means,
 * although this can be disabled in the constructor.
 * <p>
 * Implementation was originally inspired by the SciPy's "gmm.py".
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 */
public class GaussianMixtureModelEM implements Cloneable {
    /**
     * Different forms of covariance matrix supported by the
     * {@link GaussianMixtureModelEM}.
     *
     * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
     */
    public static enum CovarianceType {
        /**
         * Spherical Gaussians: variance is the same along all axes and zero
         * across-axes.
         */
        Spherical {
            @Override
            protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
                // The shared initial variance is the mean over every entry of cv
                // (off-diagonal entries included).
                double mean = 0;
                for (int r = 0; r < cv.getRowDimension(); r++) {
                    for (int c = 0; c < cv.getColumnDimension(); c++) {
                        mean += cv.get(r, c);
                    }
                }
                mean /= (cv.getColumnDimension() * cv.getRowDimension());

                for (final MultivariateGaussian mg : gaussians) {
                    ((SphericalMultivariateGaussian) mg).variance = mean;
                }
            }

            @Override
            protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
                final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
                for (int i = 0; i < ngauss; i++) {
                    arr[i] = new SphericalMultivariateGaussian(ndims);
                }

                return arr;
            }

            @Override
            protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities,
                    Matrix weightedXsum, double[] norm)
            {
                final Matrix avgX2uw = responsibilities.transpose().times(X.arrayTimes(X));

                for (int i = 0; i < gmm.gaussians.length; i++) {
                    final Matrix mu = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean;
                    final Matrix covar = diagonalCovariance(mu, weightedXsum.getArray()[i],
                            avgX2uw.getArray()[i], norm[i], learner.minCovar);

                    // the spherical variance is the average of the per-dimension
                    // variances
                    ((SphericalMultivariateGaussian) gmm.gaussians[i]).variance =
                            MatrixUtils.sum(covar) / X.getColumnDimension();
                }
            }
        },
        /**
         * Gaussians with diagonal covariance matrices.
         */
        Diagonal {
            @Override
            protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
                for (final MultivariateGaussian mg : gaussians) {
                    ((DiagonalMultivariateGaussian) mg).variance = MatrixUtils.diagVector(cv);
                }
            }

            @Override
            protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
                final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
                for (int i = 0; i < ngauss; i++) {
                    arr[i] = new DiagonalMultivariateGaussian(ndims);
                }

                return arr;
            }

            @Override
            protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities,
                    Matrix weightedXsum, double[] norm)
            {
                final Matrix avgX2uw = responsibilities.transpose().times(X.arrayTimes(X));

                for (int i = 0; i < gmm.gaussians.length; i++) {
                    final Matrix mu = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean;
                    final Matrix covar = diagonalCovariance(mu, weightedXsum.getArray()[i],
                            avgX2uw.getArray()[i], norm[i], learner.minCovar);

                    ((DiagonalMultivariateGaussian) gmm.gaussians[i]).variance = covar.getArray()[0];
                }
            }
        },
        /**
         * Gaussians with full covariance
         */
        Full {
            @Override
            protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
                final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
                for (int i = 0; i < ngauss; i++) {
                    arr[i] = new FullMultivariateGaussian(ndims);
                }

                return arr;
            }

            @Override
            protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
                // each component gets its own copy so that they can diverge
                for (final MultivariateGaussian mg : gaussians) {
                    ((FullMultivariateGaussian) mg).covar = cv.copy();
                }
            }

            @Override
            protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities,
                    Matrix weightedXsum, double[] norm)
            {
                // Eq. 12 from K. Murphy,
                // "Fitting a Conditional Linear Gaussian Distribution"
                final int nfeatures = X.getColumnDimension();
                final int lastRow = X.getRowDimension() - 1;

                // minCovar on the diagonal; only ever read, so it is built once
                final Matrix regulariser = Matrix.identity(nfeatures, nfeatures).times(learner.minCovar);

                for (int c = 0; c < learner.nComponents; c++) {
                    // responsibilities of component c as a single row
                    final Matrix post = responsibilities.getMatrix(0, lastRow, c, c).transpose();

                    final double factor = 1.0 / (ArrayUtils.sumValues(post.getArray()) + 10 * MathUtils.EPSILON);

                    // scale every column (sample) of X^T by its responsibility
                    final Matrix pXt = X.transpose();
                    for (int i = 0; i < pXt.getRowDimension(); i++) {
                        for (int j = 0; j < pXt.getColumnDimension(); j++) {
                            pXt.set(i, j, pXt.get(i, j) * post.get(0, j));
                        }
                    }

                    final Matrix argcv = pXt.times(X).times(factor);
                    final Matrix mu = ((FullMultivariateGaussian) gmm.gaussians[c]).mean;

                    ((FullMultivariateGaussian) gmm.gaussians[c]).covar =
                            argcv.minusEquals(mu.transpose().times(mu)).plusEquals(regulariser);
                }
            }
        },
        /**
         * Gaussians with a tied covariance matrix; the same covariance matrix
         * is shared by all the gaussians.
         */
        Tied {
            @Override
            protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
                // deliberately the same instance for every component
                for (final MultivariateGaussian mg : gaussians) {
                    ((FullMultivariateGaussian) mg).covar = cv;
                }
            }

            @Override
            protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
                final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
                final Matrix covar = new Matrix(ndims, ndims);

                for (int i = 0; i < ngauss; i++) {
                    arr[i] = new FullMultivariateGaussian(new Matrix(1, ndims), covar);
                }

                return arr;
            }

            @Override
            protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities,
                    Matrix weightedXsum, double[] norm)
            {
                // Eq. 15 from K. Murphy,
                // "Fitting a Conditional Linear Gaussian Distribution"
                final int nfeatures = X.getColumnDimension();

                final Matrix avgX2 = X.transpose().times(X);

                // stack the (1 x nfeatures) means into an (ncomponents x
                // nfeatures) matrix
                final double[][] mudata = new double[gmm.gaussians.length][];
                for (int i = 0; i < mudata.length; i++) {
                    mudata[i] = ((FullMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0];
                }
                final Matrix mu = new Matrix(mudata);

                final Matrix avgMeans2 = mu.transpose().times(weightedXsum);
                final Matrix covar = avgX2.minus(avgMeans2)
                        .plus(Matrix.identity(nfeatures, nfeatures).times(learner.minCovar))
                        .times(1.0 / X.getRowDimension());

                for (int i = 0; i < learner.nComponents; i++) {
                    ((FullMultivariateGaussian) gmm.gaussians[i]).covar = covar;
                }
            }
        };

        /**
         * Create the array of gaussians of the form matching this covariance
         * type.
         *
         * @param ngauss
         *            the number of gaussians to create
         * @param ndims
         *            the number of dimensions of each gaussian
         * @return the newly created gaussians
         */
        protected abstract MultivariateGaussian[] createGaussians(int ngauss, int ndims);

        /**
         * Set the covariance of every gaussian from the given covariance
         * matrix, reducing it to the form required by this covariance type.
         *
         * @param gaussians
         *            the gaussians to update
         * @param cv
         *            the covariance matrix to take the values from
         */
        protected abstract void setCovariances(MultivariateGaussian[] gaussians, Matrix cv);

        /**
         * Mode specific maximisation-step. Implementors should use the state to
         * update the covariance of each of the gaussians of the given mixture
         * model.
         *
         * @param gmm
         *            the mixture model being learned
         * @param learner
         *            the learner performing the estimation; provides the number
         *            of components and the minimum covariance
         * @param X
         *            the data
         * @param responsibilities
         *            matrix with the same number of rows as X where each col is
         *            the amount that the data point belongs to each gaussian
         * @param weightedXsum
         *            responsibilities.T * X
         * @param inverseWeights
         *            1/weights
         */
        protected abstract void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X,
                Matrix responsibilities, Matrix weightedXsum, double[] inverseWeights);
    }

    /**
     * Options for controlling what gets updated during the initialisation
     * and/or iterations.
     *
     * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
     */
    public static enum UpdateOptions {
        /**
         * Update the means
         */
        Means,
        /**
         * Update the weights
         */
        Weights,
        /**
         * Update the covariances
         */
        Covariances
    }

    /**
     * The mixture model being built by the learner. It starts without any
     * gaussians and with uniform weights.
     */
    private static class EMGMM extends MixtureOfGaussians {
        EMGMM(int nComponents) {
            super(null, null);

            this.weights = new double[nComponents];
            Arrays.fill(this.weights, 1.0 / nComponents);
        }
    }

    private static final double DEFAULT_THRESH = 1e-2;
    private static final double DEFAULT_MIN_COVAR = 1e-3;
    private static final int DEFAULT_NITERS = 100;
    private static final int DEFAULT_NINIT = 1;

    CovarianceType ctype;
    int nComponents;
    private double thresh;
    private double minCovar;
    private int nIters;
    private int nInit;

    private boolean converged = false;
    private EnumSet<UpdateOptions> initOpts;
    private EnumSet<UpdateOptions> iterOpts;

    /**
     * Construct with the given arguments.
     *
     * @param nComponents
     *            the number of gaussian components
     * @param ctype
     *            the form of the covariance matrices
     * @param thresh
     *            the threshold at which to stop iterating
     * @param minCovar
     *            the minimum value allowed in the diagonal of the estimated
     *            covariance matrices to prevent overfitting
     * @param nIters
     *            the maximum number of iterations
     * @param nInit
     *            the number of runs of the algorithm to perform; the best
     *            result will be kept.
     * @param iterOpts
     *            options controlling what is updated during iteration
     * @param initOpts
     *            options controlling what is updated during initialisation.
     *            Enabling the {@link UpdateOptions#Means} option will cause
     *            K-Means to be used to generate initial starting points for the
     *            means.
     * @throws IllegalArgumentException
     *             if nInit is less than one
     */
    public GaussianMixtureModelEM(int nComponents, CovarianceType ctype, double thresh, double minCovar,
            int nIters, int nInit, EnumSet<UpdateOptions> iterOpts, EnumSet<UpdateOptions> initOpts)
    {
        if (nInit < 1) {
            throw new IllegalArgumentException("GMM estimation requires at least one run");
        }

        this.ctype = ctype;
        this.nComponents = nComponents;
        this.thresh = thresh;
        this.minCovar = minCovar;
        this.nIters = nIters;
        this.nInit = nInit;
        this.iterOpts = iterOpts;
        this.initOpts = initOpts;
        this.converged = false;
    }

    /**
     * Construct with the given number of components and covariance form,
     * using the default threshold, minimum covariance, iteration count and
     * number of runs, and updating everything during both initialisation and
     * iteration.
     *
     * @param nComponents
     *            the number of gaussian components
     * @param ctype
     *            the form of the covariance matrices
     */
    public GaussianMixtureModelEM(int nComponents, CovarianceType ctype) {
        this(nComponents, ctype, DEFAULT_THRESH, DEFAULT_MIN_COVAR, DEFAULT_NITERS, DEFAULT_NINIT,
                EnumSet.allOf(UpdateOptions.class), EnumSet.allOf(UpdateOptions.class));
    }

    /**
     * Gets the convergence state of the algorithm. Will return false if
     * {@link #estimate(double[][])} has not been called, or if the run that
     * produced the model returned by the last call to
     * {@link #estimate(double[][])} failed to reach convergence before running
     * out of iterations.
     *
     * @return true if the last call to {@link #estimate(double[][])} reached
     *         convergence; false otherwise
     */
    public boolean hasConverged() {
        return converged;
    }

    /**
     * Estimate a new {@link MixtureOfGaussians} from the given data. Use
     * {@link #hasConverged()} to check whether the EM algorithm reached
     * convergence in the estimation of the returned model.
     *
     * @param X
     *            the data matrix.
     * @return the generated GMM.
     */
    public MixtureOfGaussians estimate(Matrix X) {
        return estimate(X.getArray());
    }

    /**
     * Estimate a new {@link MixtureOfGaussians} from the given data. Use
     * {@link #hasConverged()} to check whether the EM algorithm reached
     * convergence in the estimation of the returned model.
     * <p>
     * The algorithm is run the configured number of times; the parameters of
     * the run with the highest final log-likelihood are the ones returned.
     *
     * @param X
     *            the data array.
     * @return the generated GMM.
     * @throws IllegalArgumentException
     *             if there are fewer samples than components
     * @throws RuntimeException
     *             if no run was able to produce a valid likelihood
     */
    public MixtureOfGaussians estimate(double[][] X) {
        final EMGMM gmm = new EMGMM(nComponents);

        if (X.length < nComponents) {
            throw new IllegalArgumentException(String.format(
                    "GMM estimation with %d components, but got only %d samples", nComponents, X.length));
        }

        converged = false;

        double maxLogProb = Double.NEGATIVE_INFINITY;
        double[] bestWeights = null;
        MultivariateGaussian[] bestMixture = null;
        boolean bestConverged = false;

        for (int run = 0; run < nInit; run++) {
            initialise(gmm, X);

            // EM algorithm
            final TDoubleArrayList logLikelihood = new TDoubleArrayList();
            boolean runConverged = false;

            for (int i = 0; i < nIters; i++) {
                // Expectation step
                final IndependentPair<double[], double[][]> score = gmm.scoreSamples(X);
                final double[] currLogLikelihood = score.firstObject();
                final double[][] responsibilities = score.secondObject();
                logLikelihood.add(ArrayUtils.sumValues(currLogLikelihood));

                // Check for convergence.
                if (i > 0 && Math.abs(logLikelihood.get(i) - logLikelihood.get(i - 1)) < thresh) {
                    runConverged = true;
                    break;
                }

                // Perform the maximisation step
                mstep(gmm, X, responsibilities);
            }

            // if the results of this run are better, keep them
            if (nIters > 0) {
                final double finalLogProb = logLikelihood.get(logLikelihood.size() - 1);

                if (finalLogProb > maxLogProb) {
                    maxLogProb = finalLogProb;
                    // The weights array can be reused and overwritten by a later
                    // run, so it must be copied. The gaussians are created
                    // afresh at the start of every run, so keeping the reference
                    // is sufficient.
                    bestWeights = gmm.weights.clone();
                    bestMixture = gmm.gaussians;
                    bestConverged = runConverged;
                }
            }
        }

        if (nIters > 0) {
            // check the existence of an init param that was not subject to
            // likelihood computation issue.
            if (Double.isInfinite(maxLogProb)) {
                throw new RuntimeException(
                        "EM algorithm was never able to compute a valid likelihood given initial " +
                                "parameters. Try different init parameters (or increasing n_init) or " +
                                "check for degenerate data.");
            }

            gmm.gaussians = bestMixture;
            gmm.weights = bestWeights;
        }

        converged = bestConverged;

        return gmm;
    }

    /**
     * Set up the model for a new run of the EM algorithm: fresh gaussians are
     * created and then the means, weights and covariances are initialised as
     * requested by the initialisation options.
     *
     * @param gmm
     *            the mixture model being learned
     * @param X
     *            the data
     */
    private void initialise(EMGMM gmm, double[][] X) {
        gmm.gaussians = ctype.createGaussians(nComponents, X[0].length);

        if (initOpts.contains(UpdateOptions.Means)) {
            // initialise using k-means
            final DoubleKMeans km = DoubleKMeans.createExact(nComponents);
            final DoubleCentroidsResult means = km.cluster(X);

            for (int i = 0; i < nComponents; i++) {
                ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0] = means.centroids[i];
            }
        }

        if (initOpts.contains(UpdateOptions.Weights)) {
            gmm.weights = new double[nComponents];
            Arrays.fill(gmm.weights, 1.0 / nComponents);
        }

        if (initOpts.contains(UpdateOptions.Covariances)) {
            // The reference implementation uses
            // cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1]);
            // note that the minCovar term is not added here.
            final Matrix cv = MeanAndCovariance.computeCovariance(X);

            ctype.setCovariances(gmm.gaussians, cv);
        }
    }

    /**
     * Perform the maximisation step, updating the weights, means and
     * covariances of the model as requested by the iteration options.
     *
     * @param gmm
     *            the mixture model being learned
     * @param X
     *            the data
     * @param responsibilities
     *            the per-sample (rows) responsibility of each gaussian (cols)
     */
    private void mstep(EMGMM gmm, double[][] X, double[][] responsibilities) {
        final double[] weights = ArrayUtils.colSum(responsibilities);
        final Matrix resMat = new Matrix(responsibilities);
        final Matrix Xmat = new Matrix(X);

        final Matrix weightedXsum = resMat.transpose().times(Xmat);

        // a small epsilon guards against division by zero for empty components
        final double[] inverseWeights = new double[weights.length];
        for (int i = 0; i < inverseWeights.length; i++) {
            inverseWeights[i] = 1.0 / (weights[i] + 10 * MathUtils.EPSILON);
        }

        if (iterOpts.contains(UpdateOptions.Weights)) {
            final double sum = ArrayUtils.sumValues(weights);
            for (int i = 0; i < weights.length; i++) {
                gmm.weights[i] = (weights[i] / (sum + 10 * MathUtils.EPSILON) + MathUtils.EPSILON);
            }
        }

        if (iterOpts.contains(UpdateOptions.Means)) {
            // self.means_ = weighted_X_sum * inverse_weights
            final double[][] wx = weightedXsum.getArray();

            for (int i = 0; i < nComponents; i++) {
                final double[][] m = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean.getArray();

                for (int j = 0; j < m[0].length; j++) {
                    m[0][j] = wx[i][j] * inverseWeights[i];
                }
            }
        }

        if (iterOpts.contains(UpdateOptions.Covariances)) {
            ctype.mstep(gmm, this, Xmat, resMat, weightedXsum, inverseWeights);
        }
    }

    /**
     * Compute the regularised per-dimension variances of a single component as
     * <code>E[x^2] - 2 * mu * E[x] + mu^2 + minCovar</code>. Shared by the
     * spherical and diagonal maximisation steps.
     *
     * @param mu
     *            the mean of the component (a single row)
     * @param weightedXsumRow
     *            the row of responsibilities.T * X for the component
     * @param avgX2uwRow
     *            the row of responsibilities.T * (X .* X) for the component
     * @param norm
     *            the inverse of the summed responsibilities of the component
     * @param minCovar
     *            the value added to every variance
     * @return a single-row matrix of variances
     */
    private static Matrix diagonalCovariance(Matrix mu, double[] weightedXsumRow, double[] avgX2uwRow,
            double norm, double minCovar)
    {
        final Matrix weightedXsumi = new Matrix(new double[][] { weightedXsumRow });
        final Matrix avgX2uwi = new Matrix(new double[][] { avgX2uwRow });

        final Matrix avgX2 = avgX2uwi.times(norm);
        final Matrix avgMeans2 = MatrixUtils.pow(mu, 2);
        final Matrix avgXmeans = mu.arrayTimes(weightedXsumi).times(norm);

        return MatrixUtils.plus(avgX2.minus(avgXmeans.times(2)).plus(avgMeans2), minCovar);
    }

    /**
     * Create a copy of this learner with the same configuration. The option
     * sets are copied so that the clone is independent of this instance.
     *
     * @return the copy
     */
    @Override
    public GaussianMixtureModelEM clone() {
        try {
            final GaussianMixtureModelEM copy = (GaussianMixtureModelEM) super.clone();

            if (initOpts != null) {
                copy.initOpts = initOpts.clone();
            }
            if (iterOpts != null) {
                copy.iterOpts = iterOpts.clone();
            }

            return copy;
        } catch (final CloneNotSupportedException e) {
            // cannot happen now that the class implements Cloneable
            throw new RuntimeException(e);
        }
    }
}