001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.knn;
031
032import java.util.ArrayList;
033import java.util.List;
034
035import org.openimaj.util.comparator.DistanceComparator;
036import org.openimaj.util.pair.IntFloatPair;
037import org.openimaj.util.queue.BoundedPriorityQueue;
038
039/**
040 * Exact (brute-force) k-nearest-neighbour implementation for objects with a
041 * compatible {@link DistanceComparator}.
042 * 
043 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
044 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
045 * 
046 * @param <T>
047 *            Type of object being compared.
048 */
049public class ObjectNearestNeighboursExact<T> extends ObjectNearestNeighbours<T>
050                implements
051                IncrementalNearestNeighbours<T, float[], IntFloatPair>
052{
053        protected final List<T> pnts;
054
055        /**
056         * Construct the {@link ObjectNearestNeighboursExact} over the provided
057         * dataset with the given distance function.
058         * <p>
059         * Note: If the distance function provides similarities rather than
060         * distances they are automatically inverted.
061         * 
062         * @param pnts
063         *            the dataset
064         * @param distance
065         *            the distance function
066         */
067        public ObjectNearestNeighboursExact(final List<T> pnts, final DistanceComparator<? super T> distance) {
068                super(distance);
069                this.pnts = pnts;
070        }
071
072        /**
073         * Construct any empty {@link ObjectNearestNeighboursExact} with the given
074         * distance function.
075         * <p>
076         * Note: If the distance function provides similarities rather than
077         * distances they are automatically inverted.
078         * 
079         * @param distance
080         *            the distance function
081         */
082        public ObjectNearestNeighboursExact(final DistanceComparator<T> distance) {
083                super(distance);
084                this.pnts = new ArrayList<T>();
085        }
086
087        @Override
088        public void searchNN(final T[] qus, int[] indices, float[] distances) {
089                final int N = qus.length;
090
091                final BoundedPriorityQueue<IntFloatPair> queue =
092                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
093
094                // prepare working data
095                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
096                list.add(new IntFloatPair());
097                list.add(new IntFloatPair());
098
099                for (int n = 0; n < N; ++n) {
100                        final List<IntFloatPair> result = search(qus[n], queue, list);
101
102                        final IntFloatPair p = result.get(0);
103                        indices[n] = p.first;
104                        distances[n] = p.second;
105                }
106        }
107
108        @Override
109        public void searchKNN(final T[] qus, int K, int[][] indices, float[][] distances) {
110                // Fix for when the user asks for too many points.
111                K = Math.min(K, pnts.size());
112
113                final int N = qus.length;
114
115                final BoundedPriorityQueue<IntFloatPair> queue =
116                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
117
118                // prepare working data
119                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
120                for (int i = 0; i < K + 1; i++) {
121                        list.add(new IntFloatPair());
122                }
123
124                // search on each query
125                for (int n = 0; n < N; ++n) {
126                        final List<IntFloatPair> result = search(qus[n], queue, list);
127
128                        for (int k = 0; k < K; ++k) {
129                                final IntFloatPair p = result.get(k);
130                                indices[n][k] = p.first;
131                                distances[n][k] = p.second;
132                        }
133                }
134        }
135
136        @Override
137        public void searchNN(final List<T> qus, int[] indices, float[] distances) {
138                final int N = qus.size();
139
140                final BoundedPriorityQueue<IntFloatPair> queue =
141                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
142
143                // prepare working data
144                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
145                list.add(new IntFloatPair());
146                list.add(new IntFloatPair());
147
148                for (int n = 0; n < N; ++n) {
149                        final List<IntFloatPair> result = search(qus.get(n), queue, list);
150
151                        final IntFloatPair p = result.get(0);
152                        indices[n] = p.first;
153                        distances[n] = p.second;
154                }
155        }
156
157        @Override
158        public void searchKNN(final List<T> qus, int K, int[][] indices, float[][] distances) {
159                // Fix for when the user asks for too many points.
160                K = Math.min(K, pnts.size());
161
162                final int N = qus.size();
163
164                final BoundedPriorityQueue<IntFloatPair> queue =
165                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
166
167                // prepare working data
168                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
169                for (int i = 0; i < K + 1; i++) {
170                        list.add(new IntFloatPair());
171                }
172
173                // search on each query
174                for (int n = 0; n < N; ++n) {
175                        final List<IntFloatPair> result = search(qus.get(n), queue, list);
176
177                        for (int k = 0; k < K; ++k) {
178                                final IntFloatPair p = result.get(k);
179                                indices[n][k] = p.first;
180                                distances[n][k] = p.second;
181                        }
182                }
183        }
184
185        @Override
186        public List<IntFloatPair> searchKNN(T query, int K) {
187                // Fix for when the user asks for too many points.
188                K = Math.min(K, pnts.size());
189
190                final BoundedPriorityQueue<IntFloatPair> queue =
191                                new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
192
193                // prepare working data
194                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1);
195                for (int i = 0; i < K + 1; i++) {
196                        list.add(new IntFloatPair());
197                }
198
199                // search
200                return search(query, queue, list);
201        }
202
203        @Override
204        public IntFloatPair searchNN(final T query) {
205                final BoundedPriorityQueue<IntFloatPair> queue =
206                                new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
207
208                // prepare working data
209                final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2);
210                list.add(new IntFloatPair());
211                list.add(new IntFloatPair());
212
213                return search(query, queue, list).get(0);
214        }
215
216        private List<IntFloatPair> search(T query, BoundedPriorityQueue<IntFloatPair> queue, List<IntFloatPair> results)
217        {
218                IntFloatPair wp = null;
219
220                // reset all values in the queue to MAX, -1
221                for (final IntFloatPair p : results) {
222                        p.second = Float.MAX_VALUE;
223                        p.first = -1;
224                        wp = queue.offerItem(p);
225                }
226
227                // perform the search
228                final int size = this.pnts.size();
229                for (int i = 0; i < size; i++) {
230                        wp.second = ObjectNearestNeighbours.distanceFunc(distance, query, pnts.get(i));
231                        wp.first = i;
232                        wp = queue.offerItem(wp);
233                }
234
235                return queue.toOrderedListDestructive();
236        }
237
238        @Override
239        public int size() {
240                return this.pnts.size();
241        }
242
243        @Override
244        public int[] addAll(final List<T> d) {
245                final int[] indexes = new int[d.size()];
246
247                for (int i = 0; i < indexes.length; i++) {
248                        indexes[i] = this.add(d.get(i));
249                }
250
251                return indexes;
252        }
253
254        @Override
255        public int add(final T o) {
256                final int ret = this.pnts.size();
257                this.pnts.add(o);
258                return ret;
259        }
260}