/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.image.objectdetection.haar.training;

import java.util.ArrayList;
import java.util.List;

import org.openimaj.util.pair.ObjectFloatPair;

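/**
 * Basic AdaBoost implementation that learns an ensemble of weighted
 * {@link StumpClassifier}s over Haar-like feature responses. Each boosting
 * round trains a stump against the current instance weights, measures its
 * weighted error, and re-weights the instances to emphasise the mistakes.
 * <p>
 * A minimal usage sketch, assuming a {@link HaarTrainingData} instance has
 * already been constructed elsewhere:
 *
 * <pre>
 * {@code
 * final AdaBoost boost = new AdaBoost();
 * final List<ObjectFloatPair<StumpClassifier>> ensemble = boost.learn(data, 100);
 * final boolean positive = AdaBoost.Classify(data.getInstanceFeature(0), ensemble);
 * }
 * </pre>
 */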
public class AdaBoost {
	StumpClassifier.WeightedLearner factory = new StumpClassifier.WeightedLearner();

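	/**
	 * Learn a boosted ensemble of weighted stump classifiers from the given
	 * training data.
	 *
	 * @param trainingSet
	 *            the training data
	 * @param _numberOfRounds
	 *            the maximum number of boosting rounds
	 * @return the ensemble of stump classifiers and their alpha weights
	 */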
	public List<ObjectFloatPair<StumpClassifier>> learn(HaarTrainingData trainingSet, int _numberOfRounds) {
		// Initialise weights uniformly
		final float[] _weights = new float[trainingSet.numInstances()];
		for (int i = 0; i < trainingSet.numInstances(); i++)
			_weights[i] = 1.0f / trainingSet.numInstances();

		final boolean[] actualClasses = trainingSet.getClasses();

		final List<ObjectFloatPair<StumpClassifier>> _h = new ArrayList<ObjectFloatPair<StumpClassifier>>();

		// Perform the learning
		for (int t = 0; t < _numberOfRounds; t++) {
			System.out.println("Iteration: " + t);

			// Create the weak learner and train it
			final ObjectFloatPair<StumpClassifier> h = factory.learn(trainingSet, _weights);

			// Compute the classifications and the weighted training error
			final boolean[] hClassification = new boolean[trainingSet.numInstances()];
			final float[] responses = trainingSet.getResponses(h.first.dimension);
			double epsilon = 0.0;
			for (int i = 0; i < trainingSet.numInstances(); i++) {
				hClassification[i] = h.first.classify(responses[i]);
				epsilon += hClassification[i] != actualClasses[i] ? _weights[i] : 0.0;
			}

			System.out.println("epsilon = " + epsilon);

			// Stop if the weak learner is no better than chance
			if (epsilon >= 0.5)
				break;

			// Calculate the classifier weight, alpha = 0.5 * ln((1 - epsilon) / epsilon).
			// Epsilon is clamped away from zero so that alpha stays finite when
			// the stump classifies the training data perfectly.
			final double eps = Math.max(epsilon, 1e-10);
			final float alpha = (float) (0.5 * Math.log((1 - eps) / eps));

			System.out.println("alpha = " + alpha);

			// Update the weights: misclassified samples are up-weighted by
			// exp(alpha); correctly classified samples are down-weighted by exp(-alpha)
			float weightsSum = 0.0f;
			for (int i = 0; i < trainingSet.numInstances(); i++) {
				_weights[i] *= Math.exp(-alpha * (actualClasses[i] ? 1 : -1) * (hClassification[i] ? 1 : -1));
				weightsSum += _weights[i];
			}
			// Normalise so the weights again form a distribution
			for (int i = 0; i < trainingSet.numInstances(); i++)
				_weights[i] /= weightsSum;

			// Store the weak learner and its alpha value
			_h.add(new ObjectFloatPair<StumpClassifier>(h.first, alpha));

			System.out.println("feature = " + h.first.dimension);

			// Break if perfectly classifying the data
			if (epsilon == 0.0)
				break;
		}

		return _h;
	}

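	/**
	 * Print the confusion matrix (TP/FN/FP/TN counts) together with the
	 * false-positive and true-positive rates of the given ensemble over the
	 * given data.
	 *
	 * @param data
	 *            the data to evaluate against
	 * @param ensemble
	 *            the ensemble of weighted stump classifiers
	 * @param threshold
	 *            the classification threshold
	 */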
	public void printClassificationQuality(HaarTrainingData data, List<ObjectFloatPair<StumpClassifier>> ensemble,
			float threshold)
	{
		int tp = 0;
		int fn = 0;
		int tn = 0;
		int fp = 0;

		final int ninstances = data.numInstances();
		final boolean[] classes = data.getClasses();
		for (int i = 0; i < ninstances; i++) {
			final float[] feature = data.getInstanceFeature(i);

			final boolean predicted = AdaBoost.Classify(feature, ensemble, threshold);
			final boolean actual = classes[i];

			if (actual) {
				if (predicted)
					tp++;
				else
					fn++;
			} else {
				if (predicted)
					fp++;
				else
					tn++;
			}
		}

		System.out.format("TP: %d\tFN: %d\tFP: %d\tTN: %d\n", tp, fn, fp, tn);

		final float fpr = (float) fp / (float) (fp + tn);
		final float tpr = (float) tp / (float) (tp + fn);

		System.out.format("FPR: %2.2f\tTPR: %2.2f\n", fpr, tpr);
	}

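	/**
	 * Classify the given feature vector with the ensemble, using the default
	 * decision threshold of zero.
	 *
	 * @param data
	 *            the feature vector to classify
	 * @param _h
	 *            the ensemble of weighted stump classifiers
	 * @return true if the ensemble predicts the positive class; false
	 *         otherwise
	 */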
	public static boolean Classify(float[] data, List<ObjectFloatPair<StumpClassifier>> _h) {
		double classification = 0.0;

		// Call the weak learner classify methods and combine results
		for (int t = 0; t < _h.size(); t++)
			classification += _h.get(t).second * (_h.get(t).first.classify(data) ? 1 : -1);

		// Return the thresholded classification
		return classification > 0.0;
	}

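	/**
	 * Classify the given feature vector with the ensemble against the given
	 * decision threshold. A higher threshold reduces false positives at the
	 * cost of fewer true positives.
	 *
	 * @param data
	 *            the feature vector to classify
	 * @param _h
	 *            the ensemble of weighted stump classifiers
	 * @param threshold
	 *            the decision threshold
	 * @return true if the weighted ensemble score exceeds the threshold;
	 *         false otherwise
	 */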
	public static boolean Classify(float[] data, List<ObjectFloatPair<StumpClassifier>> _h, float threshold) {
		double classification = 0.0;

		// Call the weak learner classify methods and combine results
		for (int t = 0; t < _h.size(); t++)
			classification += _h.get(t).second * (_h.get(t).first.classify(data) ? 1 : -1);

		// Return the thresholded classification
		return classification > threshold;
	}
}