1 /**
2 * @file methods/lsh/lsh_search_impl.hpp
3 * @author Parikshit Ram
4 *
5 * Implementation of the LSHSearch class.
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
13 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_IMPL_HPP
14
15 #include <mlpack/prereqs.hpp>
16 #include <mlpack/core/math/random.hpp>
17
18 namespace mlpack {
19 namespace neighbor {
20
21 // Construct the object with random tables
22 template<typename SortPolicy, typename MatType>
23 LSHSearch<SortPolicy, MatType>::
LSHSearch(MatType referenceSet,const size_t numProj,const size_t numTables,const double hashWidthIn,const size_t secondHashSize,const size_t bucketSize)24 LSHSearch(MatType referenceSet,
25 const size_t numProj,
26 const size_t numTables,
27 const double hashWidthIn,
28 const size_t secondHashSize,
29 const size_t bucketSize) :
30 numProj(numProj),
31 numTables(numTables),
32 hashWidth(hashWidthIn),
33 secondHashSize(secondHashSize),
34 bucketSize(bucketSize),
35 distanceEvaluations(0)
36 {
37 // Pass work to training function.
38 Train(std::move(referenceSet), numProj, numTables, hashWidthIn,
39 secondHashSize, bucketSize);
40 }
41
42 // Construct the object with given tables
43 template<typename SortPolicy, typename MatType>
44 LSHSearch<SortPolicy, MatType>::
LSHSearch(MatType referenceSet,const arma::cube & projections,const double hashWidthIn,const size_t secondHashSize,const size_t bucketSize)45 LSHSearch(MatType referenceSet,
46 const arma::cube& projections,
47 const double hashWidthIn,
48 const size_t secondHashSize,
49 const size_t bucketSize) :
50 numProj(projections.n_cols),
51 numTables(projections.n_slices),
52 hashWidth(hashWidthIn),
53 secondHashSize(secondHashSize),
54 bucketSize(bucketSize),
55 distanceEvaluations(0)
56 {
57 // Pass work to training function.
58 Train(std::move(referenceSet), numProj, numTables, hashWidthIn,
59 secondHashSize, bucketSize, projections);
60 }
61
62 // Empty constructor.
63 template<typename SortPolicy, typename MatType>
LSHSearch()64 LSHSearch<SortPolicy, MatType>::LSHSearch() :
65 numProj(0),
66 numTables(0),
67 hashWidth(0),
68 secondHashSize(99901),
69 bucketSize(500),
70 distanceEvaluations(0)
71 {
72 }
73
74 // Copy constructor.
75 template<typename SortPolicy, typename MatType>
LSHSearch(const LSHSearch & other)76 LSHSearch<SortPolicy, MatType>::LSHSearch(const LSHSearch& other) :
77 referenceSet(other.referenceSet), // Copy the other set.
78 numProj(other.numProj),
79 numTables(other.numTables),
80 projections(other.projections),
81 offsets(other.offsets),
82 hashWidth(other.hashWidth),
83 secondHashSize(other.secondHashSize),
84 secondHashWeights(other.secondHashWeights),
85 bucketSize(other.bucketSize),
86 secondHashTable(other.secondHashTable),
87 bucketContentSize(other.bucketContentSize),
88 bucketRowInHashTable(other.bucketRowInHashTable),
89 distanceEvaluations(other.distanceEvaluations)
90 {
91 // Nothing to do.
92 }
93
94 // Move constructor.
95 template<typename SortPolicy, typename MatType>
LSHSearch(LSHSearch && other)96 LSHSearch<SortPolicy, MatType>::LSHSearch(LSHSearch&& other) :
97 referenceSet(std::move(other.referenceSet)),
98 numProj(other.numProj),
99 numTables(other.numTables),
100 projections(std::move(other.projections)),
101 offsets(std::move(other.offsets)),
102 hashWidth(other.hashWidth),
103 secondHashSize(other.secondHashSize),
104 secondHashWeights(std::move(other.secondHashWeights)),
105 bucketSize(other.bucketSize),
106 secondHashTable(std::move(other.secondHashTable)),
107 bucketContentSize(std::move(other.bucketContentSize)),
108 bucketRowInHashTable(std::move(other.bucketRowInHashTable)),
109 distanceEvaluations(other.distanceEvaluations)
110 {
111 // Reset other model to defaults.
112 other.numProj = 0;
113 other.numTables = 0;
114 other.hashWidth = 0;
115 other.secondHashSize = 99901;
116 other.bucketSize = 500;
117 other.distanceEvaluations = 0;
118 }
119
120 // Copy operator.
121 template<typename SortPolicy, typename MatType>
operator =(const LSHSearch & other)122 LSHSearch<SortPolicy, MatType>& LSHSearch<SortPolicy, MatType>::operator=(
123 const LSHSearch& other)
124 {
125 referenceSet = other.referenceSet;
126 numProj = other.numProj;
127 numTables = other.numTables;
128 projections = other.projections;
129 offsets = other.offsets;
130 hashWidth = other.hashWidth;
131 secondHashSize = other.secondHashSize;
132 secondHashWeights = other.secondHashWeights;
133 bucketSize = other.bucketSize;
134 secondHashTable = other.secondHashTable;
135 bucketContentSize = other.bucketContentSize;
136 bucketRowInHashTable = other.bucketRowInHashTable;
137 distanceEvaluations = other.distanceEvaluations;
138
139 return *this;
140 }
141
142 // Move operator.
143 template<typename SortPolicy, typename MatType>
operator =(LSHSearch && other)144 LSHSearch<SortPolicy, MatType>& LSHSearch<SortPolicy, MatType>::operator=(
145 LSHSearch&& other)
146 {
147 referenceSet = std::move(other.referenceSet);
148 numProj = other.numProj;
149 numTables = other.numTables;
150 projections = std::move(other.projections);
151 offsets = std::move(other.offsets);
152 hashWidth = other.hashWidth;
153 secondHashSize = other.secondHashSize;
154 secondHashWeights = std::move(other.secondHashWeights);
155 bucketSize = other.bucketSize;
156 secondHashTable = std::move(other.secondHashTable);
157 bucketContentSize = std::move(other.bucketContentSize);
158 bucketRowInHashTable = std::move(other.bucketRowInHashTable);
159 distanceEvaluations = other.distanceEvaluations;
160
161 // Reset other model to defaults.
162 other.numProj = 0;
163 other.numTables = 0;
164 other.hashWidth = 0;
165 other.secondHashSize = 99901;
166 other.bucketSize = 500;
167 other.distanceEvaluations = 0;
168
169 return *this;
170 }
171
172 // Train on a new reference set.
173 template<typename SortPolicy, typename MatType>
Train(MatType referenceSet,const size_t numProj,const size_t numTables,const double hashWidthIn,const size_t secondHashSize,const size_t bucketSize,const arma::cube & projection)174 void LSHSearch<SortPolicy, MatType>::Train(MatType referenceSet,
175 const size_t numProj,
176 const size_t numTables,
177 const double hashWidthIn,
178 const size_t secondHashSize,
179 const size_t bucketSize,
180 const arma::cube& projection)
181 {
182 // Set new reference set.
183 this->referenceSet = std::move(referenceSet);
184
185 // Set new parameters.
186 this->numProj = numProj;
187 this->numTables = numTables;
188 this->hashWidth = hashWidthIn;
189 this->secondHashSize = secondHashSize;
190 this->bucketSize = bucketSize;
191
192 if (hashWidth == 0.0) // The user has not provided any value.
193 {
194 const size_t numSamples = 25;
195 // Compute a heuristic hash width from the data.
196 for (size_t i = 0; i < numSamples; ++i)
197 {
198 size_t p1 = (size_t) math::RandInt(this->referenceSet.n_cols);
199 size_t p2 = (size_t) math::RandInt(this->referenceSet.n_cols);
200
201 hashWidth += std::sqrt(metric::EuclideanDistance::Evaluate(
202 this->referenceSet.col(p1),
203 this->referenceSet.col(p2)));
204 }
205
206 hashWidth /= numSamples;
207 }
208
209 Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
210
211 // Hash building procedure:
212 // The first level hash for a single table outputs a 'numProj'-dimensional
213 // integer key for each point in the set -- (key, pointID). The key creation
214 // details are presented below.
215
216 // Step I: Prepare the second level hash.
217
218 // Obtain the weights for the second hash.
219 secondHashWeights = arma::floor(arma::randu(numProj) *
220 (double) secondHashSize);
221
222 // Instead of putting the points in the row corresponding to the bucket, we
223 // chose the next empty row and keep track of the row in which the bucket
224 // lies. This allows us to stack together and slice out the empty buckets at
225 // the end of the hashing.
226 bucketRowInHashTable.set_size(secondHashSize);
227 bucketRowInHashTable.fill(secondHashSize);
228
229 // Step II: The offsets for all projections in all tables.
230 // Since the 'offsets' are in [0, hashWidth], we obtain the 'offsets'
231 // as randu(numProj, numTables) * hashWidth.
232 offsets.randu(numProj, numTables);
233 offsets *= hashWidth;
234
235 // Step III: Obtain the 'numProj' projections for each table.
236 projections.clear(); // Reset projections vector.
237
238 if (projection.n_slices == 0) // Randomly generate the tables.
239 {
240 // For L2 metric, 2-stable distributions are used, and the normal Z ~ N(0,
241 // 1) is a 2-stable distribution.
242
243 // Build numTables random tables arranged in a cube.
244 projections.randn(this->referenceSet.n_rows, numProj, numTables);
245 }
246 else if (projection.n_slices == numTables) // Take user-defined tables.
247 {
248 projections = projection;
249 }
250 else // The user gave something wrong.
251 {
252 throw std::invalid_argument("LSHSearch::Train(): number of projection "
253 "tables provided must be equal to numProj");
254 }
255
256 // We will store the second hash vectors in this matrix; the second hash
257 // vector for table i will be held in row i. We have to use int and not
258 // size_t, otherwise negative numbers are cast to 0.
259 arma::Mat<size_t> secondHashVectors(numTables, this->referenceSet.n_cols);
260
261 for (size_t i = 0; i < numTables; ++i)
262 {
263 // Step IV: create the 'numProj'-dimensional key for each point in each
264 // table.
265
266 // The following code performs the task of hashing each point to a
267 // 'numProj'-dimensional integer key. Hence you get a ('numProj' x
268 // 'referenceSet.n_cols') key matrix.
269 //
270 // For a single table, let the 'numProj' projections be denoted by 'proj_i'
271 // and the corresponding offset be 'offset_i'. Then the key of a single
272 // point is obtained as:
273 // key = { floor((<proj_i, point> + offset_i) / 'hashWidth') forall i }
274 arma::mat offsetMat = arma::repmat(offsets.unsafe_col(i), 1,
275 this->referenceSet.n_cols);
276 arma::mat hashMat = projections.slice(i).t() * (this->referenceSet);
277 hashMat += offsetMat;
278 hashMat /= hashWidth;
279
280 // Step V: Putting the points in the 'secondHashTable' by hashing the key.
281 // Now we hash every key, point ID to its corresponding bucket. We must
282 // also normalize the hashes to the range [0, secondHashSize).
283 arma::rowvec unmodVector = secondHashWeights.t() * arma::floor(hashMat);
284 for (size_t j = 0; j < unmodVector.n_elem; ++j)
285 {
286 double shs = (double) secondHashSize; // Convenience cast.
287 if (unmodVector[j] >= 0.0)
288 {
289 const size_t key = size_t(fmod(unmodVector[j], shs));
290 secondHashVectors(i, j) = key;
291 }
292 else
293 {
294 const double mod = fmod(-unmodVector[j], shs);
295 const size_t key = (mod < 1.0) ? 0 : secondHashSize - size_t(mod);
296 secondHashVectors(i, j) = key;
297 }
298 }
299 }
300
301 // Now, using the hash vectors for each table, count the number of rows we
302 // have in the second hash table.
303 arma::Row<size_t> secondHashBinCounts(secondHashSize, arma::fill::zeros);
304 for (size_t i = 0; i < secondHashVectors.n_elem; ++i)
305 secondHashBinCounts[secondHashVectors[i]]++;
306
307 // Enforce the maximum bucket size.
308 const size_t effectiveBucketSize = (bucketSize == 0) ? SIZE_MAX : bucketSize;
309 secondHashBinCounts.transform([effectiveBucketSize](size_t val)
310 { return std::min(val, effectiveBucketSize); });
311
312 const size_t numRowsInTable = arma::accu(secondHashBinCounts > 0);
313 bucketContentSize.zeros(numRowsInTable);
314 secondHashTable.resize(numRowsInTable);
315
316 // Next we must assign each point in each table to the right second hash
317 // table.
318 size_t currentRow = 0;
319 for (size_t i = 0; i < numTables; ++i)
320 {
321 // Insert the point in the corresponding row to its bucket in the
322 // 'secondHashTable'.
323 for (size_t j = 0; j < secondHashVectors.n_cols; ++j)
324 {
325 // This is the bucket number.
326 size_t hashInd = (size_t) secondHashVectors(i, j);
327 // The point ID is 'j'.
328
329 // If this is currently an empty bucket, start a new row keep track of
330 // which row corresponds to the bucket.
331 const size_t maxSize = secondHashBinCounts[hashInd];
332 if (bucketRowInHashTable[hashInd] == secondHashSize)
333 {
334 bucketRowInHashTable[hashInd] = currentRow;
335 secondHashTable[currentRow].set_size(maxSize);
336 currentRow++;
337 }
338
339 // If this vector in the hash table is not full, add the point.
340 const size_t index = bucketRowInHashTable[hashInd];
341 if (bucketContentSize[index] < maxSize)
342 secondHashTable[index](bucketContentSize[index]++) = j;
343 } // Loop over all points in the reference set.
344 } // Loop over tables.
345
346 Log::Info << "Final hash table size: " << numRowsInTable << " rows, with a "
347 << "maximum length of " << arma::max(secondHashBinCounts) << ", "
348 << "totaling " << arma::accu(secondHashBinCounts) << " elements."
349 << std::endl;
350 }
351
352 // Base case where the query set is the reference set. (So, we can't return
353 // ourselves as the nearest neighbor.)
354 template<typename SortPolicy, typename MatType>
355 inline force_inline
BaseCase(const size_t queryIndex,const arma::uvec & referenceIndices,const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances) const356 void LSHSearch<SortPolicy, MatType>::BaseCase(
357 const size_t queryIndex,
358 const arma::uvec& referenceIndices,
359 const size_t k,
360 arma::Mat<size_t>& neighbors,
361 arma::mat& distances) const
362 {
363 // Let's build the list of candidate neighbors for the given query point.
364 // It will be initialized with k candidates:
365 // (WorstDistance, referenceSet.n_cols)
366 const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
367 referenceSet.n_cols);
368 std::vector<Candidate> vect(k, def);
369 CandidateList pqueue(CandidateCmp(), std::move(vect));
370
371 for (size_t j = 0; j < referenceIndices.n_elem; ++j)
372 {
373 const size_t referenceIndex = referenceIndices[j];
374 // If the points are the same, skip this point.
375 if (queryIndex == referenceIndex)
376 continue;
377
378 const double distance = metric::EuclideanDistance::Evaluate(
379 referenceSet.col(queryIndex),
380 referenceSet.col(referenceIndex));
381
382 Candidate c = std::make_pair(distance, referenceIndex);
383 // If this distance is better than the worst candidate, let's insert it.
384 if (CandidateCmp()(c, pqueue.top()))
385 {
386 pqueue.pop();
387 pqueue.push(c);
388 }
389 }
390
391 for (size_t j = 1; j <= k; ++j)
392 {
393 neighbors(k - j, queryIndex) = pqueue.top().second;
394 distances(k - j, queryIndex) = pqueue.top().first;
395 pqueue.pop();
396 }
397 }
398
399 // Base case for bichromatic search.
400 template<typename SortPolicy, typename MatType>
401 inline force_inline
BaseCase(const size_t queryIndex,const arma::uvec & referenceIndices,const size_t k,const MatType & querySet,arma::Mat<size_t> & neighbors,arma::mat & distances) const402 void LSHSearch<SortPolicy, MatType>::BaseCase(
403 const size_t queryIndex,
404 const arma::uvec& referenceIndices,
405 const size_t k,
406 const MatType& querySet,
407 arma::Mat<size_t>& neighbors,
408 arma::mat& distances) const
409 {
410 // Let's build the list of candidate neighbors for the given query point.
411 // It will be initialized with k candidates:
412 // (WorstDistance, referenceSet.n_cols)
413 const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
414 referenceSet.n_cols);
415 std::vector<Candidate> vect(k, def);
416 CandidateList pqueue(CandidateCmp(), std::move(vect));
417
418 for (size_t j = 0; j < referenceIndices.n_elem; ++j)
419 {
420 const size_t referenceIndex = referenceIndices[j];
421 const double distance = metric::EuclideanDistance::Evaluate(
422 querySet.col(queryIndex),
423 referenceSet.col(referenceIndex));
424
425 Candidate c = std::make_pair(distance, referenceIndex);
426 // If this distance is better than the worst candidate, let's insert it.
427 if (CandidateCmp()(c, pqueue.top()))
428 {
429 pqueue.pop();
430 pqueue.push(c);
431 }
432 }
433
434 for (size_t j = 1; j <= k; ++j)
435 {
436 neighbors(k - j, queryIndex) = pqueue.top().second;
437 distances(k - j, queryIndex) = pqueue.top().first;
438 pqueue.pop();
439 }
440 }
441
442 template<typename SortPolicy, typename MatType>
443 inline force_inline
PerturbationScore(const std::vector<bool> & A,const arma::vec & scores) const444 double LSHSearch<SortPolicy, MatType>::PerturbationScore(
445 const std::vector<bool>& A,
446 const arma::vec& scores) const
447 {
448 double score = 0.0;
449 for (size_t i = 0; i < A.size(); ++i)
450 if (A[i])
451 score += scores(i); // add scores of non-zero indices
452 return score;
453 }
454
455 template<typename SortPolicy, typename MatType>
456 inline force_inline
PerturbationShift(std::vector<bool> & A) const457 bool LSHSearch<SortPolicy, MatType>::PerturbationShift(
458 std::vector<bool>& A) const
459 {
460 size_t maxPos = 0;
461 for (size_t i = 0; i < A.size(); ++i)
462 if (A[i] == 1) // Marked true.
463 maxPos = i;
464
465 if (maxPos + 1 < A.size()) // Otherwise, this is an invalid vector.
466 {
467 A[maxPos] = 0;
468 A[maxPos + 1] = 1;
469 return true; // valid
470 }
471 return false; // invalid
472 }
473
474 template<typename SortPolicy, typename MatType>
475 inline force_inline
PerturbationExpand(std::vector<bool> & A) const476 bool LSHSearch<SortPolicy, MatType>::PerturbationExpand(
477 std::vector<bool>& A) const
478 {
479 // Find the last '1' in A.
480 size_t maxPos = 0;
481 for (size_t i = 0; i < A.size(); ++i)
482 if (A[i]) // Marked true.
483 maxPos = i;
484
485 if (maxPos + 1 < A.size()) // Otherwise, this is an invalid vector.
486 {
487 A[maxPos + 1] = 1;
488 return true;
489 }
490 return false;
491 }
492
493 template<typename SortPolicy, typename MatType>
494 inline force_inline
PerturbationValid(const std::vector<bool> & A) const495 bool LSHSearch<SortPolicy, MatType>::PerturbationValid(
496 const std::vector<bool>& A) const
497 {
498 // Use check to mark dimensions we have seen before in A. If a dimension is
499 // seen twice (or more), A is not a valid perturbation.
500 std::vector<bool> check(numProj);
501
502 if (A.size() > 2 * numProj)
503 return false; // This should never happen.
504
505 // Check that we only see each dimension once. If not, vector is not valid.
506 for (size_t i = 0; i < A.size(); ++i)
507 {
508 // Only check dimensions that were included.
509 if (!A[i])
510 continue;
511
512 // If dimesnion is unseen thus far, mark it as seen.
513 if (check[i % numProj] == false)
514 check[i % numProj] = true;
515 else
516 return false; // If dimension was seen before, set is not valid.
517 }
518 // If we didn't fail, set is valid.
519 return true;
520 }
521
522 // Compute additional probing bins for a query
523 template<typename SortPolicy, typename MatType>
GetAdditionalProbingBins(const arma::vec & queryCode,const arma::vec & queryCodeNotFloored,const size_t T,arma::mat & additionalProbingBins) const524 void LSHSearch<SortPolicy, MatType>::GetAdditionalProbingBins(
525 const arma::vec& queryCode,
526 const arma::vec& queryCodeNotFloored,
527 const size_t T,
528 arma::mat& additionalProbingBins) const
529 {
530 // No additional bins requested. Our work is done.
531 if (T == 0)
532 return;
533
534 // Each column of additionalProbingBins is the code of a bin.
535 additionalProbingBins.set_size(numProj, T);
536
537 // Copy the query's code, then in the end we will add/subtract according
538 // to perturbations we calculated.
539 for (size_t c = 0; c < T; ++c)
540 additionalProbingBins.col(c) = queryCode;
541
542
543 // Calculate query point's projection position.
544 arma::mat projection = queryCodeNotFloored;
545
546 // Use projection to calculate query's distance from hash limits.
547 arma::vec limLow = projection - queryCode * hashWidth;
548 arma::vec limHigh = hashWidth - limLow;
549
550 // Calculate scores. score = distance^2.
551 arma::vec scores(2 * numProj);
552 scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
553 scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2);
554
555 // Actions vector describes what perturbation (-1/+1) corresponds to a score.
556 arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]
557 actions.rows(0, numProj - 1) = // First numProj rows.
558 -1 * arma::ones< arma::Col<short int> > (numProj); // -1s
559 actions.rows(numProj, (2 * numProj) - 1) = // Last numProj rows.
560 arma::ones< arma::Col<short int> > (numProj); // 1s
561
562
563 // Acting dimension vector shows which coordinate to transform according to
564 // actions (actions are described by actions vector above).
565 arma::Col<size_t> positions(2 * numProj); // Will be [0 1 2 ... 0 1 2 ...].
566 positions.rows(0, numProj - 1) =
567 arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
568 positions.rows(numProj, 2 * numProj - 1) =
569 arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
570
571 // Special case: No need to create heap for 1 or 2 codes.
572 if (T <= 2)
573 {
574 // First, find location of minimum score, generate 1 perturbation vector,
575 // and add its code to additionalProbingBins column 0.
576
577 // Find location and value of smallest element of scores vector.
578 double minscore = scores[0];
579 size_t minloc = 0;
580 for (size_t s = 1; s < (2 * numProj); ++s)
581 {
582 if (minscore > scores[s])
583 {
584 minscore = scores[s];
585 minloc = s;
586 }
587 }
588
589 // Add or subtract 1 to dimension corresponding to minimum score.
590 additionalProbingBins(positions[minloc], 0) += actions[minloc];
591 if (T == 1)
592 return; // Done if asked for only 1 code.
593
594 // Now, find location of second smallest score and generate one more vector.
595 // The second perturbation vector still can't comprise of more than one
596 // change in the bin codes, because of the way perturbation vectors
597 // are generated: First we create the one with the smallest score (Ao) and
598 // then we either add 1 extra dimension to it (Ae) or shift it by one (As).
599 // Since As contains the second smallest score, and Ae contains both the
600 // smallest and the second smallest, it's obvious that score(Ae) >
601 // score(As). Therefore the second perturbation vector is ALWAYS the vector
602 // containing only the second-lowest scoring perturbation.
603 double minscore2 = scores[0];
604 size_t minloc2 = 0;
605 for (size_t s = 0; s < (2 * numProj); ++s) // Here we can't start from 1.
606 {
607 if (minscore2 > scores[s] && s != minloc) // Second smallest.
608 {
609 minscore2 = scores[s];
610 minloc2 = s;
611 }
612 }
613
614 // Add or subtract 1 to create second-lowest scoring vector.
615 additionalProbingBins(positions[minloc2], 1) += actions[minloc2];
616 return;
617 }
618
619 // General case: more than 2 perturbation vectors require use of minheap.
620 // Sort everything in increasing order.
621 arma::uvec sortidx = arma::sort_index(scores);
622 scores = scores(sortidx);
623 actions = actions(sortidx);
624 positions = positions(sortidx);
625
626 // Theory:
627 // A probing sequence is a sequence of T probing bins where a query's
628 // neighbors are most likely to be. Likelihood is dependent only on a bin's
629 // score, which is the sum of scores of all dimension-action pairs, so we
630 // need to calculate the T smallest sums of scores that are not conflicting.
631 //
632 // Method:
633 // Store each perturbation set (pair of (dimension, action)) in a
634 // std::vector. Create a minheap of scores, with each node pointing to its
635 // relevant perturbation set. Each perturbation set popped from the minheap
636 // is the next most likely perturbation set.
637 // Transform perturbation set to perturbation vector by setting the
638 // dimensions specified by the set to queryCode+action (action is {-1, 1}).
639
640 // Perturbation sets (A) mark with 1 the (score, action, dimension) positions
641 // included in a given perturbation vector. Other spaces are 0.
642 std::vector<bool> Ao(2 * numProj);
643 Ao[0] = 1; // Smallest vector includes only smallest score.
644
645 std::vector< std::vector<bool> > perturbationSets;
646 perturbationSets.push_back(Ao); // Storage of perturbation sets.
647
648 std::priority_queue<
649 std::pair<double, size_t>, // contents: pairs of (score, index)
650 std::vector< // container: vector of pairs
651 std::pair<double, size_t>
652 >,
653 std::greater< std::pair<double, size_t> > // comparator of pairs
654 > minHeap; // our minheap
655
656 // Start by adding the lowest scoring set to the minheap.
657 minHeap.push(std::make_pair(PerturbationScore(Ao, scores), 0));
658
659 // Loop invariable: after pvec iterations, additionalProbingBins contains pvec
660 // valid codes of the lowest-scoring bins (bins most likely to contain
661 // neighbors of the query).
662 for (size_t pvec = 0; pvec < T; ++pvec)
663 {
664 std::vector<bool> Ai;
665 do
666 {
667 // Get the perturbation set corresponding to the minimum score.
668 Ai = perturbationSets[ minHeap.top().second ];
669 minHeap.pop(); // .top() returns, .pop() removes
670
671 // Shift operation on Ai (replace max with max+1).
672 std::vector<bool> As = Ai;
673
674 // Don't add invalid sets.
675 if (PerturbationShift(As) && PerturbationValid(As))
676 {
677 perturbationSets.push_back(As); // add shifted set to sets
678 minHeap.push(
679 std::make_pair(PerturbationScore(As, scores),
680 perturbationSets.size() - 1));
681 }
682
683 // Expand operation on Ai (add max+1 to set).
684 std::vector<bool> Ae = Ai;
685
686 // Don't add invalid sets.
687 if (PerturbationExpand(Ae) && PerturbationValid(Ae))
688 {
689 perturbationSets.push_back(Ae); // add expanded set to sets
690 minHeap.push(
691 std::make_pair(PerturbationScore(Ae, scores),
692 perturbationSets.size() - 1));
693 }
694 } while (!PerturbationValid(Ai)); // Discard invalid perturbations
695
696 // Found valid perturbation set Ai. Construct perturbation vector from set.
697 for (size_t pos = 0; pos < Ai.size(); ++pos)
698 {
699 // If Ai[pos] is marked, add action to probing vector.
700 additionalProbingBins(positions(pos), pvec) += Ai[pos] ? actions(pos) : 0;
701 }
702 }
703 }
704
705 template<typename SortPolicy, typename MatType>
706 template<typename VecType>
ReturnIndicesFromTable(const VecType & queryPoint,arma::uvec & referenceIndices,size_t numTablesToSearch,const size_t T) const707 void LSHSearch<SortPolicy, MatType>::ReturnIndicesFromTable(
708 const VecType& queryPoint,
709 arma::uvec& referenceIndices,
710 size_t numTablesToSearch,
711 const size_t T) const
712 {
713 // Decide on the number of tables to look into.
714 if (numTablesToSearch == 0) // If no user input is given, search all.
715 numTablesToSearch = numTables;
716
717 // Sanity check to make sure that the existing number of tables is not
718 // exceeded.
719 if (numTablesToSearch > numTables)
720 numTablesToSearch = numTables;
721
722 // Hash the query in each of the 'numTablesToSearch' hash tables using the
723 // 'numProj' projections for each table. This gives us 'numTablesToSearch'
724 // keys for the query where each key is a 'numProj' dimensional integer
725 // vector.
726
727 // Compute the projection of the query in each table.
728 arma::mat allProjInTables(numProj, numTablesToSearch);
729 arma::mat queryCodesNotFloored(numProj, numTablesToSearch);
730 for (size_t i = 0; i < numTablesToSearch; ++i)
731 queryCodesNotFloored.unsafe_col(i) = projections.slice(i).t() * queryPoint;
732
733 queryCodesNotFloored += offsets.cols(0, numTablesToSearch - 1);
734 allProjInTables = arma::floor(queryCodesNotFloored / hashWidth);
735
736 // Use hashMat to store the primary probing codes and any additional codes
737 // from multiprobe LSH.
738 arma::Mat<size_t> hashMat;
739 hashMat.set_size(T + 1, numTablesToSearch);
740
741 // Compute the primary hash value of each key of the query into a bucket of
742 // the secondHashTable using the secondHashWeights.
743 hashMat.row(0) = arma::conv_to<arma::Row<size_t>> // Floor by typecasting
744 ::from(secondHashWeights.t() * allProjInTables);
745 // Mod to compute 2nd-level codes.
746 for (size_t i = 0; i < numTablesToSearch; ++i)
747 hashMat(0, i) = (hashMat(0, i) % secondHashSize);
748
749 // Compute hash codes of additional probing bins.
750 if (T > 0)
751 {
752 for (size_t i = 0; i < numTablesToSearch; ++i)
753 {
754 // Construct this table's probing sequence of length T.
755 arma::mat additionalProbingBins;
756 GetAdditionalProbingBins(allProjInTables.unsafe_col(i),
757 queryCodesNotFloored.unsafe_col(i),
758 T,
759 additionalProbingBins);
760
761 // Map each probing bin to a bin in secondHashTable (just like we did for
762 // the primary hash table).
763 hashMat(arma::span(1, T), i) = // Compute code of rows 1:end of column i
764 arma::conv_to< arma::Col<size_t> >:: // floor by typecasting to size_t
765 from(secondHashWeights.t() * additionalProbingBins);
766 for (size_t p = 1; p < T + 1; ++p)
767 hashMat(p, i) = (hashMat(p, i) % secondHashSize);
768 }
769 }
770
771 // Count number of points hashed in the same bucket as the query.
772 size_t maxNumPoints = 0;
773 for (size_t i = 0; i < numTablesToSearch; ++i)
774 {
775 for (size_t p = 0; p < T + 1; ++p)
776 {
777 const size_t hashInd = hashMat(p, i); // find query's bucket
778 const size_t tableRow = bucketRowInHashTable[hashInd];
779 if (tableRow < secondHashSize)
780 maxNumPoints += bucketContentSize[tableRow]; // count bucket contents
781 }
782 }
783
784 // There are two ways to proceed here:
785 // Either allocate a maxNumPoints-size vector, place all candidates, and run
786 // unique on the vector to discard duplicates.
787 // Or allocate a referenceSet.n_cols size vector (i.e. number of reference
788 // points) of zeros, and mark found indices as 1.
789 // Option 1 runs faster for small maxNumPoints but worse for larger values, so
790 // we choose based on a heuristic.
791 const float cutoff = 0.1;
792 const float selectivity = static_cast<float>(maxNumPoints) /
793 static_cast<float>(referenceSet.n_cols);
794
795 if (selectivity > cutoff)
796 {
797 // Heuristic: larger maxNumPoints means we should use find() because it
798 // should be faster.
799 // Reference points hashed in the same bucket as the query are set to >0.
800 arma::Col<size_t> refPointsConsidered;
801 refPointsConsidered.zeros(referenceSet.n_cols);
802
803 for (size_t i = 0; i < numTablesToSearch; ++i) // for all tables
804 {
805 for (size_t p = 0; p < T + 1; ++p) // For entire probing sequence.
806 {
807 // get the sequence code
808 size_t hashInd = hashMat(p, i);
809 size_t tableRow = bucketRowInHashTable[hashInd];
810
811 if (tableRow < secondHashSize && bucketContentSize[tableRow] > 0)
812 {
813 // Pick the indices in the bucket corresponding to hashInd.
814 for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
815 refPointsConsidered[ secondHashTable[tableRow](j) ]++;
816 }
817 }
818 }
819
820 // Only keep reference points found in at least one bucket.
821 referenceIndices = arma::find(refPointsConsidered > 0);
822 return;
823 }
824 else
825 {
826 // Heuristic: smaller maxNumPoints means we should use unique() because it
827 // should be faster.
828 // Allocate space for the query's potential neighbors.
829 arma::uvec refPointsConsideredSmall;
830 refPointsConsideredSmall.zeros(maxNumPoints);
831
832 // Retrieve candidates.
833 size_t start = 0;
834
835 for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
836 {
837 for (size_t p = 0; p < T + 1; ++p)
838 {
839 const size_t hashInd = hashMat(p, i); // Find the query's bucket.
840 const size_t tableRow = bucketRowInHashTable[hashInd];
841
842 if (tableRow < secondHashSize)
843 {
844 // Store all secondHashTable points in the candidates set.
845 for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
846 refPointsConsideredSmall(start++) = secondHashTable[tableRow](j);
847 }
848 }
849 }
850
851 // Keep only one copy of each candidate.
852 referenceIndices = arma::unique(refPointsConsideredSmall);
853 return;
854 }
855 }
856
857 // Search for nearest neighbors in a given query set.
858 template<typename SortPolicy, typename MatType>
Search(const MatType & querySet,const size_t k,arma::Mat<size_t> & resultingNeighbors,arma::mat & distances,const size_t numTablesToSearch,const size_t T)859 void LSHSearch<SortPolicy, MatType>::Search(
860 const MatType& querySet,
861 const size_t k,
862 arma::Mat<size_t>& resultingNeighbors,
863 arma::mat& distances,
864 const size_t numTablesToSearch,
865 const size_t T)
866 {
867 // Ensure the dimensionality of the query set is correct.
868 if (querySet.n_rows != referenceSet.n_rows)
869 {
870 std::ostringstream oss;
871 oss << "LSHSearch::Search(): dimensionality of query set ("
872 << querySet.n_rows << ") is not equal to the dimensionality the model "
873 << "was trained on (" << referenceSet.n_rows << ")!" << std::endl;
874 throw std::invalid_argument(oss.str());
875 }
876
877 if (k > referenceSet.n_cols)
878 {
879 std::ostringstream oss;
880 oss << "LSHSearch::Search(): requested " << k << " approximate nearest "
881 << "neighbors, but reference set has " << referenceSet.n_cols
882 << " points!" << std::endl;
883 throw std::invalid_argument(oss.str());
884 }
885
886 // Set the size of the neighbor and distance matrices.
887 resultingNeighbors.set_size(k, querySet.n_cols);
888 distances.set_size(k, querySet.n_cols);
889
890 // If the user asked for 0 nearest neighbors... uh... we're done.
891 if (k == 0)
892 return;
893
894 // If the user requested more than the available number of additional probing
895 // bins, set Teffective to maximum T. Maximum T is 2^numProj - 1
896 size_t Teffective = T;
897 if (T > ((size_t) ((1 << numProj) - 1)))
898 {
899 Teffective = (1 << numProj) - 1;
900 Log::Warn << "Requested " << T << " additional bins are more than "
901 << "theoretical maximum. Using " << Teffective << " instead."
902 << std::endl;
903 }
904
905 // If the user set multiprobe, log it
906 if (Teffective > 0)
907 Log::Info << "Running multiprobe LSH with " << Teffective
908 <<" additional probing bins per table per query." << std::endl;
909
910 size_t avgIndicesReturned = 0;
911
912 Timer::Start("computing_neighbors");
913
914 // Parallelization to process more than one query at a time.
915 #pragma omp parallel for \
916 shared(resultingNeighbors, distances) \
917 schedule(dynamic)\
918 reduction(+:avgIndicesReturned)
919 for (omp_size_t i = 0; i < (omp_size_t) querySet.n_cols; ++i)
920 {
921 // Go through every query point.
922 // Hash every query into every hash table and eventually into the
923 // 'secondHashTable' to obtain the neighbor candidates.
924 arma::uvec refIndices;
925 ReturnIndicesFromTable(querySet.col(i), refIndices, numTablesToSearch,
926 Teffective);
927
928 // An informative book-keeping for the number of neighbor candidates
929 // returned on average.
930 // Make atomic to avoid race conditions when multiple threads are running
931 // #pragma omp atomic
932 avgIndicesReturned = avgIndicesReturned + refIndices.n_elem;
933
934 // Sequentially go through all the candidates and save the best 'k'
935 // candidates.
936 BaseCase(i, refIndices, k, querySet, resultingNeighbors, distances);
937 }
938
939 Timer::Stop("computing_neighbors");
940
941 distanceEvaluations += avgIndicesReturned;
942 avgIndicesReturned /= querySet.n_cols;
943 Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
944 std::endl;
945 }
946
947 // Search for approximate neighbors of the reference set.
948 template<typename SortPolicy, typename MatType>
949 void LSHSearch<SortPolicy, MatType>::
Search(const size_t k,arma::Mat<size_t> & resultingNeighbors,arma::mat & distances,const size_t numTablesToSearch,size_t T)950 Search(const size_t k,
951 arma::Mat<size_t>& resultingNeighbors,
952 arma::mat& distances,
953 const size_t numTablesToSearch,
954 size_t T)
955 {
956 // This is monochromatic search; the query set is the reference set.
957 resultingNeighbors.set_size(k, referenceSet.n_cols);
958 distances.set_size(k, referenceSet.n_cols);
959
960 // If the user requested more than the available number of additional probing
961 // bins, set Teffective to maximum T. Maximum T is 2^numProj - 1
962 size_t Teffective = T;
963 if (T > ((size_t) ((1 << numProj) - 1)))
964 {
965 Teffective = (1 << numProj) - 1;
966 Log::Warn << "Requested " << T << " additional bins are more than "
967 << "theoretical maximum. Using " << Teffective << " instead."
968 << std::endl;
969 }
970
971 // If the user set multiprobe, log it
972 if (T > 0)
973 Log::Info << "Running multiprobe LSH with " << Teffective <<
974 " additional probing bins per table per query."<< std::endl;
975
976 size_t avgIndicesReturned = 0;
977
978 Timer::Start("computing_neighbors");
979
980 // Parallelization to process more than one query at a time.
981 #pragma omp parallel for \
982 shared(resultingNeighbors, distances) \
983 schedule(dynamic)\
984 reduction(+:avgIndicesReturned)
985 for (omp_size_t i = 0; i < (omp_size_t) referenceSet.n_cols; ++i)
986 {
987 // Go through every query point.
988 // Hash every query into every hash table and eventually into the
989 // 'secondHashTable' to obtain the neighbor candidates.
990 arma::uvec refIndices;
991 ReturnIndicesFromTable(referenceSet.col(i), refIndices, numTablesToSearch,
992 Teffective);
993
994 // An informative book-keeping for the number of neighbor candidates
995 // returned on average.
996 // Make atomic to avoid race conditions when multiple threads are running.
997 // #pragma omp atomic
998 avgIndicesReturned += refIndices.n_elem;
999
1000 // Sequentially go through all the candidates and save the best 'k'
1001 // candidates.
1002 BaseCase(i, refIndices, k, resultingNeighbors, distances);
1003 }
1004
1005 Timer::Stop("computing_neighbors");
1006
1007 distanceEvaluations += avgIndicesReturned;
1008 avgIndicesReturned /= referenceSet.n_cols;
1009 Log::Info << avgIndicesReturned << " distinct indices returned on average." <<
1010 std::endl;
1011 }
1012
1013 template<typename SortPolicy, typename MatType>
ComputeRecall(const arma::Mat<size_t> & foundNeighbors,const arma::Mat<size_t> & realNeighbors)1014 double LSHSearch<SortPolicy, MatType>::ComputeRecall(
1015 const arma::Mat<size_t>& foundNeighbors,
1016 const arma::Mat<size_t>& realNeighbors)
1017 {
1018 if (foundNeighbors.n_rows != realNeighbors.n_rows ||
1019 foundNeighbors.n_cols != realNeighbors.n_cols)
1020 throw std::invalid_argument("LSHSearch::ComputeRecall(): matrices provided"
1021 " must have equal size");
1022
1023 const size_t queries = foundNeighbors.n_cols;
1024 const size_t neighbors = foundNeighbors.n_rows; // Should be equal to k.
1025
1026 // The recall is the set intersection of found and real neighbors.
1027 size_t found = 0;
1028 for (size_t col = 0; col < queries; ++col)
1029 for (size_t row = 0; row < neighbors; ++row)
1030 for (size_t nei = 0; nei < realNeighbors.n_rows; ++nei)
1031 if (realNeighbors(row, col) == foundNeighbors(nei, col))
1032 {
1033 found++;
1034 break;
1035 }
1036
1037 return ((double) found) / realNeighbors.n_elem;
1038 }
1039
1040 template<typename SortPolicy, typename MatType>
1041 template<typename Archive>
serialize(Archive & ar,const unsigned int version)1042 void LSHSearch<SortPolicy, MatType>::serialize(Archive& ar,
1043 const unsigned int version)
1044 {
1045 ar & BOOST_SERIALIZATION_NVP(referenceSet);
1046 ar & BOOST_SERIALIZATION_NVP(numProj);
1047 ar & BOOST_SERIALIZATION_NVP(numTables);
1048
1049 // Delete existing projections, if necessary.
1050 if (Archive::is_loading::value)
1051 projections.reset();
1052
1053 // Backward compatibility: older versions of LSHSearch stored the projection
1054 // tables in a std::vector<arma::mat>.
1055 if (version == 0)
1056 {
1057 std::vector<arma::mat> tmpProj;
1058 ar & BOOST_SERIALIZATION_NVP(tmpProj);
1059
1060 projections.set_size(tmpProj[0].n_rows, tmpProj[0].n_cols, tmpProj.size());
1061 for (size_t i = 0; i < tmpProj.size(); ++i)
1062 projections.slice(i) = tmpProj[i];
1063 }
1064 else
1065 {
1066 ar & BOOST_SERIALIZATION_NVP(projections);
1067 }
1068
1069 ar & BOOST_SERIALIZATION_NVP(offsets);
1070 ar & BOOST_SERIALIZATION_NVP(hashWidth);
1071 ar & BOOST_SERIALIZATION_NVP(secondHashSize);
1072 ar & BOOST_SERIALIZATION_NVP(secondHashWeights);
1073 ar & BOOST_SERIALIZATION_NVP(bucketSize);
1074 // needs specific handling for new version
1075
1076 // Backward compatibility: in older versions of LSHSearch, the secondHashTable
1077 // was stored as an arma::Mat<size_t>. So we need to properly load that, then
1078 // prune it down to size.
1079 if (version == 0)
1080 {
1081 arma::Mat<size_t> tmpSecondHashTable;
1082 ar & BOOST_SERIALIZATION_NVP(tmpSecondHashTable);
1083
1084 // The old secondHashTable was stored in row-major format, so we transpose
1085 // it.
1086 tmpSecondHashTable = tmpSecondHashTable.t();
1087
1088 secondHashTable.resize(tmpSecondHashTable.n_cols);
1089 for (size_t i = 0; i < tmpSecondHashTable.n_cols; ++i)
1090 {
1091 // Find length of each column. We know we are at the end of the list when
1092 // the value referenceSet.n_cols is seen.
1093
1094 size_t len = 0;
1095 for (; len < tmpSecondHashTable.n_rows; ++len)
1096 if (tmpSecondHashTable(len, i) == referenceSet.n_cols)
1097 break;
1098
1099 // Set the size of the new column correctly.
1100 secondHashTable[i].set_size(len);
1101 for (size_t j = 0; j < len; ++j)
1102 secondHashTable[i](j) = tmpSecondHashTable(j, i);
1103 }
1104 }
1105 else
1106 {
1107 size_t tables;
1108 if (Archive::is_saving::value)
1109 tables = secondHashTable.size();
1110 ar & BOOST_SERIALIZATION_NVP(tables);
1111
1112 // Set size of second hash table if needed.
1113 if (Archive::is_loading::value)
1114 {
1115 secondHashTable.clear();
1116 secondHashTable.resize(tables);
1117 }
1118
1119 ar & BOOST_SERIALIZATION_NVP(secondHashTable);
1120 }
1121
1122 // Backward compatibility: old versions of LSHSearch held bucketContentSize
1123 // for all possible buckets (of size secondHashSize), but now we hold a
1124 // compressed representation.
1125 if (version == 0)
1126 {
1127 // The vector was stored in the old uncompressed form. So we need to shrink
1128 // it. But we can't do that until we have bucketRowInHashTable, so we also
1129 // have to load that.
1130 arma::Col<size_t> tmpBucketContentSize;
1131 ar & BOOST_SERIALIZATION_NVP(tmpBucketContentSize);
1132 ar & BOOST_SERIALIZATION_NVP(bucketRowInHashTable);
1133
1134 // Compress into a smaller vector by just dropping all of the zeros.
1135 bucketContentSize.set_size(secondHashTable.size());
1136 for (size_t i = 0; i < tmpBucketContentSize.n_elem; ++i)
1137 if (tmpBucketContentSize[i] > 0)
1138 bucketContentSize[bucketRowInHashTable[i]] = tmpBucketContentSize[i];
1139 }
1140 else
1141 {
1142 ar & BOOST_SERIALIZATION_NVP(bucketContentSize);
1143 ar & BOOST_SERIALIZATION_NVP(bucketRowInHashTable);
1144 }
1145
1146 ar & BOOST_SERIALIZATION_NVP(distanceEvaluations);
1147 }
1148
1149 } // namespace neighbor
1150 } // namespace mlpack
1151
1152 #endif
1153