1 /**
2  * @file add_separable_evaluate_with_gradient.hpp
3  * @author Ryan Curtin
4  *
5  * Adds a separable EvaluateWithGradient() function if both a separable
6  * Evaluate() and a separable Gradient() function exist.
7  *
8  * ensmallen is free software; you may redistribute it and/or modify it under
9  * the terms of the 3-clause BSD license.  You should have received a copy of
10  * the 3-clause BSD license along with ensmallen.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef ENSMALLEN_FUNCTION_ADD_DECOMPOSABLE_EVALUATE_W_GRADIENT_HPP
14 #define ENSMALLEN_FUNCTION_ADD_DECOMPOSABLE_EVALUATE_W_GRADIENT_HPP
15 
16 #include "traits.hpp"
17 
18 namespace ens {
19 
20 /**
21  * The AddSeparableEvaluateWithGradient mixin class will add a separable
22  * EvaluateWithGradient() method if a separable Evaluate() method and a
23  * separable Gradient() method exists, or nothing otherwise.
24  */
25 template<typename FunctionType,
26          typename MatType,
27          typename GradType,
28          // Check if there is at least one non-const Evaluate() or Gradient().
29          bool HasSeparableEvaluateGradient = traits::HasNonConstSignatures<
30              FunctionType,
31              traits::HasEvaluate,
32              traits::TypedForms<MatType, GradType>::template
33                  SeparableEvaluateForm,
34              traits::TypedForms<MatType, GradType>::template
35                  SeparableEvaluateConstForm,
36              traits::TypedForms<MatType, GradType>::template
37                  SeparableEvaluateStaticForm,
38              traits::HasGradient,
39              traits::TypedForms<MatType, GradType>::template
40                  SeparableGradientForm,
41              traits::TypedForms<MatType, GradType>::template
42                  SeparableGradientConstForm,
43              traits::TypedForms<MatType, GradType>::template
44                  SeparableGradientStaticForm>::value,
45          bool HasSeparableEvaluateWithGradient =
46              traits::HasEvaluateWithGradient<FunctionType,
47                  traits::TypedForms<MatType, GradType>::template
48                      SeparableEvaluateWithGradientForm>::value>
49 class AddSeparableEvaluateWithGradient
50 {
51  public:
52   // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this
53   // object.
54   typename MatType::elem_type EvaluateWithGradient(
55       traits::UnconstructableType&,
56       const size_t,
57       const size_t);
58 };
59 
60 /**
61  * Reflect the existing EvaluateWithGradient().
62  */
63 template<typename FunctionType,
64          typename MatType,
65          typename GradType,
66          bool HasSeparableEvaluateGradient>
67 class AddSeparableEvaluateWithGradient<FunctionType, MatType, GradType,
68     HasSeparableEvaluateGradient, true>
69 {
70  public:
71   // Reflect the existing EvaluateWithGradient().
EvaluateWithGradient(const MatType & coordinates,const size_t begin,GradType & gradient,const size_t batchSize)72   typename MatType::elem_type EvaluateWithGradient(const MatType& coordinates,
73                                                    const size_t begin,
74                                                    GradType& gradient,
75                                                    const size_t batchSize)
76   {
77     return static_cast<FunctionType*>(
78         static_cast<Function<FunctionType,
79                              MatType,
80                              GradType>*>(this))->EvaluateWithGradient(
81         coordinates, begin, gradient, batchSize);
82   }
83 };
84 
85 /**
86  * If we have a both separable Evaluate() and a separable Gradient() but
87  * not a separable EvaluateWithGradient(), add a separable
88  * EvaluateWithGradient() method.
89  */
90 template<typename FunctionType, typename MatType, typename GradType>
91 class AddSeparableEvaluateWithGradient<FunctionType, MatType, GradType, true,
92     false>
93 {
94  public:
95   /**
96    * Return both the evaluated objective function and its gradient, storing the
97    * gradient in the given matrix, starting at the given separable function
98    * and using the given batch size.
99    *
100    * @param coordinates Coordinates to evaluate the function at.
101    * @param begin Index of separable function to begin with.
102    * @param gradient Matrix to store the gradient into.
103    * @param batchSize Number of separable functions to evaluate.
104    */
EvaluateWithGradient(const MatType & coordinates,const size_t begin,GradType & gradient,const size_t batchSize)105   typename MatType::elem_type EvaluateWithGradient(const MatType& coordinates,
106                                                    const size_t begin,
107                                                    GradType& gradient,
108                                                    const size_t batchSize)
109   {
110     const typename MatType::elem_type objective =
111         static_cast<Function<FunctionType, MatType, GradType>*>(this)->Evaluate(
112         coordinates, begin, batchSize);
113     static_cast<Function<FunctionType, MatType, GradType>*>(this)->Gradient(
114         coordinates, begin, gradient, batchSize);
115     return objective;
116   }
117 };
118 
119 /**
120  * The AddSeparableEvaluateWithGradientConst mixin class will add a
121  * separable const EvaluateWithGradient() method if both a separable const
122  * Evaluate() and a separable const Gradient() function exist, or nothing
123  * otherwise.
124  */
125 template<typename FunctionType,
126          typename MatType,
127          typename GradType,
128          // Check if there is at least one const Evaluate() or Gradient().
129          bool HasSeparableEvaluateGradient = traits::HasConstSignatures<
130              FunctionType,
131              traits::HasEvaluate,
132              traits::TypedForms<MatType, GradType>::template
133                  SeparableEvaluateConstForm,
134              traits::TypedForms<MatType, GradType>::template
135                  SeparableEvaluateStaticForm,
136              traits::HasGradient,
137              traits::TypedForms<MatType, GradType>::template
138                  SeparableGradientConstForm,
139              traits::TypedForms<MatType, GradType>::template
140                  SeparableGradientStaticForm>::value,
141          bool HasSeparableEvaluateWithGradient =
142              traits::HasEvaluateWithGradient<FunctionType,
143                  traits::TypedForms<MatType, GradType>::template
144                      SeparableEvaluateWithGradientConstForm>::value>
145 class AddSeparableEvaluateWithGradientConst
146 {
147  public:
148   // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this
149   // object.
150   typename MatType::elem_type EvaluateWithGradient(
151       traits::UnconstructableType&,
152       const size_t,
153       const size_t) const;
154 };
155 
156 /**
157  * Reflect the existing EvaluateWithGradient().
158  */
159 template<typename FunctionType,
160          typename MatType,
161          typename GradType,
162          bool HasSeparableEvaluateGradient>
163 class AddSeparableEvaluateWithGradientConst<FunctionType, MatType, GradType,
164     HasSeparableEvaluateGradient, true>
165 {
166  public:
167   // Reflect the existing Evaluate().
EvaluateWithGradient(const MatType & coordinates,const size_t begin,GradType & gradient,const size_t batchSize) const168   typename MatType::elem_type EvaluateWithGradient(const MatType& coordinates,
169                                                    const size_t begin,
170                                                    GradType& gradient,
171                                                    const size_t batchSize) const
172   {
173     return static_cast<const FunctionType*>(
174         static_cast<const Function<FunctionType,
175                                    MatType,
176                                    GradType>*>(this))->EvaluateWithGradient(
177         coordinates, begin, gradient, batchSize);
178   }
179 };
180 
181 /**
182  * If we have both a separable const Evaluate() and a separable const
183  * Gradient() but not a separable const EvaluateWithGradient(), add a
184  * separable const EvaluateWithGradient() method.
185  */
186 template<typename FunctionType, typename MatType, typename GradType>
187 class AddSeparableEvaluateWithGradientConst<FunctionType, MatType, GradType,
188     true, false>
189 {
190  public:
191   /**
192    * Return both the evaluated objective function and its gradient, storing the
193    * gradient in the given matrix, starting at the given separable function
194    * and using the given batch size.
195    *
196    * @param coordinates Coordinates to evaluate the function at.
197    * @param begin Index of separable function to begin with.
198    * @param gradient Matrix to store the gradient into.
199    * @param batchSize Number of separable functions to evaluate.
200    */
EvaluateWithGradient(const MatType & coordinates,const size_t begin,GradType & gradient,const size_t batchSize) const201   typename MatType::elem_type EvaluateWithGradient(const MatType& coordinates,
202                                                    const size_t begin,
203                                                    GradType& gradient,
204                                                    const size_t batchSize) const
205   {
206     const typename MatType::elem_type objective =
207         static_cast<const Function<FunctionType,
208                                    MatType,
209                                    GradType>*>(this)->Evaluate(coordinates,
210         begin, batchSize);
211     static_cast<const Function<FunctionType,
212                                MatType,
213                                GradType>*>(this)->Gradient(coordinates,
214         begin, gradient, batchSize);
215     return objective;
216   }
217 };
218 
219 /**
220  * The AddSeparableEvaluateWithGradientStatic mixin class will add a
221  * separable static EvaluateWithGradient() method if both a separable
222  * static Evaluate() and a separable static gradient() function exist, or
223  * nothing otherwise.
224  */
225 template<typename FunctionType,
226          typename MatType,
227          typename GradType,
228          bool HasSeparableEvaluateGradient =
229              traits::HasEvaluate<FunctionType,
230                  traits::TypedForms<MatType, GradType>::template
231                      SeparableEvaluateStaticForm>::value &&
232              traits::HasGradient<FunctionType,
233                  traits::TypedForms<MatType, GradType>::template
234                      SeparableGradientStaticForm>::value,
235          bool HasSeparableEvaluateWithGradient =
236              traits::HasEvaluateWithGradient<FunctionType,
237                  traits::TypedForms<MatType, GradType>::template
238                      SeparableEvaluateWithGradientStaticForm>::value>
239 class AddSeparableEvaluateWithGradientStatic
240 {
241  public:
242   // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this
243   // object.
244   static typename MatType::elem_type EvaluateWithGradient(
245       traits::UnconstructableType&,
246       const size_t,
247       const size_t);
248 };
249 
250 /**
251  * Reflect the existing EvaluateWithGradient().
252  */
253 template<typename FunctionType,
254          typename MatType,
255          typename GradType,
256          bool HasSeparableEvaluateGradient>
257 class AddSeparableEvaluateWithGradientStatic<FunctionType, MatType, GradType,
258     HasSeparableEvaluateGradient, true>
259 {
260  public:
261   // Reflect the existing Evaluate().
EvaluateWithGradient(const MatType & coordinates,const size_t begin,GradType & gradient,const size_t batchSize)262   static typename MatType::elem_type EvaluateWithGradient(
263       const MatType& coordinates,
264       const size_t begin,
265       GradType& gradient,
266       const size_t batchSize)
267   {
268     return FunctionType::EvaluateWithGradient(coordinates, begin, gradient,
269         batchSize);
270   }
271 };
272 
273 /**
274  * If we have a separable static Evaluate() and a separable static
275  * Gradient() but not a separable static EvaluateWithGradient(), add a
276  * separable static Gradient() method.
277  */
278 template<typename FunctionType, typename MatType, typename GradType>
279 class AddSeparableEvaluateWithGradientStatic<FunctionType, MatType, GradType,
280     true, false>
281 {
282  public:
283   /**
284    * Return both the evaluated objective function and its gradient, storing the
285    * gradient in the given matrix, starting at the given separable function
286    * and using the given batch size.
287    *
288    * @param coordinates Coordinates to evaluate the function at.
289    * @param begin Index of separable function to begin with.
290    * @param gradient Matrix to store the gradient into.
291    * @param batchSize Number of separable functions to evaluate.
292    */
EvaluateWithGradient(const MatType & coordinates,const size_t begin,GradType & gradient,const size_t batchSize) const293   typename MatType::elem_type EvaluateWithGradient(
294       const MatType& coordinates,
295       const size_t begin,
296       GradType& gradient,
297       const size_t batchSize) const
298   {
299     const typename MatType::elem_type objective = FunctionType::Evaluate(
300         coordinates, begin, batchSize);
301     FunctionType::Gradient(coordinates, begin, gradient, batchSize);
302     return objective;
303   }
304 };
305 
306 } // namespace ens
307 
308 #endif
309