/**
 * @file methods/lsh/lsh_search_impl.hpp
 * @author Parikshit Ram
 *
 * Implementation of the LSHSearch class.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
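//
// A minimal usage sketch (illustrative only; it assumes the default template
// parameters and the constructor/Search() argument defaults declared in
// lsh_search.hpp, and a pre-loaded dataset `data`):
//
//   arma::mat data;                       // One column per point.
//   LSHSearch<> lsh(data, 10, 30);        // 10 projections, 30 tables.
//   arma::Mat<size_t> neighbors;
//   arma::mat distances;
//   lsh.Search(5, neighbors, distances);  // 5 approximate nearest neighbors.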
#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
#define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/core/math/random.hpp>

namespace mlpack {
namespace neighbor {

// Construct the object with random tables.
template<typename SortPolicy, typename MatType>
LSHSearch<SortPolicy, MatType>::
LSHSearch(MatType referenceSet,
          const size_t numProj,
          const size_t numTables,
          const double hashWidthIn,
          const size_t secondHashSize,
          const size_t bucketSize) :
  numProj(numProj),
  numTables(numTables),
  hashWidth(hashWidthIn),
  secondHashSize(secondHashSize),
  bucketSize(bucketSize),
  distanceEvaluations(0)
{
  // Pass work to training function.
  Train(std::move(referenceSet), numProj, numTables, hashWidthIn,
      secondHashSize, bucketSize);
}

// Construct the object with given tables.
template<typename SortPolicy, typename MatType>
LSHSearch<SortPolicy, MatType>::
LSHSearch(MatType referenceSet,
          const arma::cube& projections,
          const double hashWidthIn,
          const size_t secondHashSize,
          const size_t bucketSize) :
  numProj(projections.n_cols),
  numTables(projections.n_slices),
  hashWidth(hashWidthIn),
  secondHashSize(secondHashSize),
  bucketSize(bucketSize),
  distanceEvaluations(0)
{
  // Pass work to training function.
  Train(std::move(referenceSet), numProj, numTables, hashWidthIn,
      secondHashSize, bucketSize, projections);
}

// Empty constructor.
template<typename SortPolicy, typename MatType>
LSHSearch<SortPolicy, MatType>::LSHSearch() :
    numProj(0),
    numTables(0),
    hashWidth(0),
    secondHashSize(99901),
    bucketSize(500),
    distanceEvaluations(0)
{
}

// Copy constructor.
template<typename SortPolicy, typename MatType>
LSHSearch<SortPolicy, MatType>::LSHSearch(const LSHSearch& other) :
    referenceSet(other.referenceSet), // Copy the other set.
    numProj(other.numProj),
    numTables(other.numTables),
    projections(other.projections),
    offsets(other.offsets),
    hashWidth(other.hashWidth),
    secondHashSize(other.secondHashSize),
    secondHashWeights(other.secondHashWeights),
    bucketSize(other.bucketSize),
    secondHashTable(other.secondHashTable),
    bucketContentSize(other.bucketContentSize),
    bucketRowInHashTable(other.bucketRowInHashTable),
    distanceEvaluations(other.distanceEvaluations)
{
  // Nothing to do.
}

// Move constructor.
template<typename SortPolicy, typename MatType>
LSHSearch<SortPolicy, MatType>::LSHSearch(LSHSearch&& other) :
    referenceSet(std::move(other.referenceSet)),
    numProj(other.numProj),
    numTables(other.numTables),
    projections(std::move(other.projections)),
    offsets(std::move(other.offsets)),
    hashWidth(other.hashWidth),
    secondHashSize(other.secondHashSize),
    secondHashWeights(std::move(other.secondHashWeights)),
    bucketSize(other.bucketSize),
    secondHashTable(std::move(other.secondHashTable)),
    bucketContentSize(std::move(other.bucketContentSize)),
    bucketRowInHashTable(std::move(other.bucketRowInHashTable)),
    distanceEvaluations(other.distanceEvaluations)
{
  // Reset other model to defaults.
  other.numProj = 0;
  other.numTables = 0;
  other.hashWidth = 0;
  other.secondHashSize = 99901;
  other.bucketSize = 500;
  other.distanceEvaluations = 0;
}

// Copy operator.
template<typename SortPolicy, typename MatType>
LSHSearch<SortPolicy, MatType>& LSHSearch<SortPolicy, MatType>::operator=(
    const LSHSearch& other)
{
  referenceSet = other.referenceSet;
  numProj = other.numProj;
  numTables = other.numTables;
  projections = other.projections;
  offsets = other.offsets;
  hashWidth = other.hashWidth;
  secondHashSize = other.secondHashSize;
  secondHashWeights = other.secondHashWeights;
  bucketSize = other.bucketSize;
  secondHashTable = other.secondHashTable;
  bucketContentSize = other.bucketContentSize;
  bucketRowInHashTable = other.bucketRowInHashTable;
  distanceEvaluations = other.distanceEvaluations;

  return *this;
}

// Move operator.
template<typename SortPolicy, typename MatType>
LSHSearch<SortPolicy, MatType>& LSHSearch<SortPolicy, MatType>::operator=(
    LSHSearch&& other)
{
  referenceSet = std::move(other.referenceSet);
  numProj = other.numProj;
  numTables = other.numTables;
  projections = std::move(other.projections);
  offsets = std::move(other.offsets);
  hashWidth = other.hashWidth;
  secondHashSize = other.secondHashSize;
  secondHashWeights = std::move(other.secondHashWeights);
  bucketSize = other.bucketSize;
  secondHashTable = std::move(other.secondHashTable);
  bucketContentSize = std::move(other.bucketContentSize);
  bucketRowInHashTable = std::move(other.bucketRowInHashTable);
  distanceEvaluations = other.distanceEvaluations;

  // Reset other model to defaults.
  other.numProj = 0;
  other.numTables = 0;
  other.hashWidth = 0;
  other.secondHashSize = 99901;
  other.bucketSize = 500;
  other.distanceEvaluations = 0;

  return *this;
}

// Train on a new reference set.
template<typename SortPolicy, typename MatType>
void LSHSearch<SortPolicy, MatType>::Train(MatType referenceSet,
                                           const size_t numProj,
                                           const size_t numTables,
                                           const double hashWidthIn,
                                           const size_t secondHashSize,
                                           const size_t bucketSize,
                                           const arma::cube& projection)
{
  // Set new reference set.
  this->referenceSet = std::move(referenceSet);

  // Set new parameters.
  this->numProj = numProj;
  this->numTables = numTables;
  this->hashWidth = hashWidthIn;
  this->secondHashSize = secondHashSize;
  this->bucketSize = bucketSize;

  if (hashWidth == 0.0) // The user has not provided any value.
  {
    const size_t numSamples = 25;
    // Compute a heuristic hash width from the data.
    for (size_t i = 0; i < numSamples; ++i)
    {
      size_t p1 = (size_t) math::RandInt(this->referenceSet.n_cols);
      size_t p2 = (size_t) math::RandInt(this->referenceSet.n_cols);

      hashWidth += std::sqrt(metric::EuclideanDistance::Evaluate(
          this->referenceSet.col(p1),
          this->referenceSet.col(p2)));
    }

    hashWidth /= numSamples;
  }

  Log::Info << "Hash width chosen as: " << hashWidth << std::endl;

  // Hash building procedure:
  // The first level hash for a single table outputs a 'numProj'-dimensional
  // integer key for each point in the set -- (key, pointID).  The key creation
  // details are presented below.

  // Step I: Prepare the second level hash.

  // Obtain the weights for the second hash.
  secondHashWeights = arma::floor(arma::randu(numProj) *
                                  (double) secondHashSize);

  // Instead of putting the points in the row corresponding to the bucket, we
  // choose the next empty row and keep track of the row in which the bucket
  // lies. This allows us to stack together and slice out the empty buckets at
  // the end of the hashing.
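  // Note: a value of secondHashSize in bucketRowInHashTable is used below as a
  // sentinel meaning "this bucket has not been assigned a row yet".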
  bucketRowInHashTable.set_size(secondHashSize);
  bucketRowInHashTable.fill(secondHashSize);

  // Step II: The offsets for all projections in all tables.
  // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
  // as randu(numProj, numTables) * hashWidth.
  offsets.randu(numProj, numTables);
  offsets *= hashWidth;

  // Step III: Obtain the 'numProj' projections for each table.
  projections.clear(); // Reset projections vector.

  if (projection.n_slices == 0) // Randomly generate the tables.
  {
    // For the L2 metric, 2-stable distributions are used, and the normal
    // Z ~ N(0, 1) is a 2-stable distribution.

    // Build numTables random tables arranged in a cube.
    projections.randn(this->referenceSet.n_rows, numProj, numTables);
  }
  else if (projection.n_slices == numTables) // Take user-defined tables.
  {
    projections = projection;
  }
  else // The user gave something wrong.
  {
    throw std::invalid_argument("LSHSearch::Train(): number of projection "
        "tables provided must be equal to numTables");
  }

  // We will store the second hash vectors in this matrix; the second hash
  // vector for table i will be held in row i.  Negative first-level hashes are
  // mapped into [0, secondHashSize) explicitly below, so the keys can safely
  // be stored as size_t.
  arma::Mat<size_t> secondHashVectors(numTables, this->referenceSet.n_cols);

  for (size_t i = 0; i < numTables; ++i)
  {
    // Step IV: create the 'numProj'-dimensional key for each point in each
    // table.

    // The following code performs the task of hashing each point to a
    // 'numProj'-dimensional integer key.  Hence you get a ('numProj' x
    // 'referenceSet.n_cols') key matrix.
    //
    // For a single table, let the 'numProj' projections be denoted by 'proj_i'
    // and the corresponding offset be 'offset_i'.  Then the key of a single
    // point is obtained as:
    // key = { floor((<proj_i, point> + offset_i) / 'hashWidth') forall i }
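    //
    // For example, with numProj = 2, hashWidth = 4.0, and a point for which
    // <proj_1, point> + offset_1 = 4.2 and <proj_2, point> + offset_2 = -2.5,
    // the key is { floor(4.2 / 4.0), floor(-2.5 / 4.0) } = { 1, -1 }.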
    arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
                                       this->referenceSet.n_cols);
    arma::mat hashMat = projections.slice(i).t() * (this->referenceSet);
    hashMat += offsetMat;
    hashMat /= hashWidth;

    // Step V: Putting the points in the 'secondHashTable' by hashing the key.
    // Now we hash every (key, point ID) pair to its corresponding bucket.  We
    // must also normalize the hashes to the range [0, secondHashSize).
    arma::rowvec unmodVector = secondHashWeights.t() * arma::floor(hashMat);
    for (size_t j = 0; j < unmodVector.n_elem; ++j)
    {
      double shs = (double) secondHashSize; // Convenience cast.
      if (unmodVector[j] >= 0.0)
      {
        const size_t key = size_t(fmod(unmodVector[j], shs));
        secondHashVectors(i, j) = key;
      }
      else
      {
        const double mod = fmod(-unmodVector[j], shs);
        const size_t key = (mod < 1.0) ? 0 : secondHashSize - size_t(mod);
        secondHashVectors(i, j) = key;
      }
    }
  }

  // Now, using the hash vectors for each table, count the number of rows we
  // have in the second hash table.
  arma::Row<size_t> secondHashBinCounts(secondHashSize, arma::fill::zeros);
  for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
    secondHashBinCounts[secondHashVectors[i]]++;

  // Enforce the maximum bucket size.
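  // (A bucketSize of 0 is treated as "no limit" on bucket size.)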
  const size_t effectiveBucketSize = (bucketSize == 0) ? SIZE_MAX : bucketSize;
  secondHashBinCounts.transform([effectiveBucketSize](size_t val)
      { return std::min(val, effectiveBucketSize); });

  const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
  bucketContentSize.zeros(numRowsInTable);
  secondHashTable.resize(numRowsInTable);

  // Next we must assign each point in each table to the right second hash
  // table.
  size_t currentRow = 0;
  for (size_t i = 0; i < numTables; ++i)
  {
    // Insert the point in the corresponding row to its bucket in the
    // 'secondHashTable'.
    for (size_t j = 0; j < secondHashVectors.n_cols; ++j)
    {
      // This is the bucket number.
      size_t hashInd = (size_t) secondHashVectors(i, j);
      // The point ID is 'j'.

      // If this is currently an empty bucket, start a new row and keep track
      // of which row corresponds to the bucket.
      const size_t maxSize = secondHashBinCounts[hashInd];
      if (bucketRowInHashTable[hashInd] == secondHashSize)
      {
        bucketRowInHashTable[hashInd] = currentRow;
        secondHashTable[currentRow].set_size(maxSize);
        currentRow++;
      }

      // If this vector in the hash table is not full, add the point.
      const size_t index = bucketRowInHashTable[hashInd];
      if (bucketContentSize[index] < maxSize)
        secondHashTable[index](bucketContentSize[index]++) = j;
    } // Loop over all points in the reference set.
  } // Loop over tables.

  Log::Info << "Final hash table size: " << numRowsInTable << " rows, with a "
            << "maximum length of " << arma::max(secondHashBinCounts) << ", "
            << "totaling " << arma::accu(secondHashBinCounts) << " elements."
            << std::endl;
}

// Base case where the query set is the reference set.  (So, we can't return
// ourselves as the nearest neighbor.)
template<typename SortPolicy, typename MatType>
inline force_inline
void LSHSearch<SortPolicy, MatType>::BaseCase(
    const size_t queryIndex,
    const arma::uvec& referenceIndices,
    const size_t k,
    arma::Mat<size_t>& neighbors,
    arma::mat& distances) const
{
  // Let's build the list of candidate neighbors for the given query point.
  // It will be initialized with k candidates:
  // (WorstDistance, referenceSet.n_cols)
  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
      referenceSet.n_cols);
  std::vector<Candidate> vect(k, def);
  CandidateList pqueue(CandidateCmp(), std::move(vect));
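  // The priority queue is ordered so that the worst of the current k
  // candidates sits at the top; a better candidate simply replaces it.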

  for (size_t j = 0; j < referenceIndices.n_elem; ++j)
  {
    const size_t referenceIndex = referenceIndices[j];
    // If the points are the same, skip this point.
    if (queryIndex == referenceIndex)
      continue;

    const double distance = metric::EuclideanDistance::Evaluate(
        referenceSet.col(queryIndex),
        referenceSet.col(referenceIndex));

    Candidate c = std::make_pair(distance, referenceIndex);
    // If this distance is better than the worst candidate, let's insert it.
    if (CandidateCmp()(c, pqueue.top()))
    {
      pqueue.pop();
      pqueue.push(c);
    }
  }

  for (size_t j = 1; j <= k; ++j)
  {
    neighbors(k - j, queryIndex) = pqueue.top().second;
    distances(k - j, queryIndex) = pqueue.top().first;
    pqueue.pop();
  }
}

// Base case for bichromatic search.
template<typename SortPolicy, typename MatType>
inline force_inline
void LSHSearch<SortPolicy, MatType>::BaseCase(
    const size_t queryIndex,
    const arma::uvec& referenceIndices,
    const size_t k,
    const MatType& querySet,
    arma::Mat<size_t>& neighbors,
    arma::mat& distances) const
{
  // Let's build the list of candidate neighbors for the given query point.
  // It will be initialized with k candidates:
  // (WorstDistance, referenceSet.n_cols)
  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
      referenceSet.n_cols);
  std::vector<Candidate> vect(k, def);
  CandidateList pqueue(CandidateCmp(), std::move(vect));

  for (size_t j = 0; j < referenceIndices.n_elem; ++j)
  {
    const size_t referenceIndex = referenceIndices[j];
    const double distance = metric::EuclideanDistance::Evaluate(
        querySet.col(queryIndex),
        referenceSet.col(referenceIndex));

    Candidate c = std::make_pair(distance, referenceIndex);
    // If this distance is better than the worst candidate, let's insert it.
    if (CandidateCmp()(c, pqueue.top()))
    {
      pqueue.pop();
      pqueue.push(c);
    }
  }

  for (size_t j = 1; j <= k; ++j)
  {
    neighbors(k - j, queryIndex) = pqueue.top().second;
    distances(k - j, queryIndex) = pqueue.top().first;
    pqueue.pop();
  }
}

template<typename SortPolicy, typename MatType>
inline force_inline
double LSHSearch<SortPolicy, MatType>::PerturbationScore(
    const std::vector<bool>& A,
    const arma::vec& scores) const
{
  double score = 0.0;
  for (size_t i = 0; i < A.size(); ++i)
    if (A[i])
      score += scores(i); // Add scores of non-zero indices.
  return score;
}

template<typename SortPolicy, typename MatType>
inline force_inline
bool LSHSearch<SortPolicy, MatType>::PerturbationShift(
    std::vector<bool>& A) const
{
  size_t maxPos = 0;
  for (size_t i = 0; i < A.size(); ++i)
    if (A[i] == 1) // Marked true.
      maxPos = i;

  if (maxPos + 1 < A.size()) // Otherwise, this is an invalid vector.
  {
    A[maxPos] = 0;
    A[maxPos + 1] = 1;
    return true; // Valid.
  }
  return false; // Invalid.
}

template<typename SortPolicy, typename MatType>
inline force_inline
bool LSHSearch<SortPolicy, MatType>::PerturbationExpand(
    std::vector<bool>& A) const
{
  // Find the last '1' in A.
  size_t maxPos = 0;
  for (size_t i = 0; i < A.size(); ++i)
    if (A[i]) // Marked true.
      maxPos = i;

  if (maxPos + 1 < A.size()) // Otherwise, this is an invalid vector.
  {
    A[maxPos + 1] = 1;
    return true;
  }
  return false;
}

template<typename SortPolicy, typename MatType>
inline force_inline
bool LSHSearch<SortPolicy, MatType>::PerturbationValid(
    const std::vector<bool>& A) const
{
  // Use check to mark dimensions we have seen before in A. If a dimension is
  // seen twice (or more), A is not a valid perturbation.
  std::vector<bool> check(numProj);

  if (A.size() > 2 * numProj)
    return false; // This should never happen.

  // Check that we only see each dimension once. If not, vector is not valid.
  for (size_t i = 0; i < A.size(); ++i)
  {
    // Only check dimensions that were included.
    if (!A[i])
      continue;

    // If the dimension is unseen thus far, mark it as seen.
    if (check[i % numProj] == false)
      check[i % numProj] = true;
    else
      return false; // If dimension was seen before, set is not valid.
  }
  // If we didn't fail, set is valid.
  return true;
}

// Compute additional probing bins for a query.
template<typename SortPolicy, typename MatType>
void LSHSearch<SortPolicy, MatType>::GetAdditionalProbingBins(
    const arma::vec& queryCode,
    const arma::vec& queryCodeNotFloored,
    const size_t T,
    arma::mat& additionalProbingBins) const
{
  // No additional bins requested. Our work is done.
  if (T == 0)
    return;

  // Each column of additionalProbingBins is the code of a bin.
  additionalProbingBins.set_size(numProj, T);

  // Copy the query's code; at the end we will add/subtract according to the
  // perturbations we calculate.
  for (size_t c = 0; c < T; ++c)
    additionalProbingBins.col(c) = queryCode;

  // Calculate query point's projection position.
  arma::mat projection = queryCodeNotFloored;

  // Use projection to calculate query's distance from hash limits.
  arma::vec limLow = projection - queryCode * hashWidth;
  arma::vec limHigh = hashWidth - limLow;

  // Calculate scores. score = distance^2.
  arma::vec scores(2 * numProj);
  scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
  scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2);

  // Actions vector describes what perturbation (-1/+1) corresponds to a score.
  arma::Col<short int> actions(2 * numProj); // Will be [-1 ... 1 ...].
  actions.rows(0, numProj - 1) = // First numProj rows.
    -1 * arma::ones< arma::Col<short int> > (numProj); // -1s
  actions.rows(numProj, (2 * numProj) - 1) = // Last numProj rows.
    arma::ones< arma::Col<short int> > (numProj); // 1s

  // Acting dimension vector shows which coordinate to transform according to
  // actions (actions are described by actions vector above).
  arma::Col<size_t> positions(2 * numProj); // Will be [0 1 2 ... 0 1 2 ...].
  positions.rows(0, numProj - 1) =
    arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
  positions.rows(numProj, 2 * numProj - 1) =
    arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);

  // Special case: No need to create a heap for 1 or 2 codes.
  if (T <= 2)
  {
    // First, find the location of the minimum score, generate 1 perturbation
    // vector, and add its code to additionalProbingBins column 0.

    // Find location and value of smallest element of scores vector.
    double minscore = scores[0];
    size_t minloc = 0;
    for (size_t s = 1; s < (2 * numProj); ++s)
    {
      if (minscore > scores[s])
      {
        minscore = scores[s];
        minloc = s;
      }
    }

    // Add or subtract 1 to dimension corresponding to minimum score.
    additionalProbingBins(positions[minloc], 0) += actions[minloc];
    if (T == 1)
      return; // Done if asked for only 1 code.

    // Now, find the location of the second smallest score and generate one
    // more vector.  The second perturbation vector still cannot contain more
    // than one change in the bin codes, because of the way perturbation
    // vectors are generated: first we create the one with the smallest score
    // (Ao) and then we either add 1 extra dimension to it (Ae) or shift it by
    // one (As).  Since As contains the second smallest score, and Ae contains
    // both the smallest and the second smallest, it's obvious that score(Ae) >
    // score(As). Therefore the second perturbation vector is ALWAYS the vector
    // containing only the second-lowest scoring perturbation.
    double minscore2 = scores[0];
    size_t minloc2 = 0;
    for (size_t s = 0; s < (2 * numProj); ++s) // Here we can't start from 1.
    {
      if (minscore2 > scores[s] && s != minloc) // Second smallest.
      {
        minscore2 = scores[s];
        minloc2 = s;
      }
    }

    // Add or subtract 1 to create second-lowest scoring vector.
    additionalProbingBins(positions[minloc2], 1) += actions[minloc2];
    return;
  }

  // General case: more than 2 perturbation vectors require use of a minheap.
  // Sort everything in increasing order.
  arma::uvec sortidx = arma::sort_index(scores);
  scores = scores(sortidx);
  actions = actions(sortidx);
  positions = positions(sortidx);

  // Theory:
  // A probing sequence is a sequence of T probing bins where a query's
  // neighbors are most likely to be. Likelihood is dependent only on a bin's
  // score, which is the sum of scores of all dimension-action pairs, so we
  // need to calculate the T smallest sums of scores that are not conflicting.
  //
  // Method:
  // Store each perturbation set (pair of (dimension, action)) in a
  // std::vector. Create a minheap of scores, with each node pointing to its
  // relevant perturbation set. Each perturbation set popped from the minheap
  // is the next most likely perturbation set.
  // Transform perturbation set to perturbation vector by setting the
  // dimensions specified by the set to queryCode+action (action is {-1, 1}).
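  //
  // For example, with numProj = 3, a perturbation set marking
  // (dimension 1, action -1) and (dimension 2, action +1) turns a query code
  // of (5, 2, 7) into the probing bin (5, 1, 8).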

  // Perturbation sets (A) mark with 1 the (score, action, dimension) positions
  // included in a given perturbation vector. Other spaces are 0.
  std::vector<bool> Ao(2 * numProj);
  Ao[0] = 1; // Smallest vector includes only smallest score.

  std::vector< std::vector<bool> > perturbationSets;
  perturbationSets.push_back(Ao); // Storage of perturbation sets.

  std::priority_queue<
    std::pair<double, size_t>,        // Contents: pairs of (score, index).
    std::vector<                      // Container: vector of pairs.
      std::pair<double, size_t>
      >,
    std::greater< std::pair<double, size_t> > // Comparator of pairs.
  > minHeap; // Our minheap.

  // Start by adding the lowest scoring set to the minheap.
  minHeap.push(std::make_pair(PerturbationScore(Ao, scores), 0));

  // Loop invariant: after pvec iterations, additionalProbingBins contains pvec
  // valid codes of the lowest-scoring bins (bins most likely to contain
  // neighbors of the query).
  for (size_t pvec = 0; pvec < T; ++pvec)
  {
    std::vector<bool> Ai;
    do
    {
      // Get the perturbation set corresponding to the minimum score.
      Ai = perturbationSets[ minHeap.top().second ];
      minHeap.pop(); // .top() returns, .pop() removes.

      // Shift operation on Ai (replace max with max+1).
      std::vector<bool> As = Ai;

      // Don't add invalid sets.
      if (PerturbationShift(As) && PerturbationValid(As))
      {
        perturbationSets.push_back(As); // Add shifted set to sets.
        minHeap.push(
            std::make_pair(PerturbationScore(As, scores),
            perturbationSets.size() - 1));
      }

      // Expand operation on Ai (add max+1 to set).
      std::vector<bool> Ae = Ai;

      // Don't add invalid sets.
      if (PerturbationExpand(Ae) && PerturbationValid(Ae))
      {
        perturbationSets.push_back(Ae); // Add expanded set to sets.
        minHeap.push(
            std::make_pair(PerturbationScore(Ae, scores),
            perturbationSets.size() - 1));
      }
    } while (!PerturbationValid(Ai)); // Discard invalid perturbations.

    // Found valid perturbation set Ai. Construct perturbation vector from set.
    for (size_t pos = 0; pos < Ai.size(); ++pos)
    {
      // If Ai[pos] is marked, add action to probing vector.
      additionalProbingBins(positions(pos), pvec) += Ai[pos] ? actions(pos) : 0;
    }
  }
}

template<typename SortPolicy, typename MatType>
template<typename VecType>
void LSHSearch<SortPolicy, MatType>::ReturnIndicesFromTable(
    const VecType& queryPoint,
    arma::uvec& referenceIndices,
    size_t numTablesToSearch,
    const size_t T) const
{
  // Decide on the number of tables to look into.
  if (numTablesToSearch == 0) // If no user input is given, search all.
    numTablesToSearch = numTables;

  // Sanity check to make sure that the existing number of tables is not
  // exceeded.
  if (numTablesToSearch > numTables)
    numTablesToSearch = numTables;

  // Hash the query in each of the 'numTablesToSearch' hash tables using the
  // 'numProj' projections for each table. This gives us 'numTablesToSearch'
  // keys for the query where each key is a 'numProj'-dimensional integer
  // vector.

  // Compute the projection of the query in each table.
  arma::mat allProjInTables(numProj, numTablesToSearch);
  arma::mat queryCodesNotFloored(numProj, numTablesToSearch);
  for (size_t i = 0; i < numTablesToSearch; ++i)
    queryCodesNotFloored.unsafe_col(i) = projections.slice(i).t() * queryPoint;

  queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1);
  allProjInTables = arma::floor(queryCodesNotFloored / hashWidth);

  // Use hashMat to store the primary probing codes and any additional codes
  // from multiprobe LSH.
  arma::Mat<size_t> hashMat;
  hashMat.set_size(T + 1, numTablesToSearch);

  // Compute the primary hash value of each key of the query into a bucket of
  // the secondHashTable using the secondHashWeights.
  hashMat.row(0) = arma::conv_to<arma::Row<size_t>> // Floor by typecasting.
      ::from(secondHashWeights.t() * allProjInTables);
  // Mod to compute 2nd-level codes.
  for (size_t i = 0; i < numTablesToSearch; ++i)
    hashMat(0, i) = (hashMat(0, i) % secondHashSize);

  // Compute hash codes of additional probing bins.
  if (T > 0)
  {
    for (size_t i = 0; i < numTablesToSearch; ++i)
    {
      // Construct this table's probing sequence of length T.
      arma::mat additionalProbingBins;
      GetAdditionalProbingBins(allProjInTables.unsafe_col(i),
                                queryCodesNotFloored.unsafe_col(i),
                                T,
                                additionalProbingBins);

      // Map each probing bin to a bin in secondHashTable (just like we did for
      // the primary hash table).
      hashMat(arma::span(1, T), i) = // Compute code of rows 1:end of column i.
        arma::conv_to< arma::Col<size_t> >:: // Floor by typecasting to size_t.
        from(secondHashWeights.t() * additionalProbingBins);
      for (size_t p = 1; p < T + 1; ++p)
        hashMat(p, i) = (hashMat(p, i) % secondHashSize);
    }
  }

  // Count the number of points hashed in the same bucket as the query.
  size_t maxNumPoints = 0;
  for (size_t i = 0; i < numTablesToSearch; ++i)
  {
    for (size_t p = 0; p < T + 1; ++p)
    {
      const size_t hashInd = hashMat(p, i); // Find the query's bucket.
      const size_t tableRow = bucketRowInHashTable[hashInd];
      if (tableRow < secondHashSize)
        maxNumPoints += bucketContentSize[tableRow]; // Count bucket contents.
    }
  }

  // There are two ways to proceed here:
  // Either allocate a maxNumPoints-size vector, place all candidates, and run
  // unique on the vector to discard duplicates.
  // Or allocate a referenceSet.n_cols size vector (i.e. number of reference
  // points) of zeros, and mark found indices as 1.
  // Option 1 runs faster for small maxNumPoints but worse for larger values, so
  // we choose based on a heuristic.
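  // For example, with 100,000 reference points and maxNumPoints = 5,000, the
  // selectivity below is 0.05 < 0.1, so the unique()-based branch is taken.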
  const float cutoff = 0.1;
  const float selectivity = static_cast<float>(maxNumPoints) /
      static_cast<float>(referenceSet.n_cols);

  if (selectivity > cutoff)
  {
    // Heuristic: larger maxNumPoints means we should use find() because it
    // should be faster.
    // Reference points hashed in the same bucket as the query are set to >0.
    arma::Col<size_t> refPointsConsidered;
    refPointsConsidered.zeros(referenceSet.n_cols);

    for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables.
    {
      for (size_t p = 0; p < T + 1; ++p) // For entire probing sequence.
      {
        // Get the sequence code.
        size_t hashInd = hashMat(p, i);
        size_t tableRow = bucketRowInHashTable[hashInd];

        if (tableRow < secondHashSize && bucketContentSize[tableRow] > 0)
        {
          // Pick the indices in the bucket corresponding to hashInd.
          for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
            refPointsConsidered[ secondHashTable[tableRow](j) ]++;
        }
      }
    }

    // Only keep reference points found in at least one bucket.
    referenceIndices = arma::find(refPointsConsidered > 0);
    return;
  }
  else
  {
    // Heuristic: smaller maxNumPoints means we should use unique() because it
    // should be faster.
    // Allocate space for the query's potential neighbors.
    arma::uvec refPointsConsideredSmall;
    refPointsConsideredSmall.zeros(maxNumPoints);

    // Retrieve candidates.
    size_t start = 0;

    for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables.
    {
      for (size_t p = 0; p < T + 1; ++p)
      {
        const size_t hashInd = hashMat(p, i); // Find the query's bucket.
        const size_t tableRow = bucketRowInHashTable[hashInd];

        if (tableRow < secondHashSize)
        {
          // Store all secondHashTable points in the candidates set.
          for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
            refPointsConsideredSmall(start++) = secondHashTable[tableRow](j);
        }
      }
    }

    // Keep only one copy of each candidate.
    referenceIndices = arma::unique(refPointsConsideredSmall);
    return;
  }
}

// Search for nearest neighbors in a given query set.
template<typename SortPolicy, typename MatType>
void LSHSearch<SortPolicy, MatType>::Search(
    const MatType& querySet,
    const size_t k,
    arma::Mat<size_t>& resultingNeighbors,
    arma::mat& distances,
    const size_t numTablesToSearch,
    const size_t T)
{
  // Ensure the dimensionality of the query set is correct.
  if (querySet.n_rows != referenceSet.n_rows)
  {
    std::ostringstream oss;
    oss << "LSHSearch::Search(): dimensionality of query set ("
        << querySet.n_rows << ") is not equal to the dimensionality the model "
        << "was trained on (" << referenceSet.n_rows << ")!" << std::endl;
    throw std::invalid_argument(oss.str());
  }

  if (k > referenceSet.n_cols)
  {
    std::ostringstream oss;
    oss << "LSHSearch::Search(): requested " << k << " approximate nearest "
        << "neighbors, but reference set has " << referenceSet.n_cols
        << " points!" << std::endl;
    throw std::invalid_argument(oss.str());
  }

  // Set the size of the neighbor and distance matrices.
  resultingNeighbors.set_size(k, querySet.n_cols);
  distances.set_size(k, querySet.n_cols);

  // If the user asked for 0 nearest neighbors... uh... we're done.
  if (k == 0)
    return;

  // If the user requested more than the available number of additional probing
  // bins, cap Teffective at the maximum T, which is 2^numProj - 1.
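  // (For example, with numProj = 3 the cap is (1 << 3) - 1 = 7 additional
  // bins.)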
  size_t Teffective = T;
  if (T > ((size_t) ((1 << numProj) - 1)))
  {
    Teffective = (1 << numProj) - 1;
    Log::Warn << "Requested " << T << " additional bins are more than "
        << "theoretical maximum. Using " << Teffective << " instead."
        << std::endl;
  }

  // If the user set multiprobe, log it.
  if (Teffective > 0)
    Log::Info << "Running multiprobe LSH with " << Teffective
        << " additional probing bins per table per query." << std::endl;

  size_t avgIndicesReturned = 0;

  Timer::Start("computing_neighbors");

  // Parallelization to process more than one query at a time.
  #pragma omp parallel for \
      shared(resultingNeighbors, distances) \
      schedule(dynamic)\
      reduction(+:avgIndicesReturned)
  for (omp_size_t i = 0; i < (omp_size_t) querySet.n_cols; ++i)
  {
    // Go through every query point.
    // Hash every query into every hash table and eventually into the
    // 'secondHashTable' to obtain the neighbor candidates.
    arma::uvec refIndices;
    ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch,
        Teffective);

    // Informative book-keeping for the number of neighbor candidates returned
    // on average.  The reduction clause above makes this accumulation safe
    // across threads, so no atomic is needed.
    avgIndicesReturned += refIndices.n_elem;

    // Sequentially go through all the candidates and save the best 'k'
    // candidates.
    BaseCase(i, refIndices, k, querySet, resultingNeighbors, distances);
  }

  Timer::Stop("computing_neighbors");

  distanceEvaluations += avgIndicesReturned;
  avgIndicesReturned /= querySet.n_cols;
  Log::Info << avgIndicesReturned << " distinct indices returned on average."
      << std::endl;
}

// Search for approximate neighbors of the reference set.
template<typename SortPolicy, typename MatType>
void LSHSearch<SortPolicy, MatType>::
Search(const size_t k,
       arma::Mat<size_t>& resultingNeighbors,
       arma::mat& distances,
       const size_t numTablesToSearch,
       size_t T)
{
  // This is monochromatic search; the query set is the reference set.
  resultingNeighbors.set_size(k, referenceSet.n_cols);
  distances.set_size(k, referenceSet.n_cols);

  // If the user requested more than the available number of additional probing
  // bins, cap Teffective at the maximum T, which is 2^numProj - 1.
  size_t Teffective = T;
  if (T > ((size_t) ((1 << numProj) - 1)))
  {
    Teffective = (1 << numProj) - 1;
    Log::Warn << "Requested " << T << " additional bins are more than "
        << "theoretical maximum. Using " << Teffective << " instead."
        << std::endl;
  }

  // If the user set multiprobe, log it.
  if (T > 0)
    Log::Info << "Running multiprobe LSH with " << Teffective
        << " additional probing bins per table per query." << std::endl;

  size_t avgIndicesReturned = 0;

  Timer::Start("computing_neighbors");

  // Parallelization to process more than one query at a time.
  #pragma omp parallel for \
      shared(resultingNeighbors, distances) \
      schedule(dynamic)\
      reduction(+:avgIndicesReturned)
  for (omp_size_t i = 0; i < (omp_size_t) referenceSet.n_cols; ++i)
  {
    // Go through every query point.
    // Hash every query into every hash table and eventually into the
    // 'secondHashTable' to obtain the neighbor candidates.
    arma::uvec refIndices;
    ReturnIndicesFromTable(referenceSet.col(i), refIndices, numTablesToSearch,
        Teffective);

    // Informative book-keeping for the number of neighbor candidates returned
    // on average.  The reduction clause above makes this accumulation safe
    // across threads, so no atomic is needed.
    avgIndicesReturned += refIndices.n_elem;

    // Sequentially go through all the candidates and save the best 'k'
    // candidates.
    BaseCase(i, refIndices, k, resultingNeighbors, distances);
  }

  Timer::Stop("computing_neighbors");

  distanceEvaluations += avgIndicesReturned;
  avgIndicesReturned /= referenceSet.n_cols;
  Log::Info << avgIndicesReturned << " distinct indices returned on average."
      << std::endl;
}

template<typename SortPolicy, typename MatType>
double LSHSearch<SortPolicy, MatType>::ComputeRecall(
    const arma::Mat<size_t>& foundNeighbors,
    const arma::Mat<size_t>& realNeighbors)
{
  if (foundNeighbors.n_rows != realNeighbors.n_rows ||
      foundNeighbors.n_cols != realNeighbors.n_cols)
    throw std::invalid_argument("LSHSearch::ComputeRecall(): matrices provided"
        " must have equal size");

  const size_t queries = foundNeighbors.n_cols;
  const size_t neighbors = foundNeighbors.n_rows; // Should be equal to k.

  // The recall is the fraction of real neighbors that also appear among the
  // found neighbors.
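  // For example, if realNeighbors has 100 entries and 70 of them also appear
  // in foundNeighbors for the corresponding query, the recall is 0.7.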
  size_t found = 0;
  for (size_t col = 0; col < queries; ++col)
    for (size_t row = 0; row < neighbors; ++row)
      for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei)
        if (realNeighbors(row, col) == foundNeighbors(nei, col))
        {
          found++;
          break;
        }

  return ((double) found) / realNeighbors.n_elem;
}

template<typename SortPolicy, typename MatType>
template<typename Archive>
void LSHSearch<SortPolicy, MatType>::serialize(Archive& ar,
                                               const unsigned int version)
{
  ar & BOOST_SERIALIZATION_NVP(referenceSet);
  ar & BOOST_SERIALIZATION_NVP(numProj);
  ar & BOOST_SERIALIZATION_NVP(numTables);

  // Delete existing projections, if necessary.
  if (Archive::is_loading::value)
    projections.reset();

  // Backward compatibility: older versions of LSHSearch stored the projection
  // tables in a std::vector<arma::mat>.
  if (version == 0)
  {
    std::vector<arma::mat> tmpProj;
    ar & BOOST_SERIALIZATION_NVP(tmpProj);

    projections.set_size(tmpProj[0].n_rows, tmpProj[0].n_cols, tmpProj.size());
    for (size_t i = 0; i < tmpProj.size(); ++i)
      projections.slice(i) = tmpProj[i];
  }
  else
  {
    ar & BOOST_SERIALIZATION_NVP(projections);
  }

  ar & BOOST_SERIALIZATION_NVP(offsets);
  ar & BOOST_SERIALIZATION_NVP(hashWidth);
  ar & BOOST_SERIALIZATION_NVP(secondHashSize);
  ar & BOOST_SERIALIZATION_NVP(secondHashWeights);
  ar & BOOST_SERIALIZATION_NVP(bucketSize);
  // The second hash table needs version-specific handling below.

  // Backward compatibility: in older versions of LSHSearch, the secondHashTable
  // was stored as an arma::Mat<size_t>.  So we need to properly load that, then
  // prune it down to size.
  if (version == 0)
  {
    arma::Mat<size_t> tmpSecondHashTable;
    ar & BOOST_SERIALIZATION_NVP(tmpSecondHashTable);

    // The old secondHashTable was stored in row-major format, so we transpose
    // it.
    tmpSecondHashTable = tmpSecondHashTable.t();

    secondHashTable.resize(tmpSecondHashTable.n_cols);
    for (size_t i = 0; i < tmpSecondHashTable.n_cols; ++i)
    {
      // Find the length of each column.  We know we are at the end of the list
      // when the value referenceSet.n_cols is seen.
      size_t len = 0;
      for (; len < tmpSecondHashTable.n_rows; ++len)
        if (tmpSecondHashTable(len, i) == referenceSet.n_cols)
          break;

      // Set the size of the new column correctly.
      secondHashTable[i].set_size(len);
      for (size_t j = 0; j < len; ++j)
        secondHashTable[i](j) = tmpSecondHashTable(j, i);
    }
  }
  else
  {
    size_t tables;
    if (Archive::is_saving::value)
      tables = secondHashTable.size();
    ar & BOOST_SERIALIZATION_NVP(tables);

    // Set the size of the second hash table if needed.
    if (Archive::is_loading::value)
    {
      secondHashTable.clear();
      secondHashTable.resize(tables);
    }

    ar & BOOST_SERIALIZATION_NVP(secondHashTable);
  }

  // Backward compatibility: old versions of LSHSearch held bucketContentSize
  // for all possible buckets (of size secondHashSize), but now we hold a
  // compressed representation.
  if (version == 0)
  {
    // The vector was stored in the old uncompressed form.  So we need to shrink
    // it.  But we can't do that until we have bucketRowInHashTable, so we also
    // have to load that.
    arma::Col<size_t> tmpBucketContentSize;
    ar & BOOST_SERIALIZATION_NVP(tmpBucketContentSize);
    ar & BOOST_SERIALIZATION_NVP(bucketRowInHashTable);

    // Compress into a smaller vector by just dropping all of the zeros.
    bucketContentSize.set_size(secondHashTable.size());
    for (size_t i = 0; i < tmpBucketContentSize.n_elem; ++i)
      if (tmpBucketContentSize[i] > 0)
        bucketContentSize[bucketRowInHashTable[i]] = tmpBucketContentSize[i];
  }
  else
  {
    ar & BOOST_SERIALIZATION_NVP(bucketContentSize);
    ar & BOOST_SERIALIZATION_NVP(bucketRowInHashTable);
  }

  ar & BOOST_SERIALIZATION_NVP(distanceEvaluations);
}

} // namespace neighbor
} // namespace mlpack

#endif