1 /** 2 * @file tests/main_tests/hoeffding_tree_test.cpp 3 * @author Haritha Nair 4 * 5 * Test mlpackMain() of hoeffding_tree_main.cpp. 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #define BINDING_TYPE BINDING_TYPE_TEST 13 14 #include <mlpack/core.hpp> 15 static const std::string testName = "HoeffdingTree"; 16 17 #include <mlpack/core/util/mlpack_main.hpp> 18 #include <mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp> 19 #include "test_helper.hpp" 20 21 #include "../catch.hpp" 22 #include "../test_catch_tools.hpp" 23 24 using namespace mlpack; 25 using namespace data; 26 27 struct HoeffdingTreeTestFixture 28 { 29 public: HoeffdingTreeTestFixtureHoeffdingTreeTestFixture30 HoeffdingTreeTestFixture() 31 { 32 // Cache in the options for this program. 33 IO::RestoreSettings(testName); 34 } 35 ~HoeffdingTreeTestFixtureHoeffdingTreeTestFixture36 ~HoeffdingTreeTestFixture() 37 { 38 // Clear the settings. 39 bindings::tests::CleanMemory(); 40 IO::ClearSettings(); 41 } 42 }; 43 44 /** 45 * Check that number of output points and 46 * number of input points are equal. 47 */ 48 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingTreeOutputDimensionTest", 49 "[HoeffdingTreeMainTest][BindingTest]") 50 { 51 arma::mat inputData; 52 DatasetInfo info; 53 if (!data::Load("vc2.csv", inputData, info)) 54 FAIL("Cannot load train dataset vc2.csv!"); 55 56 arma::Row<size_t> labels; 57 if (!data::Load("vc2_labels.txt", labels)) 58 FAIL("Cannot load labels for vc2_labels.txt"); 59 60 arma::mat testData; 61 if (!data::Load("vc2_test.csv", testData, info)) 62 FAIL("Cannot load test dataset vc2.csv!"); 63 64 size_t testSize = testData.n_cols; 65 66 // Input training data. 67 SetInputParam("training", std::make_tuple(info, inputData)); 68 SetInputParam("labels", std::move(labels)); 69 70 // Input test data. 71 SetInputParam("test", std::make_tuple(info, testData)); 72 73 mlpackMain(); 74 75 // Check that number of output points are equal to number of input points. 76 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 77 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize); 78 79 // Check number of output rows equals 1 for probabilities and predictions. 80 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 81 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1); 82 } 83 84 /** 85 * Check that number of output points and number 86 * of input points are equal for categorical dataset. 87 */ 88 TEST_CASE_METHOD(HoeffdingTreeTestFixture, 89 "HoeffdingTreeCategoricalOutputDimensionTest", 90 "[HoeffdingTreeMainTest][BindingTest]") 91 { 92 arma::mat inputData; 93 DatasetInfo info; 94 if (!data::Load("braziltourism.arff", inputData, info)) 95 FAIL("Cannot load train dataset braziltourism.arff!"); 96 97 arma::Row<size_t> labels; 98 if (!data::Load("braziltourism_labels.txt", labels)) 99 FAIL("Cannot load labels for braziltourism_labels.txt"); 100 101 arma::mat testData; 102 if (!data::Load("braziltourism_test.arff", testData, info)) 103 FAIL("Cannot load test dataset braziltourism_test.arff!"); 104 105 size_t testSize = testData.n_cols; 106 107 // Input training data. 108 SetInputParam("training", std::make_tuple(info, inputData)); 109 SetInputParam("labels", std::move(labels)); 110 111 // Input test data. 112 SetInputParam("test", std::make_tuple(info, testData)); 113 114 mlpackMain(); 115 116 // Check that number of output points are equal to number of input points. 117 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 118 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize); 119 120 // Check number of output rows equals 1 for probabilities and predictions. 121 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 122 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1); 123 } 124 125 /** 126 * Check whether providing labels explicitly and extracting from last 127 * dimension give the same output. 128 */ 129 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingTreeLabelLessTest", 130 "[HoeffdingTreeMainTest][BindingTest]") 131 { 132 arma::mat inputData; 133 DatasetInfo info; 134 if (!data::Load("vc2.csv", inputData, info)) 135 FAIL("Cannot load train dataset vc2.csv!"); 136 137 arma::Row<size_t> labels; 138 if (!data::Load("vc2_labels.txt", labels)) 139 FAIL("Cannot load labels for vc2_labels.txt"); 140 141 arma::mat testData; 142 if (!data::Load("vc2_test.csv", testData, info)) 143 FAIL("Cannot load test dataset vc2.csv!"); 144 145 // Append labels to the training set. 146 inputData.resize(inputData.n_rows+1, inputData.n_cols); 147 for (size_t i = 0; i < inputData.n_cols; ++i) 148 inputData(inputData.n_rows-1, i) = labels[i]; 149 150 size_t testSize = testData.n_cols; 151 152 // Input training data. 153 SetInputParam("training", std::make_tuple(info, inputData)); 154 155 // Input test data. 156 SetInputParam("test", std::make_tuple(info, testData)); 157 158 mlpackMain(); 159 160 // Check that number of output points are equal to number of input points. 161 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 162 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize); 163 164 // Check number of output rows equals number of classes in case of 165 // probabilities and 1 for predictions. 166 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 167 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1); 168 169 // Reset passed parameters. 170 IO::GetSingleton().Parameters()["training"].wasPassed = false; 171 IO::GetSingleton().Parameters()["test"].wasPassed = false; 172 173 arma::Row<size_t> predictions; 174 arma::mat probabilities; 175 predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions")); 176 probabilities = std::move(IO::GetParam<arma::mat>("probabilities")); 177 178 bindings::tests::CleanMemory(); 179 180 inputData.shed_row(inputData.n_rows - 1); 181 182 // Input training data. 183 SetInputParam("training", std::make_tuple(info, inputData)); 184 SetInputParam("test", std::make_tuple(info, testData)); 185 // Pass Labels. 186 SetInputParam("labels", std::move(labels)); 187 188 mlpackMain(); 189 190 // Check that number of output points are equal to number of input points. 191 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 192 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize); 193 194 // Check number of output rows equals 1 for probabilities and predictions. 195 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 196 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1); 197 198 // Check that initial and current predictions are same. 199 CheckMatrices( 200 predictions, IO::GetParam<arma::Row<size_t>>("predictions")); 201 CheckMatrices( 202 probabilities, IO::GetParam<arma::mat>("probabilities")); 203 } 204 205 /** 206 * Ensure that saved model can be used again. 207 */ 208 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingModelReuseTest", 209 "[HoeffdingTreeMainTest][BindingTest]") 210 { 211 arma::mat inputData; 212 DatasetInfo info; 213 if (!data::Load("vc2.csv", inputData, info)) 214 FAIL("Cannot load train dataset vc2.csv!"); 215 216 arma::Row<size_t> labels; 217 if (!data::Load("vc2_labels.txt", labels)) 218 FAIL("Cannot load labels for vc2_labels.txt"); 219 220 arma::mat testData; 221 if (!data::Load("vc2_test.csv", testData, info)) 222 FAIL("Cannot load test dataset vc2.csv!"); 223 224 size_t testSize = testData.n_cols; 225 226 // Input training data. 227 SetInputParam("training", std::make_tuple(info, inputData)); 228 SetInputParam("labels", std::move(labels)); 229 230 // Input test data. 231 SetInputParam("test", std::make_tuple(info, testData)); 232 233 mlpackMain(); 234 235 arma::Row<size_t> predictions; 236 arma::mat probabilities; 237 predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions")); 238 probabilities = std::move(IO::GetParam<arma::mat>("probabilities")); 239 240 // Reset passed parameters. 241 IO::GetSingleton().Parameters()["training"].wasPassed = false; 242 IO::GetSingleton().Parameters()["labels"].wasPassed = false; 243 IO::GetSingleton().Parameters()["test"].wasPassed = false; 244 245 if (!data::Load("vc2_test.csv", testData, info)) 246 FAIL("Cannot load test dataset vc2.csv!"); 247 248 // Input trained model. 249 SetInputParam("test", std::make_tuple(info, testData)); 250 SetInputParam("input_model", 251 IO::GetParam<HoeffdingTreeModel*>("output_model")); 252 253 mlpackMain(); 254 255 // Check that number of output points are equal to number of input points. 256 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 257 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize); 258 259 // Check number of output rows equals 1 for probabilities and predictions. 260 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 261 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1); 262 263 // Check that initial predictions and predictions using saved model are same. 264 CheckMatrices( 265 predictions, IO::GetParam<arma::Row<size_t>>("predictions")); 266 CheckMatrices( 267 probabilities, IO::GetParam<arma::mat>("probabilities")); 268 } 269 270 /** 271 * Ensure that saved model trained on categorical dataset can be used again. 272 */ 273 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingModelCategoricalReuseTest", 274 "[HoeffdingTreeMainTest][BindingTest]") 275 { 276 arma::mat inputData; 277 DatasetInfo info; 278 if (!data::Load("braziltourism.arff", inputData, info)) 279 FAIL("Cannot load train dataset braziltourism.arff!"); 280 281 arma::Row<size_t> labels; 282 if (!data::Load("braziltourism_labels.txt", labels)) 283 FAIL("Cannot load labels for braziltourism_labels.txt"); 284 285 arma::mat testData; 286 if (!data::Load("braziltourism_test.arff", testData, info)) 287 FAIL("Cannot load test dataset braziltourism_test.arff!"); 288 289 size_t testSize = testData.n_cols; 290 291 // Input training data. 292 SetInputParam("training", std::make_tuple(info, inputData)); 293 SetInputParam("labels", std::move(labels)); 294 295 // Input test data. 296 SetInputParam("test", std::make_tuple(info, testData)); 297 298 mlpackMain(); 299 300 // Reset passed parameters. 301 IO::GetSingleton().Parameters()["training"].wasPassed = false; 302 IO::GetSingleton().Parameters()["labels"].wasPassed = false; 303 IO::GetSingleton().Parameters()["test"].wasPassed = false; 304 305 arma::Row<size_t> predictions; 306 arma::mat probabilities; 307 predictions = std::move(IO::GetParam<arma::Row<size_t>>("predictions")); 308 probabilities = std::move(IO::GetParam<arma::mat>("probabilities")); 309 310 if (!data::Load("braziltourism_test.arff", testData, info)) 311 FAIL("Cannot load test dataset braziltourism_test.arff!"); 312 313 // Input trained model. 314 SetInputParam("test", std::make_tuple(info, testData)); 315 SetInputParam("input_model", 316 IO::GetParam<HoeffdingTreeModel*>("output_model")); 317 318 mlpackMain(); 319 320 // Check that number of output points are equal to number of input points. 321 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_cols == testSize); 322 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_cols == testSize); 323 324 // Check number of output rows equals 1 for probabilities and predictions. 325 REQUIRE(IO::GetParam<arma::Row<size_t>>("predictions").n_rows == 1); 326 REQUIRE(IO::GetParam<arma::mat>("probabilities").n_rows == 1); 327 328 // Check that initial predictions and predictions using saved model are same. 329 CheckMatrices( 330 predictions, IO::GetParam<arma::Row<size_t>>("predictions")); 331 CheckMatrices( 332 probabilities, IO::GetParam<arma::mat>("probabilities")); 333 } 334 335 /** 336 * Ensure that small min_samples creates larger model. 337 */ 338 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingMinSamplesTest", 339 "[HoeffdingTreeMainTest][BindingTest]") 340 { 341 arma::mat inputData; 342 DatasetInfo info; 343 int nodes; 344 if (!data::Load("vc2.csv", inputData, info)) 345 FAIL("Cannot load train dataset vc2.csv!"); 346 347 arma::Row<size_t> labels; 348 if (!data::Load("vc2_labels.txt", labels)) 349 FAIL("Cannot load labels for vc2_labels.txt"); 350 351 arma::mat testData; 352 if (!data::Load("vc2_test.csv", testData, info)) 353 FAIL("Cannot load test dataset vc2.csv!"); 354 355 // Input training data. 356 SetInputParam("training", std::make_tuple(info, inputData)); 357 SetInputParam("labels", std::move(labels)); 358 359 // Input test data. 360 SetInputParam("test", std::make_tuple(info, testData)); 361 362 SetInputParam("min_samples", 10); 363 SetInputParam("confidence", 0.25); 364 365 mlpackMain(); 366 367 // Reset passed parameters. 368 IO::GetSingleton().Parameters()["training"].wasPassed = false; 369 IO::GetSingleton().Parameters()["labels"].wasPassed = false; 370 IO::GetSingleton().Parameters()["test"].wasPassed = false; 371 IO::GetSingleton().Parameters()["min_samples"].wasPassed = false; 372 IO::GetSingleton().Parameters()["confidence"].wasPassed = false; 373 374 nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes(); 375 376 bindings::tests::CleanMemory(); 377 378 if (!data::Load("vc2.csv", inputData, info)) 379 FAIL("Cannot load train dataset vc2.csv!"); 380 381 if (!data::Load("vc2_labels.txt", labels)) 382 FAIL("Cannot load labels for vc2_labels.txt"); 383 384 if (!data::Load("vc2_test.csv", testData, info)) 385 FAIL("Cannot load test dataset vc2.csv!"); 386 387 // Input training data. 388 SetInputParam("training", std::make_tuple(info, inputData)); 389 SetInputParam("labels", std::move(labels)); 390 391 // Input test data. 392 SetInputParam("test", std::make_tuple(info, testData)); 393 394 SetInputParam("min_samples", 2000); 395 SetInputParam("confidence", 0.25); 396 397 mlpackMain(); 398 399 // Check that small min_samples creates larger model. 400 REQUIRE((IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes() < 401 nodes); 402 } 403 404 /** 405 * Ensure that large max_samples creates smaller model. 406 */ 407 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingMaxSamplesTest", 408 "[HoeffdingTreeMainTest][BindingTest]") 409 { 410 arma::mat inputData; 411 DatasetInfo info; 412 int nodes; 413 if (!data::Load("vc2.csv", inputData, info)) 414 FAIL("Cannot load train dataset vc2.csv!"); 415 416 arma::Row<size_t> labels; 417 if (!data::Load("vc2_labels.txt", labels)) 418 FAIL("Cannot load labels for vc2_labels.txt"); 419 420 arma::mat testData; 421 if (!data::Load("vc2_test.csv", testData, info)) 422 FAIL("Cannot load test dataset vc2.csv!"); 423 424 // Input training data. 425 SetInputParam("training", std::make_tuple(info, inputData)); 426 SetInputParam("labels", std::move(labels)); 427 428 // Input test data. 429 SetInputParam("test", std::make_tuple(info, testData)); 430 431 SetInputParam("max_samples", 50000); 432 SetInputParam("confidence", 0.95); 433 434 mlpackMain(); 435 436 // Reset passed parameters. 437 IO::GetSingleton().Parameters()["training"].wasPassed = false; 438 IO::GetSingleton().Parameters()["labels"].wasPassed = false; 439 IO::GetSingleton().Parameters()["test"].wasPassed = false; 440 IO::GetSingleton().Parameters()["max_samples"].wasPassed = false; 441 IO::GetSingleton().Parameters()["confidence"].wasPassed = false; 442 443 nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes(); 444 445 bindings::tests::CleanMemory(); 446 447 if (!data::Load("vc2.csv", inputData, info)) 448 FAIL("Cannot load train dataset vc2.csv!"); 449 450 if (!data::Load("vc2_labels.txt", labels)) 451 FAIL("Cannot load labels for vc2_labels.txt"); 452 453 if (!data::Load("vc2_test.csv", testData, info)) 454 FAIL("Cannot load test dataset vc2.csv!"); 455 456 // Input training data. 457 SetInputParam("training", std::make_tuple(info, inputData)); 458 SetInputParam("labels", std::move(labels)); 459 460 // Input test data. 461 SetInputParam("test", std::make_tuple(info, testData)); 462 463 SetInputParam("max_samples", 5); 464 SetInputParam("confidence", 0.95); 465 466 mlpackMain(); 467 468 // Check that large max_samples creates smaller model. 469 REQUIRE(nodes < 470 (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes()); 471 } 472 473 /** 474 * Ensure that small confidence value creates larger model. 475 */ 476 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingConfidenceTest", 477 "[HoeffdingTreeMainTest][BindingTest]") 478 { 479 arma::mat inputData; 480 DatasetInfo info; 481 int nodes; 482 if (!data::Load("vc2.csv", inputData, info)) 483 FAIL("Cannot load train dataset vc2.csv!"); 484 485 arma::Row<size_t> labels; 486 if (!data::Load("vc2_labels.txt", labels)) 487 FAIL("Cannot load labels for vc2_labels.txt"); 488 489 arma::mat testData; 490 if (!data::Load("vc2_test.csv", testData, info)) 491 FAIL("Cannot load test dataset vc2.csv!"); 492 493 // Input training data. 494 SetInputParam("training", std::make_tuple(info, inputData)); 495 SetInputParam("labels", std::move(labels)); 496 497 // Input test data. 498 SetInputParam("test", std::make_tuple(info, testData)); 499 500 SetInputParam("confidence", 0.95); 501 502 mlpackMain(); 503 504 // Reset passed parameters. 505 IO::GetSingleton().Parameters()["training"].wasPassed = false; 506 IO::GetSingleton().Parameters()["labels"].wasPassed = false; 507 IO::GetSingleton().Parameters()["test"].wasPassed = false; 508 IO::GetSingleton().Parameters()["confidence"].wasPassed = false; 509 510 // Model with high confidence. 511 nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes(); 512 513 bindings::tests::CleanMemory(); 514 515 if (!data::Load("vc2.csv", inputData, info)) 516 FAIL("Cannot load train dataset vc2.csv!"); 517 518 if (!data::Load("vc2_labels.txt", labels)) 519 FAIL("Cannot load labels for vc2_labels.txt"); 520 521 if (!data::Load("vc2_test.csv", testData, info)) 522 FAIL("Cannot load test dataset vc2.csv!"); 523 524 // Input training data. 525 SetInputParam("training", std::make_tuple(info, inputData)); 526 SetInputParam("labels", std::move(labels)); 527 528 // Input test data. 529 SetInputParam("test", std::make_tuple(info, testData)); 530 531 // Model with low confidence. 532 SetInputParam("confidence", 0.25); 533 534 mlpackMain(); 535 // Check that higher confidence creates smaller tree. 536 REQUIRE(nodes < 537 (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes()); 538 } 539 540 /** 541 * Ensure that large number of passes creates larger model. 542 */ 543 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingPassesTest", 544 "[HoeffdingTreeMainTest][BindingTest]") 545 { 546 arma::mat inputData; 547 DatasetInfo info; 548 int nodes; 549 if (!data::Load("vc2.csv", inputData, info)) 550 FAIL("Cannot load train dataset vc2.csv!"); 551 552 arma::Row<size_t> labels; 553 if (!data::Load("vc2_labels.txt", labels)) 554 FAIL("Cannot load labels for vc2_labels.txt"); 555 556 arma::mat testData; 557 if (!data::Load("vc2_test.csv", testData, info)) 558 FAIL("Cannot load test dataset vc2.csv!"); 559 560 // Input training data. 561 SetInputParam("training", std::make_tuple(info, inputData)); 562 SetInputParam("labels", std::move(labels)); 563 564 // Input test data. 565 SetInputParam("test", std::make_tuple(info, testData)); 566 567 SetInputParam("passes", 1); 568 569 mlpackMain(); 570 571 // Reset passed parameters. 572 IO::GetSingleton().Parameters()["training"].wasPassed = false; 573 IO::GetSingleton().Parameters()["labels"].wasPassed = false; 574 IO::GetSingleton().Parameters()["test"].wasPassed = false; 575 IO::GetSingleton().Parameters()["passes"].wasPassed = false; 576 577 // Model with smaller number of passes. 578 nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes(); 579 580 bindings::tests::CleanMemory(); 581 582 if (!data::Load("vc2.csv", inputData, info)) 583 FAIL("Cannot load train dataset vc2.csv!"); 584 585 if (!data::Load("vc2_labels.txt", labels)) 586 FAIL("Cannot load labels for vc2_labels.txt"); 587 588 if (!data::Load("vc2_test.csv", testData, info)) 589 FAIL("Cannot load test dataset vc2.csv!"); 590 591 // Input training data. 592 SetInputParam("training", std::make_tuple(info, inputData)); 593 SetInputParam("labels", std::move(labels)); 594 595 // Input test data. 596 SetInputParam("test", std::make_tuple(info, testData)); 597 598 // Model with larger number of passes. 599 SetInputParam("passes", 100); 600 601 mlpackMain(); 602 603 // Check that model with larger number of passes has greater number of nodes. 604 REQUIRE(nodes < 605 (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes()); 606 } 607 608 /** 609 * Ensure that the root node has 2 children when splitting strategy is binary. 610 */ 611 TEST_CASE_METHOD(HoeffdingTreeTestFixture, 612 "HoeffdingBinarySplittingStrategyTest", 613 "[HoeffdingTreeMainTest][BindingTest]") 614 { 615 arma::mat inputData; 616 DatasetInfo info; 617 if (!data::Load("vc2.csv", inputData, info)) 618 FAIL("Cannot load train dataset vc2.csv!"); 619 620 arma::Row<size_t> labels; 621 if (!data::Load("vc2_labels.txt", labels)) 622 FAIL("Cannot load labels for vc2_labels.txt"); 623 624 arma::mat testData; 625 if (!data::Load("vc2_test.csv", testData, info)) 626 FAIL("Cannot load test dataset vc2.csv!"); 627 628 // Input training data. 629 SetInputParam("training", std::make_tuple(info, inputData)); 630 SetInputParam("labels", std::move(labels)); 631 632 // Input test data. 633 SetInputParam("test", std::make_tuple(info, testData)); 634 635 SetInputParam("numeric_split_strategy", (string) "binary"); 636 SetInputParam("max_samples", 50); 637 638 SetInputParam("confidence", 0.25); 639 640 mlpackMain(); 641 642 // Check that number of children is 2. 643 REQUIRE( 644 (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes() - 1 == 2); 645 } 646 647 /** 648 * Ensure that the number of children varies with varying 'bins' in domingos. 649 */ 650 TEST_CASE_METHOD(HoeffdingTreeTestFixture, 651 "HoeffdingDomingosSplittingStrategyTest", 652 "[HoeffdingTreeMainTest][BindingTest]") 653 { 654 arma::mat inputData; 655 DatasetInfo info; 656 int nodes; 657 if (!data::Load("vc2.csv", inputData, info)) 658 FAIL("Cannot load train dataset vc2.csv!"); 659 660 arma::Row<size_t> labels; 661 if (!data::Load("vc2_labels.txt", labels)) 662 FAIL("Cannot load labels for vc2_labels.txt"); 663 664 arma::mat testData; 665 if (!data::Load("vc2_test.csv", testData, info)) 666 FAIL("Cannot load test dataset vc2.csv!"); 667 668 // Input training data. 669 SetInputParam("training", std::make_tuple(info, inputData)); 670 SetInputParam("labels", std::move(labels)); 671 672 // Input test data. 673 SetInputParam("test", std::make_tuple(info, testData)); 674 675 SetInputParam("numeric_split_strategy", (string) "domingos"); 676 SetInputParam("max_samples", 50); 677 SetInputParam("bins", 20); 678 679 mlpackMain(); 680 681 // Reset passed parameters. 682 IO::GetSingleton().Parameters()["training"].wasPassed = false; 683 IO::GetSingleton().Parameters()["labels"].wasPassed = false; 684 IO::GetSingleton().Parameters()["test"].wasPassed = false; 685 IO::GetSingleton().Parameters()["max_samples"].wasPassed = false; 686 IO::GetSingleton().Parameters()["numeric_split_strategy"].wasPassed = false; 687 IO::GetSingleton().Parameters()["bins"].wasPassed = false; 688 689 // Initial model. 690 nodes = (IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes(); 691 692 bindings::tests::CleanMemory(); 693 694 if (!data::Load("vc2.csv", inputData, info)) 695 FAIL("Cannot load train dataset vc2.csv!"); 696 697 if (!data::Load("vc2_labels.txt", labels)) 698 FAIL("Cannot load labels for vc2_labels.txt"); 699 700 if (!data::Load("vc2_test.csv", testData, info)) 701 FAIL("Cannot load test dataset vc2.csv!"); 702 703 // Input training data. 704 SetInputParam("training", std::make_tuple(info, inputData)); 705 SetInputParam("labels", std::move(labels)); 706 707 // Input test data. 708 SetInputParam("test", std::make_tuple(info, testData)); 709 710 SetInputParam("numeric_split_strategy", (string) "domingos"); 711 SetInputParam("max_samples", 50); 712 SetInputParam("bins", 10); 713 714 mlpackMain(); 715 716 // Check that both models have different number of nodes. 717 CHECK((IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes() != 718 nodes); 719 } 720 721 /** 722 * Ensure that the model doesn't split if observations before binning 723 * is greater than total number of samples passed. 724 */ 725 TEST_CASE_METHOD(HoeffdingTreeTestFixture, "HoeffdingBinningTest", 726 "[HoeffdingTreeMainTest][BindingTests]") 727 { 728 arma::mat inputData; 729 arma::mat modData; 730 arma::Row<size_t> modLabels; 731 DatasetInfo info; 732 if (!data::Load("vc2.csv", inputData, info)) 733 FAIL("Cannot load train dataset vc2.csv!"); 734 735 arma::Row<size_t> labels; 736 if (!data::Load("vc2_labels.txt", labels)) 737 FAIL("Cannot load labels for vc2_labels.txt"); 738 739 arma::mat testData; 740 if (!data::Load("vc2_test.csv", testData, info)) 741 FAIL("Cannot load test dataset vc2.csv!"); 742 743 modData = inputData.cols(0, 49); 744 modLabels = labels.cols(0, 49); 745 746 // Input training data. 747 SetInputParam("training", std::make_tuple(info, modData)); 748 SetInputParam("labels", std::move(modLabels)); 749 750 // Input test data. 751 SetInputParam("test", std::make_tuple(info, testData)); 752 753 SetInputParam("numeric_split_strategy", (string) "domingos"); 754 SetInputParam("min_samples", 10); 755 756 // Set parameter to a value greater than number of samples. 757 SetInputParam("observations_before_binning", 100); 758 SetInputParam("confidence", 0.25); 759 760 mlpackMain(); 761 762 // Check that no splitting has happened. 763 REQUIRE((IO::GetParam<HoeffdingTreeModel*>("output_model"))->NumNodes() 764 == 1); 765 } 766