1 /** 2 * @file function_test.cpp 3 * @author Ryan Curtin 4 * @author Shikhar Bhardwaj 5 * @author Marcus Edel 6 * @author Conrad Sanderson 7 * 8 * ensmallen is free software; you may redistribute it and/or modify it under 9 * the terms of the 3-clause BSD license. You should have received a copy of 10 * the 3-clause BSD license along with ensmallen. If not, see 11 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 12 */ 13 14 #include <ensmallen.hpp> 15 #include "catch.hpp" 16 17 using namespace ens; 18 using namespace ens::test; 19 using namespace ens::traits; 20 21 /** 22 * Utility class with no functions. 23 */ 24 class EmptyTestFunction { }; 25 26 /** 27 * Utility class with Evaluate() but no Evaluate(). 28 */ 29 class EvaluateTestFunction 30 { 31 public: Evaluate(const arma::mat & coordinates)32 double Evaluate(const arma::mat& coordinates) 33 { 34 return arma::accu(coordinates); 35 } 36 Evaluate(const arma::mat & coordinates,const size_t begin,const size_t batchSize)37 double Evaluate(const arma::mat& coordinates, 38 const size_t begin, 39 const size_t batchSize) 40 { 41 return arma::accu(coordinates) + begin + batchSize; 42 } 43 }; 44 45 /** 46 * Utility class with Gradient() but no Evaluate(). 47 */ 48 class GradientTestFunction 49 { 50 public: Gradient(const arma::mat & coordinates,arma::mat & gradient)51 void Gradient(const arma::mat& coordinates, arma::mat& gradient) 52 { 53 gradient.ones(coordinates.n_rows, coordinates.n_cols); 54 } 55 Gradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)56 void Gradient(const arma::mat& coordinates, 57 const size_t /* begin */, 58 arma::mat& gradient, 59 const size_t /* batchSize */) 60 { 61 gradient.ones(coordinates.n_rows, coordinates.n_cols); 62 } 63 }; 64 65 /** 66 * Utility class with Gradient() and Evaluate(). 67 */ 68 class EvaluateGradientTestFunction 69 { 70 public: Evaluate(const arma::mat & coordinates)71 double Evaluate(const arma::mat& coordinates) 72 { 73 return arma::accu(coordinates); 74 } 75 Evaluate(const arma::mat & coordinates,const size_t,const size_t)76 double Evaluate(const arma::mat& coordinates, 77 const size_t /* begin */, 78 const size_t /* batchSize */) 79 { 80 return arma::accu(coordinates); 81 } 82 Gradient(const arma::mat & coordinates,arma::mat & gradient)83 void Gradient(const arma::mat& coordinates, arma::mat& gradient) 84 { 85 gradient.ones(coordinates.n_rows, coordinates.n_cols); 86 } 87 Gradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)88 void Gradient(const arma::mat& coordinates, 89 const size_t /* begin */, 90 arma::mat& gradient, 91 const size_t /* batchSize */) 92 { 93 gradient.ones(coordinates.n_rows, coordinates.n_cols); 94 } 95 }; 96 97 /** 98 * Utility class with EvaluateWithGradient(). 99 */ 100 class EvaluateWithGradientTestFunction 101 { 102 public: EvaluateWithGradient(const arma::mat & coordinates,arma::mat & gradient)103 double EvaluateWithGradient(const arma::mat& coordinates, arma::mat& gradient) 104 { 105 gradient.ones(coordinates.n_rows, coordinates.n_cols); 106 return arma::accu(coordinates); 107 } 108 EvaluateWithGradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)109 double EvaluateWithGradient(const arma::mat& coordinates, 110 const size_t /* begin */, 111 arma::mat& gradient, 112 const size_t /* batchSize */) 113 { 114 gradient.ones(coordinates.n_rows, coordinates.n_cols); 115 return arma::accu(coordinates); 116 } 117 }; 118 119 /** 120 * Utility class with all three functions. 121 */ 122 class EvaluateAndWithGradientTestFunction 123 { 124 public: Evaluate(const arma::mat & coordinates)125 double Evaluate(const arma::mat& coordinates) 126 { 127 return arma::accu(coordinates); 128 } 129 Evaluate(const arma::mat & coordinates,const size_t begin,const size_t batchSize)130 double Evaluate(const arma::mat& coordinates, 131 const size_t begin, 132 const size_t batchSize) 133 { 134 return arma::accu(coordinates) + batchSize + begin; 135 } 136 Gradient(const arma::mat & coordinates,arma::mat & gradient)137 void Gradient(const arma::mat& coordinates, arma::mat& gradient) 138 { 139 gradient.ones(coordinates.n_rows, coordinates.n_cols); 140 } 141 Gradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)142 void Gradient(const arma::mat& coordinates, 143 const size_t /* begin */, 144 arma::mat& gradient, 145 const size_t /* batchSize */) 146 { 147 gradient.ones(coordinates.n_rows, coordinates.n_cols); 148 } 149 EvaluateWithGradient(const arma::mat & coordinates,arma::mat & gradient)150 double EvaluateWithGradient(const arma::mat& coordinates, arma::mat& gradient) 151 { 152 gradient.ones(coordinates.n_rows, coordinates.n_cols); 153 return arma::accu(coordinates); 154 } 155 EvaluateWithGradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)156 double EvaluateWithGradient(const arma::mat& coordinates, 157 const size_t /* begin */, 158 arma::mat& gradient, 159 const size_t /* batchSize */) 160 { 161 gradient.ones(coordinates.n_rows, coordinates.n_cols); 162 return arma::accu(coordinates); 163 } 164 }; 165 166 /** 167 * Utility class with const Evaluate() and non-const Gradient(). 168 */ 169 class EvaluateAndNonConstGradientTestFunction 170 { 171 public: Evaluate(const arma::mat & coordinates) const172 double Evaluate(const arma::mat& coordinates) const 173 { 174 return arma::accu(coordinates); 175 } 176 Gradient(const arma::mat & coordinates,arma::mat & gradient)177 void Gradient(const arma::mat& coordinates, arma::mat& gradient) 178 { 179 gradient.ones(coordinates.n_rows, coordinates.n_cols); 180 } 181 }; 182 183 /** 184 * Utility class with const Evaluate() and non-const Gradient(). 185 */ 186 class EvaluateAndStaticGradientTestFunction 187 { 188 public: Evaluate(const arma::mat & coordinates) const189 double Evaluate(const arma::mat& coordinates) const 190 { 191 return arma::accu(coordinates); 192 } 193 Gradient(const arma::mat & coordinates,arma::mat & gradient)194 static void Gradient(const arma::mat& coordinates, arma::mat& gradient) 195 { 196 gradient.ones(coordinates.n_rows, coordinates.n_cols); 197 } 198 }; 199 200 /** 201 * Make sure that an empty class doesn't have any methods added to it. 202 */ 203 TEST_CASE("AddEvaluateWithGradientEmptyTest", "[FunctionTest]") 204 { 205 const bool hasEvaluate = HasEvaluate< 206 Function<EmptyTestFunction, arma::mat, arma::mat>, 207 TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value; 208 const bool hasGradient = HasGradient< 209 Function<EmptyTestFunction, arma::mat, arma::mat>, 210 TypedForms<arma::mat, arma::mat>::template GradientForm>::value; 211 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 212 Function<EmptyTestFunction, arma::mat, arma::mat>, 213 TypedForms<arma::mat, arma::mat>::template 214 EvaluateWithGradientForm>::value; 215 216 REQUIRE(hasEvaluate == false); 217 REQUIRE(hasGradient == false); 218 REQUIRE(hasEvaluateWithGradient == false); 219 } 220 221 /** 222 * Make sure we don't add any functions if we only have Evaluate(). 223 */ 224 TEST_CASE("AddEvaluateWithGradientEvaluateOnlyTest", "[FunctionTest]") 225 { 226 const bool hasEvaluate = HasEvaluate< 227 Function<EvaluateTestFunction, arma::mat, arma::mat>, 228 TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value; 229 const bool hasGradient = HasGradient< 230 Function<EvaluateTestFunction, arma::mat, arma::mat>, 231 TypedForms<arma::mat, arma::mat>::template GradientForm>::value; 232 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 233 Function<EvaluateTestFunction, arma::mat, arma::mat>, 234 TypedForms<arma::mat, arma::mat>::template 235 EvaluateWithGradientForm>::value; 236 237 REQUIRE(hasEvaluate == true); 238 REQUIRE(hasGradient == false); 239 REQUIRE(hasEvaluateWithGradient == false); 240 } 241 242 /** 243 * Make sure we don't add any functions if we only have Gradient(). 244 */ 245 TEST_CASE("AddEvaluateWithGradientGradientOnlyTest", "[FunctionTest]") 246 { 247 const bool hasEvaluate = HasEvaluate< 248 Function<GradientTestFunction, arma::mat, arma::mat>, 249 TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value; 250 const bool hasGradient = HasGradient< 251 Function<GradientTestFunction, arma::mat, arma::mat>, 252 TypedForms<arma::mat, arma::mat>::template GradientForm>::value; 253 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 254 Function<GradientTestFunction, arma::mat, arma::mat>, 255 TypedForms<arma::mat, arma::mat>::template 256 EvaluateWithGradientForm>::value; 257 258 REQUIRE(hasEvaluate == false); 259 REQUIRE(hasGradient == true); 260 REQUIRE(hasEvaluateWithGradient == false); 261 } 262 263 /** 264 * Make sure we add EvaluateWithGradient() when we have both Evaluate() and 265 * Gradient(). 266 */ 267 TEST_CASE("AddEvaluateWithGradientBothTest", "[FunctionTest]") 268 { 269 const bool hasEvaluate = HasEvaluate< 270 Function<EvaluateGradientTestFunction, arma::mat, arma::mat>, 271 TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value; 272 const bool hasGradient = HasGradient< 273 Function<EvaluateGradientTestFunction, arma::mat, arma::mat>, 274 TypedForms<arma::mat, arma::mat>::template GradientForm>::value; 275 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 276 Function<EvaluateGradientTestFunction, arma::mat, arma::mat>, 277 TypedForms<arma::mat, arma::mat>::template 278 EvaluateWithGradientForm>::value; 279 280 REQUIRE(hasEvaluate == true); 281 REQUIRE(hasGradient == true); 282 REQUIRE(hasEvaluateWithGradient == true); 283 } 284 285 /** 286 * Make sure we add Evaluate() and Gradient() when we have only 287 * EvaluateWithGradient(). 288 */ 289 TEST_CASE("AddEvaluateWithGradientEvaluateWithGradientTest", "[FunctionTest]") 290 { 291 const bool hasEvaluate = HasEvaluate< 292 Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>, 293 TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value; 294 const bool hasGradient = HasGradient< 295 Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>, 296 TypedForms<arma::mat, arma::mat>::template GradientForm>::value; 297 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 298 Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>, 299 TypedForms<arma::mat, arma::mat>::template 300 EvaluateWithGradientForm>::value; 301 302 REQUIRE(hasEvaluate == true); 303 REQUIRE(hasGradient == true); 304 REQUIRE(hasEvaluateWithGradient == true); 305 } 306 307 /** 308 * Make sure we add no methods when we already have all three. 309 */ 310 TEST_CASE("AddEvaluateWithGradientAllThreeTest", "[FunctionTest]") 311 { 312 const bool hasEvaluate = HasEvaluate< 313 Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>, 314 TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value; 315 const bool hasGradient = HasGradient< 316 Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>, 317 TypedForms<arma::mat, arma::mat>::template GradientForm>::value; 318 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 319 Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>, 320 TypedForms<arma::mat, arma::mat>::template 321 EvaluateWithGradientForm>::value; 322 323 REQUIRE(hasEvaluate == true); 324 REQUIRE(hasGradient == true); 325 REQUIRE(hasEvaluateWithGradient == true); 326 } 327 328 TEST_CASE("LogisticRegressionEvaluateWithGradientTest", "[FunctionTest]") 329 { 330 const bool hasEvaluate = HasEvaluate< 331 Function<LogisticRegressionFunction<>, arma::mat, arma::mat>, 332 TypedForms<arma::mat, arma::mat>::template EvaluateConstForm>::value; 333 const bool hasGradient = HasGradient< 334 Function<LogisticRegressionFunction<>, arma::mat, arma::mat>, 335 TypedForms<arma::mat, arma::mat>::template GradientConstForm>::value; 336 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 337 Function<LogisticRegressionFunction<>, arma::mat, arma::mat>, 338 TypedForms<arma::mat, arma::mat>::template 339 EvaluateWithGradientConstForm>::value; 340 341 REQUIRE(hasEvaluate == true); 342 REQUIRE(hasGradient == true); 343 REQUIRE(hasEvaluateWithGradient == true); 344 } 345 346 TEST_CASE("SDPTest", "[FunctionTest]") 347 { 348 typedef AugLagrangianFunction<LRSDPFunction<SDP<arma::mat>>> FunctionType; 349 350 const bool hasEvaluate = HasEvaluate< 351 Function<FunctionType, arma::mat, arma::mat>, 352 TypedForms<arma::mat, arma::mat>::template EvaluateConstForm>::value; 353 const bool hasGradient = HasGradient< 354 Function<FunctionType, arma::mat, arma::mat>, 355 TypedForms<arma::mat, arma::mat>::template GradientConstForm>::value; 356 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 357 Function<FunctionType, arma::mat, arma::mat>, 358 TypedForms<arma::mat, arma::mat>::template 359 EvaluateWithGradientConstForm>::value; 360 361 REQUIRE(hasEvaluate == true); 362 REQUIRE(hasGradient == true); 363 REQUIRE(hasEvaluateWithGradient == true); 364 } 365 366 /** 367 * Make sure that an empty class doesn't have any methods added to it. 368 */ 369 TEST_CASE("AddSeparableEvaluateWithGradientEmptyTest", "[FunctionTest]") 370 { 371 const bool hasEvaluate = HasEvaluate< 372 Function<EmptyTestFunction, arma::mat, arma::mat>, 373 TypedForms<arma::mat, arma::mat>::template 374 SeparableEvaluateForm>::value; 375 const bool hasGradient = HasGradient< 376 Function<EmptyTestFunction, arma::mat, arma::mat>, 377 TypedForms<arma::mat, arma::mat>::template 378 SeparableGradientForm>::value; 379 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 380 Function<EmptyTestFunction, arma::mat, arma::mat>, 381 TypedForms<arma::mat, arma::mat>::template 382 SeparableEvaluateWithGradientForm>::value; 383 384 REQUIRE(hasEvaluate == false); 385 REQUIRE(hasGradient == false); 386 REQUIRE(hasEvaluateWithGradient == false); 387 } 388 389 /** 390 * Make sure we don't add any functions if we only have Evaluate(). 391 */ 392 TEST_CASE("AddSeparableEvaluateWithGradientEvaluateOnlyTest", 393 "[FunctionTest]") 394 { 395 const bool hasEvaluate = HasEvaluate< 396 Function<EvaluateTestFunction, arma::mat, arma::mat>, 397 TypedForms<arma::mat, arma::mat>::template 398 SeparableEvaluateForm>::value; 399 const bool hasGradient = HasGradient< 400 Function<EvaluateTestFunction, arma::mat, arma::mat>, 401 TypedForms<arma::mat, arma::mat>::template 402 SeparableGradientForm>::value; 403 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 404 Function<EvaluateTestFunction, arma::mat, arma::mat>, 405 TypedForms<arma::mat, arma::mat>::template 406 SeparableEvaluateWithGradientForm>::value; 407 408 REQUIRE(hasEvaluate == true); 409 REQUIRE(hasGradient == false); 410 REQUIRE(hasEvaluateWithGradient == false); 411 } 412 413 /** 414 * Make sure we don't add any functions if we only have Gradient(). 415 */ 416 TEST_CASE("AddSeparableEvaluateWithGradientGradientOnlyTest", 417 "[FunctionTest]") 418 { 419 const bool hasEvaluate = HasEvaluate< 420 Function<GradientTestFunction, arma::mat, arma::mat>, 421 TypedForms<arma::mat, arma::mat>::template 422 SeparableEvaluateForm>::value; 423 const bool hasGradient = HasGradient< 424 Function<GradientTestFunction, arma::mat, arma::mat>, 425 TypedForms<arma::mat, arma::mat>::template 426 SeparableGradientForm>::value; 427 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 428 Function<GradientTestFunction, arma::mat, arma::mat>, 429 TypedForms<arma::mat, arma::mat>::template 430 SeparableEvaluateWithGradientForm>::value; 431 432 REQUIRE(hasEvaluate == false); 433 REQUIRE(hasGradient == true); 434 REQUIRE(hasEvaluateWithGradient == false); 435 } 436 437 /** 438 * Make sure we add EvaluateWithGradient() when we have both Evaluate() and 439 * Gradient(). 440 */ 441 TEST_CASE("AddSeparableEvaluateWithGradientBothTest", "[FunctionTest]") 442 { 443 const bool hasEvaluate = HasEvaluate< 444 Function<EvaluateGradientTestFunction, arma::mat, arma::mat>, 445 TypedForms<arma::mat, arma::mat>::template 446 SeparableEvaluateForm>::value; 447 const bool hasGradient = HasGradient< 448 Function<EvaluateGradientTestFunction, arma::mat, arma::mat>, 449 TypedForms<arma::mat, arma::mat>::template 450 SeparableGradientForm>::value; 451 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 452 Function<EvaluateGradientTestFunction, arma::mat, arma::mat>, 453 TypedForms<arma::mat, arma::mat>::template 454 SeparableEvaluateWithGradientForm>::value; 455 456 REQUIRE(hasEvaluate == true); 457 REQUIRE(hasGradient == true); 458 REQUIRE(hasEvaluateWithGradient == true); 459 } 460 461 /** 462 * Make sure we add Evaluate() and Gradient() when we have only 463 * EvaluateWithGradient(). 464 */ 465 TEST_CASE("AddSeparableEvaluateWGradientEvaluateWithGradientTest", 466 "[FunctionTest]") 467 { 468 const bool hasEvaluate = HasEvaluate< 469 Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>, 470 TypedForms<arma::mat, arma::mat>::template 471 SeparableEvaluateForm>::value; 472 const bool hasGradient = HasGradient< 473 Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>, 474 TypedForms<arma::mat, arma::mat>::template 475 SeparableGradientForm>::value; 476 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 477 Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>, 478 TypedForms<arma::mat, arma::mat>::template 479 SeparableEvaluateWithGradientForm>::value; 480 481 Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat> f; 482 arma::mat coordinates(10, 10, arma::fill::ones); 483 arma::mat gradient; 484 f.Gradient(coordinates, 0, gradient, 5); 485 486 REQUIRE(hasEvaluate == true); 487 REQUIRE(hasGradient == true); 488 REQUIRE(hasEvaluateWithGradient == true); 489 } 490 491 /** 492 * Make sure we add no methods when we already have all three. 493 */ 494 TEST_CASE("AddSeparableEvaluateWithGradientAllThreeTest", "[FunctionTest]") 495 { 496 const bool hasEvaluate = HasEvaluate< 497 Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>, 498 TypedForms<arma::mat, arma::mat>::template 499 SeparableEvaluateForm>::value; 500 const bool hasGradient = HasGradient< 501 Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>, 502 TypedForms<arma::mat, arma::mat>::template 503 SeparableGradientForm>::value; 504 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 505 Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>, 506 TypedForms<arma::mat, arma::mat>::template 507 SeparableEvaluateWithGradientForm>::value; 508 509 REQUIRE(hasEvaluate == true); 510 REQUIRE(hasGradient == true); 511 REQUIRE(hasEvaluateWithGradient == true); 512 } 513 514 /** 515 * Make sure we can properly create EvaluateWithGradient() even when one of the 516 * functions is non-const. 517 */ 518 TEST_CASE("AddEvaluateWithGradientMixedTypesTest", "[FunctionTest]") 519 { 520 const bool hasEvaluate = HasEvaluate< 521 Function<EvaluateAndNonConstGradientTestFunction, arma::mat, arma::mat>, 522 TypedForms<arma::mat, arma::mat>::template EvaluateConstForm>::value; 523 const bool hasGradient = HasGradient< 524 Function<EvaluateAndNonConstGradientTestFunction, arma::mat, arma::mat>, 525 TypedForms<arma::mat, arma::mat>::template GradientForm>::value; 526 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 527 Function<EvaluateAndNonConstGradientTestFunction, arma::mat, arma::mat>, 528 TypedForms<arma::mat, arma::mat>::template 529 EvaluateWithGradientForm>::value; 530 531 REQUIRE(hasEvaluate == true); 532 REQUIRE(hasGradient == true); 533 REQUIRE(hasEvaluateWithGradient == true); 534 } 535 536 /** 537 * Make sure we can properly create EvaluateWithGradient() even when one of the 538 * functions is static. 539 */ 540 TEST_CASE("AddEvaluateWithGradientMixedTypesStaticTest", "[FunctionTest]") 541 { 542 const bool hasEvaluate = HasEvaluate< 543 Function<EvaluateAndStaticGradientTestFunction, arma::mat, arma::mat>, 544 TypedForms<arma::mat, arma::mat>::template EvaluateConstForm>::value; 545 const bool hasGradient = HasGradient< 546 Function<EvaluateAndStaticGradientTestFunction, arma::mat, arma::mat>, 547 TypedForms<arma::mat, arma::mat>::template GradientStaticForm>::value; 548 const bool hasEvaluateWithGradient = HasEvaluateWithGradient< 549 Function<EvaluateAndStaticGradientTestFunction, arma::mat, arma::mat>, 550 TypedForms<arma::mat, arma::mat>::template 551 EvaluateWithGradientConstForm>::value; 552 553 REQUIRE(hasEvaluate == true); 554 REQUIRE(hasGradient == true); 555 REQUIRE(hasEvaluateWithGradient == true); 556 } 557 558 class A 559 { 560 public: 561 size_t NumFunctions() const; 562 size_t NumFeatures() const; 563 double Evaluate(const arma::mat&, const size_t, const size_t) const; 564 void Gradient(const arma::mat&, const size_t, arma::mat&, const size_t) const; 565 void Gradient(const arma::mat&, const size_t, arma::sp_mat&, const size_t) 566 const; 567 void PartialGradient(const arma::mat&, const size_t, arma::sp_mat&) const; 568 }; 569 570 class B 571 { 572 public: 573 size_t NumFunctions(); 574 size_t NumFeatures(); 575 double Evaluate(const arma::mat&, const size_t, const size_t); 576 void Gradient(const arma::mat&, const size_t, arma::mat&, const size_t); 577 void Gradient(const arma::mat&, const size_t, arma::sp_mat&, const size_t); 578 void PartialGradient(const arma::mat&, const size_t, arma::sp_mat&); 579 }; 580 581 class C 582 { 583 public: 584 size_t NumConstraints() const; 585 double Evaluate(const arma::mat&) const; 586 void Gradient(const arma::mat&, arma::mat&) const; 587 double EvaluateConstraint(const size_t, const arma::mat&) const; 588 void GradientConstraint(const size_t, const arma::mat&, arma::mat&) const; 589 }; 590 591 class D 592 { 593 public: 594 size_t NumConstraints(); 595 double Evaluate(const arma::mat&); 596 void Gradient(const arma::mat&, arma::mat&); 597 double EvaluateConstraint(const size_t, const arma::mat&); 598 void GradientConstraint(const size_t, const arma::mat&, arma::mat&); 599 }; 600 601 602 /** 603 * Test the correctness of the static check for SeparableFunctionType API. 604 */ 605 TEST_CASE("SeparableFunctionTypeCheckTest", "[FunctionTest]") 606 { 607 static_assert(CheckNumFunctions<A, arma::mat, arma::mat>::value, 608 "CheckNumFunctions static check failed."); 609 static_assert(CheckNumFunctions<B, arma::mat, arma::mat>::value, 610 "CheckNumFunctions static check failed."); 611 static_assert(!CheckNumFunctions<C, arma::mat, arma::mat>::value, 612 "CheckNumFunctions static check failed."); 613 static_assert(!CheckNumFunctions<D, arma::mat, arma::mat>::value, 614 "CheckNumFunctions static check failed."); 615 616 static_assert(CheckSeparableEvaluate<A, arma::mat, arma::mat>::value, 617 "CheckSeparableEvaluate static check failed."); 618 static_assert(CheckSeparableEvaluate<B, arma::mat, arma::mat>::value, 619 "CheckSeparableEvaluate static check failed."); 620 static_assert(!CheckSeparableEvaluate<C, arma::mat, arma::mat>::value, 621 "CheckSeparableEvaluate static check failed."); 622 static_assert(!CheckSeparableEvaluate<D, arma::mat, arma::mat>::value, 623 "CheckSeparableEvaluate static check failed."); 624 625 static_assert(CheckSeparableGradient<A, arma::mat, arma::mat>::value, 626 "CheckSeparableGradient static check failed."); 627 static_assert(CheckSeparableGradient<B, arma::mat, arma::mat>::value, 628 "CheckSeparableGradient static check failed."); 629 static_assert(!CheckSeparableGradient<C, arma::mat, arma::mat>::value, 630 "CheckSeparableGradient static check failed."); 631 static_assert(!CheckSeparableGradient<D, arma::mat, arma::mat>::value, 632 "CheckSeparableGradient static check failed."); 633 } 634 635 /** 636 * Test the correctness of the static check for LagrangianFunctionType API. 637 */ 638 TEST_CASE("LagrangianFunctionTypeCheckTest", "[FunctionTest]") 639 { 640 static_assert(!CheckEvaluate<A, arma::mat, arma::mat>::value, 641 "CheckEvaluate static check failed."); 642 static_assert(!CheckEvaluate<B, arma::mat, arma::mat>::value, 643 "CheckEvaluate static check failed."); 644 static_assert(CheckEvaluate<C, arma::mat, arma::mat>::value, 645 "CheckEvaluate static check failed."); 646 static_assert(CheckEvaluate<D, arma::mat, arma::mat>::value, 647 "CheckEvaluate static check failed."); 648 649 static_assert(!CheckGradient<A, arma::mat, arma::mat>::value, 650 "CheckGradient static check failed."); 651 static_assert(!CheckGradient<B, arma::mat, arma::mat>::value, 652 "CheckGradient static check failed."); 653 static_assert(CheckGradient<C, arma::mat, arma::mat>::value, 654 "CheckGradient static check failed."); 655 static_assert(CheckGradient<D, arma::mat, arma::mat>::value, 656 "CheckGradient static check failed."); 657 658 static_assert(!CheckNumConstraints<A, arma::mat, arma::mat>::value, 659 "CheckNumConstraints static check failed."); 660 static_assert(!CheckNumConstraints<B, arma::mat, arma::mat>::value, 661 "CheckNumConstraints static check failed."); 662 static_assert(CheckNumConstraints<C, arma::mat, arma::mat>::value, 663 "CheckNumConstraints static check failed."); 664 static_assert(CheckNumConstraints<D, arma::mat, arma::mat>::value, 665 "CheckNumConstraints static check failed."); 666 667 static_assert(!CheckEvaluateConstraint<A, arma::mat, arma::mat>::value, 668 "CheckEvaluateConstraint static check failed."); 669 static_assert(!CheckEvaluateConstraint<B, arma::mat, arma::mat>::value, 670 "CheckEvaluateConstraint static check failed."); 671 static_assert(CheckEvaluateConstraint<C, arma::mat, arma::mat>::value, 672 "CheckEvaluateConstraint static check failed."); 673 static_assert(CheckEvaluateConstraint<D, arma::mat, arma::mat>::value, 674 "CheckEvaluateConstraint static check failed."); 675 676 static_assert(!CheckGradientConstraint<A, arma::mat, arma::mat>::value, 677 "CheckGradientConstraint static check failed."); 678 static_assert(!CheckGradientConstraint<B, arma::mat, arma::mat>::value, 679 "CheckGradientConstraint static check failed."); 680 static_assert(CheckGradientConstraint<C, arma::mat, arma::mat>::value, 681 "CheckGradientConstraint static check failed."); 682 static_assert(CheckGradientConstraint<D, arma::mat, arma::mat>::value, 683 "CheckGradientConstraint static check failed."); 684 } 685 686 /** 687 * Test the correctness of the static check for SparseFunctionType API. 688 */ 689 TEST_CASE("SparseFunctionTypeCheckTest", "[FunctionTest]") 690 { 691 static_assert(CheckSparseGradient<A, arma::mat, arma::mat>::value, 692 "CheckSparseGradient static check failed."); 693 static_assert(CheckSparseGradient<B, arma::mat, arma::mat>::value, 694 "CheckSparseGradient static check failed."); 695 static_assert(!CheckSparseGradient<C, arma::mat, arma::mat>::value, 696 "CheckSparseGradient static check failed."); 697 static_assert(!CheckSparseGradient<D, arma::mat, arma::mat>::value, 698 "CheckSparseGradient static check failed."); 699 } 700 701 /** 702 * Test the correctness of the static check for SparseFunctionType API. 703 */ 704 TEST_CASE("ResolvableFunctionTypeCheckTest", "[FunctionTest]") 705 { 706 static_assert(CheckNumFeatures<A, arma::mat, arma::mat>::value, 707 "CheckNumFeatures static check failed."); 708 static_assert(CheckNumFeatures<B, arma::mat, arma::mat>::value, 709 "CheckNumFeatures static check failed."); 710 static_assert(!CheckNumFeatures<C, arma::mat, arma::mat>::value, 711 "CheckNumFeatures static check failed."); 712 static_assert(!CheckNumFeatures<D, arma::mat, arma::mat>::value, 713 "CheckNumFeatures static check failed."); 714 715 static_assert(CheckPartialGradient<A, arma::mat, arma::sp_mat>::value, 716 "CheckPartialGradient static check failed."); 717 static_assert(CheckPartialGradient<B, arma::mat, arma::sp_mat>::value, 718 "CheckPartialGradient static check failed."); 719 static_assert(!CheckPartialGradient<C, arma::mat, arma::sp_mat>::value, 720 "CheckPartialGradient static check failed."); 721 static_assert(!CheckPartialGradient<D, arma::mat, arma::sp_mat>::value, 722 "CheckPartialGradient static check failed."); 723 } 724