/***********************************************************************
 * Software License Agreement (BSD License)
 *
 * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
 * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
 *
 * THE BSD LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/ 30 31 /*********************************************************************** 32 * Author: Vincent Rabaud 33 *************************************************************************/ 34 35 #ifndef FLANN_LSH_INDEX_H_ 36 #define FLANN_LSH_INDEX_H_ 37 38 #include <algorithm> 39 #include <cassert> 40 #include <cstring> 41 #include <map> 42 #include <vector> 43 44 #include "FLANN/general.h" 45 #include "FLANN/algorithms/nn_index.h" 46 #include "FLANN/util/matrix.h" 47 #include "FLANN/util/result_set.h" 48 #include "FLANN/util/heap.h" 49 #include "FLANN/util/lsh_table.h" 50 #include "FLANN/util/allocator.h" 51 #include "FLANN/util/random.h" 52 #include "FLANN/util/saving.h" 53 54 namespace flann 55 { 56 57 struct LshIndexParams : public IndexParams 58 { 59 LshIndexParams(unsigned int table_number = 12, unsigned int key_size = 20, unsigned int multi_probe_level = 2) 60 { 61 (* this)["algorithm"] = FLANN_INDEX_LSH; 62 // The number of hash tables to use 63 (*this)["table_number"] = table_number; 64 // The length of the key in the hash tables 65 (*this)["key_size"] = key_size; 66 // Number of levels to use in multi-probe (0 for standard LSH) 67 (*this)["multi_probe_level"] = multi_probe_level; 68 } 69 }; 70 71 /** 72 * Locality-sensitive hashing index 73 * 74 * Contains the tables and other information for indexing a set of points 75 * for nearest-neighbor matching. 
