/**
 * @file add_evaluate_with_gradient.hpp
 * @author Ryan Curtin
 *
 * This file defines mixins for the Function class (non-const, const and
 * static variants) that ensure the EvaluateWithGradient() function is
 * available if possible.
 *
 * ensmallen is free software; you may redistribute it and/or modify it under
 * the terms of the 3-clause BSD license.  You should have received a copy of
 * the 3-clause BSD license along with ensmallen.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
13 #ifndef ENSMALLEN_FUNCTION_ADD_EVALUATE_WITH_GRADIENT_HPP
14 #define ENSMALLEN_FUNCTION_ADD_EVALUATE_WITH_GRADIENT_HPP
15 
16 #include "sfinae_utility.hpp"
17 #include "traits.hpp"
18 
19 namespace ens {
20 
21 /**
22  * The AddEvaluateWithGradient mixin class will provide an
23  * EvaluateWithGradient() method if the given FunctionType has both Evaluate()
24  * and Gradient(), or it will provide nothing otherwise.
25  */
26 template<typename FunctionType,
27          typename MatType,
28          typename GradType,
29          // Check if there is at least one non-const Evaluate() or Gradient().
30          bool HasEvaluateGradient = traits::HasNonConstSignatures<
31              FunctionType,
32              traits::HasEvaluate,
33              traits::TypedForms<MatType, GradType>::template EvaluateForm,
34              traits::TypedForms<MatType, GradType>::template EvaluateConstForm,
35              traits::TypedForms<MatType, GradType>::template EvaluateStaticForm,
36              traits::HasGradient,
37              traits::TypedForms<MatType, GradType>::template GradientForm,
38              traits::TypedForms<MatType, GradType>::template GradientConstForm,
39              traits::TypedForms<MatType, GradType>::template GradientStaticForm
40          >::value,
41          bool HasEvaluateWithGradient = traits::HasEvaluateWithGradient<
42              FunctionType,
43              traits::TypedForms<MatType, GradType>::template
44                  EvaluateWithGradientForm>::value>
45 class AddEvaluateWithGradient
46 {
47  public:
48   // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this
49   // object.
50   typename MatType::elem_type EvaluateWithGradient(
51       traits::UnconstructableType&);
52 };
53 
54 /**
55  * Reflect the existing EvaluateWithGradient().
56  */
57 template<typename FunctionType,
58          typename MatType,
59          typename GradType,
60          bool HasEvaluateGradient>
61 class AddEvaluateWithGradient<FunctionType,
62                               MatType,
63                               GradType,
64                               HasEvaluateGradient,
65                               true>
66 {
67  public:
68   // Reflect the existing EvaluateWithGradient().
EvaluateWithGradient(const MatType & coordinates,GradType & gradient)69   typename MatType::elem_type EvaluateWithGradient(
70       const MatType& coordinates, GradType& gradient)
71   {
72     return static_cast<FunctionType*>(
73         static_cast<Function<FunctionType,
74                              MatType,
75                              GradType>*>(this))->EvaluateWithGradient(
76         coordinates, gradient);
77   }
78 };
79 
80 /**
81  * If the FunctionType has Evaluate() and Gradient(), provide
82  * EvaluateWithGradient().
83  */
84 template<typename FunctionType, typename MatType, typename GradType>
85 class AddEvaluateWithGradient<FunctionType, MatType, GradType, true, false>
86 {
87  public:
88   /**
89    * Return both the evaluated objective function and its gradient, storing the
90    * gradient in the given matrix.
91    *
92    * @param coordinates Coordinates to evaluate the function at.
93    * @param gradient Matrix to store the gradient into.
94    */
EvaluateWithGradient(const MatType & coordinates,GradType & gradient)95   typename MatType::elem_type EvaluateWithGradient(const MatType& coordinates,
96                                                    GradType& gradient)
97   {
98     const typename MatType::elem_type objective =
99         static_cast<Function<FunctionType,
100                              MatType, GradType>*>(this)->Evaluate(coordinates);
101     static_cast<Function<FunctionType,
102                          MatType,
103                          GradType>*>(this)->Gradient(coordinates, gradient);
104     return objective;
105   }
106 };
107 
108 /**
109  * The AddEvaluateWithGradient mixin class will provide an
110  * EvaluateWithGradient() const method if the given FunctionType has both
111  * Evaluate() const and Gradient() const, or it will provide nothing otherwise.
112  */
113 template<typename FunctionType,
114          typename MatType,
115          typename GradType,
116          // Check if there is at least one const Evaluate() or Gradient().
117          bool HasEvaluateGradient = traits::HasConstSignatures<
118              FunctionType,
119              traits::HasEvaluate,
120              traits::TypedForms<MatType, GradType>::template EvaluateConstForm,
121              traits::TypedForms<MatType, GradType>::template EvaluateStaticForm,
122              traits::HasGradient,
123              traits::TypedForms<MatType, GradType>::template GradientConstForm,
124              traits::TypedForms<MatType, GradType>::template GradientStaticForm
125          >::value,
126          bool HasEvaluateWithGradient = traits::HasEvaluateWithGradient<
127              FunctionType,
128              traits::TypedForms<
129                  MatType, GradType
130              >::template EvaluateWithGradientConstForm>::value>
131 class AddEvaluateWithGradientConst
132 {
133  public:
134   // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this
135   // object.
136   typename MatType::elem_type EvaluateWithGradient(
137       traits::UnconstructableType&) const;
138 };
139 
140 /**
141  * Reflect the existing EvaluateWithGradient().
142  */
143 template<typename FunctionType,
144          typename MatType,
145          typename GradType,
146          bool HasEvaluateGradient>
147 class AddEvaluateWithGradientConst<FunctionType,
148                                    MatType,
149                                    GradType,
150                                    HasEvaluateGradient,
151                                    true>
152 {
153  public:
154   // Reflect the existing EvaluateWithGradient().
EvaluateWithGradient(const MatType & coordinates,GradType & gradient) const155   typename MatType::elem_type EvaluateWithGradient(
156       const MatType& coordinates, GradType& gradient) const
157   {
158     return static_cast<const FunctionType*>(
159         static_cast<const Function<FunctionType,
160                                    MatType,
161                                    GradType>*>(this))->EvaluateWithGradient(
162         coordinates, gradient);
163   }
164 };
165 
166 /**
167  * If the FunctionType has Evaluate() const and Gradient() const, provide
168  * EvaluateWithGradient() const.
169  */
170 template<typename FunctionType, typename MatType, typename GradType>
171 class AddEvaluateWithGradientConst<FunctionType, MatType, GradType, true, false>
172 {
173  public:
174   /**
175    * Return both the evaluated objective function and its gradient, storing the
176    * gradient in the given matrix.
177    *
178    * @param coordinates Coordinates to evaluate the function at.
179    * @param gradient Matrix to store the gradient into.
180    */
EvaluateWithGradient(const MatType & coordinates,GradType & gradient) const181   typename MatType::elem_type EvaluateWithGradient(const MatType& coordinates,
182                                                    GradType& gradient) const
183   {
184     const typename MatType::elem_type objective =
185         static_cast<const Function<FunctionType,
186                                    MatType,
187                                    GradType>*>(this)->Evaluate(coordinates);
188     static_cast<const Function<FunctionType,
189                                MatType,
190                                GradType>*>(this)->Gradient(coordinates,
191                                                            gradient);
192     return objective;
193   }
194 };
195 
196 /**
197  * The AddEvaluateWithGradientStatic mixin class will provide a
198  * static EvaluateWithGradient() method if the given FunctionType has both
199  * static Evaluate() and static Gradient(), or it will provide nothing
200  * otherwise.
201  */
202 template<typename FunctionType,
203          typename MatType,
204          typename GradType,
205          bool HasEvaluateGradient =
206              traits::HasEvaluate<FunctionType,
207                  traits::TypedForms<MatType, GradType>::template
208                      EvaluateStaticForm
209              >::value &&
210              traits::HasGradient<FunctionType,
211                  traits::TypedForms<MatType, GradType>::template
212                      GradientStaticForm
213              >::value,
214          bool HasEvaluateWithGradient =
215              traits::HasEvaluateWithGradient<FunctionType,
216                  traits::TypedForms<MatType,
217                                     GradType>::template
218                      EvaluateWithGradientStaticForm
219              >::value>
220 class AddEvaluateWithGradientStatic
221 {
222  public:
223   // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this
224   // object.
225   static typename MatType::elem_type EvaluateWithGradient(
226       traits::UnconstructableType&);
227 };
228 
229 /**
230  * Reflect the existing EvaluateWithGradient().
231  */
232 template<typename FunctionType,
233          typename MatType,
234          typename GradType,
235          bool HasEvaluateGradient>
236 class AddEvaluateWithGradientStatic<FunctionType,
237                                     MatType,
238                                     GradType,
239                                     HasEvaluateGradient,
240                                     true>
241 {
242  public:
243   // Reflect the existing EvaluateWithGradient().
EvaluateWithGradient(const MatType & coordinates,GradType & gradient)244   static typename MatType::elem_type EvaluateWithGradient(
245       const MatType& coordinates, GradType& gradient)
246   {
247     return FunctionType::EvaluateWithGradient(coordinates, gradient);
248   }
249 };
250 
251 /**
252  * If the FunctionType has static Evaluate() and static Gradient(), provide
253  * static EvaluateWithGradient().
254  */
255 template<typename FunctionType, typename MatType, typename GradType>
256 class AddEvaluateWithGradientStatic<FunctionType,
257                                     MatType,
258                                     GradType,
259                                     true,
260                                     false>
261 {
262  public:
263   /**
264    * Return both the evaluated objective function and its gradient, storing the
265    * gradient in the given matrix.
266    *
267    * @param coordinates Coordinates to evaluate the function at.
268    * @param gradient Matrix to store the gradient into.
269    */
EvaluateWithGradient(const MatType & coordinates,GradType & gradient)270   static typename MatType::elem_type EvaluateWithGradient(
271       const MatType& coordinates, GradType& gradient)
272   {
273     const typename MatType::elem_type objective =
274         FunctionType::Evaluate(coordinates);
275     FunctionType::Gradient(coordinates, gradient);
276     return objective;
277   }
278 };
279 
280 } // namespace ens
281 
282 #endif
283