001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.ml.clustering.kdtree;
031
032import org.apache.commons.math.FunctionEvaluationException;
033import org.apache.commons.math.analysis.MultivariateRealFunction;
034import org.apache.commons.math.optimization.GoalType;
035import org.apache.commons.math.optimization.RealPointValuePair;
036import org.apache.commons.math.optimization.SimpleRealPointChecker;
037import org.apache.commons.math.optimization.direct.NelderMead;
038import org.apache.commons.math.stat.descriptive.moment.Mean;
039import org.openimaj.math.matrix.DiagonalMatrix;
040import org.openimaj.math.matrix.MatlibMatrixUtils;
041import org.openimaj.util.array.ArrayUtils;
042import org.openimaj.util.array.IntArrayView;
043import org.openimaj.util.pair.ObjectDoublePair;
044
045import scala.actors.threadpool.Arrays;
046import ch.akuhn.matrix.DenseMatrix;
047import ch.akuhn.matrix.SparseMatrix;
048import ch.akuhn.matrix.Vector;
049
050/**
051 * Given a vector, tell me the split
052 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
053 *
054 */
055public interface SplitDetectionMode{
056        /**
057         * minimise for y: (y' * (D - W) * y) / ( y' * D * y );
058         * s.t. y = (1 + x) - b * (1 - x);
059         * s.t. b = k / (1 - k);
060         * s.t. k = sum(d(x > 0)) / sum(d);
061         * and
062         * s.t. x is an indicator (-1 for less than t, 1 for greater than or equal to t)
063         * @author Sina Samangooei (ss@ecs.soton.ac.uk)
064         */
065        public class OPTIMISED implements SplitDetectionMode {
066                
067                private DiagonalMatrix D;
068                private SparseMatrix W;
069                private MEAN mean;
070
071                /**
072                 * @param D
073                 * @param W
074                 */
075                public OPTIMISED(DiagonalMatrix D, SparseMatrix W) {
076                        this.D = D;
077                        this.W = W;
078                        this.mean = new MEAN();
079                }
080                private ObjectDoublePair<double[]> indicator(double[] vec, double d) {
081                        double[] ind = new double[vec.length];
082                        double sumx = 0;
083                        for (int i = 0; i < ind.length; i++) {
084                                if(vec[i] > d){
085                                        ind[i] = 1;
086                                        sumx ++;
087                                }
088                                else{
089                                        ind[i] = -1;
090                                }
091                        }
092                        return ObjectDoublePair.pair(ind, sumx);
093                }
094                @Override
095                public double detect(final double[] vec) {
096                        double[] t = {this.mean.detect(vec)};
097                        MultivariateRealFunction func = new MultivariateRealFunction() {
098                                @Override
099                                public double value(double[] x) throws FunctionEvaluationException {
100                                        ObjectDoublePair<double[]> ind = indicator(vec,x[0]);
101                                        double sumd = MatlibMatrixUtils.sum(D);
102                                        double k = ind.second / sumd;
103                                        double b = k / (1-k);
104                                        double[][] y = new double[1][vec.length];
105                                        for (int i = 0; i < vec.length; i++) {
106                                                y[0][i] = ind.first[i] + 1 - b * (1 - ind.first[i]);
107                                        }
108                                        SparseMatrix dmw = MatlibMatrixUtils.minusInplace(D, W);
109                                        Vector yv = Vector.wrap(y[0]);
110                                        double nom = new DenseMatrix(y).mult(dmw.transposeMultiply(yv)).get(0); // y' * ( (D-W) * y)
111                                        double denom = new DenseMatrix(y).mult(D.transposeMultiply(yv)).get(0);
112                                        return nom/denom;
113                                }
114
115                        }; 
116                        
117//                      
118                        RealPointValuePair ret;
119                        try {
120                                NelderMead nelderMead = new NelderMead();
121                                nelderMead.setConvergenceChecker(new SimpleRealPointChecker(0.0001, -1));
122                                ret = nelderMead.optimize(func, GoalType.MINIMIZE, t);
123                                return ret.getPoint()[0];
124                        } catch (Exception e) {
125                                e.printStackTrace();
126                                System.err.println("Reverting to mean");
127                        }
128                        return t[0];
129                }
130                
131
132        }
133
134        /**
135         * Splits clusters becuase they don't have exactly the same value!
136         */
137        public static class MEDIAN implements SplitDetectionMode{
138                @Override
139                public double detect(double[] col) {
140                        double mid = ArrayUtils.quickSelect(col, col.length/2);
141                        if(ArrayUtils.minValue(col) == mid) 
142                                mid += Double.MIN_NORMAL;
143                        if(ArrayUtils.maxValue(col) == mid) 
144                                mid -= Double.MIN_NORMAL;
145                        return 0;
146                }
147
148                
149        }
150        
151        /**
152         * Use the mean to split
153         * @author Sina Samangooei (ss@ecs.soton.ac.uk)
154         *
155         */
156        public static class MEAN implements SplitDetectionMode{
157
158                @Override
159                public double detect(double[] vec) {
160                        return new Mean().evaluate(vec);
161                }
162
163                
164                
165        }
166        
167        /**
168         * Find the median, attempt to find a value which keeps clusters together
169         * @author Sina Samangooei (ss@ecs.soton.ac.uk)
170         *
171         */
172        public static class VARIABLE_MEDIAN implements SplitDetectionMode{
173
174                private double tolchange;
175                /**
176                 * Sets the change tolerance to 0.1 (i.e. if the next value is different by more than value * 0.1, we switch)
177                 */
178                public VARIABLE_MEDIAN() {
179                        this.tolchange = 0.0001;
180                }
181                
182                /**
183                 * @param tol if the next value is different by more than value * tol we found a border
184                 */
185                public VARIABLE_MEDIAN(double tol) {
186                        this.tolchange = tol;
187                }
188                
189                @Override
190                public double detect(double[] vec) {
191                        Arrays.sort(vec);
192                        // Find the median index
193                        int medInd = vec.length/2;
194                        double medVal = vec[medInd];
195                        if(vec.length % 2 == 0){
196                                medVal += vec[medInd+1];
197                                medVal /= 2.;
198                        }
199                        
200                        
201                        boolean maxWithinTol = withinTol(medVal,vec[vec.length-1]);
202                        boolean minWithinTol = withinTol(medVal,vec[0]);
203                        if(maxWithinTol && minWithinTol) 
204                        {
205                                // degenerate case, the min and max are not beyond the tolerance, return the median
206                                return medVal;
207                        }
208                        // The split works like:
209                        // < val go left
210                        // >= val go right
211                        if(maxWithinTol){
212                                // search left
213                                for (int i = medInd; i > 0; i--) {
214                                        if(!withinTol(vec[i],vec[i-1])){
215                                                return vec[i];
216                                        }
217                                }
218                        }
219                        else{
220                                // search right
221                                for (int i = medInd; i < vec.length-1; i++) {
222                                        if(!withinTol(vec[i],vec[i+1])){
223                                                return vec[i+1];
224                                        }
225                                }
226                        }
227                        
228                        
229                        
230                        return 0;
231                }
232
233                private boolean withinTol(double a, double d) {
234                        return Math.abs(a - d) / Math.abs(a) < this.tolchange;
235                }
236
237                
238                
239        };
        /**
         * Chooses the value at which the given vector should be split.
         *
         * @param vec the values to split; implementations may reorder this array
         *            (VARIABLE_MEDIAN sorts it in place)
         * @return find the split point
         */
        public abstract double detect(double[] vec);
245}