001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.demos;
031
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.Scanner;

import org.openimaj.feature.FloatFV;
import org.openimaj.io.IOUtils;
import org.openimaj.ml.linear.projection.LargeMarginDimensionalityReduction;

import Jama.Matrix;

import com.jmatio.io.MatFileReader;
import com.jmatio.types.MLSingle;
048
049public class FVFWExperiment {
050        // private static final String FOLDER =
051        // "lfw-centre-affine-pdsift-pca64-augm-fv512/";
052        // private static final String FOLDER = "lfw-centre-affine-matlab-fisher/";
053        private static final String FOLDER = "matlab-fvs/";
054
055        static class FacePair {
056                boolean same;
057                File firstFV;
058                File secondFV;
059
060                public FacePair(File first, File second, boolean same) {
061                        this.firstFV = first;
062                        this.secondFV = second;
063                        this.same = same;
064                }
065
066                FloatFV loadFirst() throws IOException {
067                        return IOUtils.read(firstFV, FloatFV.class);
068                }
069
070                FloatFV loadSecond() throws IOException {
071                        return IOUtils.read(secondFV, FloatFV.class);
072                }
073        }
074
075        static class Subset {
076                List<FacePair> testPairs = new ArrayList<FacePair>();
077                List<FacePair> trainingPairs = new ArrayList<FacePair>();
078        }
079
080        static List<Subset> loadSubsets() throws IOException {
081                final List<Subset> subsets = new ArrayList<Subset>();
082
083                for (int i = 0; i < 10; i++)
084                        subsets.add(new Subset());
085
086                loadPairs(new File("/Users/jon/Data/lfw/pairs.txt"), subsets);
087                loadPeople(new File("/Users/jon/Data/lfw/people.txt"), subsets);
088
089                return subsets;
090        }
091
092        private static void loadPairs(File file, List<Subset> subsets) throws FileNotFoundException {
093                final Scanner sc = new Scanner(file);
094
095                final int nsets = sc.nextInt();
096                final int nhpairs = sc.nextInt();
097
098                if (nsets != 10 || nhpairs != 300)
099                        throw new RuntimeException();
100
101                for (int s = 0; s < 10; s++) {
102                        for (int i = 0; i < 300; i++) {
103                                final String name = sc.next();
104                                final int firstIdx = sc.nextInt();
105                                final int secondIdx = sc.nextInt();
106
107                                final File first = new File(file.getParentFile(), FOLDER + name
108                                                + "/" + name + String.format("_%04d.bin", firstIdx));
109                                final File second = new File(file.getParentFile(), FOLDER + name
110                                                + "/" + name + String.format("_%04d.bin", secondIdx));
111
112                                subsets.get(s).testPairs.add(new FacePair(first, second, true));
113                        }
114
115                        for (int i = 0; i < 300; i++) {
116                                final String firstName = sc.next();
117                                final int firstIdx = sc.nextInt();
118                                final String secondName = sc.next();
119                                final int secondIdx = sc.nextInt();
120
121                                final File first = new File(file.getParentFile(), FOLDER
122                                                + firstName
123                                                + "/" + firstName + String.format("_%04d.bin", firstIdx));
124                                final File second = new File(file.getParentFile(), FOLDER
125                                                + secondName
126                                                + "/" + secondName + String.format("_%04d.bin", secondIdx));
127
128                                subsets.get(s).testPairs.add(new FacePair(first, second, false));
129                        }
130                }
131
132                sc.close();
133        }
134
135        private static void loadPeople(File file, List<Subset> subsets) throws FileNotFoundException {
136                final Scanner sc = new Scanner(file);
137
138                final int nsets = sc.nextInt();
139
140                if (nsets != 10)
141                        throw new RuntimeException();
142
143                for (int s = 0; s < 10; s++) {
144                        final int nnames = sc.nextInt();
145                        final List<File> files = new ArrayList<File>(nnames);
146                        for (int i = 0; i < nnames; i++) {
147                                final String name = sc.next();
148                                final int numPeople = sc.nextInt();
149                                for (int j = 1; j <= numPeople; j++) {
150                                        final File f = new File(file.getParentFile(), FOLDER + name
151                                                        + "/" + name + String.format("_%04d.bin", j));
152
153                                        files.add(f);
154                                }
155                        }
156
157                        for (int i = 0; i < files.size(); i++) {
158                                final File first = files.get(i);
159                                for (int j = i + 1; j < files.size(); j++) {
160                                        final File second = files.get(j);
161
162                                        final boolean same = first.getName().substring(0, first.getName().lastIndexOf("_"))
163                                                        .equals(second.getName().substring(0, second.getName().lastIndexOf("_")));
164
165                                        subsets.get(s).trainingPairs.add(new FacePair(first, second, same));
166                                        subsets.get(s).trainingPairs.add(new FacePair(second, first, same));
167                                }
168                        }
169                }
170
171                sc.close();
172        }
173
174        static Subset createExperimentalFold(List<Subset> subsets, int foldIdx) {
175                final Subset subset = new Subset();
176                // testing data is from the indexed fold
177                subset.testPairs = subsets.get(foldIdx).testPairs;
178
179                // training data is from the other folds
180                final List<FacePair> training = new ArrayList<FacePair>();
181                for (int i = 0; i < foldIdx; i++)
182                        training.addAll(subsets.get(i).trainingPairs);
183                for (int i = foldIdx + 1; i < subsets.size(); i++)
184                        training.addAll(subsets.get(i).trainingPairs);
185
186                subset.trainingPairs = reorder(training);
187
188                return subset;
189        }
190
191        private static List<FacePair> reorder(List<FacePair> training) {
192                final List<FacePair> trainingTrue = new ArrayList<FacePair>();
193                final List<FacePair> trainingFalse = new ArrayList<FacePair>();
194
195                for (final FacePair fp : training) {
196                        if (fp.same)
197                                trainingTrue.add(fp);
198                        else
199                                trainingFalse.add(fp);
200                }
201
202                resample(trainingTrue, 4000000);
203                resample(trainingFalse, 4000000);
204
205                final List<FacePair> trainingResorted = new ArrayList<FacePair>();
206                for (int i = 0; i < trainingTrue.size(); i++) {
207                        trainingResorted.add(trainingTrue.get(i));
208                        trainingResorted.add(trainingFalse.get(i));
209                }
210
211                return trainingResorted;
212        }
213
214        private static void resample(List<FacePair> pairs, int sz) {
215                final List<FacePair> oldPairs = new ArrayList<FVFWExperiment.FacePair>(sz);
216                oldPairs.addAll(pairs);
217                pairs.clear();
218
219                final Random r = new Random();
220
221                for (int i = 0; i < sz; i++) {
222                        pairs.add(oldPairs.get(r.nextInt(oldPairs.size())));
223                }
224        }
225
226        public static void main(String[] args) throws IOException {
227                final List<Subset> subsets = loadSubsets();
228                final Subset fold = createExperimentalFold(subsets, 1);
229
230                // // final LargeMarginDimensionalityReduction lmdr = new
231                // // LargeMarginDimensionalityReduction(128);
232                // final LargeMarginDimensionalityReduction lmdr = loadMatlabPCAW();
233                //
234                // final double[][] fInit = new double[1000][];
235                // final double[][] sInit = new double[1000][];
236                // final boolean[] same = new boolean[1000];
237                // for (int i = 0; i < 1000; i++) {
238                // final FacePair p =
239                // fold.trainingPairs.get(i);
240                // fInit[i] = p.loadFirst().asDoubleVector();
241                // sInit[i] = p.loadSecond().asDoubleVector();
242                // same[i] = p.same;
243                //
244                // for (int j = 0; j < fInit[i].length; j++) {
245                // if (Double.isInfinite(fInit[i][j]) || Double.isNaN(fInit[i][j]))
246                // throw new RuntimeException("" + fold.trainingPairs.get(i).firstFV);
247                // if (Double.isInfinite(sInit[i][j]) || Double.isNaN(sInit[i][j]))
248                // throw new RuntimeException("" + fold.trainingPairs.get(i).secondFV);
249                // }
250                // }
251                //
252                // System.out.println("LMDR Init");
253                // lmdr.recomputeBias(fInit, sInit, same);
254                // // lmdr.initialise(fInit, sInit, same);
255                // IOUtils.writeToFile(lmdr, new
256                // File("/Users/jon/Data/lfw/lmdr-matlabfvs-pcaw-init.bin"));
257                // // final LargeMarginDimensionalityReduction lmdr = IOUtils
258                // // .readFromFile(new File("/Users/jon/Data/lfw/lmdr-init.bin"));
259                //
260                // for (int i = 0; i < 1e6; i++) {
261                // if (i % 100 == 0)
262                // System.out.println("Iter " + i);
263                // final FacePair p = fold.trainingPairs.get(i);
264                // lmdr.step(p.loadFirst().asDoubleVector(),
265                // p.loadSecond().asDoubleVector(), p.same);
266                // }
267                // IOUtils.writeToFile(lmdr, new
268                // File("/Users/jon/Data/lfw/lmdr-matlabfvs-pcaw.bin"));
269
270                final LargeMarginDimensionalityReduction lmdr =
271                                IOUtils.readFromFile(new
272                                                File("/Users/jon/Data/lfw/lmdr-matlabfvs-pcaw.bin"));
273                // final LargeMarginDimensionalityReduction lmdr = loadMatlabLMDR();
274                // final LargeMarginDimensionalityReduction lmdr = loadMatlabPCAW();
275
276                final double[][] first = new double[fold.testPairs.size()][];
277                final double[][] second = new double[fold.testPairs.size()][];
278                final boolean[] same = new boolean[fold.testPairs.size()];
279                for (int j = 0; j < same.length; j++) {
280                        final FacePair p = fold.testPairs.get(j);
281                        first[j] = p.loadFirst().asDoubleVector();
282                        second[j] = p.loadSecond().asDoubleVector();
283                        same[j] = p.same;
284                }
285                // System.out.println("Current bias: " + lmdr.getBias());
286                // lmdr.recomputeBias(first, second, same);
287                // System.out.println("Best bias: " + lmdr.getBias());
288
289                double correct = 0;
290                double count = 0;
291                for (int j = 0; j < same.length; j++) {
292                        final boolean pred = lmdr.classify(first[j],
293                                        second[j]);
294
295                        if (pred == same[j])
296                                correct++;
297                        count++;
298                }
299                System.out.println(lmdr.getBias() + " " + (correct / count));
300        }
301
302        // private static double[] reorder(double[] in) {
303        // final double[] out = new double[in.length];
304        // final int D = 64;
305        // final int K = 512;
306        // for (int k = 0; k < K; k++) {
307        // for (int j = 0; j < D; j++) {
308        // out[k * D + j] = in[k * 2 * D + j];
309        // out[k * D + j + D * K] = in[k * 2 * D + j + D];
310        // }
311        // }
312        // return out;
313        // }
314
315        private static LargeMarginDimensionalityReduction loadMatlabLMDR() throws IOException {
316                final LargeMarginDimensionalityReduction lmdr = new LargeMarginDimensionalityReduction(128);
317
318                final MatFileReader reader = new MatFileReader(new File("/Users/jon/lmdr.mat"));
319                final MLSingle W = (MLSingle) reader.getContent().get("W");
320                final MLSingle b = (MLSingle) reader.getContent().get("b");
321
322                lmdr.setBias(b.get(0, 0));
323
324                final Matrix proj = new Matrix(W.getM(), W.getN());
325                for (int j = 0; j < W.getN(); j++) {
326                        for (int i = 0; i < W.getM(); i++) {
327                                proj.set(i, j, W.get(i, j));
328                        }
329                }
330
331                lmdr.setTransform(proj);
332
333                return lmdr;
334        }
335
336        private static LargeMarginDimensionalityReduction loadMatlabPCAW() throws IOException {
337                final LargeMarginDimensionalityReduction lmdr = new LargeMarginDimensionalityReduction(128);
338
339                final MatFileReader reader = new MatFileReader(new File("/Users/jon/pcaw.mat"));
340                final MLSingle W = (MLSingle) reader.getContent().get("proj");
341
342                lmdr.setBias(169.6264190673828);
343
344                final Matrix proj = new Matrix(W.getM(), W.getN());
345                for (int j = 0; j < W.getN(); j++) {
346                        for (int i = 0; i < W.getM(); i++) {
347                                proj.set(i, j, W.get(i, j));
348                        }
349                }
350
351                lmdr.setTransform(proj);
352
353                return lmdr;
354        }
355}