76 */ 77 template<typename Distance> 78 class LshIndex : public NNIndex<Distance> 79 { 80 public: 81 typedef typename Distance::ElementType ElementType; 82 typedef typename Distance::ResultType DistanceType; 83 84 typedef NNIndex<Distance> BaseClass; 85 86 /** Constructor 87 * @param params parameters passed to the LSH algorithm 88 * @param d the distance used 89 */ 90 LshIndex(const IndexParams& params = LshIndexParams(), Distance d = Distance()) : BaseClass(params,d)91 BaseClass(params, d) 92 { 93 table_number_ = get_param<unsigned int>(index_params_,"table_number",12); 94 key_size_ = get_param<unsigned int>(index_params_,"key_size",20); 95 multi_probe_level_ = get_param<unsigned int>(index_params_,"multi_probe_level",2); 96 97 fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_); 98 } 99 100 101 /** Constructor 102 * @param input_data dataset with the input features 103 * @param params parameters passed to the LSH algorithm 104 * @param d the distance used 105 */ 106 LshIndex(const Matrix<ElementType>& input_data, const IndexParams& params = LshIndexParams(), Distance d = Distance()) : BaseClass(params,d)107 BaseClass(params, d) 108 { 109 table_number_ = get_param<unsigned int>(index_params_,"table_number",12); 110 key_size_ = get_param<unsigned int>(index_params_,"key_size",20); 111 multi_probe_level_ = get_param<unsigned int>(index_params_,"multi_probe_level",2); 112 113 fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_); 114 115 setDataset(input_data); 116 } 117 LshIndex(const LshIndex & other)118 LshIndex(const LshIndex& other) : BaseClass(other), 119 tables_(other.tables_), 120 table_number_(other.table_number_), 121 key_size_(other.key_size_), 122 multi_probe_level_(other.multi_probe_level_), 123 xor_masks_(other.xor_masks_) 124 { 125 } 126 127 LshIndex& operator=(LshIndex other) 128 { 129 this->swap(other); 130 return *this; 131 } 132 ~LshIndex()133 virtual ~LshIndex() 134 { 135 freeIndex(); 136 } 137 138 clone()139 BaseClass* clone() 
const 140 { 141 return new LshIndex(*this); 142 } 143 144 using BaseClass::buildIndex; 145 146 void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2) 147 { 148 assert(points.cols==veclen_); 149 size_t old_size = size_; 150 151 extendDataset(points); 152 153 if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) { 154 buildIndex(); 155 } 156 else { 157 for (unsigned int i = 0; i < table_number_; ++i) { 158 lsh::LshTable<ElementType>& table = tables_[i]; 159 for (size_t i=old_size;i<size_;++i) { 160 table.add(i, points_[i]); 161 } 162 } 163 } 164 } 165 166 getType()167 flann_algorithm_t getType() const 168 { 169 return FLANN_INDEX_LSH; 170 } 171 172 173 template<typename Archive> serialize(Archive & ar)174 void serialize(Archive& ar) 175 { 176 ar.setObject(this); 177 178 ar & *static_cast<NNIndex<Distance>*>(this); 179 180 ar & table_number_; 181 ar & key_size_; 182 ar & multi_probe_level_; 183 184 ar & xor_masks_; 185 ar & tables_; 186 187 if (Archive::is_loading::value) { 188 index_params_["algorithm"] = getType(); 189 index_params_["table_number"] = table_number_; 190 index_params_["key_size"] = key_size_; 191 index_params_["multi_probe_level"] = multi_probe_level_; 192 } 193 } 194 saveIndex(FILE * stream)195 void saveIndex(FILE* stream) 196 { 197 serialization::SaveArchive sa(stream); 198 sa & *this; 199 } 200 loadIndex(FILE * stream)201 void loadIndex(FILE* stream) 202 { 203 serialization::LoadArchive la(stream); 204 la & *this; 205 } 206 207 /** 208 * Computes the index memory usage 209 * Returns: memory used by the index 210 */ usedMemory()211 int usedMemory() const 212 { 213 return size_ * sizeof(int); 214 } 215 216 /** 217 * \brief Perform k-nearest neighbor search 218 * \param[in] queries The query points for which to find the nearest neighbors 219 * \param[out] indices The indices of the nearest neighbors found 220 * \param[out] dists Distances to the nearest neighbors found 221 * \param[in] knn Number of nearest 
neighbors to return 222 * \param[in] params Search parameters 223 */ knnSearch(const Matrix<ElementType> & queries,Matrix<size_t> & indices,Matrix<DistanceType> & dists,size_t knn,const SearchParams & params)224 int knnSearch(const Matrix<ElementType>& queries, 225 Matrix<size_t>& indices, 226 Matrix<DistanceType>& dists, 227 size_t knn, 228 const SearchParams& params) const 229 { 230 assert(queries.cols == veclen_); 231 assert(indices.rows >= queries.rows); 232 assert(dists.rows >= queries.rows); 233 assert(indices.cols >= knn); 234 assert(dists.cols >= knn); 235 236 int count = 0; 237 if (params.use_heap==FLANN_True) { 238 #pragma omp parallel num_threads(params.cores) 239 { 240 KNNUniqueResultSet<DistanceType> resultSet(knn); 241 #pragma omp for schedule(dynamic) reduction(+:count) 242 for (int i = 0; i < (int)queries.rows; i++) { 243 resultSet.clear(); 244 findNeighbors(resultSet, queries[i], params); 245 size_t n = std::min(resultSet.size(), knn); 246 resultSet.copy(indices[i], dists[i], n, params.sorted); 247 indices_to_ids(indices[i], indices[i], n); 248 count += n; 249 } 250 } 251 } 252 else { 253 #pragma omp parallel num_threads(params.cores) 254 { 255 KNNResultSet<DistanceType> resultSet(knn); 256 #pragma omp for schedule(dynamic) reduction(+:count) 257 for (int i = 0; i < (int)queries.rows; i++) { 258 resultSet.clear(); 259 findNeighbors(resultSet, queries[i], params); 260 size_t n = std::min(resultSet.size(), knn); 261 resultSet.copy(indices[i], dists[i], n, params.sorted); 262 indices_to_ids(indices[i], indices[i], n); 263 count += n; 264 } 265 } 266 } 267 268 return count; 269 } 270 271 /** 272 * \brief Perform k-nearest neighbor search 273 * \param[in] queries The query points for which to find the nearest neighbors 274 * \param[out] indices The indices of the nearest neighbors found 275 * \param[out] dists Distances to the nearest neighbors found 276 * \param[in] knn Number of nearest neighbors to return 277 * \param[in] params Search parameters 278 
*/ knnSearch(const Matrix<ElementType> & queries,std::vector<std::vector<size_t>> & indices,std::vector<std::vector<DistanceType>> & dists,size_t knn,const SearchParams & params)279 int knnSearch(const Matrix<ElementType>& queries, 280 std::vector< std::vector<size_t> >& indices, 281 std::vector<std::vector<DistanceType> >& dists, 282 size_t knn, 283 const SearchParams& params) const 284 { 285 assert(queries.cols == veclen_); 286 if (indices.size() < queries.rows ) indices.resize(queries.rows); 287 if (dists.size() < queries.rows ) dists.resize(queries.rows); 288 289 int count = 0; 290 if (params.use_heap==FLANN_True) { 291 #pragma omp parallel num_threads(params.cores) 292 { 293 KNNUniqueResultSet<DistanceType> resultSet(knn); 294 #pragma omp for schedule(dynamic) reduction(+:count) 295 for (int i = 0; i < (int)queries.rows; i++) { 296 resultSet.clear(); 297 findNeighbors(resultSet, queries[i], params); 298 size_t n = std::min(resultSet.size(), knn); 299 indices[i].resize(n); 300 dists[i].resize(n); 301 if (n > 0) { 302 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted); 303 indices_to_ids(&indices[i][0], &indices[i][0], n); 304 } 305 count += n; 306 } 307 } 308 } 309 else { 310 #pragma omp parallel num_threads(params.cores) 311 { 312 KNNResultSet<DistanceType> resultSet(knn); 313 #pragma omp for schedule(dynamic) reduction(+:count) 314 for (int i = 0; i < (int)queries.rows; i++) { 315 resultSet.clear(); 316 findNeighbors(resultSet, queries[i], params); 317 size_t n = std::min(resultSet.size(), knn); 318 indices[i].resize(n); 319 dists[i].resize(n); 320 if (n > 0) { 321 resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted); 322 indices_to_ids(&indices[i][0], &indices[i][0], n); 323 } 324 count += n; 325 } 326 } 327 } 328 329 return count; 330 } 331 332 /** 333 * Find set of nearest neighbors to vec. Their indices are stored inside 334 * the result object. 
335 * 336 * Params: 337 * result = the result object in which the indices of the nearest-neighbors are stored 338 * vec = the vector for which to search the nearest neighbors 339 * maxCheck = the maximum number of restarts (in a best-bin-first manner) 340 */ findNeighbors(ResultSet<DistanceType> & result,const ElementType * vec,const SearchParams &)341 void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& /*searchParams*/) const 342 { 343 getNeighbors(vec, result); 344 } 345 346 protected: 347 348 /** 349 * Builds the index 350 */ buildIndexImpl()351 void buildIndexImpl() 352 { 353 tables_.resize(table_number_); 354 std::vector<std::pair<size_t,ElementType*> > features; 355 features.reserve(points_.size()); 356 for (size_t i=0;i<points_.size();++i) { 357 features.push_back(std::make_pair(i, points_[i])); 358 } 359 for (unsigned int i = 0; i < table_number_; ++i) { 360 lsh::LshTable<ElementType>& table = tables_[i]; 361 table = lsh::LshTable<ElementType>(veclen_, key_size_); 362 363 // Add the features to the table 364 table.add(features); 365 } 366 } 367 freeIndex()368 void freeIndex() 369 { 370 /* nothing to do here */ 371 } 372 373 374 private: 375 /** Defines the comparator on score and index 376 */ 377 typedef std::pair<float, unsigned int> ScoreIndexPair; 378 struct SortScoreIndexPairOnSecond 379 { operatorSortScoreIndexPairOnSecond380 bool operator()(const ScoreIndexPair& left, const ScoreIndexPair& right) const 381 { 382 return left.second < right.second; 383 } 384 }; 385 386 /** Fills the different xor masks to use when getting the neighbors in multi-probe LSH 387 * @param key the key we build neighbors from 388 * @param lowest_index the lowest index of the bit set 389 * @param level the multi-probe level we are at 390 * @param xor_masks all the xor mask 391 */ fill_xor_mask(lsh::BucketKey key,int lowest_index,unsigned int level,std::vector<lsh::BucketKey> & xor_masks)392 void fill_xor_mask(lsh::BucketKey key, int 
lowest_index, unsigned int level, 393 std::vector<lsh::BucketKey>& xor_masks) 394 { 395 xor_masks.push_back(key); 396 if (level == 0) return; 397 for (int index = lowest_index - 1; index >= 0; --index) { 398 // Create a new key 399 lsh::BucketKey new_key = key | (lsh::BucketKey(1) << index); 400 fill_xor_mask(new_key, index, level - 1, xor_masks); 401 } 402 } 403 404 /** Performs the approximate nearest-neighbor search. 405 * @param vec the feature to analyze 406 * @param do_radius flag indicating if we check the radius too 407 * @param radius the radius if it is a radius search 408 * @param do_k flag indicating if we limit the number of nn 409 * @param k_nn the number of nearest neighbors 410 * @param checked_average used for debugging 411 */ getNeighbors(const ElementType * vec,bool do_radius,float radius,bool do_k,unsigned int k_nn,float & checked_average)412 void getNeighbors(const ElementType* vec, bool do_radius, float radius, bool do_k, unsigned int k_nn, 413 float& checked_average) 414 { 415 static std::vector<ScoreIndexPair> score_index_heap; 416 417 if (do_k) { 418 unsigned int worst_score = std::numeric_limits<unsigned int>::max(); 419 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin(); 420 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end(); 421 for (; table != table_end; ++table) { 422 size_t key = table->getKey(vec); 423 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin(); 424 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end(); 425 for (; xor_mask != xor_mask_end; ++xor_mask) { 426 size_t sub_key = key ^ (*xor_mask); 427 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key); 428 if (bucket == 0) continue; 429 430 // Go over each descriptor index 431 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin(); 432 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end(); 433 
DistanceType hamming_distance; 434 435 // Process the rest of the candidates 436 for (; training_index < last_training_index; ++training_index) { 437 if (removed_ && removed_points_.test(*training_index)) continue; 438 hamming_distance = distance_(vec, points_[*training_index].point, veclen_); 439 440 if (hamming_distance < worst_score) { 441 // Insert the new element 442 score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index)); 443 std::push_heap(score_index_heap.begin(), score_index_heap.end()); 444 445 if (score_index_heap.size() > (unsigned int)k_nn) { 446 // Remove the highest distance value as we have too many elements 447 std::pop_heap(score_index_heap.begin(), score_index_heap.end()); 448 score_index_heap.pop_back(); 449 // Keep track of the worst score 450 worst_score = score_index_heap.front().first; 451 } 452 } 453 } 454 } 455 } 456 } 457 else { 458 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin(); 459 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end(); 460 for (; table != table_end; ++table) { 461 size_t key = table->getKey(vec); 462 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin(); 463 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end(); 464 for (; xor_mask != xor_mask_end; ++xor_mask) { 465 size_t sub_key = key ^ (*xor_mask); 466 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key); 467 if (bucket == 0) continue; 468 469 // Go over each descriptor index 470 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin(); 471 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end(); 472 DistanceType hamming_distance; 473 474 // Process the rest of the candidates 475 for (; training_index < last_training_index; ++training_index) { 476 if (removed_ && removed_points_.test(*training_index)) continue; 477 // Compute the Hamming distance 478 hamming_distance = 
distance_(vec, points_[*training_index].point, veclen_); 479 if (hamming_distance < radius) score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index)); 480 } 481 } 482 } 483 } 484 } 485 486 /** Performs the approximate nearest-neighbor search. 487 * This is a slower version than the above as it uses the ResultSet 488 * @param vec the feature to analyze 489 */ getNeighbors(const ElementType * vec,ResultSet<DistanceType> & result)490 void getNeighbors(const ElementType* vec, ResultSet<DistanceType>& result) const 491 { 492 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin(); 493 typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end(); 494 for (; table != table_end; ++table) { 495 size_t key = table->getKey(vec); 496 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin(); 497 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end(); 498 for (; xor_mask != xor_mask_end; ++xor_mask) { 499 size_t sub_key = key ^ (*xor_mask); 500 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key); 501 if (bucket == 0) continue; 502 503 // Go over each descriptor index 504 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin(); 505 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end(); 506 DistanceType hamming_distance; 507 508 // Process the rest of the candidates 509 for (; training_index < last_training_index; ++training_index) { 510 if (removed_ && removed_points_.test(*training_index)) continue; 511 // Compute the Hamming distance 512 hamming_distance = distance_(vec, points_[*training_index], veclen_); 513 result.addPoint(hamming_distance, *training_index); 514 } 515 } 516 } 517 } 518 519 swap(LshIndex & other)520 void swap(LshIndex& other) 521 { 522 BaseClass::swap(other); 523 std::swap(tables_, other.tables_); 524 std::swap(size_at_build_, other.size_at_build_); 525 std::swap(table_number_, 
other.table_number_); 526 std::swap(key_size_, other.key_size_); 527 std::swap(multi_probe_level_, other.multi_probe_level_); 528 std::swap(xor_masks_, other.xor_masks_); 529 } 530 531 /** The different hash tables */ 532 std::vector<lsh::LshTable<ElementType> > tables_; 533 534 /** table number */ 535 unsigned int table_number_; 536 /** key size */ 537 unsigned int key_size_; 538 /** How far should we look for neighbors in multi-probe LSH */ 539 unsigned int multi_probe_level_; 540 541 /** The XOR masks to apply to a key to get the neighboring buckets */ 542 std::vector<lsh::BucketKey> xor_masks_; 543 544 USING_BASECLASS_SYMBOLS 545 }; 546 } 547 548 #endif //FLANN_LSH_INDEX_H_ 549