1 /** 2 * @file tests/main_tests/lmnn_test.cpp 3 * @author Manish Kumar 4 * 5 * Test mlpackMain() of lmnn_main.cpp. 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 13 #include <string> 14 15 #define BINDING_TYPE BINDING_TYPE_TEST 16 static const std::string testName = "LMNN"; 17 18 #include <mlpack/core.hpp> 19 #include <mlpack/core/util/mlpack_main.hpp> 20 #include <mlpack/core/metrics/lmetric.hpp> 21 22 #include "test_helper.hpp" 23 #include <mlpack/methods/lmnn/lmnn_main.cpp> 24 25 #include "../test_catch_tools.hpp" 26 #include "../catch.hpp" 27 28 using namespace mlpack; 29 30 struct LMNNTestFixture 31 { 32 public: LMNNTestFixtureLMNNTestFixture33 LMNNTestFixture() 34 { 35 // Cache in the options for this program. 36 IO::RestoreSettings(testName); 37 } 38 ~LMNNTestFixtureLMNNTestFixture39 ~LMNNTestFixture() 40 { 41 // Clear the settings. 42 bindings::tests::CleanMemory(); 43 IO::ClearSettings(); 44 } 45 }; 46 47 /** 48 * Ensure that, when labels are implicitily given with input, 49 * the last column is treated as labels and that we get the 50 * desired shape of output. 51 */ 52 TEST_CASE_METHOD(LMNNTestFixture, "LMNNExplicitImplicitLabelsTest", 53 "[LMNNMainTest][BindingTests]") 54 { 55 // Dataset containing labels as last column. 56 arma::mat inputData; 57 if (!data::Load("iris_train.csv", inputData)) 58 FAIL("Cannot load iris.csv!"); 59 60 SetInputParam("input", inputData); 61 62 mlpackMain(); 63 64 // Check that final output has expected number of rows and colums. 65 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 66 inputData.n_rows - 1); 67 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == 68 inputData.n_rows - 1); 69 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_rows == 70 inputData.n_rows - 1); 71 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_cols == 72 inputData.n_cols); 73 74 // Reset Settings. 75 IO::ClearSettings(); 76 IO::RestoreSettings(testName); 77 78 // Now check that when labels are explicitely given, the last column 79 // of input is not treated as labels. 80 if (!data::Load("iris.csv", inputData)) 81 FAIL("Cannot load iris.csv!"); 82 83 arma::Row<size_t> labels; 84 if (!data::Load("iris_labels.txt", labels)) 85 FAIL("Cannot load iris_labels.txt!"); 86 87 SetInputParam("input", inputData); 88 SetInputParam("labels", std::move(labels)); 89 90 mlpackMain(); 91 92 // Check that final output has expected number of rows and colums. 93 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 94 inputData.n_rows); 95 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == 96 inputData.n_rows); 97 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_rows == 98 inputData.n_rows); 99 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_cols == 100 inputData.n_cols); 101 } 102 103 /** 104 * Ensure that when we pass optimizer of type lbfgs, we also get the desired 105 * shape of output. 106 */ 107 TEST_CASE_METHOD(LMNNTestFixture, "LMNNOptimizerTest", 108 "[LMNNMainTest][BindingTests]") 109 { 110 arma::mat inputData; 111 if (!data::Load("iris.csv", inputData)) 112 FAIL("Cannot load iris.csv!"); 113 114 arma::Row<size_t> labels; 115 if (!data::Load("iris_labels.txt", labels)) 116 FAIL("Cannot load iris_labels.txt!"); 117 118 // Input random data points. 119 SetInputParam("input", inputData); 120 SetInputParam("labels", labels); 121 // TODO: set back to bbsgd---this was done for #1490 and should be reverted 122 // when that is fixed. 123 SetInputParam("optimizer", std::string("amsgrad")); 124 125 mlpackMain(); 126 127 // Check that final output has expected number of rows and colums. 128 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 129 inputData.n_rows); 130 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == 131 inputData.n_rows); 132 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_rows == 133 inputData.n_rows); 134 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_cols == 135 inputData.n_cols); 136 137 // Reset rettings. 138 IO::ClearSettings(); 139 IO::RestoreSettings(testName); 140 141 // Input random data points. 142 SetInputParam("input", inputData); 143 SetInputParam("labels", labels); 144 SetInputParam("optimizer", std::string("sgd")); 145 146 mlpackMain(); 147 148 // Check that final output has expected number of rows and colums. 149 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 150 inputData.n_rows); 151 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == 152 inputData.n_rows); 153 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_rows == 154 inputData.n_rows); 155 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_cols == 156 inputData.n_cols); 157 158 // Reset rettings. 159 IO::ClearSettings(); 160 IO::RestoreSettings(testName); 161 162 // Input random data points. 163 SetInputParam("input", inputData); 164 SetInputParam("labels", std::move(labels)); 165 SetInputParam("optimizer", std::string("lbfgs")); 166 167 mlpackMain(); 168 169 // Check that final output has expected number of rows and colums. 170 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 171 inputData.n_rows); 172 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == 173 inputData.n_rows); 174 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_rows == 175 inputData.n_rows); 176 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_cols == 177 inputData.n_cols); 178 } 179 180 /** 181 * Ensure that when we pass a valid initial learning point, we get 182 * output of the same dimensions. 183 */ 184 TEST_CASE_METHOD(LMNNTestFixture, "LMNNValidDistanceTest", 185 "[LMNNMainTest][BindingTests]") 186 { 187 arma::mat inputData; 188 if (!data::Load("iris.csv", inputData)) 189 FAIL("Cannot load iris.csv!"); 190 191 arma::Row<size_t> labels; 192 if (!data::Load("iris_labels.txt", labels)) 193 FAIL("Cannot load iris_labels.txt!"); 194 195 // Initial learning point. 196 arma::mat distance; 197 distance.randu(inputData.n_rows - 1, inputData.n_rows); 198 199 // Input random data points. 200 SetInputParam("input", inputData); 201 SetInputParam("labels", std::move(labels)); 202 SetInputParam("distance", std::move(distance)); 203 204 mlpackMain(); 205 206 // Check that final output has expected number of rows and colums. 207 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 208 inputData.n_rows - 1); 209 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == 210 inputData.n_rows); 211 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_rows == 212 inputData.n_rows - 1); 213 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_cols == 214 inputData.n_cols); 215 } 216 217 /** 218 * Ensure that when we pass a valid initial square matrix as the learning 219 * point, we get output of the same dimensions. 220 */ 221 TEST_CASE_METHOD(LMNNTestFixture, "LMNNValidDistanceTest2", 222 "[LMNNMainTest][BindingTests]") 223 { 224 arma::mat inputData; 225 if (!data::Load("iris.csv", inputData)) 226 FAIL("Cannot load iris.csv!"); 227 228 arma::Row<size_t> labels; 229 if (!data::Load("iris_labels.txt", labels)) 230 FAIL("Cannot load iris_labels.txt!"); 231 232 // Initial learning point (square matrix). 233 arma::mat distance; 234 distance.randu(inputData.n_rows, inputData.n_rows); 235 236 // Input random data points. 237 SetInputParam("input", inputData); 238 SetInputParam("labels", std::move(labels)); 239 SetInputParam("distance", std::move(distance)); 240 241 mlpackMain(); 242 243 // Check that final output has expected number of rows and colums. 244 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 245 inputData.n_rows); 246 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == 247 inputData.n_rows); 248 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_rows == 249 inputData.n_rows); 250 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_cols == 251 inputData.n_cols); 252 } 253 254 /** 255 * Ensure that when we pass an invalid initial learning point, we get 256 * output as the square matrix. 257 */ 258 TEST_CASE_METHOD(LMNNTestFixture, "LMNNInvalidDistanceTest", 259 "[LMNNMainTest][BindingTests]") 260 { 261 arma::mat inputData; 262 if (!data::Load("iris.csv", inputData)) 263 FAIL("Cannot load iris.csv!"); 264 265 arma::Row<size_t> labels; 266 if (!data::Load("iris_labels.txt", labels)) 267 FAIL("Cannot load iris_labels.txt!"); 268 269 // Initial learning point. 270 arma::mat distance; 271 distance.randu(inputData.n_rows + 1, inputData.n_rows); 272 273 // Input random data points. 274 SetInputParam("input", inputData); 275 SetInputParam("labels", std::move(labels)); 276 SetInputParam("distance", std::move(distance)); 277 278 mlpackMain(); 279 280 // Check that final output has expected number of rows and colums. 281 REQUIRE(IO::GetParam<arma::mat>("output").n_rows == 282 inputData.n_rows); 283 REQUIRE(IO::GetParam<arma::mat>("output").n_cols == 284 inputData.n_rows); 285 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_rows == 286 inputData.n_rows); 287 REQUIRE(IO::GetParam<arma::mat>("transformed_data").n_cols == 288 inputData.n_cols); 289 } 290 291 /** 292 * Ensure that if number of available labels in a class is less than 293 * the number of targets, an error occurs. 294 */ 295 TEST_CASE_METHOD(LMNNTestFixture, "LMNNNumTargetsTest", 296 "[LMNNMainTest][BindingTests]") 297 { 298 // Input Dataset 299 arma::mat inputData = "-0.1 -0.1 -0.1 0.1 0.1 0.1;" 300 " 1.0 0.0 -1.0 1.0 0.0 -1.0 "; 301 arma::Row<size_t> labels = " 0 0 0 1 1 1"; 302 303 SetInputParam("input", std::move(inputData)); 304 SetInputParam("labels", std::move(labels)); 305 SetInputParam("k", (int) 5); 306 307 // Check that an error is thrown. 308 Log::Fatal.ignoreInput = true; 309 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 310 Log::Fatal.ignoreInput = false; 311 } 312 313 /** 314 * Ensure that setting normalize as true results in a 315 * different output matrix then when set to false. 316 */ 317 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffNormalizationTest", 318 "[LMNNMainTest][BindingTests]") 319 { 320 arma::mat inputData; 321 if (!data::Load("iris.csv", inputData)) 322 FAIL("Cannot load iris.csv!"); 323 324 arma::Row<size_t> labels; 325 if (!data::Load("iris_labels.txt", labels)) 326 FAIL("Cannot load iris_labels.txt!"); 327 328 // Set parameters and set normalize to true. 329 SetInputParam("input", inputData); 330 SetInputParam("labels", labels); 331 SetInputParam("linear_scan", true); 332 SetInputParam("tolerance", 0.01); 333 334 mlpackMain(); 335 336 arma::mat output = IO::GetParam<arma::mat>("output"); 337 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 338 339 // Reset rettings. 340 IO::ClearSettings(); 341 IO::RestoreSettings(testName); 342 343 // Use the same input but set normalize to false. 344 SetInputParam("input", std::move(inputData)); 345 SetInputParam("labels", std::move(labels)); 346 SetInputParam("normalize", true); 347 SetInputParam("linear_scan", true); 348 SetInputParam("tolerance", 0.01); 349 350 mlpackMain(); 351 352 // Check that the output matrices are different. 353 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 354 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 355 transformedData) > 0); 356 } 357 358 /** 359 * Ensure that output is different when step_size is different. 360 */ 361 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffStepSizeTest", 362 "[LMNNMainTest][BindingTests]") 363 { 364 arma::mat inputData; 365 if (!data::Load("iris.csv", inputData)) 366 FAIL("Cannot load iris.csv!"); 367 368 arma::Row<size_t> labels; 369 if (!data::Load("iris_labels.txt", labels)) 370 FAIL("Cannot load iris_labels.txt!"); 371 372 // Set parameters with a small step_size. 373 SetInputParam("input", inputData); 374 SetInputParam("labels", labels); 375 SetInputParam("step_size", (double) 0.01); 376 SetInputParam("linear_scan", (bool) true); 377 378 mlpackMain(); 379 380 arma::mat output = IO::GetParam<arma::mat>("output"); 381 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 382 383 // Reset settings. 384 IO::ClearSettings(); 385 IO::RestoreSettings(testName); 386 387 // Set parameters using the same input but with a larger step_size. 388 SetInputParam("input", std::move(inputData)); 389 SetInputParam("labels", std::move(labels)); 390 SetInputParam("step_size", (double) 20.5); 391 SetInputParam("linear_scan", (bool) true); 392 393 mlpackMain(); 394 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 395 transformedData) > 0); 396 // Check that the output matrices are different. 397 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 398 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 399 transformedData) > 0); 400 } 401 402 /** 403 * Ensure that output is different when the tolerance is different. 404 */ 405 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffToleranceTest", 406 "[LMNNMainTest][BindingTests]") 407 { 408 arma::mat inputData; 409 if (!data::Load("iris.csv", inputData)) 410 FAIL("Cannot load iris.csv!"); 411 412 arma::Row<size_t> labels; 413 if (!data::Load("iris_labels.txt", labels)) 414 FAIL("Cannot load iris_labels.txt!"); 415 416 // Set parameters with a small tolerance. 417 SetInputParam("input", inputData); 418 SetInputParam("tolerance", (double) 1e-6); 419 SetInputParam("linear_scan", (bool) true); 420 421 mlpackMain(); 422 423 arma::mat output = IO::GetParam<arma::mat>("output"); 424 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 425 426 // Reset settings. 427 IO::ClearSettings(); 428 IO::RestoreSettings(testName); 429 430 // Set parameters using the same input but with a larger tolerance. 431 SetInputParam("input", std::move(inputData)); 432 SetInputParam("tolerance", (double) 0.3); 433 SetInputParam("linear_scan", (bool) true); 434 435 mlpackMain(); 436 437 // Check that the output matrices are different. 438 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 439 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 440 transformedData) > 0); 441 } 442 443 /** 444 * Ensure that output is different when batch_size is different. 445 */ 446 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffBatchSizeTest", 447 "[LMNNMainTest][BindingTests]") 448 { 449 arma::mat inputData; 450 if (!data::Load("iris.csv", inputData)) 451 FAIL("Cannot load iris.csv!"); 452 453 arma::Row<size_t> labels; 454 if (!data::Load("iris_labels.txt", labels)) 455 FAIL("Cannot load iris_labels.txt!"); 456 457 // Set parameters with a small batch_size. 458 SetInputParam("input", inputData); 459 SetInputParam("labels", labels); 460 SetInputParam("batch_size", (int) 20); 461 SetInputParam("linear_scan", (bool) true); 462 463 mlpackMain(); 464 465 arma::mat output = IO::GetParam<arma::mat>("output"); 466 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 467 468 // Reset settings. 469 IO::ClearSettings(); 470 IO::RestoreSettings(testName); 471 472 // Set parameters using the same input but with a larger batch_size. 473 SetInputParam("input", std::move(inputData)); 474 SetInputParam("labels", std::move(labels)); 475 SetInputParam("batch_size", (int) 30); 476 SetInputParam("linear_scan", (bool) true); 477 478 mlpackMain(); 479 480 // Check that the output matrices are different. 481 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 482 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 483 transformedData) > 0); 484 } 485 486 /** 487 * Ensure that different value of number of targets results in a 488 * different output matrix. 489 */ 490 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffNumTargetsTest", 491 "[LMNNMainTest][BindingTests]") 492 { 493 arma::mat inputData; 494 if (!data::Load("iris.csv", inputData)) 495 FAIL("Cannot load iris.csv!"); 496 497 arma::Row<size_t> labels; 498 if (!data::Load("iris_labels.txt", labels)) 499 FAIL("Cannot load iris_labels.txt!"); 500 501 // Set parameters. 502 SetInputParam("input", inputData); 503 SetInputParam("labels", labels); 504 SetInputParam("k", 1); 505 SetInputParam("linear_scan", (bool) true); 506 507 mlpackMain(); 508 509 arma::mat output = IO::GetParam<arma::mat>("output"); 510 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 511 512 // Reset settings. 513 IO::ClearSettings(); 514 IO::RestoreSettings(testName); 515 516 // Set different parameters. 517 SetInputParam("input", std::move(inputData)); 518 SetInputParam("labels", std::move(labels)); 519 SetInputParam("k", 5); 520 SetInputParam("linear_scan", (bool) true); 521 522 mlpackMain(); 523 524 // Check that the output matrices are different. 525 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 526 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 527 transformedData) > 0); 528 } 529 530 /** 531 * Ensure that different value of regularization results in a 532 * different output matrix. 533 */ 534 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffRegularizationTest", 535 "[LMNNMainTest][BindingTests]") 536 { 537 arma::mat inputData; 538 if (!data::Load("iris.csv", inputData)) 539 FAIL("Cannot load iris.csv!"); 540 541 arma::Row<size_t> labels; 542 if (!data::Load("iris_labels.txt", labels)) 543 FAIL("Cannot load iris_labels.txt!"); 544 545 // Set parameters. 546 SetInputParam("input", inputData); 547 SetInputParam("labels", labels); 548 SetInputParam("linear_scan", (bool) true); 549 SetInputParam("regularization", 1.0); 550 551 mlpackMain(); 552 553 arma::mat output = IO::GetParam<arma::mat>("output"); 554 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 555 556 // Reset settings. 557 IO::ClearSettings(); 558 IO::RestoreSettings(testName); 559 560 // Set different parameters. 561 SetInputParam("input", std::move(inputData)); 562 SetInputParam("labels", std::move(labels)); 563 SetInputParam("linear_scan", (bool) true); 564 SetInputParam("regularization", 0.1); 565 566 mlpackMain(); 567 568 // Check that the output matrices are different. 569 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 570 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 571 transformedData) > 0); 572 } 573 574 /** 575 * Ensure that different value of range results in a 576 * different output matrix. 577 */ 578 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffRangeTest", 579 "[LMNNMainTest][BindingTests]") 580 { 581 arma::mat inputData; 582 if (!data::Load("iris.csv", inputData)) 583 FAIL("Cannot load iris.csv!"); 584 585 arma::Row<size_t> labels; 586 if (!data::Load("iris_labels.txt", labels)) 587 FAIL("Cannot load iris_labels.txt!"); 588 589 // Set parameters. 590 SetInputParam("input", inputData); 591 SetInputParam("labels", labels); 592 SetInputParam("linear_scan", (bool) true); 593 594 mlpackMain(); 595 596 arma::mat output = IO::GetParam<arma::mat>("output"); 597 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 598 599 // Reset settings. 600 IO::ClearSettings(); 601 IO::RestoreSettings(testName); 602 603 // Set different parameters. 604 SetInputParam("input", std::move(inputData)); 605 SetInputParam("labels", std::move(labels)); 606 SetInputParam("linear_scan", (bool) true); 607 SetInputParam("range", 100); 608 609 mlpackMain(); 610 611 // Check that the output matrices are different. 612 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 613 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 614 transformedData) > 0); 615 } 616 617 /** 618 * Ensure that using a different value of max_iteration 619 * results in a different output matrix. 620 */ 621 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffMaxIterationTest", 622 "[LMNNMainTest][BindingTests]") 623 { 624 arma::mat inputData; 625 if (!data::Load("iris.csv", inputData)) 626 FAIL("Cannot load iris.csv!"); 627 628 arma::Row<size_t> labels; 629 if (!data::Load("iris_labels.txt", labels)) 630 FAIL("Cannot load iris_labels.txt!"); 631 632 // Set parameters with a small max_iterations. 633 SetInputParam("input", inputData); 634 SetInputParam("labels", labels); 635 SetInputParam("linear_scan", (bool) true); 636 SetInputParam("optimizer", std::string("lbfgs")); 637 SetInputParam("k", 5); 638 SetInputParam("max_iterations", (int) 2); 639 640 mlpackMain(); 641 642 arma::mat output = IO::GetParam<arma::mat>("output"); 643 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 644 645 // Reset settings. 646 IO::ClearSettings(); 647 IO::RestoreSettings(testName); 648 649 // Set parameters using the same input but with a larger max_iterations. 650 SetInputParam("input", std::move(inputData)); 651 SetInputParam("labels", labels); 652 SetInputParam("linear_scan", (bool) true); 653 SetInputParam("optimizer", std::string("lbfgs")); 654 SetInputParam("k", 5); 655 SetInputParam("max_iterations", (int) 500); 656 657 mlpackMain(); 658 659 // Check that the output matrices are different. 660 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 661 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 662 transformedData) > 0); 663 } 664 665 /** 666 * Ensure that using a different value of passes 667 * results in a different output matrix. 668 */ 669 TEST_CASE_METHOD(LMNNTestFixture, "LMNNDiffPassesTest", 670 "[LMNNMainTest][BindingTests]") 671 { 672 arma::mat inputData; 673 if (!data::Load("iris.csv", inputData)) 674 FAIL("Cannot load iris.csv!"); 675 676 arma::Row<size_t> labels; 677 if (!data::Load("iris_labels.txt", labels)) 678 FAIL("Cannot load iris_labels.txt!"); 679 680 // Set parameters with a small passes. 681 SetInputParam("input", inputData); 682 SetInputParam("labels", labels); 683 SetInputParam("linear_scan", (bool) true); 684 SetInputParam("passes", (int) 2); 685 686 mlpackMain(); 687 688 arma::mat output = IO::GetParam<arma::mat>("output"); 689 arma::mat transformedData = IO::GetParam<arma::mat>("transformed_data"); 690 691 // Reset settings. 692 IO::ClearSettings(); 693 IO::RestoreSettings(testName); 694 695 // Set parameters using the same input but with a larger passes. 696 SetInputParam("input", std::move(inputData)); 697 SetInputParam("labels", labels); 698 SetInputParam("linear_scan", (bool) true); 699 SetInputParam("passes", (int) 6); 700 701 mlpackMain(); 702 703 // Check that the output matrices are different. 704 REQUIRE(arma::accu(IO::GetParam<arma::mat>("output") != output) > 0); 705 REQUIRE(arma::accu(IO::GetParam<arma::mat>("transformed_data") != 706 transformedData) > 0); 707 } 708 709 /** 710 * Ensure that number of targets, range, batch size must be always positive 711 * and regularization, step size, max iterations, rank, passes & tolerance are 712 * always non-negative 713 */ 714 TEST_CASE_METHOD(LMNNTestFixture, "LMNNBoundsTest", 715 "[LMNNMainTest][BindingTests]") 716 { 717 arma::mat inputData; 718 if (!data::Load("iris.csv", inputData)) 719 FAIL("Cannot load iris.csv!"); 720 721 arma::Row<size_t> labels; 722 if (!data::Load("iris_labels.txt", labels)) 723 FAIL("Cannot load iris_labels.txt!"); 724 725 // Test for number of targets value. 726 727 // Input training data. 728 SetInputParam("input", inputData); 729 SetInputParam("labels", labels); 730 SetInputParam("k", (int) 0); 731 732 Log::Fatal.ignoreInput = true; 733 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 734 Log::Fatal.ignoreInput = false; 735 736 // Reset settings. 737 IO::ClearSettings(); 738 IO::RestoreSettings(testName); 739 740 // Test for range value. 741 742 // Input training data. 743 SetInputParam("input", inputData); 744 SetInputParam("labels", labels); 745 SetInputParam("range", (int) 0); 746 747 Log::Fatal.ignoreInput = true; 748 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 749 Log::Fatal.ignoreInput = false; 750 751 // Reset settings. 752 IO::ClearSettings(); 753 IO::RestoreSettings(testName); 754 755 // Test for batch size value. 756 757 // Input training data. 758 SetInputParam("input", inputData); 759 SetInputParam("labels", labels); 760 SetInputParam("batch_size", (int) 0); 761 762 Log::Fatal.ignoreInput = true; 763 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 764 Log::Fatal.ignoreInput = false; 765 766 // Reset settings. 767 IO::ClearSettings(); 768 IO::RestoreSettings(testName); 769 770 // Test for regularization value. 771 772 // Input training data. 773 SetInputParam("input", inputData); 774 SetInputParam("labels", labels); 775 SetInputParam("regularization", (double) -1.0); 776 777 Log::Fatal.ignoreInput = true; 778 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 779 Log::Fatal.ignoreInput = false; 780 781 // Reset settings. 782 IO::ClearSettings(); 783 IO::RestoreSettings(testName); 784 785 // Test for step size value. 786 787 // Input training data. 788 SetInputParam("input", inputData); 789 SetInputParam("labels", labels); 790 SetInputParam("step_size", (double) -1.0); 791 792 Log::Fatal.ignoreInput = true; 793 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 794 Log::Fatal.ignoreInput = false; 795 796 // Reset settings. 797 IO::ClearSettings(); 798 IO::RestoreSettings(testName); 799 800 // Test for max iterations value. 801 802 // Input training data. 803 SetInputParam("input", inputData); 804 SetInputParam("labels", labels); 805 SetInputParam("max_iterations", (int) -1.0); 806 807 Log::Fatal.ignoreInput = true; 808 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 809 Log::Fatal.ignoreInput = false; 810 811 // Reset settings. 812 IO::ClearSettings(); 813 IO::RestoreSettings(testName); 814 815 // Test for passes value. 816 817 // Input training data. 818 SetInputParam("input", inputData); 819 SetInputParam("labels", labels); 820 SetInputParam("passes", (int) -1.0); 821 822 Log::Fatal.ignoreInput = true; 823 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 824 Log::Fatal.ignoreInput = false; 825 826 // Reset settings. 827 IO::ClearSettings(); 828 IO::RestoreSettings(testName); 829 830 // Test for max iterations value. 831 832 // Input training data. 833 SetInputParam("input", inputData); 834 SetInputParam("labels", labels); 835 SetInputParam("rank", (int) -1.0); 836 837 Log::Fatal.ignoreInput = true; 838 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 839 Log::Fatal.ignoreInput = false; 840 841 // Reset settings. 842 IO::ClearSettings(); 843 IO::RestoreSettings(testName); 844 845 // Test for tolerance value. 846 847 // Input training data. 848 SetInputParam("input", std::move(inputData)); 849 SetInputParam("labels", std::move(labels)); 850 SetInputParam("tolerance", (double) -1.0); 851 852 Log::Fatal.ignoreInput = true; 853 REQUIRE_THROWS_AS(mlpackMain(), std::runtime_error); 854 Log::Fatal.ignoreInput = false; 855 } 856