1 /**
2  * @file traits.hpp
3  * @author Ryan Curtin
4  *
5  * This file provides metaprogramming utilities for detecting certain members of
6  * FunctionType classes.
7  *
8  * ensmallen is free software; you may redistribute it and/or modify it under
9  * the terms of the 3-clause BSD license.  You should have received a copy of
10  * the 3-clause BSD license along with ensmallen.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef ENSMALLEN_FUNCTION_TRAITS_HPP
14 #define ENSMALLEN_FUNCTION_TRAITS_HPP
15 
16 #include "sfinae_utility.hpp"
17 #include "arma_traits.hpp"
18 
19 namespace ens {
20 namespace traits {
21 
22 //! Detect an Evaluate() method.
23 ENS_HAS_EXACT_METHOD_FORM(Evaluate, HasEvaluate)
24 //! Detect a Gradient() method.
25 ENS_HAS_EXACT_METHOD_FORM(Gradient, HasGradient)
26 //! Detect an EvaluateWithGradient() method.
27 ENS_HAS_EXACT_METHOD_FORM(EvaluateWithGradient, HasEvaluateWithGradient)
28 //! Detect a NumFunctions() method.
29 ENS_HAS_EXACT_METHOD_FORM(NumFunctions, HasNumFunctions)
30 //! Detect a Shuffle() method.
31 ENS_HAS_EXACT_METHOD_FORM(Shuffle, HasShuffle)
32 //! Detect a NumConstraints() method.
33 ENS_HAS_EXACT_METHOD_FORM(NumConstraints, HasNumConstraints)
34 //! Detect an EvaluateConstraint() method.
35 ENS_HAS_EXACT_METHOD_FORM(EvaluateConstraint, HasEvaluateConstraint)
36 //! Detect a GradientConstraint() method.
37 ENS_HAS_EXACT_METHOD_FORM(GradientConstraint, HasGradientConstraint)
38 //! Detect a NumFeatures() method.
39 ENS_HAS_EXACT_METHOD_FORM(NumFeatures, HasNumFeatures)
40 //! Detect a PartialGradient() method.
41 ENS_HAS_EXACT_METHOD_FORM(PartialGradient, HasPartialGradient)
42 //! Detect an MaxIterations() method.
43 ENS_HAS_EXACT_METHOD_FORM(MaxIterations, HasMaxIterations)
44 //! Detect an ResetPolicy() method.
45 ENS_HAS_EXACT_METHOD_FORM(ResetPolicy, HasResetPolicy)
46 //! Detect an BatchSize() method.
47 ENS_HAS_EXACT_METHOD_FORM(BatchSize, HasBatchSize)
48 //! Detect an StepSize() method.
49 ENS_HAS_EXACT_METHOD_FORM(StepSize, HasStepSize)
50 
51 template<typename MatType, typename GradType>
52 struct TypedForms
53 {
54   typedef typename MatTypeTraits<MatType>::BaseMatType BaseMatType;
55   typedef typename MatTypeTraits<GradType>::BaseMatType BaseGradType;
56 
57   //! This is the form of a non-const Evaluate() method.
58   template<typename FunctionType>
59   using EvaluateForm =
60       typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&);
61 
62   //! This is the form of a const Evaluate() method.
63   template<typename FunctionType>
64   using EvaluateConstForm = typename BaseMatType::elem_type(FunctionType::*)(
65       const BaseMatType&) const;
66 
67   //! This is the form of a static Evaluate() method.
68   template<typename FunctionType>
69   using EvaluateStaticForm = typename BaseMatType::elem_type(*)(
70       const BaseMatType&);
71 
72   //! This is the form of a non-const Gradient() method.
73   template<typename FunctionType>
74   using GradientForm = void(FunctionType::*)(const BaseMatType&, BaseGradType&);
75 
76   //! This is the form of a const Gradient() method.
77   template<typename FunctionType>
78   using GradientConstForm =
79       void(FunctionType::*)(const BaseMatType&, BaseGradType&) const;
80 
81   //! This is the form of a static Gradient() method.
82   template<typename FunctionType>
83   using GradientStaticForm = void(*)(const BaseMatType&, BaseGradType&);
84 
85   //! This is the form of a non-const EvaluateWithGradient() method.
86   template<typename FunctionType>
87   using EvaluateWithGradientForm =
88       typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&,
89                                                        BaseGradType&);
90 
91   //! This is the form of a const EvaluateWithGradient() method.
92   template<typename FunctionType>
93   using EvaluateWithGradientConstForm =
94       typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&,
95                                                        BaseGradType&) const;
96 
97   //! This is the form of a static EvaluateWithGradient() method.
98   template<typename FunctionType>
99   using EvaluateWithGradientStaticForm = typename BaseMatType::elem_type(*)(
100       const BaseMatType&, BaseGradType&);
101 
102   //! This is the form of a non-const NumFunctions() method.
103   template <typename FunctionType>
104   using NumFunctionsForm = size_t(FunctionType::*)();
105 
106   //! This is the form of a const NumFunctions() method.
107   template <typename FunctionType>
108   using NumFunctionsConstForm = size_t(FunctionType::*)() const;
109 
110   //! This is the form of a static NumFunctions() method.
111   template<typename FunctionType>
112   using NumFunctionsStaticForm = size_t(*)();
113 
114   //! This is the form of a non-const Shuffle() method.
115   template<typename FunctionType>
116   using ShuffleForm = void(FunctionType::*)();
117 
118   //! This is the form of a const Shuffle() method.
119   template<typename FunctionType>
120   using ShuffleConstForm = void(FunctionType::*)() const;
121 
122   //! This is the form of a static Shuffle() method.
123   template<typename FunctionType>
124   using ShuffleStaticForm = void(*)();
125 
126   //! This is the form of a separable Evaluate() method.
127   template<typename FunctionType>
128   using SeparableEvaluateForm =
129       typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&,
130                                                        const size_t,
131                                                        const size_t);
132 
133   //! This is the form of a separable const Evaluate() method.
134   template<typename FunctionType>
135   using SeparableEvaluateConstForm =
136       typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&,
137                                                        const size_t,
138                                                        const size_t) const;
139 
140   //! This is the form of a separable static Evaluate() method.
141   template<typename FunctionType>
142   using SeparableEvaluateStaticForm = typename BaseMatType::elem_type(*)(
143         const BaseMatType&, const size_t, const size_t);
144 
145   //! This is the form of a separable non-const Gradient() method.
146   template<typename FunctionType>
147   using SeparableGradientForm = void(FunctionType::*)(
148       const BaseMatType&, const size_t, BaseGradType&, const size_t);
149 
150   //! This the form of a separable const Gradient() method.
151   template<typename FunctionType>
152   using SeparableGradientConstForm = void(FunctionType::*)(
153       const BaseMatType&, const size_t, BaseGradType&, const size_t) const;
154 
155   //! This is the form of a separable static Gradient() method.
156   template<typename FunctionType>
157   using SeparableGradientStaticForm = void(*)(
158       const BaseMatType&, const size_t, BaseGradType&, const size_t);
159 
160   //! This is the form of a separable non-const EvaluateWithGradient()
161   //! method.
162   template<typename FunctionType>
163   using SeparableEvaluateWithGradientForm =
164       typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&,
165                                                        const size_t,
166                                                        BaseGradType&,
167                                                        const size_t);
168 
169   //! This is the form of a separable const EvaluateWithGradient() method.
170   template<typename FunctionType>
171   using SeparableEvaluateWithGradientConstForm =
172       typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&,
173                                                        const size_t,
174                                                        BaseGradType&,
175                                                        const size_t) const;
176 
177   //! This is the form of a separable static EvaluateWithGradient() method.
178   template<typename FunctionType>
179   using SeparableEvaluateWithGradientStaticForm =
180       typename BaseMatType::elem_type(*)(const BaseMatType&,
181                                          const size_t,
182                                          BaseGradType&,
183                                          const size_t);
184 
185   //! This is the form of a non-const NumConstraints() method.
186   template<typename FunctionType>
187   using NumConstraintsForm = size_t(FunctionType::*)();
188 
189   //! This is the form of a const NumConstraints() method.
190   template<typename FunctionType>
191   using NumConstraintsConstForm = size_t(FunctionType::*)() const;
192 
193   //! This is the form of a static NumConstraints() method.
194   template<typename FunctionType>
195   using NumConstraintsStaticForm = size_t(*)();
196 
197   //! This is the form of a non-const EvaluateConstraint() method.
198   template <typename FunctionType>
199   using EvaluateConstraintForm =
200       typename BaseMatType::elem_type(FunctionType::*)(const size_t,
201                                                        const BaseMatType&);
202 
203   //! This is the form of a const EvaluateConstraint() method.
204   template<typename FunctionType>
205   using EvaluateConstraintConstForm =
206       typename BaseMatType::elem_type(FunctionType::*)(const size_t,
207                                                        const BaseMatType&)
208           const;
209 
210   //! This is the form of a static EvaluateConstraint() method.
211   template<typename FunctionType>
212   using EvaluateConstraintStaticForm = typename BaseMatType::elem_type(*)(
213       const size_t, const BaseMatType&);
214 
215   //! This is the form of a non-const GradientConstraint() method.
216   template <typename FunctionType>
217   using GradientConstraintForm = void(FunctionType::*)(
218       const size_t, const BaseMatType&, BaseGradType&);
219 
220   //! This is the form of a const GradientConstraint() method.
221   template<typename FunctionType>
222   using GradientConstraintConstForm = void(FunctionType::*)(
223       const size_t, const BaseMatType&, BaseGradType&) const;
224 
225   //! This is the form of a static GradientConstraint() method.
226   template<typename Class, typename... Ts>
227   using GradientConstraintStaticForm = void(*)(
228       const size_t, const BaseMatType&, BaseGradType&);
229 
230   //! This is the form of a non-const sparse Gradient() method.
231   //! This check isn't particularly useful---the user needs to specify a sparse
232   //! gradient type...
233   template<typename FunctionType>
234   using SparseGradientForm = void(FunctionType::*)(
235       const BaseMatType&, const size_t, BaseGradType&, const size_t);
236 
237   //! This is the form of a const sparse Gradient() method.
238   //! This check isn't particularly useful---the user needs to specify a sparse
239   //! gradient type...
240   template<typename FunctionType>
241   using SparseGradientConstForm = void(FunctionType::*)(
242       const BaseMatType&, const size_t, BaseGradType&, const size_t) const;
243 
244   //! This is the form of a static sparse Gradient() method.
245   //! This check isn't particularly useful---the user needs to specify a sparse
246   //! gradient type...
247   template<typename FunctionType>
248   using SparseGradientStaticForm = void(*)(
249       const BaseMatType&, const size_t, BaseGradType&, const size_t);
250 
251   //! This is the form of a non-const NumFeatures() method.
252   template<typename FunctionType>
253   using NumFeaturesForm = size_t(FunctionType::*)();
254 
255   //! This is the form of a const NumFeatures() method.
256   template<typename FunctionType>
257   using NumFeaturesConstForm = size_t(FunctionType::*)() const;
258 
259   //! This is the form of a static NumFeatures() method.
260   template<typename FunctionType>
261   using NumFeaturesStaticForm = size_t(*)();
262 
263   //! This is the form of a non-const PartialGradient() method.
264   template<typename FunctionType>
265   using PartialGradientForm = void(FunctionType::*)(
266       const BaseMatType&, const size_t, BaseGradType&);
267 
268   //! This is the form of a const PartialGradient() method.
269   template<typename FunctionType>
270   using PartialGradientConstForm = void(FunctionType::*)(
271       const BaseMatType&, const size_t, BaseGradType&) const;
272 
273   //! This is the form of a static PartialGradient() method.
274   template<typename FunctionType>
275   using PartialGradientStaticForm = void(*)(
276       const BaseMatType&, const size_t, BaseGradType&);
277 
278   //! This is a utility struct that will match any non-const form.
279   template<typename FunctionType, typename... Ts>
280   using OtherForm = typename BaseMatType::elem_type(FunctionType::*)(Ts...);
281 
282   //! This is a utility struct that will match any const form.
283   template<typename FunctionType, typename... Ts>
284   using OtherConstForm = typename BaseMatType::elem_type(FunctionType::*)(Ts...)
285       const;
286 
287   //! This is a utility struct that will match any static form.
288   template<typename FunctionType, typename... Ts>
289   using OtherStaticForm = typename BaseMatType::elem_type(*)(Ts...);
290 };
291 
/**
 * This is a utility type used to provide unusable overloads from each of the
 * mixin classes.  If you are seeing an error mentioning this class, the most
 * likely issue is that you have not implemented the right methods for your
 * FunctionType class.
 *
 * The only constructor is private and user-provided.  As a result, the type
 * can be neither default-constructed nor aggregate-initialized from outside
 * the class.
 */
