001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.model.fit;
031
032import java.util.ArrayList;
033import java.util.List;
034
035import org.openimaj.math.model.EstimatableModel;
036import org.openimaj.math.model.fit.residuals.ResidualCalculator;
037import org.openimaj.math.util.distance.DistanceCheck;
038import org.openimaj.math.util.distance.ThresholdDistanceCheck;
039import org.openimaj.util.CollectionSampler;
040import org.openimaj.util.UniformSampler;
041import org.openimaj.util.pair.IndependentPair;
042
043/**
044 * The RANSAC Algorithm (RANdom SAmple Consensus)
045 * <p>
046 * For fitting noisy data consisting of inliers and outliers to a model.
047 * <p>
048 * Assume: M data items required to estimate parameter x N data items in total
049 * <p>
050 * 1.) select M data items at random <br/>
051 * 2.) estimate parameter x <br/>
052 * 3.) find how many of the N data items fit (i.e. have an error less than a
053 * threshold or pass some check) x within tolerence tol, call this K <br/>
054 * 4.) if K is large enough (bigger than numItems) accept x and exit with
055 * success <br/>
056 * 5.) repeat 1..4 nIter times <br/>
057 * 6.) fail - no good x fit of data
058 * <p>
059 * In this implementation, the conditions that control the iterations are
060 * configurable. In addition, the best matching model is always stored, even if
061 * the fitData() method returns false.
062 * 
063 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
064 * 
065 * @param <I>
066 *            type of independent data
067 * @param <D>
068 *            type of dependent data
069 * @param <M>
070 *            concrete type of model learned
071 */
072public class RANSAC<I, D, M extends EstimatableModel<I, D>> implements RobustModelFitting<I, D, M> {
073        /**
074         * Interface for classes that can control RANSAC iterations
075         */
076        public static interface StoppingCondition {
077                /**
078                 * Initialise the stopping condition if necessary. Return false if the
079                 * initialisation cannot be performed and RANSAC should fail
080                 * 
081                 * @param data
082                 *            The data being fitted
083                 * @param model
084                 *            The model to fit
085                 * @return true if initialisation is successful, false otherwise.
086                 */
087                public abstract boolean init(final List<?> data, EstimatableModel<?, ?> model);
088
089                /**
090                 * Should we stop iterating and return the model?
091                 * 
092                 * @param numInliers
093                 *            number of inliers in this iteration
094                 * @return true if the model is good and iterations should stop
095                 */
096                public abstract boolean shouldStopIterations(int numInliers);
097
098                /**
099                 * Should the model be considered to fit after the final iteration has
100                 * passed?
101                 * 
102                 * @param numInliers
103                 *            number of inliers in the final model
104                 * @return true if the model fits, false otherwise
105                 */
106                public abstract boolean finalFitCondition(int numInliers);
107        }
108
109        /**
110         * Stopping condition that tests the number of matches against a threshold.
111         * If the number exceeds the threshold, then the model is considered to fit.
112         */
113        public static class NumberInliersStoppingCondition implements StoppingCondition {
114                int limit;
115
116                /**
117                 * Construct the stopping condition with the given threshold on the
118                 * number of data points which must match for a model to be considered a
119                 * fit.
120                 * 
121                 * @param limit
122                 *            the threshold
123                 */
124                public NumberInliersStoppingCondition(int limit) {
125                        this.limit = limit;
126                }
127
128                @Override
129                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
130                        if (limit < model.numItemsToEstimate()) {
131                                limit = model.numItemsToEstimate();
132                        }
133
134                        if (data.size() < limit)
135                                return false;
136                        return true;
137                }
138
139                @Override
140                public boolean shouldStopIterations(int numInliers) {
141                        return numInliers >= limit; // stop if there are more inliers than
142                                                                                // our limit
143                }
144
145                @Override
146                public boolean finalFitCondition(int numInliers) {
147                        return numInliers >= limit;
148                }
149        }
150
151        /**
152         * Stopping condition that tests the number of matches against a percentage
153         * threshold of the whole data. If the number exceeds the threshold, then
154         * the model is considered to fit.
155         */
156        public static class PercentageInliersStoppingCondition extends NumberInliersStoppingCondition {
157                double percentageLimit;
158
159                /**
160                 * Construct the stopping condition with the given percentage threshold
161                 * on the number of data points which must match for a model to be
162                 * considered a fit.
163                 * 
164                 * @param percentageLimit
165                 *            the percentage threshold
166                 */
167                public PercentageInliersStoppingCondition(double percentageLimit) {
168                        super(0);
169                        this.percentageLimit = percentageLimit;
170                }
171
172                @Override
173                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
174                        this.limit = (int) Math.rint(percentageLimit * data.size());
175                        return super.init(data, model);
176                }
177        }
178
179        /**
180         * Stopping condition that tests the number of matches against a percentage
181         * threshold of the whole data. If the number exceeds the threshold, then
182         * the model is considered to fit.
183         */
184        public static class ProbabilisticMinInliersStoppingCondition implements StoppingCondition {
185                private static final double DEFAULT_INLIER_IS_BAD_PROBABILITY = 0.1;
186                private static final double DEFAULT_PERCENTAGE_INLIERS = 0.25;
187                private double inlierIsBadProbability;
188                private double desiredErrorProbability;
189                private double percentageInliers;
190
191                private int numItemsToEstimate;
192                private int iteration = 0;
193                private int limit;
194                private int maxInliers = 0;
195                private double currentProb;
196                private int numDataItems;
197
198                /**
199                 * Default constructor.
200                 * 
201                 * @param desiredErrorProbability
202                 *            The desired error rate
203                 * @param inlierIsBadProbability
204                 *            The probability an inlier is bad
205                 * @param percentageInliers
206                 *            The percentage of inliers in the data
207                 */
208                public ProbabilisticMinInliersStoppingCondition(double desiredErrorProbability, double inlierIsBadProbability,
209                                double percentageInliers)
210                {
211                        this.desiredErrorProbability = desiredErrorProbability;
212                        this.inlierIsBadProbability = inlierIsBadProbability;
213                        this.percentageInliers = percentageInliers;
214                }
215
216                /**
217                 * Constructor with defaults for bad inlier probability and percentage
218                 * inliers.
219                 * 
220                 * @param desiredErrorProbability
221                 *            The desired error rate
222                 */
223                public ProbabilisticMinInliersStoppingCondition(double desiredErrorProbability) {
224                        this(desiredErrorProbability, DEFAULT_INLIER_IS_BAD_PROBABILITY, DEFAULT_PERCENTAGE_INLIERS);
225                }
226
227                @Override
228                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
229                        numItemsToEstimate = model.numItemsToEstimate();
230                        numDataItems = data.size();
231                        this.limit = calculateMinInliers();
232                        this.iteration = 0;
233                        this.currentProb = 1.0;
234                        this.maxInliers = 0;
235
236                        return true;
237                }
238
239                @Override
240                public boolean finalFitCondition(int numInliers) {
241                        return numInliers >= limit;
242                }
243
244                private int calculateMinInliers() {
245                        double pi, sum;
246                        int i, j;
247
248                        for (j = numItemsToEstimate + 1; j <= numDataItems; j++)
249                        {
250                                sum = 0;
251                                for (i = j; i <= numDataItems; i++)
252                                {
253                                        pi = (i - numItemsToEstimate) * Math.log(inlierIsBadProbability)
254                                                        + (numDataItems - i + numItemsToEstimate) * Math.log(1.0 - inlierIsBadProbability) +
255                                                        log_factorial(numDataItems - numItemsToEstimate) - log_factorial(i - numItemsToEstimate)
256                                                        - log_factorial(numDataItems - i);
257                                        /*
258                                         * Last three terms above are equivalent to log( n-m choose
259                                         * i-m )
260                                         */
261                                        sum += Math.exp(pi);
262                                }
263                                if (sum < desiredErrorProbability)
264                                        break;
265                        }
266                        return j;
267                }
268
269                private double log_factorial(int n) {
270                        double f = 0;
271                        int i;
272
273                        for (i = 1; i <= n; i++)
274                                f += Math.log(i);
275
276                        return f;
277                }
278
279                @Override
280                public boolean shouldStopIterations(int numInliers) {
281
282                        if (numInliers > maxInliers) {
283                                maxInliers = numInliers;
284                                percentageInliers = (double) maxInliers / numDataItems;
285
286                                // System.err.format("Updated maxInliers: %d\n", maxInliers);
287                        }
288                        currentProb = Math.pow(1.0 - Math.pow(percentageInliers, numItemsToEstimate), ++iteration);
289                        return currentProb <= this.desiredErrorProbability;
290                }
291        }
292
293        /**
294         * Stopping condition that allows the RANSAC algorithm to run until all the
295         * iterations have been exhausted. The fitData method will return true if
296         * there are at least as many inliers as datapoints required to estimate the
297         * model, and the model will be the one from the iteration that had the most
298         * inliers.
299         */
300        public static class BestFitStoppingCondition implements StoppingCondition {
301                int required;
302
303                @Override
304                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
305                        required = model.numItemsToEstimate();
306                        return true;
307                }
308
309                @Override
310                public boolean shouldStopIterations(int numInliers) {
311                        return false; // just iterate until the end
312                }
313
314                @Override
315                public boolean finalFitCondition(int numInliers) {
316                        return numInliers > required; // accept the best result as a good
317                                                                                        // fit if there are enough inliers
318                }
319        }
320
321        protected M model;
322        protected ResidualCalculator<I, D, M> errorModel;
323        protected DistanceCheck dc;
324
325        protected int nIter;
326        protected boolean improveEstimate;
327        protected List<IndependentPair<I, D>> inliers;
328        protected List<IndependentPair<I, D>> outliers;
329        protected List<IndependentPair<I, D>> bestModelInliers;
330        protected List<IndependentPair<I, D>> bestModelOutliers;
331        protected StoppingCondition stoppingCondition;
332        protected List<? extends IndependentPair<I, D>> modelConstructionData;
333        protected CollectionSampler<IndependentPair<I, D>> sampler;
334
335        /**
336         * Create a RANSAC object with uniform random sampling for creating the
337         * subsets
338         * 
339         * @param model
340         *            Model object with which to fit data
341         * @param errorModel
342         *            object to compute the error of the model
343         * @param errorThreshold
344         *            the threshold below which error is deemed acceptable for a fit
345         * @param nIterations
346         *            Maximum number of allowed iterations (L)
347         * @param stoppingCondition
348         *            the stopping condition
349         * @param impEst
350         *            True if we want to perform a final fitting of the model with
351         *            all inliers, false otherwise
352         */
353        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
354                        double errorThreshold, int nIterations,
355                        StoppingCondition stoppingCondition, boolean impEst)
356        {
357                this(model, errorModel, new ThresholdDistanceCheck(errorThreshold), nIterations, stoppingCondition, impEst);
358        }
359
360        /**
361         * Create a RANSAC object with uniform random sampling for creating the
362         * subsets
363         * 
364         * @param model
365         *            Model object with which to fit data
366         * @param errorModel
367         *            object to compute the error of the model
368         * @param dc
369         *            the distance check that tests whether a point with given error
370         *            from the error model should be considered an inlier
371         * @param nIterations
372         *            Maximum number of allowed iterations (L)
373         * @param stoppingCondition
374         *            the stopping condition
375         * @param impEst
376         *            True if we want to perform a final fitting of the model with
377         *            all inliers, false otherwise
378         */
379        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
380                        DistanceCheck dc, int nIterations,
381                        StoppingCondition stoppingCondition, boolean impEst)
382        {
383                this(model, errorModel, dc, nIterations, stoppingCondition, impEst, new UniformSampler<IndependentPair<I, D>>());
384        }
385
386        /**
387         * Create a RANSAC object
388         * 
389         * @param model
390         *            Model object with which to fit data
391         * @param errorModel
392         *            object to compute the error of the model
393         * @param errorThreshold
394         *            the threshold below which error is deemed acceptable for a fit
395         * @param nIterations
396         *            Maximum number of allowed iterations (L)
397         * @param stoppingCondition
398         *            the stopping condition
399         * @param impEst
400         *            True if we want to perform a final fitting of the model with
401         *            all inliers, false otherwise
402         * @param sampler
403         *            the sampling algorithm for selecting random subsets
404         */
405        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
406                        double errorThreshold, int nIterations,
407                        StoppingCondition stoppingCondition, boolean impEst, CollectionSampler<IndependentPair<I, D>> sampler)
408        {
409                this(model, errorModel, new ThresholdDistanceCheck(errorThreshold), nIterations, stoppingCondition, impEst,
410                                sampler);
411        }
412
413        /**
414         * Create a RANSAC object
415         * 
416         * @param model
417         *            Model object with which to fit data
418         * @param errorModel
419         *            object to compute the error of the model
420         * @param dc
421         *            the distance check that tests whether a point with given error
422         *            from the error model should be considered an inlier
423         * @param nIterations
424         *            Maximum number of allowed iterations (L)
425         * @param stoppingCondition
426         *            the stopping condition
427         * @param impEst
428         *            True if we want to perform a final fitting of the model with
429         *            all inliers, false otherwise
430         * @param sampler
431         *            the sampling algorithm for selecting random subsets
432         */
433        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
434                        DistanceCheck dc, int nIterations,
435                        StoppingCondition stoppingCondition, boolean impEst, CollectionSampler<IndependentPair<I, D>> sampler)
436        {
437                this.stoppingCondition = stoppingCondition;
438                this.model = model;
439                this.errorModel = errorModel;
440                this.dc = dc;
441                nIter = nIterations;
442                improveEstimate = impEst;
443
444                inliers = new ArrayList<IndependentPair<I, D>>();
445                outliers = new ArrayList<IndependentPair<I, D>>();
446                this.sampler = sampler;
447        }
448
449        @Override
450        public boolean fitData(final List<? extends IndependentPair<I, D>> data)
451        {
452                int l;
453                final int M = model.numItemsToEstimate();
454
455                bestModelInliers = null;
456                bestModelOutliers = null;
457
458                if (data.size() < M || !stoppingCondition.init(data, model)) {
459                        return false; // there are not enough points to create a model, or
460                                                        // init failed
461                }
462
463                sampler.setCollection(data);
464
465                for (l = 0; l < nIter; l++) {
466                        // 1
467                        final List<? extends IndependentPair<I, D>> rnd = sampler.sample(M);
468                        this.setModelConstructionData(rnd);
469
470                        // 2
471                        if (!model.estimate(rnd))
472                                continue; // bad estimate
473
474                        errorModel.setModel(model);
475
476                        // 3
477                        int K = 0;
478                        inliers.clear();
479                        outliers.clear();
480                        for (final IndependentPair<I, D> dp : data) {
481                                if (dc.check(errorModel.computeResidual(dp)))
482                                {
483                                        K++;
484                                        inliers.add(dp);
485                                } else {
486                                        outliers.add(dp);
487                                }
488                        }
489
490                        if (bestModelInliers == null || inliers.size() >= bestModelInliers.size()) {
491                                // copy
492                                bestModelInliers = new ArrayList<IndependentPair<I, D>>(inliers);
493                                bestModelOutliers = new ArrayList<IndependentPair<I, D>>(outliers);
494                        }
495
496                        // 4
497                        if (stoppingCondition.shouldStopIterations(K)) {
498                                // generate "best" fit from all the iterations
499                                inliers = bestModelInliers;
500                                outliers = bestModelOutliers;
501
502                                if (improveEstimate) {
503                                        if (inliers.size() >= model.numItemsToEstimate())
504                                                if (!model.estimate(inliers))
505                                                        return false;
506                                }
507                                final boolean stopping = stoppingCondition.finalFitCondition(inliers.size());
508                                // System.err.format("done: %b\n",stopping);
509                                return stopping;
510                        }
511                        // 5
512                        // ...repeat...
513                }
514
515                // generate "best" fit from all the iterations
516                if (bestModelInliers == null) {
517                        bestModelInliers = new ArrayList<IndependentPair<I, D>>();
518                        bestModelOutliers = new ArrayList<IndependentPair<I, D>>();
519                }
520
521                inliers = bestModelInliers;
522                outliers = bestModelOutliers;
523
524                if (bestModelInliers.size() >= M)
525                        if (!model.estimate(bestModelInliers))
526                                return false;
527
528                // 6 - fail
529                return stoppingCondition.finalFitCondition(inliers.size());
530        }
531
532        @Override
533        public List<? extends IndependentPair<I, D>> getInliers()
534        {
535                return inliers;
536        }
537
538        @Override
539        public List<? extends IndependentPair<I, D>> getOutliers()
540        {
541                return outliers;
542        }
543
544        /**
545         * @return maximum number of allowed iterations
546         */
547        public int getMaxIterations() {
548                return nIter;
549        }
550
551        /**
552         * Set the maximum number of allowed iterations
553         * 
554         * @param nIter
555         *            maximum number of allowed iterations
556         */
557        public void setMaxIterations(int nIter) {
558                this.nIter = nIter;
559        }
560
561        @Override
562        public M getModel() {
563                return model;
564        }
565
566        /**
567         * Set the underlying model being fitted
568         * 
569         * @param model
570         *            the model
571         */
572        public void setModel(M model) {
573                this.model = model;
574        }
575
576        /**
577         * @return whether RANSAC should attempt to improve the model using all
578         *         inliers as data
579         */
580        public boolean isImproveEstimate() {
581                return improveEstimate;
582        }
583
584        /**
585         * Set whether RANSAC should attempt to improve the model using all inliers
586         * as data
587         * 
588         * @param improveEstimate
589         *            should RANSAC attempt to improve the model using all inliers
590         *            as data
591         */
592        public void setImproveEstimate(boolean improveEstimate) {
593                this.improveEstimate = improveEstimate;
594        }
595
596        /**
597         * Set the data used to construct the model
598         * 
599         * @param modelConstructionData
600         */
601        public void setModelConstructionData(List<? extends IndependentPair<I, D>> modelConstructionData) {
602                this.modelConstructionData = modelConstructionData;
603        }
604
605        /**
606         * @return The data used to construct the model.
607         */
608        public List<? extends IndependentPair<I, D>> getModelConstructionData() {
609                return modelConstructionData;
610        }
611
612        @Override
613        public int numItemsToEstimate() {
614                return model.numItemsToEstimate();
615        }
616}