001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.experiment.gmm.retrieval; 031 032import gov.sandia.cognition.collection.CollectionUtil; 033 034import java.io.File; 035import java.io.IOException; 036import java.io.InputStream; 037import java.net.URI; 038import java.net.URL; 039import java.util.ArrayList; 040import java.util.Collections; 041import java.util.Comparator; 042import java.util.HashMap; 043import java.util.List; 044import java.util.Map; 045import java.util.concurrent.Executors; 046import java.util.concurrent.ThreadPoolExecutor; 047 048import org.apache.commons.vfs2.FileObject; 049import org.apache.commons.vfs2.FileSystemException; 050import org.kohsuke.args4j.CmdLineException; 051import org.kohsuke.args4j.CmdLineParser; 052import org.kohsuke.args4j.Option; 053import org.openimaj.data.identity.Identifiable; 054import org.openimaj.feature.CachingFeatureExtractor; 055import org.openimaj.feature.DiskCachingFeatureExtractor; 056import org.openimaj.feature.FeatureExtractor; 057import org.openimaj.feature.FeatureVector; 058import org.openimaj.feature.local.LocalFeature; 059import org.openimaj.feature.local.list.LocalFeatureList; 060import org.openimaj.image.FImage; 061import org.openimaj.image.ImageUtilities; 062import org.openimaj.image.processing.resize.ResizeProcessor; 063import org.openimaj.io.InputStreamObjectReader; 064import org.openimaj.io.ObjectReader; 065import org.openimaj.math.statistics.distribution.MixtureOfGaussians; 066import org.openimaj.math.statistics.distribution.metrics.SampledMultivariateDistanceComparator; 067import org.openimaj.ml.gmm.GaussianMixtureModelEM.CovarianceType; 068import org.openimaj.util.function.Function; 069import org.openimaj.util.function.Operation; 070import org.openimaj.util.pair.DoubleIntPair; 071import org.openimaj.util.pair.IndependentPair; 072import org.openimaj.util.pair.IntDoublePair; 073import org.openimaj.util.parallel.GlobalExecutorPool; 074import org.openimaj.util.parallel.Parallel; 075import org.openimaj.util.parallel.GlobalExecutorPool.DaemonThreadFactory; 076 077/** 078 * 079 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 080 */ 081public class UKBenchGMMExperiment { 082 private final class FImageFileObjectReader implements 083 ObjectReader<FImage, FileObject> { 084 @Override 085 public FImage read(FileObject source) throws IOException { 086 return ImageUtilities.FIMAGE_READER.read(source.getContent() 087 .getInputStream()); 088 } 089 090 @Override 091 public boolean canRead(FileObject source, String name) { 092 InputStream inputStream = null; 093 try { 094 inputStream = source.getContent().getInputStream(); 095 return ImageUtilities.FIMAGE_READER.canRead(inputStream, name); 096 } catch (FileSystemException e) { 097 } finally { 098 if (inputStream != null) { 099 try { 100 inputStream.close(); 101 } catch (IOException e) { 102 throw new RuntimeException(e); 103 } 104 } 105 } 106 return false; 107 } 108 } 109 110 private final class URLFileObjectReader implements 111 ObjectReader<URL, FileObject> { 112 @Override 113 public URL read(FileObject source) throws IOException { 114 return source.getURL(); 115 } 116 117 @Override 118 public boolean canRead(FileObject source, String name) { 119 try { 120 return (source.getURL() != null); 121 } catch (FileSystemException e) { 122 return false; 123 } 124 } 125 } 126 127 private static final class IRecordWrapper<A, B> implements 128 Function<UKBenchGMMExperiment.IRecord<A>, B> { 129 Function<A, B> inner; 130 131 public IRecordWrapper(Function<A, B> extract) { 132 this.inner = extract; 133 } 134 135 @Override 136 public B apply(IRecord<A> in) { 137 return inner.apply(in.image); 138 } 139 140 public static <A, B> Function<IRecord<A>, B> wrap(Function<A, B> extract) { 141 return new IRecordWrapper<A, B>(extract); 142 } 143 } 144 145 private static class IRecord<IMAGE> implements Identifiable { 146 147 private String id; 148 private IMAGE image; 149 150 public IRecord(String id, IMAGE image) { 151 this.id = id; 152 this.image = image; 153 } 154 155 @Override 156 public String getID() { 157 return this.id; 158 } 159 160 public static <A> IRecord<A> wrap(String id, A payload) { 161 return new IRecord<A>(id, payload); 162 } 163 164 } 165 166 private static final class IRecordReader<IMAGE> implements 167 ObjectReader<IRecord<IMAGE>, FileObject> { 168 ObjectReader<IMAGE, FileObject> reader; 169 170 public IRecordReader(ObjectReader<IMAGE, FileObject> reader) { 171 this.reader = reader; 172 } 173 174 @Override 175 public IRecord<IMAGE> read(FileObject source) throws IOException { 176 String name = source.getName().getBaseName(); 177 IMAGE image = reader.read(source); 178 return new IRecord<IMAGE>(name, image); 179 } 180 181 @Override 182 public boolean canRead(FileObject source, String name) { 183 return reader.canRead(source, name); 184 } 185 } 186 187 private String ukbenchRoot = "/Users/ss/Experiments/ukbench"; 188 private ResizeProcessor resize; 189 private UKBenchGroupDataset<IRecord<URL>> dataset; 190 private FeatureExtractor<MixtureOfGaussians,IRecord<URL>> gmmExtract; 191 final SampledMultivariateDistanceComparator comp = new SampledMultivariateDistanceComparator(); 192 193 public UKBenchGMMExperiment() { 194 setup(); 195 } 196 197 public UKBenchGMMExperiment(String root) { 198 this.ukbenchRoot = root; 199 setup(); 200 } 201 202 private void setup() { 203 this.dataset = new UKBenchGroupDataset<IRecord<URL>>( 204 ukbenchRoot + "/full", 205 // new IRecordReader<FImage>(new FImageFileObjectReader()) 206 new IRecordReader<URL>(new URLFileObjectReader())); 207 208 resize = new ResizeProcessor(640, 480); 209 210 Function<URL, MixtureOfGaussians> combined = new Function<URL, MixtureOfGaussians>() { 211 212 @Override 213 public MixtureOfGaussians apply(URL in) { 214 215 final DSiftFeatureExtractor feature = new DSiftFeatureExtractor(); 216 final GMMFromFeatures gmmFunc = new GMMFromFeatures(3,CovarianceType.Diagonal); 217 System.out.println("... resize"); 218 FImage process = null; 219 try { 220 process = ImageUtilities.readF(in).process(resize); 221 } catch (IOException e) { 222 throw new RuntimeException(e); 223 } 224 System.out.println("... dsift"); 225 LocalFeatureList<? extends LocalFeature<?, ? extends FeatureVector>> apply = feature 226 .apply(process); 227 System.out.println("... gmm"); 228 return gmmFunc.apply(apply); 229 } 230 231 }; 232 this.gmmExtract = new CachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>( 233 new DiskCachingFeatureExtractor<MixtureOfGaussians, IRecord<URL>>( 234 new File(ukbenchRoot + "/gmm/dsift"), 235 FeatureExtractionFunction.wrap(IRecordWrapper.wrap(combined))) 236 ); 237 } 238 239 static class UKBenchGMMExperimentOptions { 240 @Option(name = "--input", aliases = "-i", required = true, usage = "Input location", metaVar = "STRING") 241 String input = null; 242 243 @Option(name = "--pre-extract-all", aliases = "-a", required = false, usage = "Preextract all", metaVar = "BOOLEAN") 244 boolean preextract = false; 245 246 @Option(name = "--object", aliases = "-obj", required = false, usage = "Object", metaVar = "Integer") 247 int object = -1; 248 249 @Option(name = "--image", aliases = "-img", required = false, usage = "Image", metaVar = "Integer") 250 int image = -1; 251 } 252 253 static class ObjectRecord extends IndependentPair<Integer, IRecord<URL>> { 254 255 public ObjectRecord(Integer obj1, IRecord<URL> obj2) { 256 super(obj1, obj2); 257 } 258 259 } 260 261 /** 262 * @param args 263 * @throws IOException 264 * @throws CmdLineException 265 */ 266 public static void main(String[] args) throws IOException, CmdLineException { 267 UKBenchGMMExperimentOptions opts = new UKBenchGMMExperimentOptions(); 268 final CmdLineParser parser = new CmdLineParser(opts); 269 parser.parseArgument(args); 270 final UKBenchGMMExperiment exp = new UKBenchGMMExperiment(opts.input); 271 if (opts.preextract){ 272 System.out.println("Preloading all ukbench features..."); 273 exp.extractGroupGaussians(); 274 } 275 276 if(opts.object == -1 || opts.image == -1){ 277 exp.applyToEachGroup(new Operation<UKBenchListDataset<IRecord<URL>>>() { 278 279 @Override 280 public void perform(UKBenchListDataset<IRecord<URL>> group) { 281 int object = group.getObject(); 282 for (int i = 0; i < group.size(); i++) { 283 double score = exp.score(object, i); 284 System.out.printf("Object %d, image %d, score: %2.2f\n",object,i,score); 285 } 286 } 287 }); 288 } else { 289 double score = exp.score(opts.object, opts.image); 290 System.out.printf("Object %d, image %d, score: %2.2f\n",opts.object,opts.image,score); 291 } 292 } 293 294 protected MixtureOfGaussians extract(IRecord<URL> item) { 295 return this.gmmExtract.extractFeature(item); 296 } 297 298 private void applyToEachGroup(Operation<UKBenchListDataset<IRecord<URL>>> operation) { 299 for (int i = 0; i < this.dataset.size(); i++) { 300 operation.perform(this.dataset.get(i)); 301 } 302 303 } 304 305 private void applyToEachImage(Operation<ObjectRecord> operation) { 306 for (int i = 0; i < this.dataset.size(); i++) { 307 UKBenchListDataset<IRecord<URL>> ukBenchListDataset = this.dataset.get(i); 308 for (IRecord<URL> iRecord : ukBenchListDataset) { 309 operation.perform(new ObjectRecord(i, iRecord)); 310 } 311 } 312 } 313 314 public double score(int object, int image) { 315 System.out.printf("Scoring Object %d, Image %d\n",object,image); 316 IRecord<URL> item = this.dataset.get(object).get(image); 317 final MixtureOfGaussians thisGMM = extract(item); 318 final List<IntDoublePair> scored = new ArrayList<IntDoublePair>(); 319 applyToEachImage(new Operation<UKBenchGMMExperiment.ObjectRecord>() { 320 321 @Override 322 public void perform(ObjectRecord object) { 323 MixtureOfGaussians otherGMM = extract(object.getSecondObject()); 324 325 double distance = comp.compare(thisGMM, otherGMM); 326 scored.add(IntDoublePair.pair(object.firstObject(), distance)); 327 if(scored.size() % 200 == 0){ 328 System.out.printf("Loaded: %2.1f%%\n", 100 * (float)scored.size() / (dataset.size()*4)); 329 } 330 } 331 }); 332 333 Collections.sort(scored, new Comparator<IntDoublePair>(){ 334 335 @Override 336 public int compare(IntDoublePair o1, IntDoublePair o2) { 337 return -Double.compare(o1.second, o2.second); 338 } 339 340 }); 341 double good = 0; 342 for (int i = 0; i < 4; i++) { 343 if(scored.get(i).first == object) good+=1; 344 } 345 return good/4f; 346 } 347 348 /** 349 * @return the mixture of gaussians for each group 350 */ 351 public Map<Integer, List<MixtureOfGaussians>> extractGroupGaussians() { 352 final Map<Integer, List<MixtureOfGaussians>> groups = new HashMap<Integer, List<MixtureOfGaussians>>(); 353 ThreadPoolExecutor pool = (ThreadPoolExecutor) Executors 354 .newFixedThreadPool(1, 355 new DaemonThreadFactory()); 356 final double TOTAL = this.dataset.size() * 4; 357 Parallel.forIndex(0, this.dataset.size(), 1, new Operation<Integer>() { 358 359 @Override 360 public void perform(Integer i) { 361 groups.put(i, extractGroupGaussians(i)); 362 if(groups.size() % 200 == 0){ 363 System.out.printf("Loaded: %2.1f%%\n", 100 * groups.size() * 4 / TOTAL); 364 } 365 } 366 }, pool); 367 368 return groups; 369 } 370 371 public List<MixtureOfGaussians> extractGroupGaussians(int i) { 372 return this.extractGroupGaussians(this.dataset.get(i)); 373 } 374 375 public List<MixtureOfGaussians> extractGroupGaussians( UKBenchListDataset<IRecord<URL>> ukbenchObject) { 376 List<MixtureOfGaussians> gaussians = new ArrayList<MixtureOfGaussians>(); 377 int i = 0; 378 for (IRecord<URL> imageURL : ukbenchObject) { 379 MixtureOfGaussians gmm = gmmExtract.extractFeature(imageURL); 380 gaussians.add(gmm); 381 } 382 return gaussians; 383 } 384 385}