001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.ml.clustering.kdtree; 031 032import org.apache.commons.math.FunctionEvaluationException; 033import org.apache.commons.math.analysis.MultivariateRealFunction; 034import org.apache.commons.math.optimization.GoalType; 035import org.apache.commons.math.optimization.RealPointValuePair; 036import org.apache.commons.math.optimization.SimpleRealPointChecker; 037import org.apache.commons.math.optimization.direct.NelderMead; 038import org.apache.commons.math.stat.descriptive.moment.Mean; 039import org.openimaj.math.matrix.DiagonalMatrix; 040import org.openimaj.math.matrix.MatlibMatrixUtils; 041import org.openimaj.util.array.ArrayUtils; 042import org.openimaj.util.array.IntArrayView; 043import org.openimaj.util.pair.ObjectDoublePair; 044 045import scala.actors.threadpool.Arrays; 046import ch.akuhn.matrix.DenseMatrix; 047import ch.akuhn.matrix.SparseMatrix; 048import ch.akuhn.matrix.Vector; 049 050/** 051 * Given a vector, tell me the split 052 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 053 * 054 */ 055public interface SplitDetectionMode{ 056 /** 057 * minimise for y: (y' * (D - W) * y) / ( y' * D * y ); 058 * s.t. y = (1 + x) - b * (1 - x); 059 * s.t. b = k / (1 - k); 060 * s.t. k = sum(d(x > 0)) / sum(d); 061 * and 062 * s.t. x is an indicator (-1 for less than t, 1 for greater than or equal to t) 063 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 064 */ 065 public class OPTIMISED implements SplitDetectionMode { 066 067 private DiagonalMatrix D; 068 private SparseMatrix W; 069 private MEAN mean; 070 071 /** 072 * @param D 073 * @param W 074 */ 075 public OPTIMISED(DiagonalMatrix D, SparseMatrix W) { 076 this.D = D; 077 this.W = W; 078 this.mean = new MEAN(); 079 } 080 private ObjectDoublePair<double[]> indicator(double[] vec, double d) { 081 double[] ind = new double[vec.length]; 082 double sumx = 0; 083 for (int i = 0; i < ind.length; i++) { 084 if(vec[i] > d){ 085 ind[i] = 1; 086 sumx ++; 087 } 088 else{ 089 ind[i] = -1; 090 } 091 } 092 return ObjectDoublePair.pair(ind, sumx); 093 } 094 @Override 095 public double detect(final double[] vec) { 096 double[] t = {this.mean.detect(vec)}; 097 MultivariateRealFunction func = new MultivariateRealFunction() { 098 @Override 099 public double value(double[] x) throws FunctionEvaluationException { 100 ObjectDoublePair<double[]> ind = indicator(vec,x[0]); 101 double sumd = MatlibMatrixUtils.sum(D); 102 double k = ind.second / sumd; 103 double b = k / (1-k); 104 double[][] y = new double[1][vec.length]; 105 for (int i = 0; i < vec.length; i++) { 106 y[0][i] = ind.first[i] + 1 - b * (1 - ind.first[i]); 107 } 108 SparseMatrix dmw = MatlibMatrixUtils.minusInplace(D, W); 109 Vector yv = Vector.wrap(y[0]); 110 double nom = new DenseMatrix(y).mult(dmw.transposeMultiply(yv)).get(0); // y' * ( (D-W) * y) 111 double denom = new DenseMatrix(y).mult(D.transposeMultiply(yv)).get(0); 112 return nom/denom; 113 } 114 115 }; 116 117// 118 RealPointValuePair ret; 119 try { 120 NelderMead nelderMead = new NelderMead(); 121 nelderMead.setConvergenceChecker(new SimpleRealPointChecker(0.0001, -1)); 122 ret = nelderMead.optimize(func, GoalType.MINIMIZE, t); 123 return ret.getPoint()[0]; 124 } catch (Exception e) { 125 e.printStackTrace(); 126 System.err.println("Reverting to mean"); 127 } 128 return t[0]; 129 } 130 131 132 } 133 134 /** 135 * Splits clusters becuase they don't have exactly the same value! 136 */ 137 public static class MEDIAN implements SplitDetectionMode{ 138 @Override 139 public double detect(double[] col) { 140 double mid = ArrayUtils.quickSelect(col, col.length/2); 141 if(ArrayUtils.minValue(col) == mid) 142 mid += Double.MIN_NORMAL; 143 if(ArrayUtils.maxValue(col) == mid) 144 mid -= Double.MIN_NORMAL; 145 return 0; 146 } 147 148 149 } 150 151 /** 152 * Use the mean to split 153 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 154 * 155 */ 156 public static class MEAN implements SplitDetectionMode{ 157 158 @Override 159 public double detect(double[] vec) { 160 return new Mean().evaluate(vec); 161 } 162 163 164 165 } 166 167 /** 168 * Find the median, attempt to find a value which keeps clusters together 169 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 170 * 171 */ 172 public static class VARIABLE_MEDIAN implements SplitDetectionMode{ 173 174 private double tolchange; 175 /** 176 * Sets the change tolerance to 0.1 (i.e. if the next value is different by more than value * 0.1, we switch) 177 */ 178 public VARIABLE_MEDIAN() { 179 this.tolchange = 0.0001; 180 } 181 182 /** 183 * @param tol if the next value is different by more than value * tol we found a border 184 */ 185 public VARIABLE_MEDIAN(double tol) { 186 this.tolchange = tol; 187 } 188 189 @Override 190 public double detect(double[] vec) { 191 Arrays.sort(vec); 192 // Find the median index 193 int medInd = vec.length/2; 194 double medVal = vec[medInd]; 195 if(vec.length % 2 == 0){ 196 medVal += vec[medInd+1]; 197 medVal /= 2.; 198 } 199 200 201 boolean maxWithinTol = withinTol(medVal,vec[vec.length-1]); 202 boolean minWithinTol = withinTol(medVal,vec[0]); 203 if(maxWithinTol && minWithinTol) 204 { 205 // degenerate case, the min and max are not beyond the tolerance, return the median 206 return medVal; 207 } 208 // The split works like: 209 // < val go left 210 // >= val go right 211 if(maxWithinTol){ 212 // search left 213 for (int i = medInd; i > 0; i--) { 214 if(!withinTol(vec[i],vec[i-1])){ 215 return vec[i]; 216 } 217 } 218 } 219 else{ 220 // search right 221 for (int i = medInd; i < vec.length-1; i++) { 222 if(!withinTol(vec[i],vec[i+1])){ 223 return vec[i+1]; 224 } 225 } 226 } 227 228 229 230 return 0; 231 } 232 233 private boolean withinTol(double a, double d) { 234 return Math.abs(a - d) / Math.abs(a) < this.tolchange; 235 } 236 237 238 239 }; 240 /** 241 * @param vec 242 * @return find the split point 243 */ 244 public abstract double detect(double[] vec); 245}