001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.knn; 031 032import java.util.ArrayList; 033import java.util.List; 034 035import org.openimaj.util.comparator.DistanceComparator; 036import org.openimaj.util.pair.IntFloatPair; 037import org.openimaj.util.queue.BoundedPriorityQueue; 038 039/** 040 * Exact (brute-force) k-nearest-neighbour implementation for objects with a 041 * compatible {@link DistanceComparator}. 042 * 043 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 044 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 045 * 046 * @param <T> 047 * Type of object being compared. 048 */ 049public class ObjectNearestNeighboursExact<T> extends ObjectNearestNeighbours<T> 050 implements 051 IncrementalNearestNeighbours<T, float[], IntFloatPair> 052{ 053 protected final List<T> pnts; 054 055 /** 056 * Construct the {@link ObjectNearestNeighboursExact} over the provided 057 * dataset with the given distance function. 058 * <p> 059 * Note: If the distance function provides similarities rather than 060 * distances they are automatically inverted. 061 * 062 * @param pnts 063 * the dataset 064 * @param distance 065 * the distance function 066 */ 067 public ObjectNearestNeighboursExact(final List<T> pnts, final DistanceComparator<? super T> distance) { 068 super(distance); 069 this.pnts = pnts; 070 } 071 072 /** 073 * Construct any empty {@link ObjectNearestNeighboursExact} with the given 074 * distance function. 075 * <p> 076 * Note: If the distance function provides similarities rather than 077 * distances they are automatically inverted. 078 * 079 * @param distance 080 * the distance function 081 */ 082 public ObjectNearestNeighboursExact(final DistanceComparator<T> distance) { 083 super(distance); 084 this.pnts = new ArrayList<T>(); 085 } 086 087 @Override 088 public void searchNN(final T[] qus, int[] indices, float[] distances) { 089 final int N = qus.length; 090 091 final BoundedPriorityQueue<IntFloatPair> queue = 092 new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 093 094 // prepare working data 095 final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2); 096 list.add(new IntFloatPair()); 097 list.add(new IntFloatPair()); 098 099 for (int n = 0; n < N; ++n) { 100 final List<IntFloatPair> result = search(qus[n], queue, list); 101 102 final IntFloatPair p = result.get(0); 103 indices[n] = p.first; 104 distances[n] = p.second; 105 } 106 } 107 108 @Override 109 public void searchKNN(final T[] qus, int K, int[][] indices, float[][] distances) { 110 // Fix for when the user asks for too many points. 111 K = Math.min(K, pnts.size()); 112 113 final int N = qus.length; 114 115 final BoundedPriorityQueue<IntFloatPair> queue = 116 new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 117 118 // prepare working data 119 final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1); 120 for (int i = 0; i < K + 1; i++) { 121 list.add(new IntFloatPair()); 122 } 123 124 // search on each query 125 for (int n = 0; n < N; ++n) { 126 final List<IntFloatPair> result = search(qus[n], queue, list); 127 128 for (int k = 0; k < K; ++k) { 129 final IntFloatPair p = result.get(k); 130 indices[n][k] = p.first; 131 distances[n][k] = p.second; 132 } 133 } 134 } 135 136 @Override 137 public void searchNN(final List<T> qus, int[] indices, float[] distances) { 138 final int N = qus.size(); 139 140 final BoundedPriorityQueue<IntFloatPair> queue = 141 new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 142 143 // prepare working data 144 final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2); 145 list.add(new IntFloatPair()); 146 list.add(new IntFloatPair()); 147 148 for (int n = 0; n < N; ++n) { 149 final List<IntFloatPair> result = search(qus.get(n), queue, list); 150 151 final IntFloatPair p = result.get(0); 152 indices[n] = p.first; 153 distances[n] = p.second; 154 } 155 } 156 157 @Override 158 public void searchKNN(final List<T> qus, int K, int[][] indices, float[][] distances) { 159 // Fix for when the user asks for too many points. 160 K = Math.min(K, pnts.size()); 161 162 final int N = qus.size(); 163 164 final BoundedPriorityQueue<IntFloatPair> queue = 165 new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 166 167 // prepare working data 168 final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1); 169 for (int i = 0; i < K + 1; i++) { 170 list.add(new IntFloatPair()); 171 } 172 173 // search on each query 174 for (int n = 0; n < N; ++n) { 175 final List<IntFloatPair> result = search(qus.get(n), queue, list); 176 177 for (int k = 0; k < K; ++k) { 178 final IntFloatPair p = result.get(k); 179 indices[n][k] = p.first; 180 distances[n][k] = p.second; 181 } 182 } 183 } 184 185 @Override 186 public List<IntFloatPair> searchKNN(T query, int K) { 187 // Fix for when the user asks for too many points. 188 K = Math.min(K, pnts.size()); 189 190 final BoundedPriorityQueue<IntFloatPair> queue = 191 new BoundedPriorityQueue<IntFloatPair>(K, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 192 193 // prepare working data 194 final List<IntFloatPair> list = new ArrayList<IntFloatPair>(K + 1); 195 for (int i = 0; i < K + 1; i++) { 196 list.add(new IntFloatPair()); 197 } 198 199 // search 200 return search(query, queue, list); 201 } 202 203 @Override 204 public IntFloatPair searchNN(final T query) { 205 final BoundedPriorityQueue<IntFloatPair> queue = 206 new BoundedPriorityQueue<IntFloatPair>(1, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR); 207 208 // prepare working data 209 final List<IntFloatPair> list = new ArrayList<IntFloatPair>(2); 210 list.add(new IntFloatPair()); 211 list.add(new IntFloatPair()); 212 213 return search(query, queue, list).get(0); 214 } 215 216 private List<IntFloatPair> search(T query, BoundedPriorityQueue<IntFloatPair> queue, List<IntFloatPair> results) 217 { 218 IntFloatPair wp = null; 219 220 // reset all values in the queue to MAX, -1 221 for (final IntFloatPair p : results) { 222 p.second = Float.MAX_VALUE; 223 p.first = -1; 224 wp = queue.offerItem(p); 225 } 226 227 // perform the search 228 final int size = this.pnts.size(); 229 for (int i = 0; i < size; i++) { 230 wp.second = ObjectNearestNeighbours.distanceFunc(distance, query, pnts.get(i)); 231 wp.first = i; 232 wp = queue.offerItem(wp); 233 } 234 235 return queue.toOrderedListDestructive(); 236 } 237 238 @Override 239 public int size() { 240 return this.pnts.size(); 241 } 242 243 @Override 244 public int[] addAll(final List<T> d) { 245 final int[] indexes = new int[d.size()]; 246 247 for (int i = 0; i < indexes.length; i++) { 248 indexes[i] = this.add(d.get(i)); 249 } 250 251 return indexes; 252 } 253 254 @Override 255 public int add(final T o) { 256 final int ret = this.pnts.size(); 257 this.pnts.add(o); 258 return ret; 259 } 260}