1 /**
2 * @file methods/lsh/lsh_main.cpp
3 * @author Parikshit Ram
4 *
5 * This file computes the approximate nearest-neighbors using 2-stable
6 * Locality-sensitive Hashing.
7 *
8 * mlpack is free software; you may redistribute it and/or modify it under the
9 * terms of the 3-clause BSD license. You should have received a copy of the
10 * 3-clause BSD license along with mlpack. If not, see
11 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12 */
#include <mlpack/prereqs.hpp>
#include <mlpack/core/util/io.hpp>
#include <mlpack/core/util/mlpack_main.hpp>

#include <mlpack/core/metrics/lmetric.hpp>

#include <memory>

#include "lsh_search.hpp"
20
21 using namespace std;
22 using namespace mlpack;
23 using namespace mlpack::neighbor;
24 using namespace mlpack::util;
25
26 // Program Name.
27 BINDING_NAME("K-Approximate-Nearest-Neighbor Search with LSH");
28
29 // Short description.
30 BINDING_SHORT_DESC(
31 "An implementation of approximate k-nearest-neighbor search with "
32 "locality-sensitive hashing (LSH). Given a set of reference points and a "
33 "set of query points, this will compute the k approximate nearest neighbors"
34 " of each query point in the reference set; models can be saved for future "
35 "use.");
36
37 // Long description.
38 BINDING_LONG_DESC(
39 "This program will calculate the k approximate-nearest-neighbors of a set "
40 "of points using locality-sensitive hashing. You may specify a separate set"
41 " of reference points and query points, or just a reference set which will "
42 "be used as both the reference and query set. ");
43
44 // Example.
45 BINDING_EXAMPLE(
46 "For example, the following will return 5 neighbors from the data for each "
47 "point in " + PRINT_DATASET("input") + " and store the distances in " +
48 PRINT_DATASET("distances") + " and the neighbors in " +
49 PRINT_DATASET("neighbors") + ":"
50 "\n\n" +
51 PRINT_CALL("lsh", "k", 5, "reference", "input", "distances", "distances",
52 "neighbors", "neighbors") +
53 "\n\n"
54 "The output is organized such that row i and column j in the neighbors "
55 "output corresponds to the index of the point in the reference set which "
56 "is the j'th nearest neighbor from the point in the query set with index "
57 "i. Row j and column i in the distances output file corresponds to the "
58 "distance between those two points."
59 "\n\n"
60 "Because this is approximate-nearest-neighbors search, results may be "
61 "different from run to run. Thus, the " + PRINT_PARAM_STRING("seed") +
62 " parameter can be specified to set the random seed."
63 "\n\n"
64 "This program also has many other parameters to control its functionality;"
65 " see the parameter-specific documentation for more information.");
66
67 // See also...
68 BINDING_SEE_ALSO("@knn", "#knn");
69 BINDING_SEE_ALSO("@krann", "#krann");
70 BINDING_SEE_ALSO("Locality-sensitive hashing on Wikipedia",
71 "https://en.wikipedia.org/wiki/Locality-sensitive_hashing");
72 BINDING_SEE_ALSO("Locality-sensitive hashing scheme based on p-stable"
73 " distributions(pdf)", "http://mlpack.org/papers/lsh.pdf");
74 BINDING_SEE_ALSO("mlpack::neighbor::LSHSearch C++ class documentation",
75 "@doxygen/classmlpack_1_1neighbor_1_1LSHSearch.html");
76
// Define our input parameters that this program will take.
// Dataset used to train a new model.  mlpackMain() requires exactly one of
// "reference" and "input_model" to be passed.
PARAM_MATRIX_IN("reference", "Matrix containing the reference dataset.", "r");
// Search results; they are only written when "k" is passed.
PARAM_MATRIX_OUT("distances", "Matrix to output distances into.", "d");
PARAM_UMATRIX_OUT("neighbors", "Matrix to output neighbors into.", "n");

