001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.knn;
031
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.PrintWriter;
import java.lang.reflect.Array;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Deque;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.util.Stack;

import org.openimaj.math.geometry.point.Coordinate;
046
047class KDNode<T extends Coordinate> {
048        int _discriminate;
049        T _point;
050        KDNode<T> _left, _right;
051
052        KDNode(T point, int discriminate) {
053                _point = point;
054                _left  = _right = null;
055                _discriminate = discriminate;
056        }
057}
058
059/**
060 * Implementation of a simple KDTree with range search.
061 * The KDTree allows fast search for points in relatively low-dimension
062 * spaces.
063 * 
064 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
065 * 
066 * @param <T> the type of Coordinate. 
067 */
068public class CoordinateKDTree<T extends Coordinate> implements CoordinateIndex<T> {
069        KDNode<T> _root;
070
071        /**
072         * Create an empty KDTree object  
073         */
074        public CoordinateKDTree() { _root = null; }
075
076        /**
077         * Create a KDTree object and populate it with the
078         * given data.  
079         * @param coords the data to populate the index with. 
080         */
081        public CoordinateKDTree(Collection<T> coords) {
082                _root = null;
083                insertAll(coords);
084        }
085        
086        /**
087         * Insert all the points from the given collection
088         * into the index.
089         * @param coords The points to add.
090         */
091        public void insertAll(Collection<T> coords) {
092                for (T c : coords)
093                        insert(c);
094        }
095        
096        /**
097         * Inserts a point into the tree, preserving the
098         * spatial ordering.
099         * @param point Point to insert.
100         */
101        @Override
102        public void insert(T point) {
103
104                if(_root == null)
105                        _root = new KDNode<T>(point, 0);
106                else {
107                        int discriminate, dimensions;
108                        KDNode<T> curNode, tmpNode;
109                        double ordinate1, ordinate2;
110
111                        curNode = _root;
112
113                        do {
114                                tmpNode = curNode;
115                                discriminate = tmpNode._discriminate;
116
117                                ordinate1 = point.getOrdinate(discriminate).doubleValue();
118                                ordinate2 = tmpNode._point.getOrdinate(discriminate).doubleValue();
119
120                                if(ordinate1 > ordinate2)
121                                        curNode = tmpNode._right;
122                                else
123                                        curNode = tmpNode._left;
124                        } while(curNode != null);
125
126                        dimensions = point.getDimensions();
127
128                        if(++discriminate >= dimensions)
129                                discriminate = 0;
130
131                        if(ordinate1 > ordinate2)
132                                tmpNode._right = new KDNode<T>(point, discriminate);
133                        else
134                                tmpNode._left = new KDNode<T>(point, discriminate);
135                }
136        }
137
138        /** 
139         * Determines if a point is contained within a given
140         * k-dimensional bounding box.
141         */
142        static final boolean isContained(
143                        Coordinate point, Coordinate lower, Coordinate upper)
144        {
145                int dimensions;
146                double ordinate1, ordinate2, ordinate3;
147
148                dimensions = point.getDimensions();
149
150                for(int i = 0; i < dimensions; ++i) {
151                        ordinate1 = point.getOrdinate(i).doubleValue();
152                        ordinate2 = lower.getOrdinate(i).doubleValue();
153                        ordinate3 = upper.getOrdinate(i).doubleValue();
154
155                        if(ordinate1 < ordinate2 || ordinate1 > ordinate3)
156                                return false;
157                }
158
159                return true;
160        }
161
162        /**
163         * Searches the tree for all points contained within a 
164         * given k-dimensional bounding box and stores them in a 
165         * Collection.
166         * @param results 
167         * @param lowerExtreme 
168         * @param upperExtreme 
169         */
170        @Override
171        public void rangeSearch(Collection<T> results, Coordinate lowerExtreme, Coordinate upperExtreme)
172        {
173                KDNode<T> tmpNode;
174                Stack<KDNode<T>> stack = new Stack<KDNode<T>>();
175                int discriminate;
176                double ordinate1, ordinate2;
177
178                if(_root == null)
179                        return;
180
181                stack.push(_root);
182
183                while(!stack.empty()) {
184                        tmpNode =stack.pop();
185                        discriminate = tmpNode._discriminate;
186
187                        ordinate1 = tmpNode._point.getOrdinate(discriminate).doubleValue();
188                        ordinate2 = lowerExtreme.getOrdinate(discriminate).doubleValue();
189
190                        if(ordinate1 > ordinate2 && tmpNode._left != null)
191                                stack.push(tmpNode._left);
192
193                        ordinate2 = upperExtreme.getOrdinate(discriminate).doubleValue();
194
195                        if(ordinate1 < ordinate2 && tmpNode._right != null)
196                                stack.push(tmpNode._right);
197
198                        if(isContained(tmpNode._point, lowerExtreme, upperExtreme))
199                                results.add(tmpNode._point);
200                }
201        }
202
203        protected static final float distance(Coordinate a, Coordinate b) {
204                float s=0;
205
206                for (int i=0; i<a.getDimensions(); i++) { 
207                        final float fa = ((Number)a.getOrdinate(i)).floatValue();
208                        final float fb = ((Number)b.getOrdinate(i)).floatValue();
209                        s += (fa-fb)*(fa-fb);
210                }
211                return s;
212        }
213
214        class NNState implements Comparable<NNState> {
215                T best;
216                float bestDist;
217
218                @Override public int compareTo(NNState o) {
219                        if (bestDist < o.bestDist) return 1;
220                        if (bestDist > o.bestDist) return -1;
221                        return 0;
222                }
223                @Override
224                public String toString() { return bestDist + ""; }
225        }
226
227        /**
228         * Find the nearest neighbour. Only one neighbour will be returned - if multiple neighbours
229         * share the same location, or are equidistant, then this might not be the one you expect.
230         * @param query query coordinate
231         * @return nearest neighbour
232         */
233        @Override
234        public T nearestNeighbour(Coordinate query) {
235                Stack<KDNode<T>> stack = walkdown(query);
236                NNState state = new NNState();
237                state.best = stack.peek()._point;
238                state.bestDist = distance(query, state.best);
239
240                if (state.bestDist == 0) return state.best;
241
242                while (!stack.isEmpty()) {
243                        KDNode<T> current = stack.pop();
244
245                        checkSubtree(current, query, state); 
246                }
247
248                return state.best;
249        }
250
251        @Override
252        public void kNearestNeighbour(Collection<T> result, Coordinate query, int k) {
253                Stack<KDNode<T>> stack = walkdown(query);
254                PriorityQueue<NNState> state = new PriorityQueue<NNState>(k);
255
256                NNState initialState = new NNState();
257                initialState.best = stack.peek()._point;
258                initialState.bestDist = distance(query, initialState.best);
259                state.add(initialState);
260
261                while (!stack.isEmpty()) {
262                        KDNode<T> current = stack.pop();
263
264                        checkSubtreeK(current, query, state, k); 
265                }
266
267                @SuppressWarnings("unchecked")
268                NNState[] stateList = state.toArray((NNState[])Array.newInstance(NNState.class, state.size()));
269                Arrays.sort(stateList);
270                for (int i=stateList.length-1; i>=0; i--) result.add(stateList[i].best);
271        }
272
273
274        /*
275         * Check a subtree for a closer match
276         */
277        private void checkSubtree(KDNode<T> node, Coordinate query, NNState state) {
278                if(node == null) return;
279
280                float dist = distance(query, node._point);
281                if (dist < state.bestDist) {
282                        state.best = node._point;
283                        state.bestDist = dist;
284                }
285
286                if (state.bestDist == 0) return;
287
288                float d = ((Number)node._point.getOrdinate(node._discriminate)).floatValue() - 
289                ((Number)query.getOrdinate(node._discriminate)).floatValue();
290                if (d*d > state.bestDist) {
291                        //check subtree
292                        double ordinate1 = query.getOrdinate(node._discriminate).doubleValue();
293                        double ordinate2 = node._point.getOrdinate(node._discriminate).doubleValue();
294
295                        if(ordinate1 > ordinate2)
296                                checkSubtree(node._right, query, state);
297                        else
298                                checkSubtree(node._left, query, state);
299                } else {
300                        checkSubtree(node._left, query, state);
301                        checkSubtree(node._right, query, state);
302                }
303        }
304
305        private void checkSubtreeK(KDNode<T> node, Coordinate query, PriorityQueue<NNState> state, int k) {
306                if(node == null) return;
307
308                float dist = distance(query, node._point);
309
310                boolean cont = false;
311                for (NNState s : state) 
312                        if (s.best.equals(node._point)) {
313                                cont=true;
314                                break;
315                        }
316
317                if (!cont) {
318                        if (state.size() < k) {
319                                //collect this node
320                                NNState s = new NNState();
321                                s.best = node._point;
322                                s.bestDist = dist;
323                                state.add(s);
324                        } else if (dist < state.peek().bestDist) {
325                                //replace last node
326                                NNState s = state.poll();
327                                s.best = node._point;
328                                s.bestDist = dist;
329                                state.add(s);
330                        }
331                }
332
333                float d = ((Number)node._point.getOrdinate(node._discriminate)).floatValue() - 
334                ((Number)query.getOrdinate(node._discriminate)).floatValue();
335                if (d*d > state.peek().bestDist) {
336                        //check subtree
337                        double ordinate1 = query.getOrdinate(node._discriminate).doubleValue();
338                        double ordinate2 = node._point.getOrdinate(node._discriminate).doubleValue();
339
340                        if(ordinate1 > ordinate2)
341                                checkSubtreeK(node._right, query, state, k);
342                        else
343                                checkSubtreeK(node._left, query, state, k);
344                } else {
345                        checkSubtreeK(node._left, query, state, k);
346                        checkSubtreeK(node._right, query, state, k);
347                }
348        }
349
350
351        /*
352         * walk down the tree until we hit a leaf, and return the path taken
353         */
354        private Stack<KDNode<T>> walkdown(Coordinate point) {
355                if(_root == null)
356                        return null;
357                else {
358                        Stack<KDNode<T>> stack = new Stack<KDNode<T>>();
359                        int discriminate, dimensions;
360                        KDNode<T> curNode, tmpNode;
361                        double ordinate1, ordinate2;
362
363                        curNode = _root;
364
365                        do {
366                                tmpNode = curNode;
367                                stack.push(tmpNode);
368                                if (tmpNode._point == point) return stack;
369                                discriminate = tmpNode._discriminate;
370
371                                ordinate1 = point.getOrdinate(discriminate).doubleValue();
372                                ordinate2 = tmpNode._point.getOrdinate(discriminate).doubleValue();
373
374                                if(ordinate1 > ordinate2)
375                                        curNode = tmpNode._right;
376                                else
377                                        curNode = tmpNode._left;
378                        } while(curNode != null);
379
380                        dimensions = point.getDimensions();
381
382                        if(++discriminate >= dimensions)
383                                discriminate = 0;
384
385                        return stack;
386                }
387        }
388        
389        class Coord implements Coordinate {
390                float [] coords;
391                public Coord(int i) { coords = new float[i]; }
392                public Coord(Coordinate c) { 
393                        this(c.getDimensions());
394                        for (int i=0; i<coords.length; i++) coords[i] = ((Number)c.getOrdinate(i)).floatValue();
395                }
396                @Override public int getDimensions() { return coords.length; }
397                @Override public Float getOrdinate(int dimension) { return coords[dimension]; }
398                
399                @Override
400                public void readASCII(Scanner in) throws IOException {
401                        throw new RuntimeException("not implemented");
402                }
403                @Override
404                public String asciiHeader() {
405                        throw new RuntimeException("not implemented");
406                }
407                @Override
408                public void readBinary(DataInput in) throws IOException {
409                        throw new RuntimeException("not implemented");
410                }
411                @Override
412                public byte[] binaryHeader() {
413                        throw new RuntimeException("not implemented");
414                }
415                @Override
416                public void writeASCII(PrintWriter out) throws IOException {
417                        throw new RuntimeException("not implemented");                  
418                }
419                @Override
420                public void writeBinary(DataOutput out) throws IOException {
421                        throw new RuntimeException("not implemented");
422                }
423        }
424        
425        /**
426         * Faster implementation of K-nearest-neighbours.
427         * 
428         * @param result Collection to hold the found coordinates.
429         * @param query The query coordinate.
430         * @param k The number of neighbours to find.
431         */     
432        public void fastKNN(Collection<T> result, Coordinate query, int k) {
433                List<T> tmp = new ArrayList<T>();
434                Coord lowerExtreme = new Coord(query);
435                Coord upperExtreme = new Coord(query);
436                
437                while (tmp.size()<k) {
438                        tmp.clear();
439                        for (int i=0; i<lowerExtreme.getDimensions(); i++) lowerExtreme.coords[i]-=k;
440                        for (int i=0; i<upperExtreme.getDimensions(); i++) upperExtreme.coords[i]+=k;
441                        rangeSearch(tmp, lowerExtreme, upperExtreme);
442                }
443                
444                CoordinateBruteForce<T> bf = new CoordinateBruteForce<T>(tmp);
445                bf.kNearestNeighbour(result, query, k);
446        }
447}
448
449