1 /**
2  * @file methods/lsh/lsh_main.cpp
3  * @author Parikshit Ram
4  *
5  * This file computes the approximate nearest-neighbors using 2-stable
6  * Locality-sensitive Hashing.
7  *
8  * mlpack is free software; you may redistribute it and/or modify it under the
9  * terms of the 3-clause BSD license.  You should have received a copy of the
10  * 3-clause BSD license along with mlpack.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #include <mlpack/prereqs.hpp>
14 #include <mlpack/core/util/io.hpp>
15 #include <mlpack/core/util/mlpack_main.hpp>
16 
17 #include <mlpack/core/metrics/lmetric.hpp>
18 
19 #include "lsh_search.hpp"
20 
21 using namespace std;
22 using namespace mlpack;
23 using namespace mlpack::neighbor;
24 using namespace mlpack::util;
25 
// Human-readable name of this binding.
BINDING_NAME("K-Approximate-Nearest-Neighbor Search with LSH");

// One-paragraph summary of what the binding does.
BINDING_SHORT_DESC(
    "An implementation of approximate k-nearest-neighbor search with "
    "locality-sensitive hashing (LSH).  Given a set of reference points and a "
    "set of query points, this will compute the k approximate nearest neighbors"
    " of each query point in the reference set; models can be saved for future "
    "use.");

// Longer description: explains that the query set is optional and that the
// reference set doubles as the query set when no query set is given.
BINDING_LONG_DESC(
    "This program will calculate the k approximate-nearest-neighbors of a set "
    "of points using locality-sensitive hashing. You may specify a separate set"
    " of reference points and query points, or just a reference set which will "
    "be used as both the reference and query set. ");

// Usage example.  The text is assembled from string literals joined with the
// PRINT_DATASET(), PRINT_CALL() and PRINT_PARAM_STRING() helpers, which supply
// the dataset names, the example invocation and the parameter name
// (presumably rendered differently per binding language -- confirm against the
// macro definitions).  It also describes the layout of the two output
// matrices and points to the "seed" parameter for reproducible runs.
BINDING_EXAMPLE(
    "For example, the following will return 5 neighbors from the data for each "
    "point in " + PRINT_DATASET("input") + " and store the distances in " +
    PRINT_DATASET("distances") + " and the neighbors in " +
    PRINT_DATASET("neighbors") + ":"
    "\n\n" +
    PRINT_CALL("lsh", "k", 5, "reference", "input", "distances", "distances",
        "neighbors", "neighbors") +
    "\n\n"
    "The output is organized such that row i and column j in the neighbors "
    "output corresponds to the index of the point in the reference set which "
    "is the j'th nearest neighbor from the point in the query set with index "
    "i.  Row j and column i in the distances output file corresponds to the "
    "distance between those two points."
    "\n\n"
    "Because this is approximate-nearest-neighbors search, results may be "
    "different from run to run.  Thus, the " + PRINT_PARAM_STRING("seed") +
    " parameter can be specified to set the random seed."
    "\n\n"
    "This program also has many other parameters to control its functionality;"
    " see the parameter-specific documentation for more information.");

// Cross-references: each entry is a (description, target) pair.  The "@knn" /
// "#knn" style entries presumably refer to other mlpack bindings and the
// "@doxygen/..." target to the generated API documentation -- confirm against
// the BINDING_SEE_ALSO definition.
BINDING_SEE_ALSO("@knn", "#knn");
BINDING_SEE_ALSO("@krann", "#krann");
BINDING_SEE_ALSO("Locality-sensitive hashing on Wikipedia",
        "https://en.wikipedia.org/wiki/Locality-sensitive_hashing");
BINDING_SEE_ALSO("Locality-sensitive hashing scheme based on p-stable"
        "  distributions(pdf)", "http://mlpack.org/papers/lsh.pdf");
BINDING_SEE_ALSO("mlpack::neighbor::LSHSearch C++ class documentation",
        "@doxygen/classmlpack_1_1neighbor_1_1LSHSearch.html");
76 
// Input and output parameters of the binding.  Each declaration gives the
// parameter name, its description and a short alias; the *_IN declarations
// for scalar types additionally give a default value as the last argument.

// Reference dataset to build the hash tables from, and the two result matrices
// filled in by a search.
PARAM_MATRIX_IN("reference", "Matrix containing the reference dataset.", "r");
PARAM_MATRIX_OUT("distances", "Matrix to output distances into.", "d");
PARAM_UMATRIX_OUT("neighbors", "Matrix to output neighbors into.", "n");

// We can load or save models.  mlpackMain() requires exactly one of
// "input_model" and "reference" to be passed.
PARAM_MODEL_IN(LSHSearch<>, "input_model", "Input LSH model.", "m");
PARAM_MODEL_OUT(LSHSearch<>, "output_model", "Output for trained LSH model.",
    "M");

// For testing recall: ground-truth neighbor indices to compare the computed
// neighbors against.
PARAM_UMATRIX_IN("true_neighbors", "Matrix of true neighbors to compute "
    "recall with (the recall is printed when -v is specified).", "t");

// Search parameters.  mlpackMain() only performs a search when "k" is passed,
// and requires a passed "k" to be greater than 0.
PARAM_INT_IN("k", "Number of nearest neighbors to find.", "k", 0);
PARAM_MATRIX_IN("query", "Matrix containing query points (optional).", "q");

// LSH-specific tuning parameters.  A "hash_width" of 0.0 (the default) leaves
// the choice of hash width to the LSH class, as stated in its description.
PARAM_INT_IN("projections", "The number of hash functions for each table", "K",
    10);
PARAM_INT_IN("tables", "The number of hash tables to be used.", "L", 30);
PARAM_DOUBLE_IN("hash_width", "The hash width for the first-level hashing in "
    "the LSH preprocessing. By default, the LSH class automatically estimates "
    "a hash width for its use.", "H", 0.0);
PARAM_INT_IN("num_probes", "Number of additional probes for multiprobe LSH; if "
    "0, traditional LSH is used.", "T", 0);
PARAM_INT_IN("second_hash_size", "The size of the second level hash table.",
    "S", 99901);
PARAM_INT_IN("bucket_size", "The size of a bucket in the second level hash.",
    "B", 500);
// Random seed; mlpackMain() seeds from the current time when this is 0.
PARAM_INT_IN("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
107 
mlpackMain()108 static void mlpackMain()
109 {
110   if (IO::GetParam<int>("seed") != 0)
111     math::RandomSeed((size_t) IO::GetParam<int>("seed"));
112   else
113     math::RandomSeed((size_t) time(NULL));
114 
115   // Get all the parameters after checking them.
116   if (IO::HasParam("k"))
117   {
118     RequireParamValue<int>("k", [](int x) { return x > 0; }, true,
119         "k must be greater than 0");
120   }
121   RequireParamValue<int>("second_hash_size", [](int x) { return x > 0; }, true,
122       "second hash size must be greater than 0");
123   RequireParamValue<int>("bucket_size", [](int x) { return x > 0; }, true,
124       "bucket size must be greater than 0");
125 
126   size_t k = IO::GetParam<int>("k");
127   size_t secondHashSize = IO::GetParam<int>("second_hash_size");
128   size_t bucketSize = IO::GetParam<int>("bucket_size");
129 
130   RequireOnlyOnePassed({ "input_model", "reference" }, true);
131   RequireAtLeastOnePassed({ "neighbors", "distances", "output_model" }, false,
132       "no results will be saved");
133   if (IO::HasParam("k"))
134   {
135     RequireAtLeastOnePassed({ "query", "reference", "input_model" }, true,
136         "must pass set to search");
137   }
138 
139   if (IO::HasParam("input_model") && IO::HasParam("k") &&
140       !IO::HasParam("query"))
141   {
142     Log::Info << "Performing LSH-based approximate nearest neighbor search on "
143         << "the reference dataset in the model stored in '"
144         << IO::GetPrintableParam<LSHSearch<>>("input_model") << "'." << endl;
145   }
146 
147   ReportIgnoredParam({{ "k", false }}, "neighbors");
148   ReportIgnoredParam({{ "k", false }}, "distances");
149 
150   ReportIgnoredParam({{ "reference", false }}, "bucket_size");
151   ReportIgnoredParam({{ "reference", false }}, "second_hash_size");
152   ReportIgnoredParam({{ "reference", false }}, "hash_width");
153 
154   if (IO::HasParam("input_model") && !IO::HasParam("k"))
155   {
156     Log::Warn << PRINT_PARAM_STRING("k") << " not passed; no search will be "
157         << "performed!" << std::endl;
158   }
159 
160   // These declarations are here so that the matrices don't go out of scope.
161   arma::mat referenceData;
162   arma::mat queryData;
163 
164   // Pick up the LSH-specific parameters.
165   const size_t numProj = IO::GetParam<int>("projections");
166   const size_t numTables = IO::GetParam<int>("tables");
167   const double hashWidth = IO::GetParam<double>("hash_width");
168   const size_t numProbes = (size_t) IO::GetParam<int>("num_probes");
169 
170   arma::Mat<size_t> neighbors;
171   arma::mat distances;
172 
173   if (hashWidth == 0.0)
174     Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
175         numTables << " tables (L) with default hash width." << endl;
176   else
177     Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
178         numTables << " tables (L) with hash width (r): " << hashWidth << endl;
179 
180   LSHSearch<>* allkann;
181   if (IO::HasParam("reference"))
182   {
183     allkann = new LSHSearch<>();
184     Log::Info << "Using reference data from "
185         << IO::GetPrintableParam<arma::mat>("reference") << "." << endl;
186     referenceData = std::move(IO::GetParam<arma::mat>("reference"));
187 
188     Timer::Start("hash_building");
189     allkann->Train(std::move(referenceData), numProj, numTables, hashWidth,
190         secondHashSize, bucketSize);
191     Timer::Stop("hash_building");
192   }
193   else // We must have an input model.
194   {
195     allkann = IO::GetParam<LSHSearch<>*>("input_model");
196   }
197 
198   if (IO::HasParam("k"))
199   {
200     Log::Info << "Computing " << k << " distance approximate nearest neighbors."
201         << endl;
202     if (IO::HasParam("query"))
203     {
204       Log::Info << "Loaded query data from "
205           << IO::GetPrintableParam<arma::mat>("query") << "." << endl;
206       queryData = std::move(IO::GetParam<arma::mat>("query"));
207 
208       allkann->Search(queryData, k, neighbors, distances, 0, numProbes);
209     }
210     else
211     {
212       allkann->Search(k, neighbors, distances, 0, numProbes);
213     }
214 
215     Log::Info << "Neighbors computed." << endl;
216   }
217 
218   // Compute recall, if desired.
219   if (IO::HasParam("true_neighbors"))
220   {
221     Log::Info << "Using true neighbor indices from '"
222         << IO::GetPrintableParam<arma::Mat<size_t>>("true_neighbors") << "'."
223         << endl;
224 
225     // Load the true neighbors.
226     arma::Mat<size_t> trueNeighbors =
227         std::move(IO::GetParam<arma::Mat<size_t>>("true_neighbors"));
228 
229     if (trueNeighbors.n_rows != neighbors.n_rows ||
230         trueNeighbors.n_cols != neighbors.n_cols)
231     {
232       // Delete the model if needed.
233       if (IO::HasParam("reference"))
234         delete allkann;
235       Log::Fatal << "The true neighbors file must have the same number of "
236           << "values as the set of neighbors being queried!" << endl;
237     }
238 
239     // Compute recall and print it.
240     double recallPercentage = 100 * allkann->ComputeRecall(neighbors,
241         trueNeighbors);
242 
243     Log::Info << "Recall: " << recallPercentage << endl;
244   }
245 
246   // Save output, if we did a search..
247   if (IO::HasParam("k"))
248   {
249     IO::GetParam<arma::mat>("distances") = std::move(distances);
250     IO::GetParam<arma::Mat<size_t>>("neighbors") = std::move(neighbors);
251   }
252   IO::GetParam<LSHSearch<>*>("output_model") = allkann;
253 }
254