001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.objectdetection.haar.training;
031
032import org.openimaj.util.function.Operation;
033import org.openimaj.util.pair.ObjectFloatPair;
034import org.openimaj.util.parallel.Parallel;
035import org.openimaj.util.parallel.Parallel.IntRange;
036
037public class StumpClassifier {
038        public static class WeightedLearner {
039                // Trains using Error = \sum_{i=1}^{N} D_i * [y_i != h(x_i)]
040                // and h(x) = classifier.sign * (2 * [xclassifier.dimension >
041                // classifier.threshold] - 1)
042                public ObjectFloatPair<StumpClassifier> learn(final HaarTrainingData trainingSet, final float[] _weights) {
043                        final StumpClassifier classifier = new StumpClassifier();
044
045                        // Search for minimum training set error
046                        final float[] minimumError = { Float.POSITIVE_INFINITY };
047
048                        final boolean[] classes = trainingSet.getClasses();
049                        final int nInstances = trainingSet.numInstances();
050
051                        // Determine total potential error
052                        float totalErrorC = 0.0f;
053                        for (int i = 0; i < nInstances; i++)
054                                totalErrorC += _weights[i];
055                        final float totalError = totalErrorC;
056
057                        // Initialise search error
058                        float initialErrorC = 0.0f;
059                        for (int i = 0; i < nInstances; i++)
060                                initialErrorC += !classes[i] ? _weights[i] : 0.0;
061                        final float initialError = initialErrorC;
062
063                        // Loop over possible dimensions
064                        // for (int d = 0; d < trainingSet.numFeatures(); d++) {
065                        Parallel.forRange(0, trainingSet.numFeatures(), 1, new Operation<IntRange>() {
066                                @Override
067                                public void perform(IntRange rng) {
068                                        final StumpClassifier currClassifier = new StumpClassifier();
069                                        currClassifier.dimension = -1;
070                                        currClassifier.threshold = Float.NaN;
071                                        currClassifier.sign = 0;
072
073                                        float currMinimumError = Float.POSITIVE_INFINITY;
074
075                                        for (int d = rng.start; d < rng.stop; d += rng.incr) {
076                                                // Pre-sort data-items in dimension for efficient
077                                                // threshold
078                                                // search
079                                                final float[] data = trainingSet.getResponses(d);
080                                                final int[] indices = trainingSet.getSortedIndices(d);
081
082                                                // Initialise search error
083                                                float currentError = initialError;
084
085                                                // Search through the sorted list to determine best
086                                                // threshold
087                                                for (int i = 0; i < nInstances - 1; i++) {
088                                                        // Update current error
089                                                        final int index = indices[i];
090                                                        if (classes[index])
091                                                                currentError += _weights[index];
092                                                        else
093                                                                currentError -= _weights[index];
094
095                                                        // Check for repeated values
096                                                        if (data[indices[i]] == data[indices[i + 1]])
097                                                                continue;
098
099                                                        // Compute the test threshold - maximises the margin
100                                                        // between
101                                                        // potential thresholds
102                                                        final float testThreshold = (data[indices[i]] + data[indices[i + 1]]) / 2.0f;
103
104                                                        // Compare to current best
105                                                        if (currentError < currMinimumError) // Good
106                                                        // classifier
107                                                        // with
108                                                        // classifier.sign
109                                                        // =
110                                                        // +1
111                                                        {
112                                                                currMinimumError = currentError;
113                                                                currClassifier.dimension = d;
114                                                                currClassifier.threshold = testThreshold;
115                                                                currClassifier.sign = +1;
116                                                        }
117                                                        if ((totalError - currentError) < currMinimumError) // Good
118                                                        // classifier
119                                                        // with
120                                                        // classifier.sign
121                                                        // =
122                                                        // -1
123                                                        {
124                                                                currMinimumError = (totalError - currentError);
125                                                                currClassifier.dimension = d;
126                                                                currClassifier.threshold = testThreshold;
127                                                                currClassifier.sign = -1;
128                                                        }
129                                                }
130                                        }
131
132                                        synchronized (classifier) {
133                                                if (currMinimumError < minimumError[0]) {
134                                                        minimumError[0] = currMinimumError;
135                                                        classifier.dimension = currClassifier.dimension;
136                                                        classifier.sign = currClassifier.sign;
137                                                        classifier.threshold = currClassifier.threshold;
138                                                }
139                                        }
140                                }
141                        });
142
143                        return new ObjectFloatPair<StumpClassifier>(classifier, minimumError[0]);
144                }
145        }
146
147        int dimension;
148        float threshold;
149        int sign;
150
151        public boolean classify(float[] instanceFeature) {
152                return (instanceFeature[dimension] > threshold ? sign : -sign) == 1 ? true : false;
153        }
154
155        public boolean classify(float f) {
156                return (f > threshold ? sign : -sign) == 1 ? true : false;
157        }
158}