1 /**
2 * @file tests/main_tests/kmeans_test.cpp
3 * @author Prabhat Sharma
4 *
5 * Test mlpackMain() of kmeans_main.cpp
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #include <string>
13
14 #define BINDING_TYPE BINDING_TYPE_TEST
15 static const std::string testName = "Kmeans";
16
17 #include <mlpack/core.hpp>
18 #include <mlpack/core/util/mlpack_main.hpp>
19 #include "test_helper.hpp"
20 #include <mlpack/methods/kmeans/kmeans_main.cpp>
21
22 #include "../catch.hpp"
23 #include "../test_catch_tools.hpp"
24
25 using namespace mlpack;
26
27 struct KmTestFixture
28 {
29 public:
KmTestFixtureKmTestFixture30 KmTestFixture()
31 {
32 // Cache in the options for this program.
33 IO::RestoreSettings(testName);
34 }
35
~KmTestFixtureKmTestFixture36 ~KmTestFixture()
37 {
38 // Clear the settings.
39 IO::ClearSettings();
40 }
41 };
42
ResetKmSettings()43 void ResetKmSettings()
44 {
45 IO::ClearSettings();
46 IO::RestoreSettings(testName);
47 }
48
49 /**
50 * Checking that number of Clusters are non negative
51 */
52 TEST_CASE_METHOD(KmTestFixture, "NonNegativeClustersTest",
53 "[KmeansMainTest][BindingTests]")
54 {
55 arma::mat inputData;
56 if (!data::Load("vc2.csv", inputData))
57 FAIL("Unable to load train dataset vc2.csv!");
58
59 SetInputParam("input", std::move(inputData));
60 SetInputParam("clusters", (int) -1); // Invalid
61
62 Log::Fatal.ignoreInput = true;
63 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error);
64 Log::Fatal.ignoreInput = false;
65 }
66
67
68 /**
69 * Checking that initial centroids are provided if clusters are to be auto detected
70 */
71 TEST_CASE_METHOD(KmTestFixture, "AutoDetectClusterTest",
72 "[KmeansMainTest][BindingTests]")
73 {
74 constexpr int N = 10;
75 constexpr int D = 4;
76
77 arma::mat inputData = arma::randu<arma::mat>(N, D);
78
79 SetInputParam("input", std::move(inputData));
80 SetInputParam("clusters", (int) 0); // Invalid
81
82 Log::Fatal.ignoreInput = true;
83 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error);
84 Log::Fatal.ignoreInput = false;
85 }
86
87
88 /**
89 * Checking that percentage is between 0 and 1 when --refined_start is specified
90 */
91 TEST_CASE_METHOD(KmTestFixture, "RefinedStartPercentageTest",
92 "[KmeansMainTest][BindingTests]")
93 {
94 int c = 2;
95 double P = 2.0;
96 arma::mat inputData;
97 if (!data::Load("vc2.csv", inputData))
98 FAIL("Unable to load train dataset vc2.csv!");
99
100 SetInputParam("input", std::move(inputData));
101 SetInputParam("refined_start", true);
102 SetInputParam("clusters", c);
103 SetInputParam("percentage", std::move(P)); // Invalid
104
105 Log::Fatal.ignoreInput = true;
106 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error);
107 Log::Fatal.ignoreInput = false;
108 }
109
110
111 /**
112 * Checking percentage is non-negative when --refined_start is specified
113 */
114 TEST_CASE_METHOD(KmTestFixture, "NonNegativePercentageTest",
115 "[KmeansMainTest][BindingTests]")
116 {
117 int c = 2;
118 double P = -1.0;
119 arma::mat inputData;
120 if (!data::Load("vc2.csv", inputData))
121 FAIL("Unable to load train dataset vc2.csv!");
122
123 SetInputParam("input", std::move(inputData));
124 SetInputParam("refined_start", true);
125 SetInputParam("clusters", c);
126 SetInputParam("percentage", P); // Invalid
127
128 Log::Fatal.ignoreInput = true;
129 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error);
130 Log::Fatal.ignoreInput = false;
131 }
132
133
134 /**
135 * Checking that size and dimensionality of prediction is correct.
136 */
137 TEST_CASE_METHOD(KmTestFixture, "KmClusteringSizeCheck",
138 "[KmeansMainTest][BindingTests]")
139 {
140 int c = 2;
141 arma::mat inputData;
142 if (!data::Load("vc2.csv", inputData))
143 FAIL("Unable to load train dataset vc2.csv!");
144
145 size_t col = inputData.n_cols;
146 size_t row = inputData.n_rows;
147
148 SetInputParam("input", std::move(inputData));
149 SetInputParam("clusters", c);
150
151 mlpackMain();
152
153 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == row+1);
154 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == col);
155 REQUIRE(IO::GetParam<arma::mat>("centroid").n_rows == row);
156 REQUIRE(IO::GetParam<arma::mat>("centroid").n_cols == c);
157 }
158
159 /**
160 * Checking that size and dimensionality of prediction is correct when --labels_only is specified
161 */
162 TEST_CASE_METHOD(KmTestFixture, "KmClusteringSizeCheckLabelOnly",
163 "[KmeansMainTest][BindingTests]")
164 {
165 int c = 2;
166
167 arma::mat inputData;
168 if (!data::Load("vc2.csv", inputData))
169 FAIL("Unable to load train dataset vc2.csv!");
170 size_t col = inputData.n_cols;
171 size_t row = inputData.n_rows;
172
173 SetInputParam("input", std::move(inputData));
174 SetInputParam("clusters", c);
175 SetInputParam("labels_only", true);
176
177 mlpackMain();
178
179 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 1);
180 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == col);
181 REQUIRE(IO::GetParam<arma::mat>("centroid").n_rows == row);
182 REQUIRE(IO::GetParam<arma::mat>("centroid").n_cols == c);
183 }
184
185
186 /**
187 * Checking that predictions are not same when --allow_empty_clusters or kill_empty_clusters are specified
188 */
189 TEST_CASE_METHOD(KmTestFixture, "KmClusteringEmptyClustersCheck",
190 "[KmeansMainTest][BindingTests]")
191 {
192 int c = 400;
193 int iterations = 100;
194
195 arma::mat inputData;
196 if (!data::Load("test_data_3_1000.csv", inputData))
197 FAIL("Unable to load train dataset test_data_3_1000.csv!");
198 arma::mat initCentroid = arma::randu<arma::mat>(inputData.n_rows, c);
199
200 SetInputParam("input", inputData);
201 SetInputParam("clusters", c);
202 SetInputParam("labels_only", true);
203 SetInputParam("max_iterations", iterations);
204 SetInputParam("initial_centroids", initCentroid);
205
206 mlpackMain();
207
208 arma::mat normalOutput;
209 normalOutput = std::move(IO::GetParam<arma::mat>("centroid"));
210
211 ResetKmSettings();
212
213 SetInputParam("input", inputData);
214 SetInputParam("clusters", c);
215 SetInputParam("labels_only", true);
216 SetInputParam("allow_empty_clusters", true);
217 SetInputParam("max_iterations", iterations);
218 SetInputParam("initial_centroids", initCentroid);
219
220 mlpackMain();
221
222 arma::mat allowEmptyOutput;
223 allowEmptyOutput = std::move(IO::GetParam<arma::mat>("centroid"));
224
225 ResetKmSettings();
226
227 SetInputParam("input", inputData);
228 SetInputParam("clusters", c);
229 SetInputParam("labels_only", true);
230 SetInputParam("kill_empty_clusters", true);
231 SetInputParam("max_iterations", iterations);
232 SetInputParam("initial_centroids", initCentroid);
233
234 mlpackMain();
235
236 arma::mat killEmptyOutput;
237 killEmptyOutput = std::move(IO::GetParam<arma::mat>("centroid"));
238
239 ResetKmSettings();
240
241 if (killEmptyOutput.n_elem == allowEmptyOutput.n_elem)
242 {
243 REQUIRE(arma::accu(killEmptyOutput != allowEmptyOutput) > 1);
244 REQUIRE(arma::accu(killEmptyOutput != normalOutput) > 1);
245 }
246 REQUIRE(arma::accu(normalOutput != allowEmptyOutput) > 1);
247 }
248
249 /**
250 * Checking that that size and dimensionality of Final Input File is correct
251 * when flag --in_place is specified
252 */
253 TEST_CASE_METHOD(KmTestFixture, "KmClusteringResultSizeCheck",
254 "[KmeansMainTest][BindingTests]")
255 {
256 int c = 2;
257
258 arma::mat inputData;
259 if (!data::Load("vc2.csv", inputData))
260 FAIL("Unable to load train dataset vc2.csv!");
261
262 size_t row = inputData.n_rows;
263 size_t col = inputData.n_cols;
264
265 SetInputParam("input", inputData);
266 SetInputParam("clusters", c);
267 SetInputParam("in_place", true);
268
269 mlpackMain();
270 arma::mat processedInput = IO::GetParam<arma::mat>("output");
271 // here input is actually accessed through output
272 // due to a little trick in kmeans_main
273
274 REQUIRE(processedInput.n_cols == col);
275 REQUIRE(processedInput.n_rows == row+1);
276 }
277
278 /**
279 * Ensuring that absence of Number of Clusters is checked.
280 */
281 TEST_CASE_METHOD(KmTestFixture, "KmClustersNotDefined",
282 "[KmeansMainTest][BindingTests]")
283 {
284 arma::mat inputData;
285 if (!data::Load("vc2.csv", inputData))
286 FAIL("Unable to load train dataset vc2.csv!");
287
288 SetInputParam("input", std::move(inputData));
289
290 Log::Fatal.ignoreInput = true;
291 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error);
292 Log::Fatal.ignoreInput = false;
293 }
294
295 /**
296 * Checking that all the algorithms yield same results
297 */
298 TEST_CASE_METHOD(KmTestFixture, "AlgorithmsSimilarTest",
299 "[KmeansMainTest][BindingTests]")
300 {
301 int c = 5;
302 arma::mat inputData(10, 1000);
303 inputData.randu();
304
305 arma::mat initCentroids(10, 5);
306 initCentroids.randu();
307
308 arma::mat initCentroid = arma::randu<arma::mat>(inputData.n_rows, c);
309 std::string algo = "naive";
310
311 SetInputParam("input", inputData);
312 SetInputParam("clusters", c);
313 SetInputParam("algorithm", std::move(algo));
314 SetInputParam("labels_only", true);
315 SetInputParam("initial_centroids", initCentroid);
316
317 mlpackMain();
318
319 arma::mat naiveOutput;
320 arma::mat naiveCentroid;
321 naiveOutput = std::move(IO::GetParam<arma::mat>("output"));
322 naiveCentroid = std::move(IO::GetParam<arma::mat>("centroid"));
323
324 ResetKmSettings();
325
326 algo = "elkan";
327
328 SetInputParam("input", inputData);
329 SetInputParam("clusters", c);
330 SetInputParam("algorithm", std::move(algo));
331 SetInputParam("labels_only", true);
332 SetInputParam("initial_centroids", initCentroid);
333
334 mlpackMain();
335
336 arma::mat elkanOutput;
337 arma::mat elkanCentroid;
338 elkanOutput = std::move(IO::GetParam<arma::mat>("output"));
339 elkanCentroid = std::move(IO::GetParam<arma::mat>("centroid"));
340
341 ResetKmSettings();
342
343 algo = "hamerly";
344
345 SetInputParam("input", inputData);
346 SetInputParam("clusters", c);
347 SetInputParam("algorithm", std::move(algo));
348 SetInputParam("labels_only", true);
349 SetInputParam("initial_centroids", initCentroid);
350
351 mlpackMain();
352
353 arma::mat hamerlyOutput;
354 arma::mat hamerlyCentroid;
355 hamerlyOutput = std::move(IO::GetParam<arma::mat>("output"));
356 hamerlyCentroid = std::move(IO::GetParam<arma::mat>("centroid"));
357
358 ResetKmSettings();
359
360 algo = "dualtree";
361
362 SetInputParam("input", inputData);
363 SetInputParam("clusters", c);
364 SetInputParam("algorithm", std::move(algo));
365 SetInputParam("labels_only", true);
366 SetInputParam("initial_centroids", initCentroid);
367
368 mlpackMain();
369
370 arma::mat dualTreeOutput;
371 arma::mat dualTreeCentroid;
372 dualTreeOutput = std::move(IO::GetParam<arma::mat>("output"));
373 dualTreeCentroid = std::move(IO::GetParam<arma::mat>("centroid"));
374
375 ResetKmSettings();
376
377 algo = "dualtree-covertree";
378
379 SetInputParam("input", std::move(inputData));
380 SetInputParam("clusters", c);
381 SetInputParam("algorithm", std::move(algo));
382 SetInputParam("labels_only", true);
383 SetInputParam("initial_centroids", std::move(initCentroid));
384
385 mlpackMain();
386
387 arma::mat dualCoverTreeOutput;
388 arma::mat dualCoverTreeCentroid;
389 dualCoverTreeOutput = std::move(IO::GetParam<arma::mat>("output"));
390 dualCoverTreeCentroid = std::move(IO::GetParam<arma::mat>("centroid"));
391
392 // Checking all the algorithms return same assignments
393 CheckMatrices(naiveOutput, hamerlyOutput);
394 CheckMatrices(naiveOutput, elkanOutput);
395 CheckMatrices(naiveOutput, dualTreeOutput);
396 CheckMatrices(naiveOutput, dualCoverTreeOutput);
397
398 // Checking all the algorithms return almost same centroid
399 CheckMatrices(naiveCentroid, hamerlyCentroid);
400 CheckMatrices(naiveCentroid, elkanCentroid);
401 CheckMatrices(naiveCentroid, dualTreeCentroid);
402 CheckMatrices(naiveCentroid, dualCoverTreeCentroid);
403 }
404