struct UnconstructableType
{
 private:
  // Deliberately user-provided (not "= default" or "= delete").  Before C++20,
  // a class whose only constructor is defaulted or deleted is still an
  // aggregate, so `UnconstructableType{}` would compile and bypass the private
  // constructor.
  UnconstructableType() { }
};
303 
/**
 * Utility struct: sometimes we want to know if we have two functions available,
 * and that at least one of them is non-const and non-static.  If the
 * corresponding checkers (from ENS_HAS_METHOD_FORM()) are given as CheckerA and
 * CheckerB, and the corresponding non-const, const, and static function
 * signatures are given as SignatureA, ConstSignatureA, StaticSignatureA,
 * SignatureB, ConstSignatureB, and StaticSignatureB, then 'value' will be true
 * if methods with the correct names exist in the given ClassType and at least
 * one of those two methods is non-const and non-static.
 *
 * Every checker is instantiated with 0 as its third template argument.
 */
template<typename ClassType,
         template<typename, template<typename...> class, size_t> class CheckerA,
         template<typename...> class SignatureA,
         template<typename...> class ConstSignatureA,
         template<typename...> class StaticSignatureA,
         template<typename, template<typename...> class, size_t> class CheckerB,
         template<typename...> class SignatureB,
         template<typename...> class ConstSignatureB,
         template<typename...> class StaticSignatureB>
