001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.image.objectdetection.haar.training;
031
032import java.io.File;
033import java.io.FileOutputStream;
034import java.io.IOException;
035import java.io.ObjectOutputStream;
036import java.util.ArrayList;
037import java.util.List;
038
039import org.openimaj.image.FImage;
040import org.openimaj.image.ImageUtilities;
041import org.openimaj.image.analysis.algorithm.SummedSqTiltAreaTable;
042import org.openimaj.image.objectdetection.haar.HaarFeature;
043import org.openimaj.image.objectdetection.haar.HaarFeatureClassifier;
044import org.openimaj.image.objectdetection.haar.Stage;
045import org.openimaj.image.objectdetection.haar.StageTreeClassifier;
046import org.openimaj.image.objectdetection.haar.ValueClassifier;
047import org.openimaj.io.IOUtils;
048import org.openimaj.util.pair.ObjectFloatPair;
049
050public class Testing {
051        List<HaarFeature> features;
052        List<SummedSqTiltAreaTable> positive = new ArrayList<SummedSqTiltAreaTable>();
053        List<SummedSqTiltAreaTable> negative = new ArrayList<SummedSqTiltAreaTable>();
054
055        void createFeatures(int width, int height) {
056                features = HaarFeatureType.generateFeatures(width, height, HaarFeatureType.CORE);
057
058                final float invArea = 1f / ((width - 2) * (height - 2));
059                for (final HaarFeature f : features) {
060                        f.setScale(1, invArea);
061                }
062        }
063
064        // void loadPositive(boolean tilted) throws IOException {
065        // final String base = "/Users/jsh2/Data/att_faces/s%d/%d.pgm";
066        //
067        // for (int j = 1; j <= 40; j++) {
068        // for (int i = 1; i <= 10; i++) {
069        // final File file = new File(String.format(base, j, i));
070        //
071        // FImage img = ImageUtilities.readF(file);
072        // img = img.extractCenter(50, 50);
073        // img = ResizeProcessor.resample(img, 19, 19);
074        // positive.add(new SummedSqTiltAreaTable(img, tilted));
075        // }
076        // }
077        // }
078        //
079        // void loadNegative(boolean tilted) throws IOException {
080        // final File dir = new File(
081        // "/Volumes/Raid/face_databases/haartraining/tutorial-haartraining.googlecode.com/svn/trunk/data/negatives/");
082        //
083        // for (final File f : dir.listFiles()) {
084        // if (f.getName().endsWith(".jpg")) {
085        // FImage img = ImageUtilities.readF(f);
086        //
087        // final int minwh = Math.min(img.width, img.height);
088        //
089        // img = img.extractCenter(minwh, minwh);
090        // img = ResizeProcessor.resample(img, 19, 19);
091        //
092        // negative.add(new SummedSqTiltAreaTable(img, tilted));
093        // }
094        // }
095        // }
096
097        void loadImage(File image, List<SummedSqTiltAreaTable> sats, boolean
098                        tilted) throws IOException
099        {
100                final FImage img = ImageUtilities.readF(image);
101
102                sats.add(new SummedSqTiltAreaTable(img, false));
103        }
104
105        void loadPositive(boolean tilted) throws IOException {
106                for (final File file : new File("/Volumes/Raid/face_databases/cbcl-faces/train/face").listFiles()) {
107                        if (file.getName().endsWith(".pgm")) {
108                                loadImage(file, positive, tilted);
109                        }
110                }
111        }
112
113        void loadNegative(boolean tilted) throws IOException {
114                for (final File file : new File("/Volumes/Raid/face_databases/cbcl-faces/train/non-face").listFiles()) {
115                        if (file.getName().endsWith(".pgm")) {
116                                loadImage(file, negative, tilted);
117                        }
118                }
119        }
120
121        void perform() throws IOException {
122                System.out.println("Creating feature set");
123                createFeatures(19, 19);
124
125                System.out.println("Loading positive images and computing SATs");
126                loadPositive(false);
127
128                System.out.println("Loading negative images and computing SATs");
129                loadNegative(false);
130
131                System.out.println("+ve: " + positive.size());
132                System.out.println("-ve: " + negative.size());
133                System.out.println("features: " + features.size());
134
135                System.out.println("Computing cached feature sets");
136                final CachedTrainingData data = new CachedTrainingData(positive, negative, features);
137
138                System.out.println("Starting Training");
139                final AdaBoost boost = new AdaBoost();
140                final List<ObjectFloatPair<StumpClassifier>> ensemble = boost.learn(data, 100);
141
142                System.out.println("Training complete. Ensemble has " + ensemble.size() + " classifiers.");
143
144                for (float threshold = 3; threshold >= -3; threshold -= 0.1f) {
145                        System.out.println("Threshold = " + threshold);
146                        boost.printClassificationQuality(data, ensemble, threshold);
147                }
148
149                final HaarFeatureClassifier[] trees = new HaarFeatureClassifier[ensemble.size()];
150
151                for (int i = 0; i < trees.length; i++) {
152                        final ObjectFloatPair<StumpClassifier> wc = ensemble.get(i);
153                        final StumpClassifier c = wc.first;
154                        final float alpha = wc.second;
155                        final float threshold = c.threshold;
156                        final float leftValue = c.sign > 0 ? -alpha : alpha; // right way
157                                                                                                                                        // around???
158                        final HaarFeature feature = features.get(c.dimension);
159
160                        final ValueClassifier left = new ValueClassifier(leftValue);
161                        final ValueClassifier right = new ValueClassifier(-leftValue);
162
163                        trees[i] = new HaarFeatureClassifier(feature, threshold, left, right);
164                }
165
166                final Stage root = new Stage(0, trees, null, null);
167                final StageTreeClassifier classifier = new StageTreeClassifier(19, 19, "test cascade", false, root);
168
169                final ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("test-classifier.bin")));
170                IOUtils.write(classifier, oos);
171                oos.close();
172        }
173
174        public static void main(String[] args) throws IOException {
175                new Testing().perform();
176        }
177}