// We can load or save models.
PARAM_MODEL_IN(LSHSearch<>, "input_model", "Input LSH model.", "m");
PARAM_MODEL_OUT(LSHSearch<>, "output_model", "Output for trained LSH model.",
    "M");

// For testing recall.  mlpackMain() compares these against the computed
// neighbors, so the matrix must have the same dimensions as that output.
PARAM_UMATRIX_IN("true_neighbors", "Matrix of true neighbors to compute "
    "recall with (the recall is printed when -v is specified).", "t");

// Default 0 means "no search"; mlpackMain() requires k > 0 when it is passed.
PARAM_INT_IN("k", "Number of nearest neighbors to find.", "k", 0);
// Optional; when omitted, the search runs over the model's reference set.
PARAM_MATRIX_IN("query", "Matrix containing query points (optional).", "q");

// Hashing hyperparameters.  mlpackMain() only uses "projections", "tables",
// "hash_width", "second_hash_size" and "bucket_size" when it trains a new
// model from "reference".
PARAM_INT_IN("projections", "The number of hash functions for each table", "K",
    10);
PARAM_INT_IN("tables", "The number of hash tables to be used.", "L", 30);
// 0.0 (the default) lets the LSH class estimate a hash width itself.
PARAM_DOUBLE_IN("hash_width", "The hash width for the first-level hashing in "
    "the LSH preprocessing. By default, the LSH class automatically estimates "
    "a hash width for its use.", "H", 0.0);
// Used at search time; 0 (the default) means traditional, single-probe LSH.
PARAM_INT_IN("num_probes", "Number of additional probes for multiprobe LSH; if "
    "0, traditional LSH is used.", "T", 0);
// mlpackMain() requires both second-level hash settings to be positive.
PARAM_INT_IN("second_hash_size", "The size of the second level hash table.",
    "S", 99901);
PARAM_INT_IN("bucket_size", "The size of a bucket in the second level hash.",
    "B", 500);
