/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */ 030package org.openimaj.image.objectdetection.haar.training; 031 032import org.openimaj.util.function.Operation; 033import org.openimaj.util.pair.ObjectFloatPair; 034import org.openimaj.util.parallel.Parallel; 035import org.openimaj.util.parallel.Parallel.IntRange; 036 037public class StumpClassifier { 038 public static class WeightedLearner { 039 // Trains using Error = \sum_{i=1}^{N} D_i * [y_i != h(x_i)] 040 // and h(x) = classifier.sign * (2 * [xclassifier.dimension > 041 // classifier.threshold] - 1) 042 public ObjectFloatPair<StumpClassifier> learn(final HaarTrainingData trainingSet, final float[] _weights) { 043 final StumpClassifier classifier = new StumpClassifier(); 044 045 // Search for minimum training set error 046 final float[] minimumError = { Float.POSITIVE_INFINITY }; 047 048 final boolean[] classes = trainingSet.getClasses(); 049 final int nInstances = trainingSet.numInstances(); 050 051 // Determine total potential error 052 float totalErrorC = 0.0f; 053 for (int i = 0; i < nInstances; i++) 054 totalErrorC += _weights[i]; 055 final float totalError = totalErrorC; 056 057 // Initialise search error 058 float initialErrorC = 0.0f; 059 for (int i = 0; i < nInstances; i++) 060 initialErrorC += !classes[i] ? 
_weights[i] : 0.0; 061 final float initialError = initialErrorC; 062 063 // Loop over possible dimensions 064 // for (int d = 0; d < trainingSet.numFeatures(); d++) { 065 Parallel.forRange(0, trainingSet.numFeatures(), 1, new Operation<IntRange>() { 066 @Override 067 public void perform(IntRange rng) { 068 final StumpClassifier currClassifier = new StumpClassifier(); 069 currClassifier.dimension = -1; 070 currClassifier.threshold = Float.NaN; 071 currClassifier.sign = 0; 072 073 float currMinimumError = Float.POSITIVE_INFINITY; 074 075 for (int d = rng.start; d < rng.stop; d += rng.incr) { 076 // Pre-sort data-items in dimension for efficient 077 // threshold 078 // search 079 final float[] data = trainingSet.getResponses(d); 080 final int[] indices = trainingSet.getSortedIndices(d); 081 082 // Initialise search error 083 float currentError = initialError; 084 085 // Search through the sorted list to determine best 086 // threshold 087 for (int i = 0; i < nInstances - 1; i++) { 088 // Update current error 089 final int index = indices[i]; 090 if (classes[index]) 091 currentError += _weights[index]; 092 else 093 currentError -= _weights[index]; 094 095 // Check for repeated values 096 if (data[indices[i]] == data[indices[i + 1]]) 097 continue; 098 099 // Compute the test threshold - maximises the margin 100 // between 101 // potential thresholds 102 final float testThreshold = (data[indices[i]] + data[indices[i + 1]]) / 2.0f; 103 104 // Compare to current best 105 if (currentError < currMinimumError) // Good 106 // classifier 107 // with 108 // classifier.sign 109 // = 110 // +1 111 { 112 currMinimumError = currentError; 113 currClassifier.dimension = d; 114 currClassifier.threshold = testThreshold; 115 currClassifier.sign = +1; 116 } 117 if ((totalError - currentError) < currMinimumError) // Good 118 // classifier 119 // with 120 // classifier.sign 121 // = 122 // -1 123 { 124 currMinimumError = (totalError - currentError); 125 currClassifier.dimension = d; 126 
currClassifier.threshold = testThreshold; 127 currClassifier.sign = -1; 128 } 129 } 130 } 131 132 synchronized (classifier) { 133 if (currMinimumError < minimumError[0]) { 134 minimumError[0] = currMinimumError; 135 classifier.dimension = currClassifier.dimension; 136 classifier.sign = currClassifier.sign; 137 classifier.threshold = currClassifier.threshold; 138 } 139 } 140 } 141 }); 142 143 return new ObjectFloatPair<StumpClassifier>(classifier, minimumError[0]); 144 } 145 } 146 147 int dimension; 148 float threshold; 149 int sign; 150 151 public boolean classify(float[] instanceFeature) { 152 return (instanceFeature[dimension] > threshold ? sign : -sign) == 1 ? true : false; 153 } 154 155 public boolean classify(float f) { 156 return (f > threshold ? sign : -sign) == 1 ? true : false; 157 } 158}