struct HasNonConstSignatures
{
  // Check if a non-const, const, or static version of method A exists.
  const static bool HasAnyFormA =
      CheckerA<ClassType, SignatureA, 0>::value ||
      CheckerA<ClassType, ConstSignatureA, 0>::value ||
      CheckerA<ClassType, StaticSignatureA, 0>::value;
  // Check if a non-const, const, or static version of method B exists.
  const static bool HasAnyFormB =
      CheckerB<ClassType, SignatureB, 0>::value ||
      CheckerB<ClassType, ConstSignatureB, 0>::value ||
      CheckerB<ClassType, StaticSignatureB, 0>::value;

  // Make sure at least one non-const (and non-static) version exists.
  const static bool HasEitherNonConstForm =
      CheckerA<ClassType, SignatureA, 0>::value ||
      CheckerB<ClassType, SignatureB, 0>::value;

  // Both methods must exist in some form, and at least one must be non-const.
  const static bool value = HasEitherNonConstForm && HasAnyFormA && HasAnyFormB;
};
343 
/**
 * Compile-time check that ClassType provides two methods, A and B, through
 * const or static signatures, with at least one of the two being const.
 *
 * CheckerA and CheckerB are method-form checkers (from ENS_HAS_METHOD_FORM()).
 * ConstSignatureA/StaticSignatureA and ConstSignatureB/StaticSignatureB are the
 * const and static signature forms to look for.  Each checker is instantiated
 * with 0 as its third template argument.
 *
 * 'value' is true exactly when all of the following hold:
 *  - method A exists in its const or static form,
 *  - method B exists in its const or static form, and
 *  - at least one of A and B exists in its const form.
 *
 * Non-const, non-static signatures are never examined here.
 */
