1 /** 2 * @file methods/lsh/lsh_search.hpp 3 * @author Parikshit Ram 4 * 5 * Defines the LSHSearch class, which performs an approximate 6 * nearest neighbor search for a queries in a query set 7 * over a given dataset using Locality-sensitive hashing 8 * with 2-stable distributions. 9 * 10 * The details of this method can be found in the following paper: 11 * 12 * @code 13 * @inproceedings{datar2004locality, 14 * title={Locality-sensitive hashing scheme based on p-stable distributions}, 15 * author={Datar, M. and Immorlica, N. and Indyk, P. and Mirrokni, V.S.}, 16 * booktitle= 17 * {Proceedings of the 12th Annual Symposium on Computational Geometry}, 18 * pages={253--262}, 19 * year={2004}, 20 * organization={ACM} 21 * } 22 * @endcode 23 * 24 * Additionally, the class implements Multiprobe LSH, which improves 25 * approximation results during the search for approximate nearest neighbors. 26 * The Multiprobe LSH algorithm was presented in the paper: 27 * 28 * @code 29 * @inproceedings{Lv2007multiprobe, 30 * tile={Multi-probe LSH: efficient indexing for high-dimensional similarity 31 * search}, 32 * author={Lv, Qin and Josephson, William and Wang, Zhe and Charikar, Moses and 33 * Li, Kai}, 34 * booktitle={Proceedings of the 33rd international conference on Very large 35 * data bases}, 36 * year={2007}, 37 * pages={950--961} 38 * } 39 * @endcode 40 * 41 * 42 * mlpack is free software; you may redistribute it and/or modify it under the 43 * terms of the 3-clause BSD license. You should have received a copy of the 44 * 3-clause BSD license along with mlpack. If not, see 45 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 46 */ 47 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP 48 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP 49 50 #include <mlpack/prereqs.hpp> 51 52 #include <mlpack/core/metrics/lmetric.hpp> 53 #include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp> 54 55 #include <queue> 56 57 namespace mlpack { 58 namespace neighbor { 59 60 /** 61 * The LSHSearch class; this class builds a hash on the reference set and uses 62 * this hash to compute the distance-approximate nearest-neighbors of the given 63 * queries. 64 * 65 * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort. 66 * @tparam MatType Type of matrix to use to store the data. 67 */ 68 template< 69 typename SortPolicy = NearestNeighborSort, 70 typename MatType = arma::mat 71 > 72 class LSHSearch 73 { 74 public: 75 /** 76 * This function initializes the LSH class. It builds the hash on the 77 * reference set with 2-stable distributions. See the individual functions 78 * performing the hashing for details on how the hashing is done. In order to 79 * avoid copying the reference set, it is suggested to pass that parameter 80 * with std::move(). 81 * 82 * @param referenceSet Set of reference points and the set of queries. 83 * @param projections Cube of projection tables. For a cube of size (a, b, c) 84 * we set numProj = a, numTables = c. b is the reference set 85 * dimensionality. 86 * @param hashWidth The width of hash for every table. If 0 (the default) is 87 * provided, then the hash width is automatically obtained by computing 88 * the average pairwise distance of 25 pairs. This should be a reasonable 89 * upper bound on the nearest-neighbor distance in general. 90 * @param secondHashSize The size of the second hash table. This should be a 91 * large prime number. 92 * @param bucketSize The size of the bucket in the second hash table. This is 93 * the maximum number of points that can be hashed into single bucket. A 94 * value of 0 indicates that there is no limit (so the second hash table 95 * can be arbitrarily large---be careful!). 96 */ 97 LSHSearch(MatType referenceSet, 98 const arma::cube& projections, 99 const double hashWidth = 0.0, 100 const size_t secondHashSize = 99901, 101 const size_t bucketSize = 500); 102 103 /** 104 * This function initializes the LSH class. It builds the hash one the 105 * reference set using the provided projections. See the individual functions 106 * performing the hashing for details on how the hashing is done. In order to 107 * avoid copying the reference set, consider passing the set with std::move(). 108 * 109 * @param referenceSet Set of reference points and the set of queries. 110 * @param numProj Number of projections in each hash table (anything between 111 * 10-50 might be a decent choice). 112 * @param numTables Total number of hash tables (anything between 10-20 113 * should suffice). 114 * @param hashWidth The width of hash for every table. If 0 (the default) is 115 * provided, then the hash width is automatically obtained by computing 116 * the average pairwise distance of 25 pairs. This should be a reasonable 117 * upper bound on the nearest-neighbor distance in general. 118 * @param secondHashSize The size of the second hash table. This should be a 119 * large prime number. 120 * @param bucketSize The size of the bucket in the second hash table. This is 121 * the maximum number of points that can be hashed into single bucket. A 122 * value of 0 indicates that there is no limit (so the second hash table 123 * can be arbitrarily large---be careful!). 124 */ 125 LSHSearch(MatType referenceSet, 126 const size_t numProj, 127 const size_t numTables, 128 const double hashWidth = 0.0, 129 const size_t secondHashSize = 99901, 130 const size_t bucketSize = 500); 131 132 /** 133 * Create an untrained LSH model. Be sure to call Train() before calling 134 * Search(); otherwise, an exception will be thrown when Search() is called. 135 */ 136 LSHSearch(); 137 138 /** 139 * Copy the given LSH model. 140 * 141 * @param other Other LSH model to copy. 142 */ 143 LSHSearch(const LSHSearch& other); 144 145 /** 146 * Take ownership of the given LSH model. 147 * 148 * @param other Other LSH model to take ownership of. 149 */ 150 LSHSearch(LSHSearch&& other); 151 152 /** 153 * Copy the given LSH model. 154 * 155 * @param other Other LSH model to copy. 156 */ 157 LSHSearch& operator=(const LSHSearch& other); 158 159 /** 160 * Take ownership of the given LSH model. 161 * 162 * @param other Other LSH model to take ownership of. 163 */ 164 LSHSearch& operator=(LSHSearch&& other); 165 166 /** 167 * Train the LSH model on the given dataset. If a correctly-sized projection 168 * cube is not provided, this means building new hash tables. Otherwise, we 169 * use the projections provided by the user. In order to avoid copying the 170 * reference set, consider passing that parameter with std::move(). 171 * 172 * @param referenceSet Set of reference points and the set of queries. 173 * @param numProj Number of projections in each hash table (anything between 174 * 10-50 might be a decent choice). 175 * @param numTables Total number of hash tables (anything between 10-20 176 * should suffice). 177 * @param hashWidth The width of hash for every table. If 0 (the default) is 178 * provided, then the hash width is automatically obtained by computing 179 * the average pairwise distance of 25 pairs. This should be a reasonable 180 * upper bound on the nearest-neighbor distance in general. 181 * @param secondHashSize The size of the second hash table. This should be a 182 * large prime number. 183 * @param bucketSize The size of the bucket in the second hash table. This is 184 * the maximum number of points that can be hashed into single bucket. A 185 * value of 0 indicates that there is no limit (so the second hash table 186 * can be arbitrarily large---be careful!). 187 * @param projection Cube of projection tables. For a cube of size (a, b, c) 188 * we set numProj = a, numTables = c. b is the reference set 189 * dimensionality. 190 */ 191 void Train(MatType referenceSet, 192 const size_t numProj, 193 const size_t numTables, 194 const double hashWidth = 0.0, 195 const size_t secondHashSize = 99901, 196 const size_t bucketSize = 500, 197 const arma::cube& projection = arma::cube()); 198 199 /** 200 * Compute the nearest neighbors of the points in the given query set and 201 * store the output in the given matrices. The matrices will be set to the 202 * size of n columns by k rows, where n is the number of points in the query 203 * dataset and k is the number of neighbors being searched for. 204 * 205 * @param querySet Set of query points. 206 * @param k Number of neighbors to search for. 207 * @param resultingNeighbors Matrix storing lists of neighbors for each query 208 * point. 209 * @param distances Matrix storing distances of neighbors for each query 210 * point. 211 * @param numTablesToSearch This parameter allows the user to have control 212 * over the number of hash tables to be searched. This allows 213 * the user to pick the number of tables it can afford for the time 214 * available without having to build hashing for every table size. 215 * By default, this is set to zero in which case all tables are 216 * considered. 217 * @param T The number of additional probing bins to examine with multiprobe 218 * LSH. If T = 0, classic single-probe LSH is run (default). 219 */ 220 void Search(const MatType& querySet, 221 const size_t k, 222 arma::Mat<size_t>& resultingNeighbors, 223 arma::mat& distances, 224 const size_t numTablesToSearch = 0, 225 const size_t T = 0); 226 227 /** 228 * Compute the nearest neighbors and store the output in the given matrices. 229 * The matrices will be set to the size of n columns by k rows, where n is 230 * the number of points in the query dataset and k is the number of neighbors 231 * being searched for. 232 * 233 * @param k Number of neighbors to search for. 234 * @param resultingNeighbors Matrix storing lists of neighbors for each query 235 * point. 236 * @param distances Matrix storing distances of neighbors for each query 237 * point. 238 * @param numTablesToSearch This parameter allows the user to have control 239 * over the number of hash tables to be searched. This allows 240 * the user to pick the number of tables it can afford for the time 241 * available without having to build hashing for every table size. 242 * By default, this is set to zero in which case all tables are 243 * considered. 244 * @param T Number of probing bins. 245 */ 246 void Search(const size_t k, 247 arma::Mat<size_t>& resultingNeighbors, 248 arma::mat& distances, 249 const size_t numTablesToSearch = 0, 250 size_t T = 0); 251 252 /** 253 * Compute the recall (% of neighbors found) given the neighbors returned by 254 * LSHSearch::Search and a "ground truth" set of neighbors. The recall 255 * returned will be in the range [0, 1]. 256 * 257 * @param foundNeighbors Set of neighbors to compute recall of. 258 * @param realNeighbors Set of "ground truth" neighbors to compute recall 259 * against. 260 */ 261 static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors, 262 const arma::Mat<size_t>& realNeighbors); 263 264 /** 265 * Serialize the LSH model. 266 * 267 * @param ar Archive to serialize to. 268 * @param version Version number. 269 */ 270 template<typename Archive> 271 void serialize(Archive& ar, const unsigned int version); 272 273 //! Return the number of distance evaluations performed. DistanceEvaluations() const274 size_t DistanceEvaluations() const { return distanceEvaluations; } 275 //! Modify the number of distance evaluations performed. DistanceEvaluations()276 size_t& DistanceEvaluations() { return distanceEvaluations; } 277 278 //! Return the reference dataset. ReferenceSet() const279 const MatType& ReferenceSet() const { return referenceSet; } 280 281 //! Get the number of projections. NumProjections() const282 size_t NumProjections() const { return projections.n_slices; } 283 284 //! Get the offsets 'b' for each of the projections. (One 'b' per column.) Offsets() const285 const arma::mat& Offsets() const { return offsets; } 286 287 //! Get the weights of the second hash. SecondHashWeights() const288 const arma::vec& SecondHashWeights() const { return secondHashWeights; } 289 290 //! Get the bucket size of the second hash. BucketSize() const291 size_t BucketSize() const { return bucketSize; } 292 293 //! Get the second hash table. SecondHashTable() const294 const std::vector<arma::Col<size_t>>& SecondHashTable() const 295 { return secondHashTable; } 296 297 //! Get the projection tables. Projections()298 const arma::cube& Projections() { return projections; } 299 300 //! Change the projection tables (this retrains the LSH model). Projections(const arma::cube & projTables)301 void Projections(const arma::cube& projTables) 302 { 303 // Simply call Train() with the given projection tables. 304 Train(referenceSet, numProj, numTables, hashWidth, secondHashSize, 305 bucketSize, projTables); 306 } 307 308 private: 309 /** 310 * This function takes a query and hashes it into each of the hash tables to 311 * get keys for the query and then the key is hashed to a bucket of the second 312 * hash table and all the points (if any) in those buckets are collected as 313 * the potential neighbor candidates. 314 * 315 * @param queryPoint The query point currently being processed. 316 * @param referenceIndices The list of neighbor candidates obtained from 317 * hashing the query into all the hash tables and eventually into 318 * multiple buckets of the second hash table. 319 * @param numTablesToSearch The number of tables to perform the search in. If 320 * 0, all tables are searched. 321 * @param T The number of additional probing bins for multiprobe LSH. If 0, 322 * single-probe is used. 323 */ 324 template<typename VecType> 325 void ReturnIndicesFromTable(const VecType& queryPoint, 326 arma::uvec& referenceIndices, 327 size_t numTablesToSearch, 328 const size_t T) const; 329 330 /** 331 * This is a helper function that computes the distance of the query to the 332 * neighbor candidates and appropriately stores the best 'k' candidates. This 333 * is specific to the monochromatic search case, where the query set is the 334 * reference set. 335 * 336 * @param queryIndex The index of the query in question 337 * @param referenceIndices The vector of indices of candidate neighbors for 338 * the query. 339 * @param k Number of neighbors to search for. 340 * @param neighbors Matrix holding output neighbors. 341 * @param distances Matrix holding output distances. 342 */ 343 void BaseCase(const size_t queryIndex, 344 const arma::uvec& referenceIndices, 345 const size_t k, 346 arma::Mat<size_t>& neighbors, 347 arma::mat& distances) const; 348 349 /** 350 * This is a helper function that computes the distance of the query to the 351 * neighbor candidates and appropriately stores the best 'k' candidates. This 352 * is specific to bichromatic search, where the query set is not the same as 353 * the reference set. 354 * 355 * @param queryIndex The index of the query in question 356 * @param referenceIndices The vector of indices of candidate neighbors for 357 * the query. 358 * @param k Number of neighbors to search for. 359 * @param querySet Set of query points. 360 * @param neighbors Matrix holding output neighbors. 361 * @param distances Matrix holding output distances. 362 */ 363 void BaseCase(const size_t queryIndex, 364 const arma::uvec& referenceIndices, 365 const size_t k, 366 const MatType& querySet, 367 arma::Mat<size_t>& neighbors, 368 arma::mat& distances) const; 369 370 /** 371 * This function implements the core idea behind Multiprobe LSH. It is called 372 * by ReturnIndicesFromTables when T > 0. Given a query's code and its 373 * projection location, GetAdditionalProbingBins will calculate the T most 374 * likely alternative bin codes (other than queryCode) where a query's 375 * neighbors might be found in. 376 * 377 * @param queryCode vector containing the numProj-dimensional query code. 378 * @param queryCodeNotFloored vector containing the projection location of the 379 * query. 380 * @param T number of additional probing bins. 381 * @param additionalProbingBins matrix. Each column will hold one additional 382 * bin. 383 */ 384 void GetAdditionalProbingBins(const arma::vec& queryCode, 385 const arma::vec& queryCodeNotFloored, 386 const size_t T, 387 arma::mat& additionalProbingBins) const; 388 389 /** 390 * Returns the score of a perturbation vector generated by perturbation set A. 391 * The score of a pertubation set (vector) is the sum of scores of the 392 * participating actions. 393 * @param A perturbation set to compute the score of. 394 * @param scores vector containing score of each perturbation. 395 */ 396 double PerturbationScore(const std::vector<bool>& A, 397 const arma::vec& scores) const; 398 399 /** 400 * Inline function used by GetAdditionalProbingBins. The vector shift 401 * operation replaces the largest element of a vector A with (largest element) 402 * + 1. Returns true if resulting vector is valid, otherwise false. 403 * 404 * @param A perturbation set to shift. 405 */ 406 bool PerturbationShift(std::vector<bool>& A) const; 407 408 /** 409 * Inline function used by GetAdditionalProbingBins. The vector expansion 410 * operation adds the element [1 + (largest_element)] to a vector A, where 411 * largest_element is the largest element of A. Returns true if resulting 412 * vector is valid, otherwise false. 413 * 414 * @param A perturbation set to expand. 415 */ 416 bool PerturbationExpand(std::vector<bool>& A) const; 417 418 /** 419 * Return true if perturbation set A is valid. A perturbation set is invalid 420 * if it contains two (or more) actions for the same dimension or dimensions 421 * that are larger than the queryCode's dimensions. 422 * 423 * @param A perturbation set to validate. 424 */ 425 bool PerturbationValid(const std::vector<bool>& A) const; 426 427 //! Reference dataset. 428 MatType referenceSet; 429 430 //! The number of projections. 431 size_t numProj; 432 //! The number of hash tables. 433 size_t numTables; 434 435 //! The arma::cube containing the projection matrix of each table. 436 arma::cube projections; // should be [numProj x dims] x numTables slices 437 438 //! The list of the offsets 'b' for each of the projection for each table. 439 arma::mat offsets; // should be numProj x numTables 440 441 //! The hash width. 442 double hashWidth; 443 444 //! The big prime representing the size of the second hash. 445 size_t secondHashSize; 446 447 //! The weights of the second hash. 448 arma::vec secondHashWeights; 449 450 //! The bucket size of the second hash. 451 size_t bucketSize; 452 453 //! The final hash table; should be (< secondHashSize) vectors each with 454 //! (<= bucketSize) elements. 455 std::vector<arma::Col<size_t>> secondHashTable; 456 457 //! The number of elements present in each hash bucket; should be 458 //! secondHashSize. 459 arma::Col<size_t> bucketContentSize; 460 461 //! For a particular hash value, points to the row in secondHashTable 462 //! corresponding to this value. Length secondHashSize. 463 arma::Col<size_t> bucketRowInHashTable; 464 465 //! The number of distance evaluations. 466 size_t distanceEvaluations; 467 468 //! Candidate represents a possible candidate neighbor (distance, index). 469 typedef std::pair<double, size_t> Candidate; 470 471 //! Compare two candidates based on the distance. 472 struct CandidateCmp { operator ()mlpack::neighbor::LSHSearch::CandidateCmp473 bool operator()(const Candidate& c1, const Candidate& c2) 474 { 475 return !SortPolicy::IsBetter(c2.first, c1.first); 476 }; 477 }; 478 479 //! Use a priority queue to represent the list of candidate neighbors. 480 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp> 481 CandidateList; 482 }; // class LSHSearch 483 484 } // namespace neighbor 485 } // namespace mlpack 486 487 //! Set the serialization version of the LSHSearch class. 488 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>, 489 mlpack::neighbor::LSHSearch<SortPolicy>, 1); 490 491 // Include implementation. 492 #include "lsh_search_impl.hpp" 493 494 #endif 495