1 /**
2  * @file function_test.cpp
3  * @author Ryan Curtin
4  * @author Shikhar Bhardwaj
5  * @author Marcus Edel
6  * @author Conrad Sanderson
7  *
8  * ensmallen is free software; you may redistribute it and/or modify it under
9  * the terms of the 3-clause BSD license.  You should have received a copy of
10  * the 3-clause BSD license along with ensmallen.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 
14 #include <ensmallen.hpp>
15 #include "catch.hpp"
16 
17 using namespace ens;
18 using namespace ens::test;
19 using namespace ens::traits;
20 
/**
 * Utility class that declares no members at all, used to check that
 * Function<> synthesizes nothing when there is nothing to build on.
 */
class EmptyTestFunction
{
};
25 
26 /**
27  * Utility class with Evaluate() but no Evaluate().
28  */
29 class EvaluateTestFunction
30 {
31  public:
Evaluate(const arma::mat & coordinates)32   double Evaluate(const arma::mat& coordinates)
33   {
34     return arma::accu(coordinates);
35   }
36 
Evaluate(const arma::mat & coordinates,const size_t begin,const size_t batchSize)37   double Evaluate(const arma::mat& coordinates,
38                   const size_t begin,
39                   const size_t batchSize)
40   {
41     return arma::accu(coordinates) + begin + batchSize;
42   }
43 };
44 
45 /**
46  * Utility class with Gradient() but no Evaluate().
47  */
48 class GradientTestFunction
49 {
50  public:
Gradient(const arma::mat & coordinates,arma::mat & gradient)51   void Gradient(const arma::mat& coordinates, arma::mat& gradient)
52   {
53     gradient.ones(coordinates.n_rows, coordinates.n_cols);
54   }
55 
Gradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)56   void Gradient(const arma::mat& coordinates,
57                 const size_t /* begin */,
58                 arma::mat& gradient,
59                 const size_t /* batchSize */)
60   {
61     gradient.ones(coordinates.n_rows, coordinates.n_cols);
62   }
63 };
64 
65 /**
66  * Utility class with Gradient() and Evaluate().
67  */
68 class EvaluateGradientTestFunction
69 {
70  public:
Evaluate(const arma::mat & coordinates)71   double Evaluate(const arma::mat& coordinates)
72   {
73     return arma::accu(coordinates);
74   }
75 
Evaluate(const arma::mat & coordinates,const size_t,const size_t)76   double Evaluate(const arma::mat& coordinates,
77                   const size_t /* begin */,
78                   const size_t /* batchSize */)
79   {
80     return arma::accu(coordinates);
81   }
82 
Gradient(const arma::mat & coordinates,arma::mat & gradient)83   void Gradient(const arma::mat& coordinates, arma::mat& gradient)
84   {
85     gradient.ones(coordinates.n_rows, coordinates.n_cols);
86   }
87 
Gradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)88   void Gradient(const arma::mat& coordinates,
89                 const size_t /* begin */,
90                 arma::mat& gradient,
91                 const size_t /* batchSize */)
92   {
93     gradient.ones(coordinates.n_rows, coordinates.n_cols);
94   }
95 };
96 
97 /**
98  * Utility class with EvaluateWithGradient().
99  */
100 class EvaluateWithGradientTestFunction
101 {
102  public:
EvaluateWithGradient(const arma::mat & coordinates,arma::mat & gradient)103   double EvaluateWithGradient(const arma::mat& coordinates, arma::mat& gradient)
104   {
105     gradient.ones(coordinates.n_rows, coordinates.n_cols);
106     return arma::accu(coordinates);
107   }
108 
EvaluateWithGradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)109   double EvaluateWithGradient(const arma::mat& coordinates,
110                               const size_t /* begin */,
111                               arma::mat& gradient,
112                               const size_t /* batchSize */)
113   {
114     gradient.ones(coordinates.n_rows, coordinates.n_cols);
115     return arma::accu(coordinates);
116   }
117 };
118 
119 /**
120  * Utility class with all three functions.
121  */
122 class EvaluateAndWithGradientTestFunction
123 {
124  public:
Evaluate(const arma::mat & coordinates)125   double Evaluate(const arma::mat& coordinates)
126   {
127     return arma::accu(coordinates);
128   }
129 
Evaluate(const arma::mat & coordinates,const size_t begin,const size_t batchSize)130   double Evaluate(const arma::mat& coordinates,
131                   const size_t begin,
132                   const size_t batchSize)
133   {
134     return arma::accu(coordinates) + batchSize + begin;
135   }
136 
Gradient(const arma::mat & coordinates,arma::mat & gradient)137   void Gradient(const arma::mat& coordinates, arma::mat& gradient)
138   {
139     gradient.ones(coordinates.n_rows, coordinates.n_cols);
140   }
141 
Gradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)142   void Gradient(const arma::mat& coordinates,
143                 const size_t /* begin */,
144                 arma::mat& gradient,
145                 const size_t /* batchSize */)
146   {
147     gradient.ones(coordinates.n_rows, coordinates.n_cols);
148   }
149 
EvaluateWithGradient(const arma::mat & coordinates,arma::mat & gradient)150   double EvaluateWithGradient(const arma::mat& coordinates, arma::mat& gradient)
151   {
152     gradient.ones(coordinates.n_rows, coordinates.n_cols);
153     return arma::accu(coordinates);
154   }
155 
EvaluateWithGradient(const arma::mat & coordinates,const size_t,arma::mat & gradient,const size_t)156   double EvaluateWithGradient(const arma::mat& coordinates,
157                               const size_t /* begin */,
158                               arma::mat& gradient,
159                               const size_t /* batchSize */)
160   {
161     gradient.ones(coordinates.n_rows, coordinates.n_cols);
162     return arma::accu(coordinates);
163   }
164 };
165 
166 /**
167  * Utility class with const Evaluate() and non-const Gradient().
168  */
169 class EvaluateAndNonConstGradientTestFunction
170 {
171  public:
Evaluate(const arma::mat & coordinates) const172   double Evaluate(const arma::mat& coordinates) const
173   {
174     return arma::accu(coordinates);
175   }
176 
Gradient(const arma::mat & coordinates,arma::mat & gradient)177   void Gradient(const arma::mat& coordinates, arma::mat& gradient)
178   {
179     gradient.ones(coordinates.n_rows, coordinates.n_cols);
180   }
181 };
182 
183 /**
184  * Utility class with const Evaluate() and non-const Gradient().
185  */
186 class EvaluateAndStaticGradientTestFunction
187 {
188  public:
Evaluate(const arma::mat & coordinates) const189   double Evaluate(const arma::mat& coordinates) const
190   {
191     return arma::accu(coordinates);
192   }
193 
Gradient(const arma::mat & coordinates,arma::mat & gradient)194   static void Gradient(const arma::mat& coordinates, arma::mat& gradient)
195   {
196     gradient.ones(coordinates.n_rows, coordinates.n_cols);
197   }
198 };
199 
200 /**
201  * Make sure that an empty class doesn't have any methods added to it.
202  */
203 TEST_CASE("AddEvaluateWithGradientEmptyTest", "[FunctionTest]")
204 {
205   const bool hasEvaluate = HasEvaluate<
206       Function<EmptyTestFunction, arma::mat, arma::mat>,
207       TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value;
208   const bool hasGradient = HasGradient<
209       Function<EmptyTestFunction, arma::mat, arma::mat>,
210       TypedForms<arma::mat, arma::mat>::template GradientForm>::value;
211   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
212       Function<EmptyTestFunction, arma::mat, arma::mat>,
213       TypedForms<arma::mat, arma::mat>::template
214           EvaluateWithGradientForm>::value;
215 
216   REQUIRE(hasEvaluate == false);
217   REQUIRE(hasGradient == false);
218   REQUIRE(hasEvaluateWithGradient == false);
219 }
220 
221 /**
222  * Make sure we don't add any functions if we only have Evaluate().
223  */
224 TEST_CASE("AddEvaluateWithGradientEvaluateOnlyTest", "[FunctionTest]")
225 {
226   const bool hasEvaluate = HasEvaluate<
227       Function<EvaluateTestFunction, arma::mat, arma::mat>,
228       TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value;
229   const bool hasGradient = HasGradient<
230       Function<EvaluateTestFunction, arma::mat, arma::mat>,
231       TypedForms<arma::mat, arma::mat>::template GradientForm>::value;
232   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
233       Function<EvaluateTestFunction, arma::mat, arma::mat>,
234       TypedForms<arma::mat, arma::mat>::template
235           EvaluateWithGradientForm>::value;
236 
237   REQUIRE(hasEvaluate == true);
238   REQUIRE(hasGradient == false);
239   REQUIRE(hasEvaluateWithGradient == false);
240 }
241 
242 /**
243  * Make sure we don't add any functions if we only have Gradient().
244  */
245 TEST_CASE("AddEvaluateWithGradientGradientOnlyTest", "[FunctionTest]")
246 {
247   const bool hasEvaluate = HasEvaluate<
248       Function<GradientTestFunction, arma::mat, arma::mat>,
249       TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value;
250   const bool hasGradient = HasGradient<
251       Function<GradientTestFunction, arma::mat, arma::mat>,
252       TypedForms<arma::mat, arma::mat>::template GradientForm>::value;
253   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
254       Function<GradientTestFunction, arma::mat, arma::mat>,
255       TypedForms<arma::mat, arma::mat>::template
256           EvaluateWithGradientForm>::value;
257 
258   REQUIRE(hasEvaluate == false);
259   REQUIRE(hasGradient == true);
260   REQUIRE(hasEvaluateWithGradient == false);
261 }
262 
263 /**
264  * Make sure we add EvaluateWithGradient() when we have both Evaluate() and
265  * Gradient().
266  */
267 TEST_CASE("AddEvaluateWithGradientBothTest", "[FunctionTest]")
268 {
269   const bool hasEvaluate = HasEvaluate<
270       Function<EvaluateGradientTestFunction, arma::mat, arma::mat>,
271       TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value;
272   const bool hasGradient = HasGradient<
273       Function<EvaluateGradientTestFunction, arma::mat, arma::mat>,
274       TypedForms<arma::mat, arma::mat>::template GradientForm>::value;
275   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
276       Function<EvaluateGradientTestFunction, arma::mat, arma::mat>,
277       TypedForms<arma::mat, arma::mat>::template
278           EvaluateWithGradientForm>::value;
279 
280   REQUIRE(hasEvaluate == true);
281   REQUIRE(hasGradient == true);
282   REQUIRE(hasEvaluateWithGradient == true);
283 }
284 
285 /**
286  * Make sure we add Evaluate() and Gradient() when we have only
287  * EvaluateWithGradient().
288  */
289 TEST_CASE("AddEvaluateWithGradientEvaluateWithGradientTest", "[FunctionTest]")
290 {
291   const bool hasEvaluate = HasEvaluate<
292       Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>,
293       TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value;
294   const bool hasGradient = HasGradient<
295       Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>,
296       TypedForms<arma::mat, arma::mat>::template GradientForm>::value;
297   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
298       Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>,
299       TypedForms<arma::mat, arma::mat>::template
300           EvaluateWithGradientForm>::value;
301 
302   REQUIRE(hasEvaluate == true);
303   REQUIRE(hasGradient == true);
304   REQUIRE(hasEvaluateWithGradient == true);
305 }
306 
307 /**
308  * Make sure we add no methods when we already have all three.
309  */
310 TEST_CASE("AddEvaluateWithGradientAllThreeTest", "[FunctionTest]")
311 {
312   const bool hasEvaluate = HasEvaluate<
313       Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>,
314       TypedForms<arma::mat, arma::mat>::template EvaluateForm>::value;
315   const bool hasGradient = HasGradient<
316       Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>,
317       TypedForms<arma::mat, arma::mat>::template GradientForm>::value;
318   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
319       Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>,
320       TypedForms<arma::mat, arma::mat>::template
321           EvaluateWithGradientForm>::value;
322 
323   REQUIRE(hasEvaluate == true);
324   REQUIRE(hasGradient == true);
325   REQUIRE(hasEvaluateWithGradient == true);
326 }
327 
328 TEST_CASE("LogisticRegressionEvaluateWithGradientTest", "[FunctionTest]")
329 {
330   const bool hasEvaluate = HasEvaluate<
331       Function<LogisticRegressionFunction<>, arma::mat, arma::mat>,
332       TypedForms<arma::mat, arma::mat>::template EvaluateConstForm>::value;
333   const bool hasGradient = HasGradient<
334       Function<LogisticRegressionFunction<>, arma::mat, arma::mat>,
335       TypedForms<arma::mat, arma::mat>::template GradientConstForm>::value;
336   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
337       Function<LogisticRegressionFunction<>, arma::mat, arma::mat>,
338       TypedForms<arma::mat, arma::mat>::template
339           EvaluateWithGradientConstForm>::value;
340 
341   REQUIRE(hasEvaluate == true);
342   REQUIRE(hasGradient == true);
343   REQUIRE(hasEvaluateWithGradient == true);
344 }
345 
346 TEST_CASE("SDPTest", "[FunctionTest]")
347 {
348   typedef AugLagrangianFunction<LRSDPFunction<SDP<arma::mat>>> FunctionType;
349 
350   const bool hasEvaluate = HasEvaluate<
351       Function<FunctionType, arma::mat, arma::mat>,
352       TypedForms<arma::mat, arma::mat>::template EvaluateConstForm>::value;
353   const bool hasGradient = HasGradient<
354       Function<FunctionType, arma::mat, arma::mat>,
355       TypedForms<arma::mat, arma::mat>::template GradientConstForm>::value;
356   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
357       Function<FunctionType, arma::mat, arma::mat>,
358       TypedForms<arma::mat, arma::mat>::template
359           EvaluateWithGradientConstForm>::value;
360 
361   REQUIRE(hasEvaluate == true);
362   REQUIRE(hasGradient == true);
363   REQUIRE(hasEvaluateWithGradient == true);
364 }
365 
366 /**
367  * Make sure that an empty class doesn't have any methods added to it.
368  */
369 TEST_CASE("AddSeparableEvaluateWithGradientEmptyTest", "[FunctionTest]")
370 {
371   const bool hasEvaluate = HasEvaluate<
372       Function<EmptyTestFunction, arma::mat, arma::mat>,
373       TypedForms<arma::mat, arma::mat>::template
374           SeparableEvaluateForm>::value;
375   const bool hasGradient = HasGradient<
376       Function<EmptyTestFunction, arma::mat, arma::mat>,
377       TypedForms<arma::mat, arma::mat>::template
378           SeparableGradientForm>::value;
379   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
380       Function<EmptyTestFunction, arma::mat, arma::mat>,
381       TypedForms<arma::mat, arma::mat>::template
382           SeparableEvaluateWithGradientForm>::value;
383 
384   REQUIRE(hasEvaluate == false);
385   REQUIRE(hasGradient == false);
386   REQUIRE(hasEvaluateWithGradient == false);
387 }
388 
389 /**
390  * Make sure we don't add any functions if we only have Evaluate().
391  */
392 TEST_CASE("AddSeparableEvaluateWithGradientEvaluateOnlyTest",
393           "[FunctionTest]")
394 {
395   const bool hasEvaluate = HasEvaluate<
396       Function<EvaluateTestFunction, arma::mat, arma::mat>,
397       TypedForms<arma::mat, arma::mat>::template
398           SeparableEvaluateForm>::value;
399   const bool hasGradient = HasGradient<
400       Function<EvaluateTestFunction, arma::mat, arma::mat>,
401       TypedForms<arma::mat, arma::mat>::template
402           SeparableGradientForm>::value;
403   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
404       Function<EvaluateTestFunction, arma::mat, arma::mat>,
405       TypedForms<arma::mat, arma::mat>::template
406           SeparableEvaluateWithGradientForm>::value;
407 
408   REQUIRE(hasEvaluate == true);
409   REQUIRE(hasGradient == false);
410   REQUIRE(hasEvaluateWithGradient == false);
411 }
412 
413 /**
414  * Make sure we don't add any functions if we only have Gradient().
415  */
416 TEST_CASE("AddSeparableEvaluateWithGradientGradientOnlyTest",
417           "[FunctionTest]")
418 {
419   const bool hasEvaluate = HasEvaluate<
420       Function<GradientTestFunction, arma::mat, arma::mat>,
421       TypedForms<arma::mat, arma::mat>::template
422           SeparableEvaluateForm>::value;
423   const bool hasGradient = HasGradient<
424       Function<GradientTestFunction, arma::mat, arma::mat>,
425       TypedForms<arma::mat, arma::mat>::template
426           SeparableGradientForm>::value;
427   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
428       Function<GradientTestFunction, arma::mat, arma::mat>,
429       TypedForms<arma::mat, arma::mat>::template
430           SeparableEvaluateWithGradientForm>::value;
431 
432   REQUIRE(hasEvaluate == false);
433   REQUIRE(hasGradient == true);
434   REQUIRE(hasEvaluateWithGradient == false);
435 }
436 
437 /**
438  * Make sure we add EvaluateWithGradient() when we have both Evaluate() and
439  * Gradient().
440  */
441 TEST_CASE("AddSeparableEvaluateWithGradientBothTest", "[FunctionTest]")
442 {
443   const bool hasEvaluate = HasEvaluate<
444       Function<EvaluateGradientTestFunction, arma::mat, arma::mat>,
445       TypedForms<arma::mat, arma::mat>::template
446           SeparableEvaluateForm>::value;
447   const bool hasGradient = HasGradient<
448       Function<EvaluateGradientTestFunction, arma::mat, arma::mat>,
449       TypedForms<arma::mat, arma::mat>::template
450           SeparableGradientForm>::value;
451   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
452       Function<EvaluateGradientTestFunction, arma::mat, arma::mat>,
453       TypedForms<arma::mat, arma::mat>::template
454           SeparableEvaluateWithGradientForm>::value;
455 
456   REQUIRE(hasEvaluate == true);
457   REQUIRE(hasGradient == true);
458   REQUIRE(hasEvaluateWithGradient == true);
459 }
460 
461 /**
462  * Make sure we add Evaluate() and Gradient() when we have only
463  * EvaluateWithGradient().
464  */
465 TEST_CASE("AddSeparableEvaluateWGradientEvaluateWithGradientTest",
466           "[FunctionTest]")
467 {
468   const bool hasEvaluate = HasEvaluate<
469       Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>,
470       TypedForms<arma::mat, arma::mat>::template
471           SeparableEvaluateForm>::value;
472   const bool hasGradient = HasGradient<
473       Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>,
474       TypedForms<arma::mat, arma::mat>::template
475           SeparableGradientForm>::value;
476   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
477       Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat>,
478       TypedForms<arma::mat, arma::mat>::template
479           SeparableEvaluateWithGradientForm>::value;
480 
481   Function<EvaluateWithGradientTestFunction, arma::mat, arma::mat> f;
482   arma::mat coordinates(10, 10, arma::fill::ones);
483   arma::mat gradient;
484   f.Gradient(coordinates, 0, gradient, 5);
485 
486   REQUIRE(hasEvaluate == true);
487   REQUIRE(hasGradient == true);
488   REQUIRE(hasEvaluateWithGradient == true);
489 }
490 
491 /**
492  * Make sure we add no methods when we already have all three.
493  */
494 TEST_CASE("AddSeparableEvaluateWithGradientAllThreeTest", "[FunctionTest]")
495 {
496   const bool hasEvaluate = HasEvaluate<
497       Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>,
498       TypedForms<arma::mat, arma::mat>::template
499           SeparableEvaluateForm>::value;
500   const bool hasGradient = HasGradient<
501       Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>,
502       TypedForms<arma::mat, arma::mat>::template
503           SeparableGradientForm>::value;
504   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
505       Function<EvaluateAndWithGradientTestFunction, arma::mat, arma::mat>,
506       TypedForms<arma::mat, arma::mat>::template
507           SeparableEvaluateWithGradientForm>::value;
508 
509   REQUIRE(hasEvaluate == true);
510   REQUIRE(hasGradient == true);
511   REQUIRE(hasEvaluateWithGradient == true);
512 }
513 
514 /**
515  * Make sure we can properly create EvaluateWithGradient() even when one of the
516  * functions is non-const.
517  */
518 TEST_CASE("AddEvaluateWithGradientMixedTypesTest", "[FunctionTest]")
519 {
520   const bool hasEvaluate = HasEvaluate<
521       Function<EvaluateAndNonConstGradientTestFunction, arma::mat, arma::mat>,
522       TypedForms<arma::mat, arma::mat>::template EvaluateConstForm>::value;
523   const bool hasGradient = HasGradient<
524       Function<EvaluateAndNonConstGradientTestFunction, arma::mat, arma::mat>,
525       TypedForms<arma::mat, arma::mat>::template GradientForm>::value;
526   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
527       Function<EvaluateAndNonConstGradientTestFunction, arma::mat, arma::mat>,
528       TypedForms<arma::mat, arma::mat>::template
529           EvaluateWithGradientForm>::value;
530 
531   REQUIRE(hasEvaluate == true);
532   REQUIRE(hasGradient == true);
533   REQUIRE(hasEvaluateWithGradient == true);
534 }
535 
536 /**
537  * Make sure we can properly create EvaluateWithGradient() even when one of the
538  * functions is static.
539  */
540 TEST_CASE("AddEvaluateWithGradientMixedTypesStaticTest", "[FunctionTest]")
541 {
542   const bool hasEvaluate = HasEvaluate<
543       Function<EvaluateAndStaticGradientTestFunction, arma::mat, arma::mat>,
544       TypedForms<arma::mat, arma::mat>::template EvaluateConstForm>::value;
545   const bool hasGradient = HasGradient<
546       Function<EvaluateAndStaticGradientTestFunction, arma::mat, arma::mat>,
547       TypedForms<arma::mat, arma::mat>::template GradientStaticForm>::value;
548   const bool hasEvaluateWithGradient = HasEvaluateWithGradient<
549       Function<EvaluateAndStaticGradientTestFunction, arma::mat, arma::mat>,
550       TypedForms<arma::mat, arma::mat>::template
551           EvaluateWithGradientConstForm>::value;
552 
553   REQUIRE(hasEvaluate == true);
554   REQUIRE(hasGradient == true);
555   REQUIRE(hasEvaluateWithGradient == true);
556 }
557 
558 class A
559 {
560  public:
561   size_t NumFunctions() const;
562   size_t NumFeatures() const;
563   double Evaluate(const arma::mat&, const size_t, const size_t) const;
564   void Gradient(const arma::mat&, const size_t, arma::mat&, const size_t) const;
565   void Gradient(const arma::mat&, const size_t, arma::sp_mat&, const size_t)
566       const;
567   void PartialGradient(const arma::mat&, const size_t, arma::sp_mat&) const;
568 };
569 
570 class B
571 {
572  public:
573   size_t NumFunctions();
574   size_t NumFeatures();
575   double Evaluate(const arma::mat&, const size_t, const size_t);
576   void Gradient(const arma::mat&, const size_t, arma::mat&, const size_t);
577   void Gradient(const arma::mat&, const size_t, arma::sp_mat&, const size_t);
578   void PartialGradient(const arma::mat&, const size_t, arma::sp_mat&);
579 };
580 
581 class C
582 {
583  public:
584   size_t NumConstraints() const;
585   double Evaluate(const arma::mat&) const;
586   void Gradient(const arma::mat&, arma::mat&) const;
587   double EvaluateConstraint(const size_t, const arma::mat&) const;
588   void GradientConstraint(const size_t, const arma::mat&, arma::mat&) const;
589 };
590 
591 class D
592 {
593  public:
594   size_t NumConstraints();
595   double Evaluate(const arma::mat&);
596   void Gradient(const arma::mat&, arma::mat&);
597   double EvaluateConstraint(const size_t, const arma::mat&);
598   void GradientConstraint(const size_t, const arma::mat&, arma::mat&);
599 };
600 
601 
602 /**
603  * Test the correctness of the static check for SeparableFunctionType API.
604  */
605 TEST_CASE("SeparableFunctionTypeCheckTest", "[FunctionTest]")
606 {
607   static_assert(CheckNumFunctions<A, arma::mat, arma::mat>::value,
608       "CheckNumFunctions static check failed.");
609   static_assert(CheckNumFunctions<B, arma::mat, arma::mat>::value,
610       "CheckNumFunctions static check failed.");
611   static_assert(!CheckNumFunctions<C, arma::mat, arma::mat>::value,
612       "CheckNumFunctions static check failed.");
613   static_assert(!CheckNumFunctions<D, arma::mat, arma::mat>::value,
614       "CheckNumFunctions static check failed.");
615 
616   static_assert(CheckSeparableEvaluate<A, arma::mat, arma::mat>::value,
617       "CheckSeparableEvaluate static check failed.");
618   static_assert(CheckSeparableEvaluate<B, arma::mat, arma::mat>::value,
619       "CheckSeparableEvaluate static check failed.");
620   static_assert(!CheckSeparableEvaluate<C, arma::mat, arma::mat>::value,
621       "CheckSeparableEvaluate static check failed.");
622   static_assert(!CheckSeparableEvaluate<D, arma::mat, arma::mat>::value,
623       "CheckSeparableEvaluate static check failed.");
624 
625   static_assert(CheckSeparableGradient<A, arma::mat, arma::mat>::value,
626       "CheckSeparableGradient static check failed.");
627   static_assert(CheckSeparableGradient<B, arma::mat, arma::mat>::value,
628       "CheckSeparableGradient static check failed.");
629   static_assert(!CheckSeparableGradient<C, arma::mat, arma::mat>::value,
630       "CheckSeparableGradient static check failed.");
631   static_assert(!CheckSeparableGradient<D, arma::mat, arma::mat>::value,
632       "CheckSeparableGradient static check failed.");
633 }
634 
635 /**
636  * Test the correctness of the static check for LagrangianFunctionType API.
637  */
638 TEST_CASE("LagrangianFunctionTypeCheckTest", "[FunctionTest]")
639 {
640   static_assert(!CheckEvaluate<A, arma::mat, arma::mat>::value,
641       "CheckEvaluate static check failed.");
642   static_assert(!CheckEvaluate<B, arma::mat, arma::mat>::value,
643       "CheckEvaluate static check failed.");
644   static_assert(CheckEvaluate<C, arma::mat, arma::mat>::value,
645       "CheckEvaluate static check failed.");
646   static_assert(CheckEvaluate<D, arma::mat, arma::mat>::value,
647       "CheckEvaluate static check failed.");
648 
649   static_assert(!CheckGradient<A, arma::mat, arma::mat>::value,
650       "CheckGradient static check failed.");
651   static_assert(!CheckGradient<B, arma::mat, arma::mat>::value,
652       "CheckGradient static check failed.");
653   static_assert(CheckGradient<C, arma::mat, arma::mat>::value,
654       "CheckGradient static check failed.");
655   static_assert(CheckGradient<D, arma::mat, arma::mat>::value,
656       "CheckGradient static check failed.");
657 
658   static_assert(!CheckNumConstraints<A, arma::mat, arma::mat>::value,
659       "CheckNumConstraints static check failed.");
660   static_assert(!CheckNumConstraints<B, arma::mat, arma::mat>::value,
661       "CheckNumConstraints static check failed.");
662   static_assert(CheckNumConstraints<C, arma::mat, arma::mat>::value,
663       "CheckNumConstraints static check failed.");
664   static_assert(CheckNumConstraints<D, arma::mat, arma::mat>::value,
665       "CheckNumConstraints static check failed.");
666 
667   static_assert(!CheckEvaluateConstraint<A, arma::mat, arma::mat>::value,
668       "CheckEvaluateConstraint static check failed.");
669   static_assert(!CheckEvaluateConstraint<B, arma::mat, arma::mat>::value,
670       "CheckEvaluateConstraint static check failed.");
671   static_assert(CheckEvaluateConstraint<C, arma::mat, arma::mat>::value,
672       "CheckEvaluateConstraint static check failed.");
673   static_assert(CheckEvaluateConstraint<D, arma::mat, arma::mat>::value,
674       "CheckEvaluateConstraint static check failed.");
675 
676   static_assert(!CheckGradientConstraint<A, arma::mat, arma::mat>::value,
677       "CheckGradientConstraint static check failed.");
678   static_assert(!CheckGradientConstraint<B, arma::mat, arma::mat>::value,
679       "CheckGradientConstraint static check failed.");
680   static_assert(CheckGradientConstraint<C, arma::mat, arma::mat>::value,
681       "CheckGradientConstraint static check failed.");
682   static_assert(CheckGradientConstraint<D, arma::mat, arma::mat>::value,
683       "CheckGradientConstraint static check failed.");
684 }
685 
686 /**
687  * Test the correctness of the static check for SparseFunctionType API.
688  */
689 TEST_CASE("SparseFunctionTypeCheckTest", "[FunctionTest]")
690 {
691   static_assert(CheckSparseGradient<A, arma::mat, arma::mat>::value,
692       "CheckSparseGradient static check failed.");
693   static_assert(CheckSparseGradient<B, arma::mat, arma::mat>::value,
694       "CheckSparseGradient static check failed.");
695   static_assert(!CheckSparseGradient<C, arma::mat, arma::mat>::value,
696       "CheckSparseGradient static check failed.");
697   static_assert(!CheckSparseGradient<D, arma::mat, arma::mat>::value,
698       "CheckSparseGradient static check failed.");
699 }
700 
701 /**
702  * Test the correctness of the static check for SparseFunctionType API.
703  */
704 TEST_CASE("ResolvableFunctionTypeCheckTest", "[FunctionTest]")
705 {
706   static_assert(CheckNumFeatures<A, arma::mat, arma::mat>::value,
707       "CheckNumFeatures static check failed.");
708   static_assert(CheckNumFeatures<B, arma::mat, arma::mat>::value,
709       "CheckNumFeatures static check failed.");
710   static_assert(!CheckNumFeatures<C, arma::mat, arma::mat>::value,
711       "CheckNumFeatures static check failed.");
712   static_assert(!CheckNumFeatures<D, arma::mat, arma::mat>::value,
713       "CheckNumFeatures static check failed.");
714 
715   static_assert(CheckPartialGradient<A, arma::mat, arma::sp_mat>::value,
716       "CheckPartialGradient static check failed.");
717   static_assert(CheckPartialGradient<B, arma::mat, arma::sp_mat>::value,
718       "CheckPartialGradient static check failed.");
719   static_assert(!CheckPartialGradient<C, arma::mat, arma::sp_mat>::value,
720       "CheckPartialGradient static check failed.");
721   static_assert(!CheckPartialGradient<D, arma::mat, arma::sp_mat>::value,
722       "CheckPartialGradient static check failed.");
723 }
724