1 /**
2  * @file add_separable_evaluate.hpp
3  * @author Ryan Curtin
4  *
5  * Adds a separable Evaluate() function if a separable
6  * EvaluateWithGradient() function exists.
7  *
8  * ensmallen is free software; you may redistribute it and/or modify it under
9  * the terms of the 3-clause BSD license.  You should have received a copy of
10  * the 3-clause BSD license along with ensmallen.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef ENSMALLEN_FUNCTION_ADD_DECOMPOSABLE_EVALUATE_HPP
14 #define ENSMALLEN_FUNCTION_ADD_DECOMPOSABLE_EVALUATE_HPP
15 
16 #include "traits.hpp"
17 
18 namespace ens {
19 
20 /**
21  * The AddSeparableEvaluate mixin class will add a separable Evaluate()
22  * method if a separable EvaluateWithGradient() function exists, or nothing
23  * otherwise.
24  */
25 template<typename FunctionType,
26          typename MatType,
27          typename GradType,
28          bool HasSeparableEvaluateWithGradient =
29              traits::HasEvaluateWithGradient<FunctionType,
30                  traits::TypedForms<MatType, GradType>::template
31                      SeparableEvaluateWithGradientForm
32              >::value,
33          bool HasSeparableEvaluate =
34              traits::HasEvaluate<FunctionType,
35                  traits::TypedForms<MatType, GradType>::template
36                       SeparableEvaluateForm>::value>
37 class AddSeparableEvaluate
38 {
39  public:
40   // Provide a dummy overload so the name 'Evaluate' exists for this object.
41   typename MatType::elem_type Evaluate(traits::UnconstructableType&,
42                                        const size_t,
43                                        const size_t);
44 };
45 
46 /**
47  * Reflect the existing Evaluate().
48  */
49 template<typename FunctionType,
50          typename MatType,
51          typename GradType,
52          bool HasSeparableEvaluateWithGradient>
53 class AddSeparableEvaluate<FunctionType, MatType, GradType,
54     HasSeparableEvaluateWithGradient, true>
55 {
56  public:
57   // Reflect the existing Evaluate().
Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize)58   typename MatType::elem_type Evaluate(const MatType& coordinates,
59                                        const size_t begin,
60                                        const size_t batchSize)
61   {
62     return static_cast<FunctionType*>(
63         static_cast<Function<FunctionType,
64                              MatType,
65                              GradType>*>(this))->Evaluate(coordinates,
66                                                           begin,
67                                                           batchSize);
68   }
69 };
70 
71 /**
72  * If we have a separable EvaluateWithGradient() but not a separable
73  * Evaluate(), add a separable Evaluate() method.
74  */
75 template<typename FunctionType, typename MatType, typename GradType>
76 class AddSeparableEvaluate<FunctionType, MatType, GradType, true, false>
77 {
78  public:
79   /**
80    * Return the objective function for the given coordinates, starting at the
81    * given separable function using the given batch size.
82    *
83    * @param coordinates Coordinates to evaluate the function at.
84    * @param begin Index of first function to evaluate.
85    * @param batchSize Number of functions to evaluate.
86    */
Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize)87   double Evaluate(const MatType& coordinates,
88                   const size_t begin,
89                   const size_t batchSize)
90   {
91     GradType gradient; // This will be ignored.
92     return static_cast<Function<FunctionType,
93                                 MatType,
94                                 GradType>*>(this)->EvaluateWithGradient(
95         coordinates, begin, gradient, batchSize);
96   }
97 };
98 
99 /**
100  * The AddSeparableEvaluateConst mixin class will add a separable const
101  * Evaluate() method if a separable const EvaluateWithGradient() function
102  * exists, or nothing otherwise.
103  */
104 template<typename FunctionType,
105          typename MatType,
106          typename GradType,
107          bool HasSeparableEvaluateWithGradient =
108              traits::HasEvaluateWithGradient<FunctionType,
109                  traits::TypedForms<MatType, GradType>::template
110                      SeparableEvaluateWithGradientConstForm>::value,
111          bool HasSeparableEvaluate =
112              traits::HasEvaluate<FunctionType,
113                  traits::TypedForms<MatType, GradType>::template
114                      SeparableEvaluateConstForm>::value>
115 class AddSeparableEvaluateConst
116 {
117  public:
118   // Provide a dummy overload so the name 'Evaluate' exists for this object.
119   typename MatType::elem_type Evaluate(traits::UnconstructableType&,
120                                        const size_t,
121                                        const size_t) const;
122 };
123 
124 /**
125  * Reflect the existing Evaluate().
126  */
127 template<typename FunctionType,
128          typename MatType,
129          typename GradType,
130          bool HasSeparableEvaluateWithGradient>
131 class AddSeparableEvaluateConst<FunctionType, MatType, GradType,
132     HasSeparableEvaluateWithGradient, true>
133 {
134  public:
135   // Reflect the existing Evaluate().
Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize) const136   typename MatType::elem_type Evaluate(const MatType& coordinates,
137                                        const size_t begin,
138                                        const size_t batchSize) const
139   {
140     return static_cast<const FunctionType*>(
141         static_cast<const Function<FunctionType,
142                                    MatType,
143                                    GradType>*>(this))->Evaluate(coordinates,
144                                                                 begin,
145                                                                 batchSize);
146   }
147 };
148 
149 /**
150  * If we have a separable const EvaluateWithGradient() but not a separable
151  * const Evaluate(), add a separable const Evaluate() method.
152  */
153 template<typename FunctionType, typename MatType, typename GradType>
154 class AddSeparableEvaluateConst<FunctionType, MatType, GradType, true, false>
155 {
156  public:
157   /**
158    * Return the objective function for the given coordinates, starting at the
159    * given separable function using the given batch size.
160    *
161    * @param coordinates Coordinates to evaluate the function at.
162    * @param begin Index of first function to evaluate.
163    * @param batchSize Number of functions to evaluate.
164    */
Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize) const165   typename MatType::elem_type Evaluate(const MatType& coordinates,
166                                        const size_t begin,
167                                        const size_t batchSize) const
168   {
169     GradType gradient; // This will be ignored.
170     return static_cast<const Function<FunctionType,
171                                       MatType,
172                                       GradType>*>(this)->EvaluateWithGradient(
173         coordinates, begin, gradient, batchSize);
174   }
175 };
176 
177 /**
178  * The AddSeparableEvaluateStatic mixin class will add a separable static
179  * Evaluate() method if a separable static EvaluateWithGradient() function
180  * exists, or nothing otherwise.
181  */
182 template<typename FunctionType,
183          typename MatType,
184          typename GradType,
185          bool HasSeparableEvaluateWithGradient =
186              traits::HasEvaluateWithGradient<FunctionType,
187                  traits::TypedForms<MatType, GradType>::template
188                      SeparableEvaluateWithGradientStaticForm>::value,
189          bool HasSeparableEvaluate =
190              traits::HasEvaluate<FunctionType,
191                  traits::TypedForms<MatType, GradType>::template
192                      SeparableEvaluateStaticForm>::value>
193 class AddSeparableEvaluateStatic
194 {
195  public:
196   // Provide a dummy overload so the name 'Evaluate' exists for this object.
197   static typename MatType::elem_type Evaluate(traits::UnconstructableType&,
198                                               const size_t,
199                                               const size_t);
200 };
201 
202 /**
203  * Reflect the existing Evaluate().
204  */
205 template<typename FunctionType,
206          typename MatType,
207          typename GradType,
208          bool HasSeparableEvaluateWithGradient>
209 class AddSeparableEvaluateStatic<FunctionType, MatType, GradType,
210     HasSeparableEvaluateWithGradient, true>
211 {
212  public:
213   // Reflect the existing Evaluate().
Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize)214   static typename MatType::elem_type Evaluate(const MatType& coordinates,
215                                               const size_t begin,
216                                               const size_t batchSize)
217   {
218     return FunctionType::Evaluate(coordinates, begin, batchSize);
219   }
220 };
221 
222 /**
223  * If we have a separable EvaluateWithGradient() but not a separable
224  * Evaluate(), add a separable Evaluate() method.
225  */
226 template<typename FunctionType, typename MatType, typename GradType>
227 class AddSeparableEvaluateStatic<FunctionType, MatType, GradType, true,
228     false>
229 {
230  public:
231   /**
232    * Return the objective function for the given coordinates, starting at the
233    * given separable function using the given batch size.
234    *
235    * @param coordinates Coordinates to evaluate the function at.
236    * @param begin Index of first function to evaluate.
237    * @param batchSize Number of functions to evaluate.
238    */
Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize)239   static typename MatType::elem_type Evaluate(const MatType& coordinates,
240                                               const size_t begin,
241                                               const size_t batchSize)
242   {
243     GradType gradient; // This will be ignored.
244     return FunctionType::EvaluateWithGradient(coordinates, begin, gradient,
245         batchSize);
246   }
247 };
248 
249 } // namespace ens
250 
251 #endif
252