// 0 (the default) seeds the random number generator from the current time.
PARAM_INT_IN("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
107
mlpackMain()108 static void mlpackMain()
109 {
110 if (IO::GetParam<int>("seed") != 0)
111 math::RandomSeed((size_t) IO::GetParam<int>("seed"));
112 else
113 math::RandomSeed((size_t) time(NULL));
114
115 // Get all the parameters after checking them.
116 if (IO::HasParam("k"))
117 {
118 RequireParamValue<int>("k", [](int x) { return x > 0; }, true,
119 "k must be greater than 0");
120 }
121 RequireParamValue<int>("second_hash_size", [](int x) { return x > 0; }, true,
122 "second hash size must be greater than 0");
123 RequireParamValue<int>("bucket_size", [](int x) { return x > 0; }, true,
124 "bucket size must be greater than 0");
125
126 size_t k = IO::GetParam<int>("k");
127 size_t secondHashSize = IO::GetParam<int>("second_hash_size");
128 size_t bucketSize = IO::GetParam<int>("bucket_size");
129
130 RequireOnlyOnePassed({ "input_model", "reference" }, true);
131 RequireAtLeastOnePassed({ "neighbors", "distances", "output_model" }, false,
132 "no results will be saved");
133 if (IO::HasParam("k"))
134 {
135 RequireAtLeastOnePassed({ "query", "reference", "input_model" }, true,
136 "must pass set to search");
137 }
138
139 if (IO::HasParam("input_model") && IO::HasParam("k") &&
140 !IO::HasParam("query"))
141 {
142 Log::Info << "Performing LSH-based approximate nearest neighbor search on "
143 << "the reference dataset in the model stored in '"
144 << IO::GetPrintableParam<LSHSearch<>>("input_model") << "'." << endl;
145 }
146
147 ReportIgnoredParam({{ "k", false }}, "neighbors");
148 ReportIgnoredParam({{ "k", false }}, "distances");
149
150 ReportIgnoredParam({{ "reference", false }}, "bucket_size");
151 ReportIgnoredParam({{ "reference", false }}, "second_hash_size");
152 ReportIgnoredParam({{ "reference", false }}, "hash_width");
153
154 if (IO::HasParam("input_model") && !IO::HasParam("k"))
155 {
156 Log::Warn << PRINT_PARAM_STRING("k") << " not passed; no search will be "
157 << "performed!" << std::endl;
158 }
159
160 // These declarations are here so that the matrices don't go out of scope.
161 arma::mat referenceData;
162 arma::mat queryData;
163
164 // Pick up the LSH-specific parameters.
165 const size_t numProj = IO::GetParam<int>("projections");
166 const size_t numTables = IO::GetParam<int>("tables");
167 const double hashWidth = IO::GetParam<double>("hash_width");
168 const size_t numProbes = (size_t) IO::GetParam<int>("num_probes");
169
170 arma::Mat<size_t> neighbors;
171 arma::mat distances;
172
173 if (hashWidth == 0.0)
174 Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
175 numTables << " tables (L) with default hash width." << endl;
176 else
177 Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
178 numTables << " tables (L) with hash width (r): " << hashWidth << endl;
179
180 LSHSearch<>* allkann;
181 if (IO::HasParam("reference"))
182 {
183 allkann = new LSHSearch<>();
184 Log::Info << "Using reference data from "
185 << IO::GetPrintableParam<arma::mat>("reference") << "." << endl;
186 referenceData = std::move(IO::GetParam<arma::mat>("reference"));
187
188 Timer::Start("hash_building");
189 allkann->Train(std::move(referenceData), numProj, numTables, hashWidth,
190 secondHashSize, bucketSize);
191 Timer::Stop("hash_building");
192 }
193 else // We must have an input model.
194 {
195 allkann = IO::GetParam<LSHSearch<>*>("input_model");
196 }
197
198 if (IO::HasParam("k"))
199 {
200 Log::Info << "Computing " << k << " distance approximate nearest neighbors."
201 << endl;
202 if (IO::HasParam("query"))
203 {
204 Log::Info << "Loaded query data from "
205 << IO::GetPrintableParam<arma::mat>("query") << "." << endl;
206 queryData = std::move(IO::GetParam<arma::mat>("query"));
207
208 allkann->Search(queryData, k, neighbors, distances, 0, numProbes);
209 }
210 else
211 {
212 allkann->Search(k, neighbors, distances, 0, numProbes);
213 }
214
215 Log::Info << "Neighbors computed." << endl;
216 }
217
218 // Compute recall, if desired.
219 if (IO::HasParam("true_neighbors"))
220 {
221 Log::Info << "Using true neighbor indices from '"
222 << IO::GetPrintableParam<arma::Mat<size_t>>("true_neighbors") << "'."
223 << endl;
224
225 // Load the true neighbors.
226 arma::Mat<size_t> trueNeighbors =
227 std::move(IO::GetParam<arma::Mat<size_t>>("true_neighbors"));
228
229 if (trueNeighbors.n_rows != neighbors.n_rows ||
230 trueNeighbors.n_cols != neighbors.n_cols)
231 {
232 // Delete the model if needed.
233 if (IO::HasParam("reference"))
234 delete allkann;
235 Log::Fatal << "The true neighbors file must have the same number of "
236 << "values as the set of neighbors being queried!" << endl;
237 }
238
239 // Compute recall and print it.
240 double recallPercentage = 100 * allkann->ComputeRecall(neighbors,
241 trueNeighbors);
242
243 Log::Info << "Recall: " << recallPercentage << endl;
244 }
245
246 // Save output, if we did a search..
247 if (IO::HasParam("k"))
248 {
249 IO::GetParam<arma::mat>("distances") = std::move(distances);
250 IO::GetParam<arma::Mat<size_t>>("neighbors") = std::move(neighbors);
251 }
252 IO::GetParam<LSHSearch<>*>("output_model") = allkann;
253 }
254