001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.knn; 031 032import java.io.DataInput; 033import java.io.DataOutput; 034import java.io.IOException; 035import java.io.PrintWriter; 036import java.lang.reflect.Array; 037import java.util.ArrayList; 038import java.util.Arrays; 039import java.util.Collection; 040import java.util.List; 041import java.util.PriorityQueue; 042import java.util.Scanner; 043import java.util.Stack; 044 045import org.openimaj.math.geometry.point.Coordinate; 046 047class KDNode<T extends Coordinate> { 048 int _discriminate; 049 T _point; 050 KDNode<T> _left, _right; 051 052 KDNode(T point, int discriminate) { 053 _point = point; 054 _left = _right = null; 055 _discriminate = discriminate; 056 } 057} 058 059/** 060 * Implementation of a simple KDTree with range search. 061 * The KDTree allows fast search for points in relatively low-dimension 062 * spaces. 063 * 064 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 065 * 066 * @param <T> the type of Coordinate. 067 */ 068public class CoordinateKDTree<T extends Coordinate> implements CoordinateIndex<T> { 069 KDNode<T> _root; 070 071 /** 072 * Create an empty KDTree object 073 */ 074 public CoordinateKDTree() { _root = null; } 075 076 /** 077 * Create a KDTree object and populate it with the 078 * given data. 079 * @param coords the data to populate the index with. 080 */ 081 public CoordinateKDTree(Collection<T> coords) { 082 _root = null; 083 insertAll(coords); 084 } 085 086 /** 087 * Insert all the points from the given collection 088 * into the index. 089 * @param coords The points to add. 090 */ 091 public void insertAll(Collection<T> coords) { 092 for (T c : coords) 093 insert(c); 094 } 095 096 /** 097 * Inserts a point into the tree, preserving the 098 * spatial ordering. 099 * @param point Point to insert. 100 */ 101 @Override 102 public void insert(T point) { 103 104 if(_root == null) 105 _root = new KDNode<T>(point, 0); 106 else { 107 int discriminate, dimensions; 108 KDNode<T> curNode, tmpNode; 109 double ordinate1, ordinate2; 110 111 curNode = _root; 112 113 do { 114 tmpNode = curNode; 115 discriminate = tmpNode._discriminate; 116 117 ordinate1 = point.getOrdinate(discriminate).doubleValue(); 118 ordinate2 = tmpNode._point.getOrdinate(discriminate).doubleValue(); 119 120 if(ordinate1 > ordinate2) 121 curNode = tmpNode._right; 122 else 123 curNode = tmpNode._left; 124 } while(curNode != null); 125 126 dimensions = point.getDimensions(); 127 128 if(++discriminate >= dimensions) 129 discriminate = 0; 130 131 if(ordinate1 > ordinate2) 132 tmpNode._right = new KDNode<T>(point, discriminate); 133 else 134 tmpNode._left = new KDNode<T>(point, discriminate); 135 } 136 } 137 138 /** 139 * Determines if a point is contained within a given 140 * k-dimensional bounding box. 141 */ 142 static final boolean isContained( 143 Coordinate point, Coordinate lower, Coordinate upper) 144 { 145 int dimensions; 146 double ordinate1, ordinate2, ordinate3; 147 148 dimensions = point.getDimensions(); 149 150 for(int i = 0; i < dimensions; ++i) { 151 ordinate1 = point.getOrdinate(i).doubleValue(); 152 ordinate2 = lower.getOrdinate(i).doubleValue(); 153 ordinate3 = upper.getOrdinate(i).doubleValue(); 154 155 if(ordinate1 < ordinate2 || ordinate1 > ordinate3) 156 return false; 157 } 158 159 return true; 160 } 161 162 /** 163 * Searches the tree for all points contained within a 164 * given k-dimensional bounding box and stores them in a 165 * Collection. 166 * @param results 167 * @param lowerExtreme 168 * @param upperExtreme 169 */ 170 @Override 171 public void rangeSearch(Collection<T> results, Coordinate lowerExtreme, Coordinate upperExtreme) 172 { 173 KDNode<T> tmpNode; 174 Stack<KDNode<T>> stack = new Stack<KDNode<T>>(); 175 int discriminate; 176 double ordinate1, ordinate2; 177 178 if(_root == null) 179 return; 180 181 stack.push(_root); 182 183 while(!stack.empty()) { 184 tmpNode =stack.pop(); 185 discriminate = tmpNode._discriminate; 186 187 ordinate1 = tmpNode._point.getOrdinate(discriminate).doubleValue(); 188 ordinate2 = lowerExtreme.getOrdinate(discriminate).doubleValue(); 189 190 if(ordinate1 > ordinate2 && tmpNode._left != null) 191 stack.push(tmpNode._left); 192 193 ordinate2 = upperExtreme.getOrdinate(discriminate).doubleValue(); 194 195 if(ordinate1 < ordinate2 && tmpNode._right != null) 196 stack.push(tmpNode._right); 197 198 if(isContained(tmpNode._point, lowerExtreme, upperExtreme)) 199 results.add(tmpNode._point); 200 } 201 } 202 203 protected static final float distance(Coordinate a, Coordinate b) { 204 float s=0; 205 206 for (int i=0; i<a.getDimensions(); i++) { 207 final float fa = ((Number)a.getOrdinate(i)).floatValue(); 208 final float fb = ((Number)b.getOrdinate(i)).floatValue(); 209 s += (fa-fb)*(fa-fb); 210 } 211 return s; 212 } 213 214 class NNState implements Comparable<NNState> { 215 T best; 216 float bestDist; 217 218 @Override public int compareTo(NNState o) { 219 if (bestDist < o.bestDist) return 1; 220 if (bestDist > o.bestDist) return -1; 221 return 0; 222 } 223 @Override 224 public String toString() { return bestDist + ""; } 225 } 226 227 /** 228 * Find the nearest neighbour. Only one neighbour will be returned - if multiple neighbours 229 * share the same location, or are equidistant, then this might not be the one you expect. 230 * @param query query coordinate 231 * @return nearest neighbour 232 */ 233 @Override 234 public T nearestNeighbour(Coordinate query) { 235 Stack<KDNode<T>> stack = walkdown(query); 236 NNState state = new NNState(); 237 state.best = stack.peek()._point; 238 state.bestDist = distance(query, state.best); 239 240 if (state.bestDist == 0) return state.best; 241 242 while (!stack.isEmpty()) { 243 KDNode<T> current = stack.pop(); 244 245 checkSubtree(current, query, state); 246 } 247 248 return state.best; 249 } 250 251 @Override 252 public void kNearestNeighbour(Collection<T> result, Coordinate query, int k) { 253 Stack<KDNode<T>> stack = walkdown(query); 254 PriorityQueue<NNState> state = new PriorityQueue<NNState>(k); 255 256 NNState initialState = new NNState(); 257 initialState.best = stack.peek()._point; 258 initialState.bestDist = distance(query, initialState.best); 259 state.add(initialState); 260 261 while (!stack.isEmpty()) { 262 KDNode<T> current = stack.pop(); 263 264 checkSubtreeK(current, query, state, k); 265 } 266 267 @SuppressWarnings("unchecked") 268 NNState[] stateList = state.toArray((NNState[])Array.newInstance(NNState.class, state.size())); 269 Arrays.sort(stateList); 270 for (int i=stateList.length-1; i>=0; i--) result.add(stateList[i].best); 271 } 272 273 274 /* 275 * Check a subtree for a closer match 276 */ 277 private void checkSubtree(KDNode<T> node, Coordinate query, NNState state) { 278 if(node == null) return; 279 280 float dist = distance(query, node._point); 281 if (dist < state.bestDist) { 282 state.best = node._point; 283 state.bestDist = dist; 284 } 285 286 if (state.bestDist == 0) return; 287 288 float d = ((Number)node._point.getOrdinate(node._discriminate)).floatValue() - 289 ((Number)query.getOrdinate(node._discriminate)).floatValue(); 290 if (d*d > state.bestDist) { 291 //check subtree 292 double ordinate1 = query.getOrdinate(node._discriminate).doubleValue(); 293 double ordinate2 = node._point.getOrdinate(node._discriminate).doubleValue(); 294 295 if(ordinate1 > ordinate2) 296 checkSubtree(node._right, query, state); 297 else 298 checkSubtree(node._left, query, state); 299 } else { 300 checkSubtree(node._left, query, state); 301 checkSubtree(node._right, query, state); 302 } 303 } 304 305 private void checkSubtreeK(KDNode<T> node, Coordinate query, PriorityQueue<NNState> state, int k) { 306 if(node == null) return; 307 308 float dist = distance(query, node._point); 309 310 boolean cont = false; 311 for (NNState s : state) 312 if (s.best.equals(node._point)) { 313 cont=true; 314 break; 315 } 316 317 if (!cont) { 318 if (state.size() < k) { 319 //collect this node 320 NNState s = new NNState(); 321 s.best = node._point; 322 s.bestDist = dist; 323 state.add(s); 324 } else if (dist < state.peek().bestDist) { 325 //replace last node 326 NNState s = state.poll(); 327 s.best = node._point; 328 s.bestDist = dist; 329 state.add(s); 330 } 331 } 332 333 float d = ((Number)node._point.getOrdinate(node._discriminate)).floatValue() - 334 ((Number)query.getOrdinate(node._discriminate)).floatValue(); 335 if (d*d > state.peek().bestDist) { 336 //check subtree 337 double ordinate1 = query.getOrdinate(node._discriminate).doubleValue(); 338 double ordinate2 = node._point.getOrdinate(node._discriminate).doubleValue(); 339 340 if(ordinate1 > ordinate2) 341 checkSubtreeK(node._right, query, state, k); 342 else 343 checkSubtreeK(node._left, query, state, k); 344 } else { 345 checkSubtreeK(node._left, query, state, k); 346 checkSubtreeK(node._right, query, state, k); 347 } 348 } 349 350 351 /* 352 * walk down the tree until we hit a leaf, and return the path taken 353 */ 354 private Stack<KDNode<T>> walkdown(Coordinate point) { 355 if(_root == null) 356 return null; 357 else { 358 Stack<KDNode<T>> stack = new Stack<KDNode<T>>(); 359 int discriminate, dimensions; 360 KDNode<T> curNode, tmpNode; 361 double ordinate1, ordinate2; 362 363 curNode = _root; 364 365 do { 366 tmpNode = curNode; 367 stack.push(tmpNode); 368 if (tmpNode._point == point) return stack; 369 discriminate = tmpNode._discriminate; 370 371 ordinate1 = point.getOrdinate(discriminate).doubleValue(); 372 ordinate2 = tmpNode._point.getOrdinate(discriminate).doubleValue(); 373 374 if(ordinate1 > ordinate2) 375 curNode = tmpNode._right; 376 else 377 curNode = tmpNode._left; 378 } while(curNode != null); 379 380 dimensions = point.getDimensions(); 381 382 if(++discriminate >= dimensions) 383 discriminate = 0; 384 385 return stack; 386 } 387 } 388 389 class Coord implements Coordinate { 390 float [] coords; 391 public Coord(int i) { coords = new float[i]; } 392 public Coord(Coordinate c) { 393 this(c.getDimensions()); 394 for (int i=0; i<coords.length; i++) coords[i] = ((Number)c.getOrdinate(i)).floatValue(); 395 } 396 @Override public int getDimensions() { return coords.length; } 397 @Override public Float getOrdinate(int dimension) { return coords[dimension]; } 398 399 @Override 400 public void readASCII(Scanner in) throws IOException { 401 throw new RuntimeException("not implemented"); 402 } 403 @Override 404 public String asciiHeader() { 405 throw new RuntimeException("not implemented"); 406 } 407 @Override 408 public void readBinary(DataInput in) throws IOException { 409 throw new RuntimeException("not implemented"); 410 } 411 @Override 412 public byte[] binaryHeader() { 413 throw new RuntimeException("not implemented"); 414 } 415 @Override 416 public void writeASCII(PrintWriter out) throws IOException { 417 throw new RuntimeException("not implemented"); 418 } 419 @Override 420 public void writeBinary(DataOutput out) throws IOException { 421 throw new RuntimeException("not implemented"); 422 } 423 } 424 425 /** 426 * Faster implementation of K-nearest-neighbours. 427 * 428 * @param result Collection to hold the found coordinates. 429 * @param query The query coordinate. 430 * @param k The number of neighbours to find. 431 */ 432 public void fastKNN(Collection<T> result, Coordinate query, int k) { 433 List<T> tmp = new ArrayList<T>(); 434 Coord lowerExtreme = new Coord(query); 435 Coord upperExtreme = new Coord(query); 436 437 while (tmp.size()<k) { 438 tmp.clear(); 439 for (int i=0; i<lowerExtreme.getDimensions(); i++) lowerExtreme.coords[i]-=k; 440 for (int i=0; i<upperExtreme.getDimensions(); i++) upperExtreme.coords[i]+=k; 441 rangeSearch(tmp, lowerExtreme, upperExtreme); 442 } 443 444 CoordinateBruteForce<T> bf = new CoordinateBruteForce<T>(tmp); 445 bf.kNearestNeighbour(result, query, k); 446 } 447} 448 449