1 /***********************************************************************
2  * Software License Agreement (BSD License)
3  *
4  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6  *
7  * THE BSD LICENSE
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *************************************************************************/
30 
31 /***********************************************************************
32  * Author: Vincent Rabaud
33  *************************************************************************/
34 
35 #ifndef FLANN_LSH_INDEX_H_
36 #define FLANN_LSH_INDEX_H_
37 
38 #include <algorithm>
39 #include <cassert>
40 #include <cstring>
41 #include <map>
42 #include <vector>
43 
44 #include "FLANN/general.h"
45 #include "FLANN/algorithms/nn_index.h"
46 #include "FLANN/util/matrix.h"
47 #include "FLANN/util/result_set.h"
48 #include "FLANN/util/heap.h"
49 #include "FLANN/util/lsh_table.h"
50 #include "FLANN/util/allocator.h"
51 #include "FLANN/util/random.h"
52 #include "FLANN/util/saving.h"
53 
54 namespace flann
55 {
56 
57 struct LshIndexParams : public IndexParams
58 {
59     LshIndexParams(unsigned int table_number = 12, unsigned int key_size = 20, unsigned int multi_probe_level = 2)
60     {
61         (* this)["algorithm"] = FLANN_INDEX_LSH;
62         // The number of hash tables to use
63         (*this)["table_number"] = table_number;
64         // The length of the key in the hash tables
65         (*this)["key_size"] = key_size;
66         // Number of levels to use in multi-probe (0 for standard LSH)
67         (*this)["multi_probe_level"] = multi_probe_level;
68     }
69 };
70 
71 /**
72  * Locality-sensitive hashing  index
73  *
74  * Contains the tables and other information for indexing a set of points
75  * for nearest-neighbor matching.
76  */
77 template<typename Distance>
78 class LshIndex : public NNIndex<Distance>
79 {
80 public:
81     typedef typename Distance::ElementType ElementType;
82     typedef typename Distance::ResultType DistanceType;
83 
84     typedef NNIndex<Distance> BaseClass;
85 
86     /** Constructor
87      * @param params parameters passed to the LSH algorithm
88      * @param d the distance used
89      */
90     LshIndex(const IndexParams& params = LshIndexParams(), Distance d = Distance()) :
BaseClass(params,d)91     	BaseClass(params, d)
92     {
93         table_number_ = get_param<unsigned int>(index_params_,"table_number",12);
94         key_size_ = get_param<unsigned int>(index_params_,"key_size",20);
95         multi_probe_level_ = get_param<unsigned int>(index_params_,"multi_probe_level",2);
96 
97         fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
98     }
99 
100 
101     /** Constructor
102      * @param input_data dataset with the input features
103      * @param params parameters passed to the LSH algorithm
104      * @param d the distance used
105      */
106     LshIndex(const Matrix<ElementType>& input_data, const IndexParams& params = LshIndexParams(), Distance d = Distance()) :
BaseClass(params,d)107     	BaseClass(params, d)
108     {
109         table_number_ = get_param<unsigned int>(index_params_,"table_number",12);
110         key_size_ = get_param<unsigned int>(index_params_,"key_size",20);
111         multi_probe_level_ = get_param<unsigned int>(index_params_,"multi_probe_level",2);
112 
113         fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
114 
115         setDataset(input_data);
116     }
117 
LshIndex(const LshIndex & other)118     LshIndex(const LshIndex& other) : BaseClass(other),
119     	tables_(other.tables_),
120     	table_number_(other.table_number_),
121     	key_size_(other.key_size_),
122     	multi_probe_level_(other.multi_probe_level_),
123     	xor_masks_(other.xor_masks_)
124     {
125     }
126 
127     LshIndex& operator=(LshIndex other)
128     {
129     	this->swap(other);
130     	return *this;
131     }
132 
~LshIndex()133     virtual ~LshIndex()
134     {
135     	freeIndex();
136     }
137 
138 
clone()139     BaseClass* clone() const
140     {
141     	return new LshIndex(*this);
142     }
143 
144     using BaseClass::buildIndex;
145 
146     void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
147     {
148         assert(points.cols==veclen_);
149         size_t old_size = size_;
150 
151         extendDataset(points);
152 
153         if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
154             buildIndex();
155         }
156         else {
157             for (unsigned int i = 0; i < table_number_; ++i) {
158                 lsh::LshTable<ElementType>& table = tables_[i];
159                 for (size_t i=old_size;i<size_;++i) {
160                     table.add(i, points_[i]);
161                 }
162             }
163         }
164     }
165 
166 
getType()167     flann_algorithm_t getType() const
168     {
169         return FLANN_INDEX_LSH;
170     }
171 
172 
173     template<typename Archive>
serialize(Archive & ar)174     void serialize(Archive& ar)
175     {
176     	ar.setObject(this);
177 
178     	ar & *static_cast<NNIndex<Distance>*>(this);
179 
180     	ar & table_number_;
181     	ar & key_size_;
182     	ar & multi_probe_level_;
183 
184     	ar & xor_masks_;
185     	ar & tables_;
186 
187     	if (Archive::is_loading::value) {
188             index_params_["algorithm"] = getType();
189             index_params_["table_number"] = table_number_;
190             index_params_["key_size"] = key_size_;
191             index_params_["multi_probe_level"] = multi_probe_level_;
192     	}
193     }
194 
saveIndex(FILE * stream)195     void saveIndex(FILE* stream)
196     {
197     	serialization::SaveArchive sa(stream);
198     	sa & *this;
199     }
200 
loadIndex(FILE * stream)201     void loadIndex(FILE* stream)
202     {
203     	serialization::LoadArchive la(stream);
204     	la & *this;
205     }
206 
207     /**
208      * Computes the index memory usage
209      * Returns: memory used by the index
210      */
usedMemory()211     int usedMemory() const
212     {
213         return size_ * sizeof(int);
214     }
215 
216     /**
217      * \brief Perform k-nearest neighbor search
218      * \param[in] queries The query points for which to find the nearest neighbors
219      * \param[out] indices The indices of the nearest neighbors found
220      * \param[out] dists Distances to the nearest neighbors found
221      * \param[in] knn Number of nearest neighbors to return
222      * \param[in] params Search parameters
223      */
knnSearch(const Matrix<ElementType> & queries,Matrix<size_t> & indices,Matrix<DistanceType> & dists,size_t knn,const SearchParams & params)224     int knnSearch(const Matrix<ElementType>& queries,
225     					Matrix<size_t>& indices,
226     					Matrix<DistanceType>& dists,
227     					size_t knn,
228     					const SearchParams& params) const
229     {
230         assert(queries.cols == veclen_);
231         assert(indices.rows >= queries.rows);
232         assert(dists.rows >= queries.rows);
233         assert(indices.cols >= knn);
234         assert(dists.cols >= knn);
235 
236         int count = 0;
237         if (params.use_heap==FLANN_True) {
238 #pragma omp parallel num_threads(params.cores)
239         	{
240         		KNNUniqueResultSet<DistanceType> resultSet(knn);
241 #pragma omp for schedule(dynamic) reduction(+:count)
242         		for (int i = 0; i < (int)queries.rows; i++) {
243         			resultSet.clear();
244         			findNeighbors(resultSet, queries[i], params);
245         			size_t n = std::min(resultSet.size(), knn);
246         			resultSet.copy(indices[i], dists[i], n, params.sorted);
247         			indices_to_ids(indices[i], indices[i], n);
248         			count += n;
249         		}
250         	}
251         }
252         else {
253 #pragma omp parallel num_threads(params.cores)
254         	{
255         		KNNResultSet<DistanceType> resultSet(knn);
256 #pragma omp for schedule(dynamic) reduction(+:count)
257         		for (int i = 0; i < (int)queries.rows; i++) {
258         			resultSet.clear();
259         			findNeighbors(resultSet, queries[i], params);
260         			size_t n = std::min(resultSet.size(), knn);
261         			resultSet.copy(indices[i], dists[i], n, params.sorted);
262         			indices_to_ids(indices[i], indices[i], n);
263         			count += n;
264         		}
265         	}
266         }
267 
268         return count;
269     }
270 
    /**
     * \brief Perform k-nearest neighbor search
     * \param[in] queries The query points for which to find the nearest neighbors
     * \param[out] indices The indices of the nearest neighbors found; grown to one entry per
     *             query if it is shorter, and each entry is resized to the number of neighbors found
     * \param[out] dists Distances to the nearest neighbors found; sized the same way as indices
     * \param[in] knn Number of nearest neighbors to return
     * \param[in] params Search parameters
     */
knnSearch(const Matrix<ElementType> & queries,std::vector<std::vector<size_t>> & indices,std::vector<std::vector<DistanceType>> & dists,size_t knn,const SearchParams & params)279     int knnSearch(const Matrix<ElementType>& queries,
280 					std::vector< std::vector<size_t> >& indices,
281 					std::vector<std::vector<DistanceType> >& dists,
282     				size_t knn,
283     				const SearchParams& params) const
284     {
285         assert(queries.cols == veclen_);
286 		if (indices.size() < queries.rows ) indices.resize(queries.rows);
287 		if (dists.size() < queries.rows ) dists.resize(queries.rows);
288 
289 		int count = 0;
290 		if (params.use_heap==FLANN_True) {
291 #pragma omp parallel num_threads(params.cores)
292 			{
293 				KNNUniqueResultSet<DistanceType> resultSet(knn);
294 #pragma omp for schedule(dynamic) reduction(+:count)
295 				for (int i = 0; i < (int)queries.rows; i++) {
296 					resultSet.clear();
297 					findNeighbors(resultSet, queries[i], params);
298 					size_t n = std::min(resultSet.size(), knn);
299 					indices[i].resize(n);
300 					dists[i].resize(n);
301 					if (n > 0) {
302 						resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
303 						indices_to_ids(&indices[i][0], &indices[i][0], n);
304 					}
305 					count += n;
306 				}
307 			}
308 		}
309 		else {
310 #pragma omp parallel num_threads(params.cores)
311 			{
312 				KNNResultSet<DistanceType> resultSet(knn);
313 #pragma omp for schedule(dynamic) reduction(+:count)
314 				for (int i = 0; i < (int)queries.rows; i++) {
315 					resultSet.clear();
316 					findNeighbors(resultSet, queries[i], params);
317 					size_t n = std::min(resultSet.size(), knn);
318 					indices[i].resize(n);
319 					dists[i].resize(n);
320 					if (n > 0) {
321 						resultSet.copy(&indices[i][0], &dists[i][0], n, params.sorted);
322 						indices_to_ids(&indices[i][0], &indices[i][0], n);
323 					}
324 					count += n;
325 				}
326 			}
327 		}
328 
329 		return count;
330     }
331 
    /**
     * Find set of nearest neighbors to vec. Their indices are stored inside
     * the result object.
     *
     * Params:
     *     result = the result object in which the indices of the nearest-neighbors are stored
     *     vec = the vector for which to search the nearest neighbors
     *     searchParams = the search parameters (not used by this index)
     */
findNeighbors(ResultSet<DistanceType> & result,const ElementType * vec,const SearchParams &)341     void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& /*searchParams*/) const
342     {
343         getNeighbors(vec, result);
344     }
345 
346 protected:
347 
348     /**
349      * Builds the index
350      */
buildIndexImpl()351     void buildIndexImpl()
352     {
353         tables_.resize(table_number_);
354         std::vector<std::pair<size_t,ElementType*> > features;
355         features.reserve(points_.size());
356         for (size_t i=0;i<points_.size();++i) {
357         	features.push_back(std::make_pair(i, points_[i]));
358         }
359         for (unsigned int i = 0; i < table_number_; ++i) {
360             lsh::LshTable<ElementType>& table = tables_[i];
361             table = lsh::LshTable<ElementType>(veclen_, key_size_);
362 
363             // Add the features to the table
364             table.add(features);
365         }
366     }
367 
freeIndex()368     void freeIndex()
369     {
370         /* nothing to do here */
371     }
372 
373 
374 private:
375     /** Defines the comparator on score and index
376      */
377     typedef std::pair<float, unsigned int> ScoreIndexPair;
378     struct SortScoreIndexPairOnSecond
379     {
operatorSortScoreIndexPairOnSecond380         bool operator()(const ScoreIndexPair& left, const ScoreIndexPair& right) const
381         {
382             return left.second < right.second;
383         }
384     };
385 
386     /** Fills the different xor masks to use when getting the neighbors in multi-probe LSH
387      * @param key the key we build neighbors from
388      * @param lowest_index the lowest index of the bit set
389      * @param level the multi-probe level we are at
390      * @param xor_masks all the xor mask
391      */
fill_xor_mask(lsh::BucketKey key,int lowest_index,unsigned int level,std::vector<lsh::BucketKey> & xor_masks)392     void fill_xor_mask(lsh::BucketKey key, int lowest_index, unsigned int level,
393                        std::vector<lsh::BucketKey>& xor_masks)
394     {
395         xor_masks.push_back(key);
396         if (level == 0) return;
397         for (int index = lowest_index - 1; index >= 0; --index) {
398             // Create a new key
399             lsh::BucketKey new_key = key | (lsh::BucketKey(1) << index);
400             fill_xor_mask(new_key, index, level - 1, xor_masks);
401         }
402     }
403 
404     /** Performs the approximate nearest-neighbor search.
405      * @param vec the feature to analyze
406      * @param do_radius flag indicating if we check the radius too
407      * @param radius the radius if it is a radius search
408      * @param do_k flag indicating if we limit the number of nn
409      * @param k_nn the number of nearest neighbors
410      * @param checked_average used for debugging
411      */
getNeighbors(const ElementType * vec,bool do_radius,float radius,bool do_k,unsigned int k_nn,float & checked_average)412     void getNeighbors(const ElementType* vec, bool do_radius, float radius, bool do_k, unsigned int k_nn,
413                       float& checked_average)
414     {
415         static std::vector<ScoreIndexPair> score_index_heap;
416 
417         if (do_k) {
418             unsigned int worst_score = std::numeric_limits<unsigned int>::max();
419             typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
420             typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
421             for (; table != table_end; ++table) {
422                 size_t key = table->getKey(vec);
423                 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
424                 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
425                 for (; xor_mask != xor_mask_end; ++xor_mask) {
426                     size_t sub_key = key ^ (*xor_mask);
427                     const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
428                     if (bucket == 0) continue;
429 
430                     // Go over each descriptor index
431                     std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
432                     std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
433                     DistanceType hamming_distance;
434 
435                     // Process the rest of the candidates
436                     for (; training_index < last_training_index; ++training_index) {
437                     	if (removed_ && removed_points_.test(*training_index)) continue;
438                         hamming_distance = distance_(vec, points_[*training_index].point, veclen_);
439 
440                         if (hamming_distance < worst_score) {
441                             // Insert the new element
442                             score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
443                             std::push_heap(score_index_heap.begin(), score_index_heap.end());
444 
445                             if (score_index_heap.size() > (unsigned int)k_nn) {
446                                 // Remove the highest distance value as we have too many elements
447                                 std::pop_heap(score_index_heap.begin(), score_index_heap.end());
448                                 score_index_heap.pop_back();
449                                 // Keep track of the worst score
450                                 worst_score = score_index_heap.front().first;
451                             }
452                         }
453                     }
454                 }
455             }
456         }
457         else {
458             typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
459             typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
460             for (; table != table_end; ++table) {
461                 size_t key = table->getKey(vec);
462                 std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
463                 std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
464                 for (; xor_mask != xor_mask_end; ++xor_mask) {
465                     size_t sub_key = key ^ (*xor_mask);
466                     const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
467                     if (bucket == 0) continue;
468 
469                     // Go over each descriptor index
470                     std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
471                     std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
472                     DistanceType hamming_distance;
473 
474                     // Process the rest of the candidates
475                     for (; training_index < last_training_index; ++training_index) {
476                     	if (removed_ && removed_points_.test(*training_index)) continue;
477                         // Compute the Hamming distance
478                         hamming_distance = distance_(vec, points_[*training_index].point, veclen_);
479                         if (hamming_distance < radius) score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
480                     }
481                 }
482             }
483         }
484     }
485 
486     /** Performs the approximate nearest-neighbor search.
487      * This is a slower version than the above as it uses the ResultSet
488      * @param vec the feature to analyze
489      */
getNeighbors(const ElementType * vec,ResultSet<DistanceType> & result)490     void getNeighbors(const ElementType* vec, ResultSet<DistanceType>& result) const
491     {
492         typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
493         typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
494         for (; table != table_end; ++table) {
495             size_t key = table->getKey(vec);
496             std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
497             std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
498             for (; xor_mask != xor_mask_end; ++xor_mask) {
499                 size_t sub_key = key ^ (*xor_mask);
500                 const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
501                 if (bucket == 0) continue;
502 
503                 // Go over each descriptor index
504                 std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
505                 std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
506                 DistanceType hamming_distance;
507 
508                 // Process the rest of the candidates
509                 for (; training_index < last_training_index; ++training_index) {
510                 	if (removed_ && removed_points_.test(*training_index)) continue;
511                     // Compute the Hamming distance
512                     hamming_distance = distance_(vec, points_[*training_index], veclen_);
513                     result.addPoint(hamming_distance, *training_index);
514                 }
515             }
516         }
517     }
518 
519 
swap(LshIndex & other)520     void swap(LshIndex& other)
521     {
522     	BaseClass::swap(other);
523     	std::swap(tables_, other.tables_);
524     	std::swap(size_at_build_, other.size_at_build_);
525     	std::swap(table_number_, other.table_number_);
526     	std::swap(key_size_, other.key_size_);
527     	std::swap(multi_probe_level_, other.multi_probe_level_);
528     	std::swap(xor_masks_, other.xor_masks_);
529     }
530 
    /** The different hash tables; buildIndexImpl() resizes this to table_number_ entries */
    std::vector<lsh::LshTable<ElementType> > tables_;

    /** Number of hash tables (read from the "table_number" index parameter) */
    unsigned int table_number_;
    /** Key size handed to each hash table when it is created (read from the "key_size" index parameter) */
    unsigned int key_size_;
    /** How far should we look for neighbors in multi-probe LSH (read from the "multi_probe_level" index parameter) */
    unsigned int multi_probe_level_;

    /** The XOR masks to apply to a key to get the neighboring buckets; filled by fill_xor_mask() in the constructors */
    std::vector<lsh::BucketKey> xor_masks_;
543 
544     USING_BASECLASS_SYMBOLS
545 };
546 }
547 
548 #endif //FLANN_LSH_INDEX_H_
549