template<typename ClassType,
         template<typename, template<typename...> class, size_t> class CheckerA,
         template<typename...> class ConstSignatureA,
         template<typename...> class StaticSignatureA,
         template<typename, template<typename...> class, size_t> class CheckerB,
         template<typename...> class ConstSignatureB,
         template<typename...> class StaticSignatureB>
struct HasConstSignatures
{
  // A const form of A or of B must be present.
  static const bool HasEitherConstForm =
      CheckerA<ClassType, ConstSignatureA, 0>::value ||
      CheckerB<ClassType, ConstSignatureB, 0>::value;

  // Method A is available through a const or a static signature.
  static const bool HasAnyFormA =
      CheckerA<ClassType, ConstSignatureA, 0>::value ||
      CheckerA<ClassType, StaticSignatureA, 0>::value;

  // Method B is available through a const or a static signature.
  static const bool HasAnyFormB =
      CheckerB<ClassType, ConstSignatureB, 0>::value ||
      CheckerB<ClassType, StaticSignatureB, 0>::value;

  static const bool value = HasAnyFormA && HasAnyFormB && HasEitherConstForm;
};
379 
380 //! Utility struct, check if size_t BatchSize() const or size_t BatchSize()
381 //! exists.
382 template<typename OptimizerType>
383 struct HasBatchSizeSignature
384 {
385   template<typename C>
386   using BatchSizeConstForm = size_t(C::*)(void) const;
387 
388   template<typename C>
389   using BatchSizeForm = size_t(C::*)(void);
390 
391   const static bool value =
392       HasBatchSize<OptimizerType, BatchSizeForm>::value ||
393       HasBatchSize<OptimizerType, BatchSizeConstForm>::value;
394 };
395 
396 //! Utility struct, check if size_t StepSize() const or size_t StepSize()
397 //! exists.
398 template<typename OptimizerType>
399 struct HasStepSizeSignature
400 {
401   template<typename C>
402   using StepSizeConstForm = double(C::*)(void) const;
403 
404   template<typename C>
405   using StepSizeForm = double(C::*)(void);
406 
407   const static bool value =
408       HasStepSize<OptimizerType, StepSizeForm>::value ||
409       HasStepSize<OptimizerType, StepSizeConstForm>::value;
410 };
411 
412 //! Utility struct, check if size_t MaxIterations() const exists.
413 template<typename OptimizerType>
414 struct HasMaxIterationsSignature
415 {
416   template<typename C>
417   using HasMaxIterationsForm = size_t(C::*)(void) const;
418 
419   const static bool value =
420       HasMaxIterations<OptimizerType, HasMaxIterationsForm>::value;
421 };
422 
423 //! Utility struct, check if size_t NumFunctions() const or
424 //! size_t NumFunctions() exists.
425 template<typename OptimizerType>
426 struct HasNumFunctionsSignature
427 {
428   template<typename C>
429   using NumFunctionsConstForm = size_t(C::*)(void) const;
430 
431   template<typename C>
432   using NumFunctionsForm = size_t(C::*)(void);
433 
434   const static bool value =
435       HasNumFunctions<OptimizerType, NumFunctionsForm>::value ||
436       HasNumFunctions<OptimizerType, NumFunctionsConstForm>::value;
437 };
438 
439 //! Utility struct, check if bool ResetPolicy() exists.
440 template<typename OptimizerType>
441 struct HasResetPolicySignature
442 {
443   template<typename C>
444   using HasResetPolicyForm = bool&(C::*)(void);
445 
446   const static bool value =
447       HasResetPolicy<OptimizerType, HasResetPolicyForm>::value;
448 };
449 
450 } // namespace traits
451 } // namespace ens
452 
453 #endif
454