1 /**
2  * @file tests/main_tests/hoeffding_tree_test.cpp
3  * @author Haritha Nair
4  *
5  * Test mlpackMain() of hoeffding_tree_main.cpp.
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #define BINDING_TYPE BINDING_TYPE_TEST
13 
14 #include <mlpack/core.hpp>
// Name under which this binding's settings are stored; the test fixture
// restores settings from it before each test case.
static const std::string testName("HoeffdingTree");
16 
17 #include <mlpack/core/util/mlpack_main.hpp>
18 #include <mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp>
19 #include "test_helper.hpp"
20 
21 #include "../catch.hpp"
22 #include "../test_catch_tools.hpp"
23 
24 using namespace mlpack;
25 using namespace data;
26 
27 struct HoeffdingTreeTestFixture
28 {
29  public:
HoeffdingTreeTestFixtureHoeffdingTreeTestFixture30   HoeffdingTreeTestFixture()
31   {
32     // Cache in the options for this program.
33     IO::RestoreSettings(testName);
34   }
35 
~HoeffdingTreeTestFixtureHoeffdingTreeTestFixture36   ~HoeffdingTreeTestFixture()
37   {
38     // Clear the settings.
39     bindings::tests::CleanMemory();
40     IO::ClearSettings();
41   }
42 };
43 
44 /**
45  * Check that number of output points and
46  * number of input points are equal.
47  */
48 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingTreeOutputDimensionTest",
49                  "[HoeffdingTreeMainTest][BindingTest]")
50 {
51   arma::mat inputData;
52   DatasetInfo info;
53   if (!data::Load("vc2.csv", inputData, info))
54     FAIL("Cannot load train dataset vc2.csv!");
55 
56   arma::Row<size_t> labels;
57   if (!data::Load("vc2_labels.txt", labels))
58     FAIL("Cannot load labels for vc2_labels.txt");
59 
60   arma::mat testData;
61   if (!data::Load("vc2_test.csv", testData, info))
62     FAIL("Cannot load test dataset vc2.csv!");
63 
64   size_t testSize = testData.n_cols;
65 
66   // Input training data.
67   SetInputParam("training", std::make_tuple(info, inputData));
68   SetInputParam("labels", std::move(labels));
69 
70   // Input test data.
71   SetInputParam("test", std::make_tuple(info, testData));
72 
73   mlpackMain();
74 
75   // Check that number of output points are equal to number of input points.
76   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
77   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize);
78 
79   // Check number of output rows equals 1 for probabilities and predictions.
80   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
81   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1);
82 }
83 
84 /**
85  * Check that number of output points and number
86  * of input points are equal for categorical dataset.
87  */
88 TEST_CASE_METHOD(HoeffdingTreeTestFixture,
89                  "HoeffdingTreeCategoricalOutputDimensionTest",
90                  "[HoeffdingTreeMainTest][BindingTest]")
91 {
92   arma::mat inputData;
93   DatasetInfo info;
94   if (!data::Load("braziltourism.arff", inputData, info))
95     FAIL("Cannot load train dataset braziltourism.arff!");
96 
97   arma::Row<size_t> labels;
98   if (!data::Load("braziltourism_labels.txt", labels))
99     FAIL("Cannot load labels for braziltourism_labels.txt");
100 
101   arma::mat testData;
102   if (!data::Load("braziltourism_test.arff", testData, info))
103     FAIL("Cannot load test dataset braziltourism_test.arff!");
104 
105   size_t testSize = testData.n_cols;
106 
107   // Input training data.
108   SetInputParam("training", std::make_tuple(info, inputData));
109   SetInputParam("labels", std::move(labels));
110 
111   // Input test data.
112   SetInputParam("test", std::make_tuple(info, testData));
113 
114   mlpackMain();
115 
116   // Check that number of output points are equal to number of input points.
117   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
118   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize);
119 
120   // Check number of output rows equals 1 for probabilities and predictions.
121   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
122   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1);
123 }
124 
125 /**
126  * Check whether providing labels explicitly and extracting from last
127  * dimension give the same output.
128  */
129 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingTreeLabelLessTest",
130                  "[HoeffdingTreeMainTest][BindingTest]")
131 {
132   arma::mat inputData;
133   DatasetInfo info;
134   if (!data::Load("vc2.csv", inputData, info))
135     FAIL("Cannot load train dataset vc2.csv!");
136 
137   arma::Row<size_t> labels;
138   if (!data::Load("vc2_labels.txt", labels))
139     FAIL("Cannot load labels for vc2_labels.txt");
140 
141   arma::mat testData;
142   if (!data::Load("vc2_test.csv", testData, info))
143     FAIL("Cannot load test dataset vc2.csv!");
144 
145   // Append labels to the training set.
146   inputData.resize(inputData.n_rows+1, inputData.n_cols);
147   for (size_t i = 0; i < inputData.n_cols; ++i)
148     inputData(inputData.n_rows-1, i) = labels[i];
149 
150   size_t testSize = testData.n_cols;
151 
152   // Input training data.
153   SetInputParam("training", std::make_tuple(info, inputData));
154 
155   // Input test data.
156   SetInputParam("test", std::make_tuple(info, testData));
157 
158   mlpackMain();
159 
160   // Check that number of output points are equal to number of input points.
161   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
162   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize);
163 
164   // Check number of output rows equals number of classes in case of
165   // probabilities and 1 for predictions.
166   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
167   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1);
168 
169   // Reset passed parameters.
170   IO::GetSingleton().Parameters()["training"].wasPassed = false;
171   IO::GetSingleton().Parameters()["test"].wasPassed = false;
172 
173   arma::Row<size_t> predictions;
174   arma::mat probabilities;
175   predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions"));
176   probabilities = std::move(IO::GetParam<arma::mat>("probabilities"));
177 
178   bindings::tests::CleanMemory();
179 
180   inputData.shed_row(inputData.n_rows - 1);
181 
182   // Input training data.
183   SetInputParam("training", std::make_tuple(info, inputData));
184   SetInputParam("test", std::make_tuple(info, testData));
185   // Pass Labels.
186   SetInputParam("labels", std::move(labels));
187 
188   mlpackMain();
189 
190   // Check that number of output points are equal to number of input points.
191   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
192   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize);
193 
194   // Check number of output rows equals 1 for probabilities and predictions.
195   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
196   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1);
197 
198   // Check that initial and current predictions are same.
199   CheckMatrices(
200       predictions, IO::GetParam<arma::Row<size_t>>("predictions"));
201   CheckMatrices(
202       probabilities, IO::GetParam<arma::mat>("probabilities"));
203 }
204 
205 /**
206  * Ensure that saved model can be used again.
207  */
208 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingModelReuseTest",
209                  "[HoeffdingTreeMainTest][BindingTest]")
210 {
211   arma::mat inputData;
212   DatasetInfo info;
213   if (!data::Load("vc2.csv", inputData, info))
214     FAIL("Cannot load train dataset vc2.csv!");
215 
216   arma::Row<size_t> labels;
217   if (!data::Load("vc2_labels.txt", labels))
218     FAIL("Cannot load labels for vc2_labels.txt");
219 
220   arma::mat testData;
221   if (!data::Load("vc2_test.csv", testData, info))
222     FAIL("Cannot load test dataset vc2.csv!");
223 
224   size_t testSize = testData.n_cols;
225 
226   // Input training data.
227   SetInputParam("training", std::make_tuple(info, inputData));
228   SetInputParam("labels", std::move(labels));
229 
230   // Input test data.
231   SetInputParam("test", std::make_tuple(info, testData));
232 
233   mlpackMain();
234 
235   arma::Row<size_t> predictions;
236   arma::mat probabilities;
237   predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions"));
238   probabilities = std::move(IO::GetParam<arma::mat>("probabilities"));
239 
240   // Reset passed parameters.
241   IO::GetSingleton().Parameters()["training"].wasPassed = false;
242   IO::GetSingleton().Parameters()["labels"].wasPassed = false;
243   IO::GetSingleton().Parameters()["test"].wasPassed = false;
244 
245   if (!data::Load("vc2_test.csv", testData, info))
246     FAIL("Cannot load test dataset vc2.csv!");
247 
248   // Input trained model.
249   SetInputParam("test", std::make_tuple(info, testData));
250   SetInputParam("input_model",
251       IO::GetParam<HoeffdingTreeModel*>("output_model"));
252 
253   mlpackMain();
254 
255   // Check that number of output points are equal to number of input points.
256   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
257   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize);
258 
259   // Check number of output rows equals 1 for probabilities and predictions.
260   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
261   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1);
262 
263   // Check that initial predictions and predictions using saved model are same.
264   CheckMatrices(
265       predictions, IO::GetParam<arma::Row<size_t>>("predictions"));
266   CheckMatrices(
267       probabilities, IO::GetParam<arma::mat>("probabilities"));
268 }
269 
270 /**
271  * Ensure that saved model trained on categorical dataset can be used again.
272  */
273 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingModelCategoricalReuseTest",
274                  "[HoeffdingTreeMainTest][BindingTest]")
275 {
276   arma::mat inputData;
277   DatasetInfo info;
278   if (!data::Load("braziltourism.arff", inputData, info))
279     FAIL("Cannot load train dataset braziltourism.arff!");
280 
281   arma::Row<size_t> labels;
282   if (!data::Load("braziltourism_labels.txt", labels))
283     FAIL("Cannot load labels for braziltourism_labels.txt");
284 
285   arma::mat testData;
286   if (!data::Load("braziltourism_test.arff", testData, info))
287     FAIL("Cannot load test dataset braziltourism_test.arff!");
288 
289   size_t testSize = testData.n_cols;
290 
291   // Input training data.
292   SetInputParam("training", std::make_tuple(info, inputData));
293   SetInputParam("labels", std::move(labels));
294 
295   // Input test data.
296   SetInputParam("test", std::make_tuple(info, testData));
297 
298   mlpackMain();
299 
300   // Reset passed parameters.
301   IO::GetSingleton().Parameters()["training"].wasPassed = false;
302   IO::GetSingleton().Parameters()["labels"].wasPassed = false;
303   IO::GetSingleton().Parameters()["test"].wasPassed = false;
304 
305   arma::Row<size_t> predictions;
306   arma::mat probabilities;
307   predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions"));
308   probabilities = std::move(IO::GetParam<arma::mat>("probabilities"));
309 
310   if (!data::Load("braziltourism_test.arff", testData, info))
311     FAIL("Cannot load test dataset braziltourism_test.arff!");
312 
313   // Input trained model.
314   SetInputParam("test", std::make_tuple(info, testData));
315   SetInputParam("input_model",
316       IO::GetParam<HoeffdingTreeModel*>("output_model"));
317 
318   mlpackMain();
319 
320   // Check that number of output points are equal to number of input points.
321   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize);
322   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize);
323 
324   // Check number of output rows equals 1 for probabilities and predictions.
325   REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1);
326   REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1);
327 
328   // Check that initial predictions and predictions using saved model are same.
329   CheckMatrices(
330       predictions, IO::GetParam<arma::Row<size_t>>("predictions"));
331   CheckMatrices(
332       probabilities, IO::GetParam<arma::mat>("probabilities"));
333 }
334 
335 /**
336  * Ensure that small min_samples creates larger model.
337  */
338 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingMinSamplesTest",
339                  "[HoeffdingTreeMainTest][BindingTest]")
340 {
341   arma::mat inputData;
342   DatasetInfo info;
343   int nodes;
344   if (!data::Load("vc2.csv", inputData, info))
345     FAIL("Cannot load train dataset vc2.csv!");
346 
347   arma::Row<size_t> labels;
348   if (!data::Load("vc2_labels.txt", labels))
349     FAIL("Cannot load labels for vc2_labels.txt");
350 
351   arma::mat testData;
352   if (!data::Load("vc2_test.csv", testData, info))
353     FAIL("Cannot load test dataset vc2.csv!");
354 
355   // Input training data.
356   SetInputParam("training", std::make_tuple(info, inputData));
357   SetInputParam("labels", std::move(labels));
358 
359   // Input test data.
360   SetInputParam("test", std::make_tuple(info, testData));
361 
362   SetInputParam("min_samples", 10);
363   SetInputParam("confidence", 0.25);
364 
365   mlpackMain();
366 
367   // Reset passed parameters.
368   IO::GetSingleton().Parameters()["training"].wasPassed = false;
369   IO::GetSingleton().Parameters()["labels"].wasPassed = false;
370   IO::GetSingleton().Parameters()["test"].wasPassed = false;
371   IO::GetSingleton().Parameters()["min_samples"].wasPassed = false;
372   IO::GetSingleton().Parameters()["confidence"].wasPassed = false;
373 
374   nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes();
375 
376   bindings::tests::CleanMemory();
377 
378   if (!data::Load("vc2.csv", inputData, info))
379     FAIL("Cannot load train dataset vc2.csv!");
380 
381   if (!data::Load("vc2_labels.txt", labels))
382     FAIL("Cannot load labels for vc2_labels.txt");
383 
384   if (!data::Load("vc2_test.csv", testData, info))
385     FAIL("Cannot load test dataset vc2.csv!");
386 
387   // Input training data.
388   SetInputParam("training", std::make_tuple(info, inputData));
389   SetInputParam("labels", std::move(labels));
390 
391   // Input test data.
392   SetInputParam("test", std::make_tuple(info, testData));
393 
394   SetInputParam("min_samples", 2000);
395   SetInputParam("confidence", 0.25);
396 
397   mlpackMain();
398 
399   // Check that small min_samples creates larger model.
400   REQUIRE((IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes() <
401       nodes);
402 }
403 
404 /**
405  * Ensure that large max_samples creates smaller model.
406  */
407 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingMaxSamplesTest",
408                  "[HoeffdingTreeMainTest][BindingTest]")
409 {
410   arma::mat inputData;
411   DatasetInfo info;
412   int nodes;
413   if (!data::Load("vc2.csv", inputData, info))
414     FAIL("Cannot load train dataset vc2.csv!");
415 
416   arma::Row<size_t> labels;
417   if (!data::Load("vc2_labels.txt", labels))
418     FAIL("Cannot load labels for vc2_labels.txt");
419 
420   arma::mat testData;
421   if (!data::Load("vc2_test.csv", testData, info))
422     FAIL("Cannot load test dataset vc2.csv!");
423 
424   // Input training data.
425   SetInputParam("training", std::make_tuple(info, inputData));
426   SetInputParam("labels", std::move(labels));
427 
428   // Input test data.
429   SetInputParam("test", std::make_tuple(info, testData));
430 
431   SetInputParam("max_samples", 50000);
432   SetInputParam("confidence", 0.95);
433 
434   mlpackMain();
435 
436   // Reset passed parameters.
437   IO::GetSingleton().Parameters()["training"].wasPassed = false;
438   IO::GetSingleton().Parameters()["labels"].wasPassed = false;
439   IO::GetSingleton().Parameters()["test"].wasPassed = false;
440   IO::GetSingleton().Parameters()["max_samples"].wasPassed = false;
441   IO::GetSingleton().Parameters()["confidence"].wasPassed = false;
442 
443   nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes();
444 
445   bindings::tests::CleanMemory();
446 
447   if (!data::Load("vc2.csv", inputData, info))
448     FAIL("Cannot load train dataset vc2.csv!");
449 
450   if (!data::Load("vc2_labels.txt", labels))
451     FAIL("Cannot load labels for vc2_labels.txt");
452 
453   if (!data::Load("vc2_test.csv", testData, info))
454     FAIL("Cannot load test dataset vc2.csv!");
455 
456   // Input training data.
457   SetInputParam("training", std::make_tuple(info, inputData));
458   SetInputParam("labels", std::move(labels));
459 
460   // Input test data.
461   SetInputParam("test", std::make_tuple(info, testData));
462 
463   SetInputParam("max_samples", 5);
464   SetInputParam("confidence", 0.95);
465 
466   mlpackMain();
467 
468   // Check that large max_samples creates smaller model.
469   REQUIRE(nodes <
470       (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes());
471 }
472 
473 /**
474  * Ensure that small confidence value creates larger model.
475  */
476 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingConfidenceTest",
477                  "[HoeffdingTreeMainTest][BindingTest]")
478 {
479   arma::mat inputData;
480   DatasetInfo info;
481   int nodes;
482   if (!data::Load("vc2.csv", inputData, info))
483     FAIL("Cannot load train dataset vc2.csv!");
484 
485   arma::Row<size_t> labels;
486   if (!data::Load("vc2_labels.txt", labels))
487     FAIL("Cannot load labels for vc2_labels.txt");
488 
489   arma::mat testData;
490   if (!data::Load("vc2_test.csv", testData, info))
491     FAIL("Cannot load test dataset vc2.csv!");
492 
493   // Input training data.
494   SetInputParam("training", std::make_tuple(info, inputData));
495   SetInputParam("labels", std::move(labels));
496 
497   // Input test data.
498   SetInputParam("test", std::make_tuple(info, testData));
499 
500   SetInputParam("confidence", 0.95);
501 
502   mlpackMain();
503 
504   // Reset passed parameters.
505   IO::GetSingleton().Parameters()["training"].wasPassed = false;
506   IO::GetSingleton().Parameters()["labels"].wasPassed = false;
507   IO::GetSingleton().Parameters()["test"].wasPassed = false;
508   IO::GetSingleton().Parameters()["confidence"].wasPassed = false;
509 
510   // Model with high confidence.
511   nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes();
512 
513   bindings::tests::CleanMemory();
514 
515   if (!data::Load("vc2.csv", inputData, info))
516     FAIL("Cannot load train dataset vc2.csv!");
517 
518   if (!data::Load("vc2_labels.txt", labels))
519     FAIL("Cannot load labels for vc2_labels.txt");
520 
521   if (!data::Load("vc2_test.csv", testData, info))
522     FAIL("Cannot load test dataset vc2.csv!");
523 
524   // Input training data.
525   SetInputParam("training", std::make_tuple(info, inputData));
526   SetInputParam("labels", std::move(labels));
527 
528   // Input test data.
529   SetInputParam("test", std::make_tuple(info, testData));
530 
531   // Model with low confidence.
532   SetInputParam("confidence", 0.25);
533 
534   mlpackMain();
535   // Check that higher confidence creates smaller tree.
536   REQUIRE(nodes <
537       (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes());
538 }
539 
540 /**
541  * Ensure that large number of passes creates larger model.
542  */
543 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingPassesTest",
544                  "[HoeffdingTreeMainTest][BindingTest]")
545 {
546   arma::mat inputData;
547   DatasetInfo info;
548   int nodes;
549   if (!data::Load("vc2.csv", inputData, info))
550     FAIL("Cannot load train dataset vc2.csv!");
551 
552   arma::Row<size_t> labels;
553   if (!data::Load("vc2_labels.txt", labels))
554     FAIL("Cannot load labels for vc2_labels.txt");
555 
556   arma::mat testData;
557   if (!data::Load("vc2_test.csv", testData, info))
558     FAIL("Cannot load test dataset vc2.csv!");
559 
560   // Input training data.
561   SetInputParam("training", std::make_tuple(info, inputData));
562   SetInputParam("labels", std::move(labels));
563 
564   // Input test data.
565   SetInputParam("test", std::make_tuple(info, testData));
566 
567   SetInputParam("passes", 1);
568 
569   mlpackMain();
570 
571   // Reset passed parameters.
572   IO::GetSingleton().Parameters()["training"].wasPassed = false;
573   IO::GetSingleton().Parameters()["labels"].wasPassed = false;
574   IO::GetSingleton().Parameters()["test"].wasPassed = false;
575   IO::GetSingleton().Parameters()["passes"].wasPassed = false;
576 
577   // Model with smaller number of passes.
578   nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes();
579 
580   bindings::tests::CleanMemory();
581 
582   if (!data::Load("vc2.csv", inputData, info))
583     FAIL("Cannot load train dataset vc2.csv!");
584 
585   if (!data::Load("vc2_labels.txt", labels))
586     FAIL("Cannot load labels for vc2_labels.txt");
587 
588   if (!data::Load("vc2_test.csv", testData, info))
589     FAIL("Cannot load test dataset vc2.csv!");
590 
591   // Input training data.
592   SetInputParam("training", std::make_tuple(info, inputData));
593   SetInputParam("labels", std::move(labels));
594 
595   // Input test data.
596   SetInputParam("test", std::make_tuple(info, testData));
597 
598   // Model with larger number of passes.
599   SetInputParam("passes", 100);
600 
601   mlpackMain();
602 
603   // Check that model with larger number of passes has greater number of nodes.
604   REQUIRE(nodes <
605       (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes());
606 }
607 
608 /**
609  * Ensure that the root node has 2 children when splitting strategy is binary.
610  */
611 TEST_CASE_METHOD(HoeffdingTreeTestFixture,
612                  "HoeffdingBinarySplittingStrategyTest",
613                  "[HoeffdingTreeMainTest][BindingTest]")
614 {
615   arma::mat inputData;
616   DatasetInfo info;
617   if (!data::Load("vc2.csv", inputData, info))
618     FAIL("Cannot load train dataset vc2.csv!");
619 
620   arma::Row<size_t> labels;
621   if (!data::Load("vc2_labels.txt", labels))
622     FAIL("Cannot load labels for vc2_labels.txt");
623 
624   arma::mat testData;
625   if (!data::Load("vc2_test.csv", testData, info))
626     FAIL("Cannot load test dataset vc2.csv!");
627 
628   // Input training data.
629   SetInputParam("training", std::make_tuple(info, inputData));
630   SetInputParam("labels", std::move(labels));
631 
632   // Input test data.
633   SetInputParam("test", std::make_tuple(info, testData));
634 
635   SetInputParam("numeric_split_strategy", (string) "binary");
636   SetInputParam("max_samples", 50);
637 
638   SetInputParam("confidence", 0.25);
639 
640   mlpackMain();
641 
642   // Check that number of children is 2.
643   REQUIRE(
644       (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes() - 1 == 2);
645 }
646 
647 /**
648  * Ensure that the number of children varies with varying 'bins' in domingos.
649  */
650 TEST_CASE_METHOD(HoeffdingTreeTestFixture,
651                  "HoeffdingDomingosSplittingStrategyTest",
652                  "[HoeffdingTreeMainTest][BindingTest]")
653 {
654   arma::mat inputData;
655   DatasetInfo info;
656   int nodes;
657   if (!data::Load("vc2.csv", inputData, info))
658     FAIL("Cannot load train dataset vc2.csv!");
659 
660   arma::Row<size_t> labels;
661   if (!data::Load("vc2_labels.txt", labels))
662     FAIL("Cannot load labels for vc2_labels.txt");
663 
664   arma::mat testData;
665   if (!data::Load("vc2_test.csv", testData, info))
666     FAIL("Cannot load test dataset vc2.csv!");
667 
668   // Input training data.
669   SetInputParam("training", std::make_tuple(info, inputData));
670   SetInputParam("labels", std::move(labels));
671 
672   // Input test data.
673   SetInputParam("test", std::make_tuple(info, testData));
674 
675   SetInputParam("numeric_split_strategy", (string) "domingos");
676   SetInputParam("max_samples", 50);
677   SetInputParam("bins", 20);
678 
679   mlpackMain();
680 
681   // Reset passed parameters.
682   IO::GetSingleton().Parameters()["training"].wasPassed = false;
683   IO::GetSingleton().Parameters()["labels"].wasPassed = false;
684   IO::GetSingleton().Parameters()["test"].wasPassed = false;
685   IO::GetSingleton().Parameters()["max_samples"].wasPassed = false;
686   IO::GetSingleton().Parameters()["numeric_split_strategy"].wasPassed = false;
687   IO::GetSingleton().Parameters()["bins"].wasPassed = false;
688 
689   // Initial model.
690   nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes();
691 
692   bindings::tests::CleanMemory();
693 
694   if (!data::Load("vc2.csv", inputData, info))
695     FAIL("Cannot load train dataset vc2.csv!");
696 
697   if (!data::Load("vc2_labels.txt", labels))
698     FAIL("Cannot load labels for vc2_labels.txt");
699 
700   if (!data::Load("vc2_test.csv", testData, info))
701     FAIL("Cannot load test dataset vc2.csv!");
702 
703   // Input training data.
704   SetInputParam("training", std::make_tuple(info, inputData));
705   SetInputParam("labels", std::move(labels));
706 
707   // Input test data.
708   SetInputParam("test", std::make_tuple(info, testData));
709 
710   SetInputParam("numeric_split_strategy", (string) "domingos");
711   SetInputParam("max_samples", 50);
712   SetInputParam("bins", 10);
713 
714   mlpackMain();
715 
716   // Check that both models have different number of nodes.
717   CHECK((IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes() !=
718       nodes);
719 }
720 
721 /**
722  * Ensure that the model doesn't split if observations before binning
723  * is greater than total number of samples passed.
724  */
725 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingBinningTest",
726                  "[HoeffdingTreeMainTest][BindingTests]")
727 {
728   arma::mat inputData;
729   arma::mat modData;
730   arma::Row<size_t> modLabels;
731   DatasetInfo info;
732   if (!data::Load("vc2.csv", inputData, info))
733     FAIL("Cannot load train dataset vc2.csv!");
734 
735   arma::Row<size_t> labels;
736   if (!data::Load("vc2_labels.txt", labels))
737     FAIL("Cannot load labels for vc2_labels.txt");
738 
739   arma::mat testData;
740   if (!data::Load("vc2_test.csv", testData, info))
741     FAIL("Cannot load test dataset vc2.csv!");
742 
743   modData = inputData.cols(0, 49);
744   modLabels = labels.cols(0, 49);
745 
746   // Input training data.
747   SetInputParam("training", std::make_tuple(info, modData));
748   SetInputParam("labels", std::move(modLabels));
749 
750   // Input test data.
751   SetInputParam("test", std::make_tuple(info, testData));
752 
753   SetInputParam("numeric_split_strategy", (string) "domingos");
754   SetInputParam("min_samples", 10);
755 
756   // Set parameter to a value greater than number of samples.
757   SetInputParam("observations_before_binning", 100);
758   SetInputParam("confidence", 0.25);
759 
760   mlpackMain();
761 
762   // Check that no splitting has happened.
763   REQUIRE((IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes()
764       == 1);
765 }
766