001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.gmm.retrieval;
031
032import gov.sandia.cognition.collection.CollectionUtil;
033
034import java.io.File;
035import java.io.IOException;
036import java.io.InputStream;
037import java.net.URI;
038import java.net.URL;
039import java.util.ArrayList;
040import java.util.Collections;
041import java.util.Comparator;
042import java.util.HashMap;
043import java.util.List;
044import java.util.Map;
045import java.util.concurrent.Executors;
046import java.util.concurrent.ThreadPoolExecutor;
047
048import org.apache.commons.vfs2.FileObject;
049import org.apache.commons.vfs2.FileSystemException;
050import org.kohsuke.args4j.CmdLineException;
051import org.kohsuke.args4j.CmdLineParser;
052import org.kohsuke.args4j.Option;
053import org.openimaj.data.identity.Identifiable;
054import org.openimaj.feature.CachingFeatureExtractor;
055import org.openimaj.feature.DiskCachingFeatureExtractor;
056import org.openimaj.feature.FeatureExtractor;
057import org.openimaj.feature.FeatureVector;
058import org.openimaj.feature.local.LocalFeature;
059import org.openimaj.feature.local.list.LocalFeatureList;
060import org.openimaj.image.FImage;
061import org.openimaj.image.ImageUtilities;
062import org.openimaj.image.processing.resize.ResizeProcessor;
063import org.openimaj.io.InputStreamObjectReader;
064import org.openimaj.io.ObjectReader;
065import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
066import org.openimaj.math.statistics.distribution.metrics.SampledMultivariateDistanceComparator;
067import org.openimaj.ml.gmm.GaussianMixtureModelEM.CovarianceType;
068import org.openimaj.util.function.Function;
069import org.openimaj.util.function.Operation;
070import org.openimaj.util.pair.DoubleIntPair;
071import org.openimaj.util.pair.IndependentPair;
072import org.openimaj.util.pair.IntDoublePair;
073import org.openimaj.util.parallel.GlobalExecutorPool;
074import org.openimaj.util.parallel.Parallel;
075import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory;
076
077/**
078 * 
079 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
080 */
081public class UKBenchGMMExperiment {
082        private final class FImageFileObjectReader implements
083                        ObjectReader<FImage, FileObject> {
084                @Override
085                public FImage read(FileObject source) throws IOException {
086                        return ImageUtilities.FIMAGE_READER.read(source.getContent()
087                                        .getInputStream());
088                }
089
090                @Override
091                public boolean canRead(FileObject source, String name) {
092                        InputStream inputStream = null;
093                        try {
094                                inputStream = source.getContent().getInputStream();
095                                return ImageUtilities.FIMAGE_READER.canRead(inputStream, name);
096                        } catch (FileSystemException e) {
097                        } finally {
098                                if (inputStream != null) {
099                                        try {
100                                                inputStream.close();
101                                        } catch (IOException e) {
102                                                throw new RuntimeException(e);
103                                        }
104                                }
105                        }
106                        return false;
107                }
108        }
109
110        private final class URLFileObjectReader implements
111                        ObjectReader<URL, FileObject> {
112                @Override
113                public URL read(FileObject source) throws IOException {
114                        return source.getURL();
115                }
116
117                @Override
118                public boolean canRead(FileObject source, String name) {
119                        try {
120                                return (source.getURL() != null);
121                        } catch (FileSystemException e) {
122                                return false;
123                        }
124                }
125        }
126
127        private static final class IRecordWrapper<A, B> implements
128                        Function<UKBenchGMMExperiment.IRecord<A>, B> {
129                Function<A, B> inner;
130
131                public IRecordWrapper(Function<A, B> extract) {
132                        this.inner = extract;
133                }
134
135                @Override
136                public B apply(IRecord<A> in) {
137                        return inner.apply(in.image);
138                }
139
140                public static <A, B> Function<IRecord<A>, B> wrap(Function<A, B> extract) {
141                        return new IRecordWrapper<A, B>(extract);
142                }
143        }
144
145        private static class IRecord<IMAGE> implements Identifiable {
146
147                private String id;
148                private IMAGE image;
149
150                public IRecord(String id, IMAGE image) {
151                        this.id = id;
152                        this.image = image;
153                }
154
155                @Override
156                public String getID() {
157                        return this.id;
158                }
159
160                public static <A> IRecord<A> wrap(String id, A payload) {
161                        return new IRecord<A>(id, payload);
162                }
163
164        }
165
166        private static final class IRecordReader<IMAGE> implements
167                        ObjectReader<IRecord<IMAGE>, FileObject> {
168                ObjectReader<IMAGE, FileObject> reader;
169
170                public IRecordReader(ObjectReader<IMAGE, FileObject> reader) {
171                        this.reader = reader;
172                }
173
174                @Override
175                public IRecord<IMAGE> read(FileObject source) throws IOException {
176                        String name = source.getName().getBaseName();
177                        IMAGE image = reader.read(source);
178                        return new IRecord<IMAGE>(name, image);
179                }
180
181                @Override
182                public boolean canRead(FileObject source, String name) {
183                        return reader.canRead(source, name);
184                }
185        }
186
187        private String ukbenchRoot = "/Users/ss/Experiments/ukbench";
188        private ResizeProcessor resize;
189        private UKBenchGroupDataset<IRecord<URL>> dataset;
190        private FeatureExtractor<MixtureOfGaussians,IRecord<URL>> gmmExtract;
191        final SampledMultivariateDistanceComparator comp = new SampledMultivariateDistanceComparator();
192
193        public UKBenchGMMExperiment() {
194                setup();
195        }
196
197        public UKBenchGMMExperiment(String root) {
198                this.ukbenchRoot = root;
199                setup();
200        }
201
202        private void setup() {
203                this.dataset = new UKBenchGroupDataset<IRecord<URL>>(
204                                ukbenchRoot + "/full",
205                                // new IRecordReader<FImage>(new FImageFileObjectReader())
206                                new IRecordReader<URL>(new URLFileObjectReader()));
207
208                resize = new ResizeProcessor(640, 480);
209
210                Function<URL, MixtureOfGaussians> combined = new Function<URL, MixtureOfGaussians>() {
211
212                        @Override
213                        public MixtureOfGaussians apply(URL in) {
214                                
215                                final DSiftFeatureExtractor feature = new DSiftFeatureExtractor();
216                                final GMMFromFeatures gmmFunc = new GMMFromFeatures(3,CovarianceType.Diagonal);
217                                System.out.println("... resize");
218                                FImage process = null;
219                                try {
220                                        process = ImageUtilities.readF(in).process(resize);
221                                } catch (IOException e) {
222                                        throw new RuntimeException(e);
223                                }
224                                System.out.println("... dsift");
225                                LocalFeatureList<? extends LocalFeature<?, ? extends FeatureVector>> apply = feature
226                                                .apply(process);
227                                System.out.println("... gmm");
228                                return gmmFunc.apply(apply);
229                        }
230
231                };
232                this.gmmExtract = new CachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>(
233                                new DiskCachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>(
234                                                new File(ukbenchRoot + "/gmm/dsift"),
235                                                FeatureExtractionFunction.wrap(IRecordWrapper.wrap(combined)))
236                );
237        }
238
239        static class UKBenchGMMExperimentOptions {
240                @Option(name = "--input", aliases = "-i", required = true, usage = "Input location", metaVar = "STRING")
241                String input = null;
242
243                @Option(name = "--pre-extract-all", aliases = "-a", required = false, usage = "Preextract all", metaVar = "BOOLEAN")
244                boolean preextract = false;
245                
246                @Option(name = "--object", aliases = "-obj", required = false, usage = "Object", metaVar = "Integer")
247                int object = -1;
248                
249                @Option(name = "--image", aliases = "-img", required = false, usage = "Image", metaVar = "Integer")
250                int image = -1;
251        }
252
253        static class ObjectRecord extends IndependentPair<Integer, IRecord<URL>> {
254
255                public ObjectRecord(Integer obj1, IRecord<URL> obj2) {
256                        super(obj1, obj2);
257                }
258
259        }
260
261        /**
262         * @param args
263         * @throws IOException
264         * @throws CmdLineException
265         */
266        public static void main(String[] args) throws IOException, CmdLineException {
267                UKBenchGMMExperimentOptions opts = new UKBenchGMMExperimentOptions();
268                final CmdLineParser parser = new CmdLineParser(opts);
269                parser.parseArgument(args);
270                final UKBenchGMMExperiment exp = new UKBenchGMMExperiment(opts.input);
271                if (opts.preextract){
272                        System.out.println("Preloading all ukbench features...");
273                        exp.extractGroupGaussians();                    
274                }
275                
276                if(opts.object == -1 || opts.image == -1){                      
277                        exp.applyToEachGroup(new Operation<UKBenchListDataset<IRecord<URL>>>() {
278                                
279                                @Override
280                                public void perform(UKBenchListDataset<IRecord<URL>> group) {
281                                        int object = group.getObject();
282                                        for (int i = 0; i < group.size(); i++) {
283                                                double score = exp.score(object, i);
284                                                System.out.printf("Object %d, image %d, score: %2.2f\n",object,i,score);
285                                        }
286                                }
287                        });
288                } else {
289                        double score = exp.score(opts.object, opts.image);
290                        System.out.printf("Object %d, image %d, score: %2.2f\n",opts.object,opts.image,score);
291                }
292        }
293
294        protected MixtureOfGaussians extract(IRecord<URL> item) {
295                return this.gmmExtract.extractFeature(item);
296        }
297
298        private void applyToEachGroup(Operation<UKBenchListDataset<IRecord<URL>>> operation) {
299                for (int i = 0; i < this.dataset.size(); i++) {
300                        operation.perform(this.dataset.get(i));
301                }
302
303        }
304
305        private void applyToEachImage(Operation<ObjectRecord> operation) {
306                for (int i = 0; i < this.dataset.size(); i++) {
307                        UKBenchListDataset<IRecord<URL>> ukBenchListDataset = this.dataset.get(i);
308                        for (IRecord<URL> iRecord : ukBenchListDataset) {
309                                operation.perform(new ObjectRecord(i, iRecord));
310                        }
311                }
312        }
313        
314        public double score(int object, int image) {
315                System.out.printf("Scoring Object %d, Image %d\n",object,image);
316                IRecord<URL> item = this.dataset.get(object).get(image);
317                final MixtureOfGaussians thisGMM = extract(item);
318                final List<IntDoublePair> scored = new ArrayList<IntDoublePair>();
319                applyToEachImage(new Operation<UKBenchGMMExperiment.ObjectRecord>() {
320
321                        @Override
322                        public void perform(ObjectRecord object) {
323                                MixtureOfGaussians otherGMM = extract(object.getSecondObject());
324                                
325                                double distance = comp.compare(thisGMM, otherGMM);
326                                scored.add(IntDoublePair.pair(object.firstObject(), distance));
327                                if(scored.size() % 200 == 0){
328                                        System.out.printf("Loaded: %2.1f%%\n", 100 * (float)scored.size() / (dataset.size()*4));
329                                }
330                        }
331                });
332                
333                Collections.sort(scored, new Comparator<IntDoublePair>(){
334
335                        @Override
336                        public int compare(IntDoublePair o1, IntDoublePair o2) {
337                                return -Double.compare(o1.second, o2.second);
338                        }
339                        
340                });
341                double good = 0;
342                for (int i = 0; i < 4; i++) {
343                        if(scored.get(i).first == object) good+=1; 
344                }
345                return good/4f;
346        }
347
348        /**
349         * @return the mixture of gaussians for each group
350         */
351        public Map<Integer, List<MixtureOfGaussians>> extractGroupGaussians() {
352                final Map<Integer, List<MixtureOfGaussians>> groups = new HashMap<Integer, List<MixtureOfGaussians>>();
353                ThreadPoolExecutor pool = (ThreadPoolExecutor) Executors
354                                .newFixedThreadPool(1,
355                                                new DaemonThreadFactory());
356                final double TOTAL = this.dataset.size() * 4;
357                Parallel.forIndex(0, this.dataset.size(), 1, new Operation<Integer>() {
358
359                        @Override
360                        public void perform(Integer i) {
361                                groups.put(i, extractGroupGaussians(i));
362                                if(groups.size() % 200 == 0){
363                                        System.out.printf("Loaded: %2.1f%%\n", 100 * groups.size() * 4 / TOTAL);
364                                }
365                        }
366                }, pool);
367
368                return groups;
369        }
370
371        public List<MixtureOfGaussians> extractGroupGaussians(int i) {
372                return this.extractGroupGaussians(this.dataset.get(i));
373        }
374
375        public List<MixtureOfGaussians> extractGroupGaussians( UKBenchListDataset<IRecord<URL>> ukbenchObject) {
376                List<MixtureOfGaussians> gaussians = new ArrayList<MixtureOfGaussians>();
377                int i = 0;
378                for (IRecord<URL> imageURL : ukbenchObject) {
379                        MixtureOfGaussians gmm = gmmExtract.extractFeature(imageURL);
380                        gaussians.add(gmm);
381                }
382                return gaussians;
383        }
384
385}