1 /**
2  * @file methods/lsh/lsh_search.hpp
3  * @author Parikshit Ram
4  *
5  * Defines the LSHSearch class, which performs an approximate
 * nearest neighbor search for queries in a query set
7  * over a given dataset using Locality-sensitive hashing
8  * with 2-stable distributions.
9  *
10  * The details of this method can be found in the following paper:
11  *
12  * @code
13  * @inproceedings{datar2004locality,
14  *  title={Locality-sensitive hashing scheme based on p-stable distributions},
15  *  author={Datar, M. and Immorlica, N. and Indyk, P. and Mirrokni, V.S.},
16  *  booktitle=
17  *      {Proceedings of the 12th Annual Symposium on Computational Geometry},
18  *  pages={253--262},
19  *  year={2004},
20  *  organization={ACM}
21  * }
22  * @endcode
23  *
24  * Additionally, the class implements Multiprobe LSH, which improves
25  * approximation results during the search for approximate nearest neighbors.
26  * The Multiprobe LSH algorithm was presented in the paper:
27  *
28  * @code
29  * @inproceedings{Lv2007multiprobe,
 *  title={Multi-probe LSH: efficient indexing for high-dimensional similarity
31  *  search},
32  *  author={Lv, Qin and Josephson, William and Wang, Zhe and Charikar, Moses and
33  *  Li, Kai},
34  *  booktitle={Proceedings of the 33rd international conference on Very large
35  *  data bases},
36  *  year={2007},
37  *  pages={950--961}
38  * }
39  * @endcode
40  *
41  *
42  * mlpack is free software; you may redistribute it and/or modify it under the
43  * terms of the 3-clause BSD license.  You should have received a copy of the
44  * 3-clause BSD license along with mlpack.  If not, see
45  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
46  */
47 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
48 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
49 
50 #include <mlpack/prereqs.hpp>
51 
52 #include <mlpack/core/metrics/lmetric.hpp>
53 #include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
54 
55 #include <queue>
56 
57 namespace mlpack {
58 namespace neighbor {
59 
60 /**
61  * The LSHSearch class; this class builds a hash on the reference set and uses
62  * this hash to compute the distance-approximate nearest-neighbors of the given
63  * queries.
64  *
65  * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
66  * @tparam MatType Type of matrix to use to store the data.
67  */
68 template<
69     typename SortPolicy = NearestNeighborSort,
70     typename MatType = arma::mat
71 >
72 class LSHSearch
73 {
74  public:
75   /**
76    * This function initializes the LSH class. It builds the hash on the
77    * reference set with 2-stable distributions. See the individual functions
78    * performing the hashing for details on how the hashing is done.  In order to
79    * avoid copying the reference set, it is suggested to pass that parameter
80    * with std::move().
81    *
82    * @param referenceSet Set of reference points and the set of queries.
83    * @param projections Cube of projection tables. For a cube of size (a, b, c)
84    *     we set numProj = a, numTables = c. b is the reference set
85    *     dimensionality.
86    * @param hashWidth The width of hash for every table. If 0 (the default) is
87    *     provided, then the hash width is automatically obtained by computing
88    *     the average pairwise distance of 25 pairs.  This should be a reasonable
89    *     upper bound on the nearest-neighbor distance in general.
90    * @param secondHashSize The size of the second hash table. This should be a
91    *     large prime number.
92    * @param bucketSize The size of the bucket in the second hash table. This is
93    *     the maximum number of points that can be hashed into single bucket.  A
94    *     value of 0 indicates that there is no limit (so the second hash table
95    *     can be arbitrarily large---be careful!).
96    */
97   LSHSearch(MatType referenceSet,
98             const arma::cube& projections,
99             const double hashWidth = 0.0,
100             const size_t secondHashSize = 99901,
101             const size_t bucketSize = 500);
102 
103   /**
104    * This function initializes the LSH class. It builds the hash one the
105    * reference set using the provided projections. See the individual functions
106    * performing the hashing for details on how the hashing is done.  In order to
107    * avoid copying the reference set, consider passing the set with std::move().
108    *
109    * @param referenceSet Set of reference points and the set of queries.
110    * @param numProj Number of projections in each hash table (anything between
111    *     10-50 might be a decent choice).
112    * @param numTables Total number of hash tables (anything between 10-20
113    *     should suffice).
114    * @param hashWidth The width of hash for every table. If 0 (the default) is
115    *     provided, then the hash width is automatically obtained by computing
116    *     the average pairwise distance of 25 pairs.  This should be a reasonable
117    *     upper bound on the nearest-neighbor distance in general.
118    * @param secondHashSize The size of the second hash table. This should be a
119    *     large prime number.
120    * @param bucketSize The size of the bucket in the second hash table. This is
121    *     the maximum number of points that can be hashed into single bucket.  A
122    *     value of 0 indicates that there is no limit (so the second hash table
123    *     can be arbitrarily large---be careful!).
124    */
125   LSHSearch(MatType referenceSet,
126             const size_t numProj,
127             const size_t numTables,
128             const double hashWidth = 0.0,
129             const size_t secondHashSize = 99901,
130             const size_t bucketSize = 500);
131 
132   /**
133    * Create an untrained LSH model.  Be sure to call Train() before calling
134    * Search(); otherwise, an exception will be thrown when Search() is called.
135    */
136   LSHSearch();
137 
138   /**
139    * Copy the given LSH model.
140    *
141    * @param other Other LSH model to copy.
142    */
143   LSHSearch(const LSHSearch& other);
144 
145   /**
146    * Take ownership of the given LSH model.
147    *
148    * @param other Other LSH model to take ownership of.
149    */
150   LSHSearch(LSHSearch&& other);
151 
152   /**
153    * Copy the given LSH model.
154    *
155    * @param other Other LSH model to copy.
156    */
157   LSHSearch& operator=(const LSHSearch& other);
158 
159   /**
160    * Take ownership of the given LSH model.
161    *
162    * @param other Other LSH model to take ownership of.
163    */
164   LSHSearch& operator=(LSHSearch&& other);
165 
166   /**
167    * Train the LSH model on the given dataset.  If a correctly-sized projection
168    * cube is not provided, this means building new hash tables. Otherwise, we
169    * use the projections provided by the user.  In order to avoid copying the
170    * reference set, consider passing that parameter with std::move().
171    *
172    * @param referenceSet Set of reference points and the set of queries.
173    * @param numProj Number of projections in each hash table (anything between
174    *     10-50 might be a decent choice).
175    * @param numTables Total number of hash tables (anything between 10-20
176    *     should suffice).
177    * @param hashWidth The width of hash for every table. If 0 (the default) is
178    *     provided, then the hash width is automatically obtained by computing
179    *     the average pairwise distance of 25 pairs.  This should be a reasonable
180    *     upper bound on the nearest-neighbor distance in general.
181    * @param secondHashSize The size of the second hash table. This should be a
182    *     large prime number.
183    * @param bucketSize The size of the bucket in the second hash table. This is
184    *     the maximum number of points that can be hashed into single bucket.  A
185    *     value of 0 indicates that there is no limit (so the second hash table
186    *     can be arbitrarily large---be careful!).
187    * @param projection Cube of projection tables. For a cube of size (a, b, c)
188    *     we set numProj = a, numTables = c. b is the reference set
189    *     dimensionality.
190    */
191   void Train(MatType referenceSet,
192              const size_t numProj,
193              const size_t numTables,
194              const double hashWidth = 0.0,
195              const size_t secondHashSize = 99901,
196              const size_t bucketSize = 500,
197              const arma::cube& projection = arma::cube());
198 
199   /**
200    * Compute the nearest neighbors of the points in the given query set and
201    * store the output in the given matrices.  The matrices will be set to the
202    * size of n columns by k rows, where n is the number of points in the query
203    * dataset and k is the number of neighbors being searched for.
204    *
205    * @param querySet Set of query points.
206    * @param k Number of neighbors to search for.
207    * @param resultingNeighbors Matrix storing lists of neighbors for each query
208    *     point.
209    * @param distances Matrix storing distances of neighbors for each query
210    *     point.
211    * @param numTablesToSearch This parameter allows the user to have control
212    *     over the number of hash tables to be searched. This allows
213    *     the user to pick the number of tables it can afford for the time
214    *     available without having to build hashing for every table size.
215    *     By default, this is set to zero in which case all tables are
216    *     considered.
217    * @param T The number of additional probing bins to examine with multiprobe
218    *     LSH. If T = 0, classic single-probe LSH is run (default).
219    */
220   void Search(const MatType& querySet,
221               const size_t k,
222               arma::Mat<size_t>& resultingNeighbors,
223               arma::mat& distances,
224               const size_t numTablesToSearch = 0,
225               const size_t T = 0);
226 
227   /**
228    * Compute the nearest neighbors and store the output in the given matrices.
229    * The matrices will be set to the size of n columns by k rows, where n is
230    * the number of points in the query dataset and k is the number of neighbors
231    * being searched for.
232    *
233    * @param k Number of neighbors to search for.
234    * @param resultingNeighbors Matrix storing lists of neighbors for each query
235    *     point.
236    * @param distances Matrix storing distances of neighbors for each query
237    *     point.
238    * @param numTablesToSearch This parameter allows the user to have control
239    *     over the number of hash tables to be searched. This allows
240    *     the user to pick the number of tables it can afford for the time
241    *     available without having to build hashing for every table size.
242    *     By default, this is set to zero in which case all tables are
243    *     considered.
244    * @param T Number of probing bins.
245    */
246   void Search(const size_t k,
247               arma::Mat<size_t>& resultingNeighbors,
248               arma::mat& distances,
249               const size_t numTablesToSearch = 0,
250               size_t T = 0);
251 
252   /**
253    * Compute the recall (% of neighbors found) given the neighbors returned by
254    * LSHSearch::Search and a "ground truth" set of neighbors.  The recall
255    * returned will be in the range [0, 1].
256    *
257    * @param foundNeighbors Set of neighbors to compute recall of.
258    * @param realNeighbors Set of "ground truth" neighbors to compute recall
259    *     against.
260    */
261   static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
262                               const arma::Mat<size_t>& realNeighbors);
263 
264   /**
265    * Serialize the LSH model.
266    *
267    * @param ar Archive to serialize to.
268    * @param version Version number.
269    */
270   template<typename Archive>
271   void serialize(Archive& ar, const unsigned int version);
272 
273   //! Return the number of distance evaluations performed.
DistanceEvaluations() const274   size_t DistanceEvaluations() const { return distanceEvaluations; }
275   //! Modify the number of distance evaluations performed.
DistanceEvaluations()276   size_t& DistanceEvaluations() { return distanceEvaluations; }
277 
278   //! Return the reference dataset.
ReferenceSet() const279   const MatType& ReferenceSet() const { return referenceSet; }
280 
281   //! Get the number of projections.
NumProjections() const282   size_t NumProjections() const { return projections.n_slices; }
283 
284   //! Get the offsets 'b' for each of the projections.  (One 'b' per column.)
Offsets() const285   const arma::mat& Offsets() const { return offsets; }
286 
287   //! Get the weights of the second hash.
SecondHashWeights() const288   const arma::vec& SecondHashWeights() const { return secondHashWeights; }
289 
290   //! Get the bucket size of the second hash.
BucketSize() const291   size_t BucketSize() const { return bucketSize; }
292 
293   //! Get the second hash table.
SecondHashTable() const294   const std::vector<arma::Col<size_t>>& SecondHashTable() const
295       { return secondHashTable; }
296 
297   //! Get the projection tables.
Projections()298   const arma::cube& Projections() { return projections; }
299 
300   //! Change the projection tables (this retrains the LSH model).
Projections(const arma::cube & projTables)301   void Projections(const arma::cube& projTables)
302   {
303     // Simply call Train() with the given projection tables.
304     Train(referenceSet, numProj, numTables, hashWidth, secondHashSize,
305         bucketSize, projTables);
306   }
307 
308  private:
309   /**
310    * This function takes a query and hashes it into each of the hash tables to
311    * get keys for the query and then the key is hashed to a bucket of the second
312    * hash table and all the points (if any) in those buckets are collected as
313    * the potential neighbor candidates.
314    *
315    * @param queryPoint The query point currently being processed.
316    * @param referenceIndices The list of neighbor candidates obtained from
317    *    hashing the query into all the hash tables and eventually into
318    *    multiple buckets of the second hash table.
319    * @param numTablesToSearch The number of tables to perform the search in. If
320    *    0, all tables are searched.
321    * @param T The number of additional probing bins for multiprobe LSH. If 0,
322    *    single-probe is used.
323    */
324   template<typename VecType>
325   void ReturnIndicesFromTable(const VecType& queryPoint,
326                               arma::uvec& referenceIndices,
327                               size_t numTablesToSearch,
328                               const size_t T) const;
329 
330   /**
331    * This is a helper function that computes the distance of the query to the
332    * neighbor candidates and appropriately stores the best 'k' candidates.  This
333    * is specific to the monochromatic search case, where the query set is the
334    * reference set.
335    *
336    * @param queryIndex The index of the query in question
337    * @param referenceIndices The vector of indices of candidate neighbors for
338    *    the query.
339    * @param k Number of neighbors to search for.
340    * @param neighbors Matrix holding output neighbors.
341    * @param distances Matrix holding output distances.
342    */
343   void BaseCase(const size_t queryIndex,
344                 const arma::uvec& referenceIndices,
345                 const size_t k,
346                 arma::Mat<size_t>& neighbors,
347                 arma::mat& distances) const;
348 
349   /**
350    * This is a helper function that computes the distance of the query to the
351    * neighbor candidates and appropriately stores the best 'k' candidates.  This
352    * is specific to bichromatic search, where the query set is not the same as
353    * the reference set.
354    *
355    * @param queryIndex The index of the query in question
356    * @param referenceIndices The vector of indices of candidate neighbors for
357    *    the query.
358    * @param k Number of neighbors to search for.
359    * @param querySet Set of query points.
360    * @param neighbors Matrix holding output neighbors.
361    * @param distances Matrix holding output distances.
362    */
363   void BaseCase(const size_t queryIndex,
364                 const arma::uvec& referenceIndices,
365                 const size_t k,
366                 const MatType& querySet,
367                 arma::Mat<size_t>& neighbors,
368                 arma::mat& distances) const;
369 
370   /**
371    * This function implements the core idea behind Multiprobe LSH. It is called
372    * by ReturnIndicesFromTables when T > 0. Given a query's code and its
373    * projection location, GetAdditionalProbingBins will calculate the T most
374    * likely alternative bin codes (other than queryCode) where a query's
375    * neighbors might be found in.
376    *
377    * @param queryCode vector containing the numProj-dimensional query code.
378    * @param queryCodeNotFloored vector containing the projection location of the
379    *    query.
380    * @param T number of additional probing bins.
381    * @param additionalProbingBins matrix. Each column will hold one additional
382    *    bin.
383   */
384   void GetAdditionalProbingBins(const arma::vec& queryCode,
385                                 const arma::vec& queryCodeNotFloored,
386                                 const size_t T,
387                                 arma::mat& additionalProbingBins) const;
388 
389   /**
390    * Returns the score of a perturbation vector generated by perturbation set A.
391    * The score of a pertubation set (vector) is the sum of scores of the
392    * participating actions.
393    * @param A perturbation set to compute the score of.
394    * @param scores vector containing score of each perturbation.
395   */
396   double PerturbationScore(const std::vector<bool>& A,
397                            const arma::vec& scores) const;
398 
399   /**
400    * Inline function used by GetAdditionalProbingBins. The vector shift
401    * operation replaces the largest element of a vector A with (largest element)
402    * + 1.  Returns true if resulting vector is valid, otherwise false.
403    *
404    * @param A perturbation set to shift.
405    */
406   bool PerturbationShift(std::vector<bool>& A) const;
407 
408   /**
409    * Inline function used by GetAdditionalProbingBins. The vector expansion
410    * operation adds the element [1 + (largest_element)] to a vector A, where
411    * largest_element is the largest element of A. Returns true if resulting
412    * vector is valid, otherwise false.
413    *
414    * @param A perturbation set to expand.
415    */
416   bool PerturbationExpand(std::vector<bool>& A) const;
417 
418   /**
419    * Return true if perturbation set A is valid. A perturbation set is invalid
420    * if it contains two (or more) actions for the same dimension or dimensions
421    * that are larger than the queryCode's dimensions.
422    *
423    * @param A perturbation set to validate.
424    */
425   bool PerturbationValid(const std::vector<bool>& A) const;
426 
427   //! Reference dataset.
428   MatType referenceSet;
429 
430   //! The number of projections.
431   size_t numProj;
432   //! The number of hash tables.
433   size_t numTables;
434 
435   //! The arma::cube containing the projection matrix of each table.
436   arma::cube projections; // should be [numProj x dims] x numTables slices
437 
438   //! The list of the offsets 'b' for each of the projection for each table.
439   arma::mat offsets; // should be numProj x numTables
440 
441   //! The hash width.
442   double hashWidth;
443 
444   //! The big prime representing the size of the second hash.
445   size_t secondHashSize;
446 
447   //! The weights of the second hash.
448   arma::vec secondHashWeights;
449 
450   //! The bucket size of the second hash.
451   size_t bucketSize;
452 
453   //! The final hash table; should be (< secondHashSize) vectors each with
454   //! (<= bucketSize) elements.
455   std::vector<arma::Col<size_t>> secondHashTable;
456 
457   //! The number of elements present in each hash bucket; should be
458   //! secondHashSize.
459   arma::Col<size_t> bucketContentSize;
460 
461   //! For a particular hash value, points to the row in secondHashTable
462   //! corresponding to this value. Length secondHashSize.
463   arma::Col<size_t> bucketRowInHashTable;
464 
465   //! The number of distance evaluations.
466   size_t distanceEvaluations;
467 
468   //! Candidate represents a possible candidate neighbor (distance, index).
469   typedef std::pair<double, size_t> Candidate;
470 
471   //! Compare two candidates based on the distance.
472   struct CandidateCmp {
operator ()mlpack::neighbor::LSHSearch::CandidateCmp473     bool operator()(const Candidate& c1, const Candidate& c2)
474     {
475       return !SortPolicy::IsBetter(c2.first, c1.first);
476     };
477   };
478 
479   //! Use a priority queue to represent the list of candidate neighbors.
480   typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
481       CandidateList;
482 }; // class LSHSearch
483 
484 } // namespace neighbor
485 } // namespace mlpack
486 
//! Set the serialization version of the LSHSearch class.
//!
//! NOTE(review): only SortPolicy appears in the template signature below, so
//! `LSHSearch<SortPolicy>` names LSHSearch<SortPolicy, arma::mat> (the MatType
//! default).  Instantiations with any other MatType are not covered by this
//! version declaration.  Extending it means passing a comma-containing
//! template signature through this macro -- confirm how
//! BOOST_TEMPLATE_CLASS_VERSION expands its arguments before changing it.
BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
    mlpack::neighbor::LSHSearch<SortPolicy>, 1);
490 
491 // Include implementation.
492 #include "lsh_search_impl.hpp"
493 
494 #endif
495