1 /**
2 * @file tests/main_tests/lsh_test.cpp
3 * @author Manish Kumar
4 *
5 * Test mlpackMain() of lsh_main.cpp.
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #include <string>
13
14 #define BINDING_TYPE BINDING_TYPE_TEST
15 static const std::string testName = "LSH";
16
17 #include <mlpack/core.hpp>
18 #include <mlpack/core/util/mlpack_main.hpp>
19 #include "test_helper.hpp"
20 #include <mlpack/methods/lsh/lsh_main.cpp>
21
22 #include <boost/test/unit_test.hpp>
23 #include "../test_tools.hpp"
24
25 using namespace mlpack;
26
27 struct LSHTestFixture
28 {
29 public:
LSHTestFixtureLSHTestFixture30 LSHTestFixture()
31 {
32 // Cache in the options for this program.
33 IO::RestoreSettings(testName);
34 }
35
~LSHTestFixtureLSHTestFixture36 ~LSHTestFixture()
37 {
38 // Clear the settings.
39 bindings::tests::CleanMemory();
40 IO::ClearSettings();
41 }
42 };
43
44 BOOST_FIXTURE_TEST_SUITE(LSHMainTest, LSHTestFixture);
45
46 /**
47 * Check that output neighbors and distances have valid dimensions.
48 */
BOOST_AUTO_TEST_CASE(LSHOutputDimensionTest)49 BOOST_AUTO_TEST_CASE(LSHOutputDimensionTest)
50 {
51 arma::mat reference = arma::randu<arma::mat>(5, 100);
52
53 SetInputParam("reference", std::move(reference));
54 SetInputParam("k", (int) 6);
55
56 mlpackMain();
57
58 // Check the neighbors matrix has 6 points for each of the 100 input points.
59 BOOST_REQUIRE_EQUAL(IO::GetParam<arma::Mat<size_t>>("neighbors").n_rows, 6);
60 BOOST_REQUIRE_EQUAL(IO::GetParam<arma::Mat<size_t>>("neighbors").n_cols,
61 100);
62
63 // Check the distances matrix has 6 points for each of the 100 input points.
64 BOOST_REQUIRE_EQUAL(IO::GetParam<arma::mat>("distances").n_rows, 6);
65 BOOST_REQUIRE_EQUAL(IO::GetParam<arma::mat>("distances").n_cols, 100);
66 }
67
68 /**
69 * Ensure that bucket_size, second_hash_size & number of nearest neighbors
70 * are always positive.
71 */
BOOST_AUTO_TEST_CASE(LSHParamValidityTest)72 BOOST_AUTO_TEST_CASE(LSHParamValidityTest)
73 {
74 arma::mat reference = arma::randu<arma::mat>(5, 100);
75
76 // Test for bucket_size.
77
78 SetInputParam("reference", reference);
79 SetInputParam("k", (int) 6);
80 SetInputParam("bucket_size", (int) -1.0);
81
82 Log::Fatal.ignoreInput = true;
83 BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error);
84 Log::Fatal.ignoreInput = false;
85
86 bindings::tests::CleanMemory();
87
88 // Test for second_hash_size.
89
90 SetInputParam("reference", reference);
91 SetInputParam("k", (int) 6);
92 SetInputParam("second_hash_size", (int) -1.0);
93
94 Log::Fatal.ignoreInput = true;
95 BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error);
96 Log::Fatal.ignoreInput = false;
97
98 bindings::tests::CleanMemory();
99
100 // Test for number of nearest neighbors.
101
102 SetInputParam("reference", std::move(reference));
103 SetInputParam("k", (int) -2);
104
105 Log::Fatal.ignoreInput = true;
106 BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error);
107 Log::Fatal.ignoreInput = false;
108 }
109
110 /**
111 * Make sure only one of reference data or pre-trained model is passed.
112 */
BOOST_AUTO_TEST_CASE(LSHModelValidityTest)113 BOOST_AUTO_TEST_CASE(LSHModelValidityTest)
114 {
115 arma::mat reference = arma::randu<arma::mat>(5, 100);
116
117 SetInputParam("reference", std::move(reference));
118 SetInputParam("k", (int) 6);
119
120 mlpackMain();
121
122 SetInputParam("input_model", IO::GetParam<LSHSearch<>*>("output_model"));
123
124 Log::Fatal.ignoreInput = true;
125 BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error);
126 Log::Fatal.ignoreInput = false;
127 }
128
129 /**
130 * Check learning process using different tables.
131 */
BOOST_AUTO_TEST_CASE(LSHDiffTablesTest)132 BOOST_AUTO_TEST_CASE(LSHDiffTablesTest)
133 {
134 arma::mat reference = arma::randu<arma::mat>(5, 100);
135
136 SetInputParam("reference", reference);
137 SetInputParam("k", (int) 6);
138
139 mlpack::math::FixedRandomSeed();
140 mlpackMain();
141
142 arma::Mat<size_t> neighbors = IO::GetParam<arma::Mat<size_t>>("neighbors");
143 arma::mat distances = IO::GetParam<arma::mat>("distances");
144
145 bindings::tests::CleanMemory();
146
147 // Train model using tables equals to 40.
148
149 SetInputParam("reference", std::move(reference));
150 SetInputParam("k", (int) 6);
151 SetInputParam("tables", (int) 40);
152
153 mlpack::math::FixedRandomSeed();
154 mlpackMain();
155
156 // Check that initial outputs and final outputs using two models are
157 // different.
158 BOOST_REQUIRE_LT(arma::accu(neighbors ==
159 IO::GetParam<arma::Mat<size_t>>("neighbors")), neighbors.n_elem);
160 BOOST_REQUIRE_LT(arma::accu(distances ==
161 IO::GetParam<arma::mat>("distances")), distances.n_elem);
162 }
163
164 /**
165 * Check learning process using different projections.
166 */
BOOST_AUTO_TEST_CASE(LSHDiffProjectionsTest)167 BOOST_AUTO_TEST_CASE(LSHDiffProjectionsTest)
168 {
169 arma::mat reference = arma::randu<arma::mat>(5, 100);
170
171 SetInputParam("reference", reference);
172 SetInputParam("k", (int) 6);
173
174 mlpack::math::FixedRandomSeed();
175 mlpackMain();
176
177 arma::Mat<size_t> neighbors = IO::GetParam<arma::Mat<size_t>>("neighbors");
178 arma::mat distances = IO::GetParam<arma::mat>("distances");
179
180 bindings::tests::CleanMemory();
181
182 // Train model using projections equals to 30.
183
184 SetInputParam("reference", std::move(reference));
185 SetInputParam("k", (int) 6);
186 SetInputParam("projections", (int) 30);
187
188 mlpack::math::FixedRandomSeed();
189 mlpackMain();
190
191 // Check that initial outputs and final outputs using two models are
192 // different.
193 BOOST_REQUIRE_LT(arma::accu(neighbors ==
194 IO::GetParam<arma::Mat<size_t>>("neighbors")), neighbors.n_elem);
195 BOOST_REQUIRE_LT(arma::accu(distances ==
196 IO::GetParam<arma::mat>("distances")), distances.n_elem);
197 }
198
199 /**
200 * Check learning process using different hash_width.
201 */
BOOST_AUTO_TEST_CASE(LSHDiffHashWidthTest)202 BOOST_AUTO_TEST_CASE(LSHDiffHashWidthTest)
203 {
204 arma::mat reference = arma::randu<arma::mat>(5, 100);
205
206 SetInputParam("reference", reference);
207 SetInputParam("k", (int) 6);
208
209 mlpack::math::FixedRandomSeed();
210 mlpackMain();
211
212 arma::Mat<size_t> neighbors = IO::GetParam<arma::Mat<size_t>>("neighbors");
213 arma::mat distances = IO::GetParam<arma::mat>("distances");
214
215 bindings::tests::CleanMemory();
216
217 // Train model using hash_width equals to 0.5.
218
219 SetInputParam("reference", std::move(reference));
220 SetInputParam("k", (int) 6);
221 SetInputParam("hash_width", (double) 0.5);
222
223 mlpack::math::FixedRandomSeed();
224 mlpackMain();
225
226 // Check that initial outputs and final outputs using two models are
227 // different.
228 BOOST_REQUIRE_LT(arma::accu(neighbors ==
229 IO::GetParam<arma::Mat<size_t>>("neighbors")), neighbors.n_elem);
230 BOOST_REQUIRE_LT(arma::accu(distances ==
231 IO::GetParam<arma::mat>("distances")), distances.n_elem);
232 }
233
234 /**
235 * Check learning process using different num_probes.
236 */
BOOST_AUTO_TEST_CASE(LSHDiffNumProbesTest)237 BOOST_AUTO_TEST_CASE(LSHDiffNumProbesTest)
238 {
239 arma::mat reference = arma::randu<arma::mat>(5, 100);
240 arma::mat query = arma::randu<arma::mat>(5, 40);
241
242 SetInputParam("reference", std::move(reference));
243 SetInputParam("query", query);
244 SetInputParam("k", (int) 6);
245
246 mlpackMain();
247
248 arma::Mat<size_t> neighbors = IO::GetParam<arma::Mat<size_t>>("neighbors");
249 arma::mat distances = IO::GetParam<arma::mat>("distances");
250
251 IO::GetSingleton().Parameters()["reference"].wasPassed = false;
252
253 // Train model using num_probes equals to 5.
254
255 SetInputParam("input_model", IO::GetParam<LSHSearch<>*>("output_model"));
256 SetInputParam("query", std::move(query));
257 SetInputParam("num_probes", (int) 5);
258
259 mlpackMain();
260
261 // Check that initial outputs and final outputs using two models are
262 // different.
263 BOOST_REQUIRE_LT(arma::accu(neighbors ==
264 IO::GetParam<arma::Mat<size_t>>("neighbors")), neighbors.n_elem);
265 BOOST_REQUIRE_LT(arma::accu(distances ==
266 IO::GetParam<arma::mat>("distances")), distances.n_elem);
267 }
268
269 /**
270 * Check learning process using different second_hash_size.
271 */
BOOST_AUTO_TEST_CASE(LSHDiffSecondHashSizeTest)272 BOOST_AUTO_TEST_CASE(LSHDiffSecondHashSizeTest)
273 {
274 arma::mat reference = arma::randu<arma::mat>(5, 100);
275
276 SetInputParam("reference", reference);
277 SetInputParam("k", (int) 6);
278
279 mlpack::math::FixedRandomSeed();
280 mlpackMain();
281
282 arma::Mat<size_t> neighbors = IO::GetParam<arma::Mat<size_t>>("neighbors");
283 arma::mat distances = IO::GetParam<arma::mat>("distances");
284
285 bindings::tests::CleanMemory();
286
287 // Train model using second_hash_size equals to 5000.
288
289 SetInputParam("reference", std::move(reference));
290 SetInputParam("k", (int) 6);
291 SetInputParam("second_hash_size", (int) 5000);
292
293 mlpack::math::FixedRandomSeed();
294 mlpackMain();
295
296 // Check that initial outputs and final outputs using two models are
297 // different.
298 BOOST_REQUIRE_LT(arma::accu(neighbors ==
299 IO::GetParam<arma::Mat<size_t>>("neighbors")), neighbors.n_elem);
300 BOOST_REQUIRE_LT(arma::accu(distances ==
301 IO::GetParam<arma::mat>("distances")), distances.n_elem);
302 }
303
304 /**
305 * Check learning process using different bucket sizes.
306 */
BOOST_AUTO_TEST_CASE(LSHDiffBucketSizeTest)307 BOOST_AUTO_TEST_CASE(LSHDiffBucketSizeTest)
308 {
309 arma::mat reference = arma::randu<arma::mat>(5, 100);
310
311 SetInputParam("reference", reference);
312 SetInputParam("k", (int) 6);
313
314 mlpack::math::FixedRandomSeed();
315 mlpackMain();
316
317 arma::Mat<size_t> neighbors = IO::GetParam<arma::Mat<size_t>>("neighbors");
318 arma::mat distances = IO::GetParam<arma::mat>("distances");
319
320 bindings::tests::CleanMemory();
321
322 // Train model using bucket_size equals to 1000.
323
324 SetInputParam("reference", std::move(reference));
325 SetInputParam("k", (int) 6);
326 SetInputParam("bucket_size", (int) 1);
327
328 mlpack::math::FixedRandomSeed();
329 mlpackMain();
330
331 // Check that initial outputs and final outputs using the two models are
332 // different.
333 BOOST_REQUIRE_LT(arma::accu(neighbors ==
334 IO::GetParam<arma::Mat<size_t>>("neighbors")), neighbors.n_elem);
335 BOOST_REQUIRE_LT(arma::accu(distances ==
336 IO::GetParam<arma::mat>("distances")), distances.n_elem);
337 }
338
339 /**
340 * Check that saved model can be reused again.
341 */
BOOST_AUTO_TEST_CASE(LSHModelReuseTest)342 BOOST_AUTO_TEST_CASE(LSHModelReuseTest)
343 {
344 arma::mat reference = arma::randu<arma::mat>(5, 100);
345 arma::mat query = arma::randu<arma::mat>(5, 40);
346
347 SetInputParam("reference", std::move(reference));
348 SetInputParam("query", query);
349 SetInputParam("k", (int) 6);
350
351 mlpackMain();
352
353 arma::Mat<size_t> neighbors = IO::GetParam<arma::Mat<size_t>>("neighbors");
354 arma::mat distances = IO::GetParam<arma::mat>("distances");
355
356 IO::GetSingleton().Parameters()["reference"].wasPassed = false;
357
358 SetInputParam("input_model", IO::GetParam<LSHSearch<>*>("output_model"));
359 SetInputParam("query", std::move(query));
360
361 mlpackMain();
362
363 // Check that initial query outputs and final outputs using saved model are
364 // same.
365 CheckMatrices(neighbors, IO::GetParam<arma::Mat<size_t>>("neighbors"));
366 CheckMatrices(distances, IO::GetParam<arma::mat>("distances"));
367 }
368
369 /**
370 * Make sure true_neighbors have valid dimensions.
371 */
BOOST_AUTO_TEST_CASE(LSHModelTrueNighborsDimTest)372 BOOST_AUTO_TEST_CASE(LSHModelTrueNighborsDimTest)
373 {
374 arma::mat reference = arma::randu<arma::mat>(5, 100);
375
376 // Initalize trueNeighbors with invalid dimensions.
377 arma::Mat<size_t> trueNeighbors = arma::randu<arma::Mat<size_t>>(7, 100);
378
379 SetInputParam("reference", std::move(reference));
380 SetInputParam("true_neighbors", std::move(trueNeighbors));
381 SetInputParam("k", (int) 6);
382
383 Log::Fatal.ignoreInput = true;
384 BOOST_REQUIRE_THROW(mlpackMain(), std::runtime_error);
385 Log::Fatal.ignoreInput = false;
386 }
387
388 BOOST_AUTO_TEST_SUITE_END();
389