/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.image.objectdetection.haar.training;

import java.util.ArrayList;
import java.util.List;

import org.openimaj.util.pair.ObjectFloatPair;

/**
 * Discrete AdaBoost over decision stumps.
 * <p>
 * Each boosting round fits a {@link StumpClassifier} to the current instance
 * weights. It measures the stump's weighted training error {@code epsilon} and
 * assigns the stump the vote weight {@code alpha = 0.5 * ln((1 - epsilon) / epsilon)}.
 * It then re-weights the instances by {@code exp(-alpha * y * h)}, with boolean
 * labels/predictions mapped to {@code +1}/{@code -1}. Progress is printed to
 * standard output.
 */
public class AdaBoost {
	/** Trainer used to fit one weighted decision stump per boosting round. */
	StumpClassifier.WeightedLearner factory = new StumpClassifier.WeightedLearner();

	/**
	 * Train a boosted ensemble of decision stumps.
	 * <p>
	 * Training stops early in two cases:
	 * <ul>
	 * <li>a round's stump has weighted error {@code >= 0.5}; that stump is not
	 * added;</li>
	 * <li>a round's stump has weighted error exactly {@code 0}; that stump is added
	 * with {@code alpha == Float.POSITIVE_INFINITY}, because
	 * {@code ln((1 - 0) / 0)} is infinite.</li>
	 * </ul>
	 *
	 * @param trainingSet
	 *            the labelled training data
	 * @param _numberOfRounds
	 *            the maximum number of boosting rounds
	 * @return the trained stumps, in training order, each paired with its vote
	 *         weight {@code alpha}
	 */
	public List<ObjectFloatPair<StumpClassifier>> learn(HaarTrainingData trainingSet, int _numberOfRounds) {
		final int numInstances = trainingSet.numInstances();

		// Start from a uniform distribution over the training instances
		final float[] weights = new float[numInstances];
		for (int i = 0; i < numInstances; i++)
			weights[i] = 1.0f / numInstances;

		final boolean[] actualClasses = trainingSet.getClasses();

		final List<ObjectFloatPair<StumpClassifier>> ensemble = new ArrayList<ObjectFloatPair<StumpClassifier>>();

		for (int t = 0; t < _numberOfRounds; t++) {
			System.out.println("Iteration: " + t);

			// Fit a weak learner to the current weight distribution
			final StumpClassifier stump = factory.learn(trainingSet, weights).first;

			// Classify every instance with the stump and accumulate the weighted training error
			final boolean[] predicted = new boolean[numInstances];
			final float[] responses = trainingSet.getResponses(stump.dimension);
			double epsilon = 0.0;
			for (int i = 0; i < numInstances; i++) {
				predicted[i] = stump.classify(responses[i]);
				if (predicted[i] != actualClasses[i])
					epsilon += weights[i];
			}

			System.out.println("epsilon = " + epsilon);

			// A stump that is no better than chance cannot improve the ensemble
			if (epsilon >= 0.5)
				break;

			// NOTE(review): epsilon == 0 gives alpha == +Infinity, so a perfect stump
			// dominates every later vote sum; consider clamping epsilon if a finite
			// weight is preferred.
			final float alpha = (float) (0.5 * Math.log((1 - epsilon) / epsilon));

			System.out.println("alpha = " + alpha);

			// After a perfect stump, re-weighting would set every weight to exp(-Infinity) == 0
			// and then normalise by a zero sum (0/0 == NaN). Training stops below anyway, so skip it.
			final boolean perfect = epsilon == 0.0;
			if (!perfect)
				updateWeights(weights, actualClasses, predicted, alpha);

			ensemble.add(new ObjectFloatPair<StumpClassifier>(stump, alpha));

			System.out.println("feature = " + stump.dimension);

			if (perfect)
				break;
		}

		return ensemble;
	}

	/**
	 * Apply the AdaBoost re-weighting step in place, then renormalise the weights so
	 * they sum to one (up to rounding).
	 * <p>
	 * A weight is multiplied by {@code exp(-alpha)} where the stump agreed with the
	 * true label, and by {@code exp(alpha)} where it disagreed.
	 */
	private static void updateWeights(float[] weights, boolean[] actual, boolean[] predicted, float alpha) {
		// Accumulate in double, matching the precision used for epsilon
		double sum = 0.0;
		for (int i = 0; i < weights.length; i++) {
			final int agreement = actual[i] == predicted[i] ? 1 : -1;
			weights[i] *= Math.exp(-alpha * agreement);
			sum += weights[i];
		}

		for (int i = 0; i < weights.length; i++)
			weights[i] /= sum;
	}

	/**
	 * Evaluate an ensemble on a data set and print the confusion-matrix counts, the
	 * false-positive rate and the true-positive rate to standard output.
	 * <p>
	 * A rate is printed as {@code NaN} when its denominator is zero (for example,
	 * no negative instances for the FPR).
	 *
	 * @param data
	 *            the labelled data to evaluate on
	 * @param ensemble
	 *            the stumps and their vote weights, as returned by {@link #learn}
	 * @param threshold
	 *            an instance is predicted positive when the weighted vote sum
	 *            exceeds this value
	 */
	public void printClassificationQuality(HaarTrainingData data, List<ObjectFloatPair<StumpClassifier>> ensemble,
			float threshold)
	{
		int tp = 0;
		int fn = 0;
		int tn = 0;
		int fp = 0;

		final int ninstances = data.numInstances();
		final boolean[] classes = data.getClasses();
		for (int i = 0; i < ninstances; i++) {
			final float[] feature = data.getInstanceFeature(i);

			final boolean predicted = AdaBoost.Classify(feature, ensemble, threshold);
			final boolean actual = classes[i];

			if (actual) {
				if (predicted)
					tp++;
				else
					fn++;
			} else {
				if (predicted)
					fp++;
				else
					tn++;
			}
		}

		System.out.format("TP: %d\tFN: %d\tFP: %d\tTN: %d\n", tp, fn, fp, tn);

		final float fpr = (float) fp / (float) (fp + tn);
		final float tpr = (float) tp / (float) (tp + fn);

		System.out.format("FPR: %2.2f\tTPR: %2.2f\n", fpr, tpr);
	}

	/**
	 * Classify an instance with an ensemble, using a decision threshold of zero.
	 * Equivalent to {@code Classify(data, _h, 0f)}.
	 *
	 * @param data
	 *            the instance's feature values
	 * @param _h
	 *            the stumps and their vote weights
	 * @return {@code true} if the weighted vote sum is strictly positive
	 */
	public static boolean Classify(float[] data, List<ObjectFloatPair<StumpClassifier>> _h) {
		return Classify(data, _h, 0f);
	}

	/**
	 * Classify an instance with an ensemble.
	 * <p>
	 * Each stump casts {@code +alpha} for a positive prediction and {@code -alpha}
	 * for a negative one.
	 *
	 * @param data
	 *            the instance's feature values
	 * @param _h
	 *            the stumps and their vote weights
	 * @param threshold
	 *            the decision threshold applied to the weighted vote sum
	 * @return {@code true} if the weighted vote sum is strictly greater than
	 *         {@code threshold}
	 */
	public static boolean Classify(float[] data, List<ObjectFloatPair<StumpClassifier>> _h, float threshold) {
		double classification = 0.0;

		for (final ObjectFloatPair<StumpClassifier> weak : _h)
			classification += weak.second * (weak.first.classify(data) ? 1 : -1);

		return classification > threshold;
	}
}