1 /**
2  * @file add_gradient.hpp
3  * @author Ryan Curtin
4  *
 * This file defines mixins for the Function class that ensure that the
 * function Gradient() is available if EvaluateWithGradient() is available
 * (in non-const, const, and static variants).
7  *
8  * ensmallen is free software; you may redistribute it and/or modify it under
9  * the terms of the 3-clause BSD license.  You should have received a copy of
10  * the 3-clause BSD license along with ensmallen.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef ENSMALLEN_FUNCTION_ADD_GRADIENT_HPP
14 #define ENSMALLEN_FUNCTION_ADD_GRADIENT_HPP
15 
16 #include "traits.hpp"
17 
18 namespace ens {
19 
20 /**
21  * The AddGradient mixin class will provide a Gradient() method if the given
22  * FunctionType has EvaluateWithGradient(), or nothing otherwise.
23  */
24 template<typename FunctionType,
25          typename MatType,
26          typename GradType,
27          bool HasEvaluateWithGradient =
28              traits::HasEvaluateWithGradient<FunctionType,
29                  traits::TypedForms<MatType, GradType>::template
30                      EvaluateWithGradientForm
31              >::value,
32          bool HasGradient = traits::HasGradient<FunctionType,
33              traits::TypedForms<MatType, GradType>::template
34                      GradientForm>::value>
35 class AddGradient
36 {
37  public:
38   // Provide a dummy overload so the name 'Gradient' exists for this object.
Gradient(traits::UnconstructableType &)39   void Gradient(traits::UnconstructableType&) { }
40 };
41 
42 /**
43  * Reflect the existing Gradient().
44  */
45 template<typename FunctionType,
46          typename MatType,
47          typename GradType,
48          bool HasEvaluateWithGradient>
49 class AddGradient<FunctionType,
50                   MatType,
51                   GradType,
52                   HasEvaluateWithGradient,
53                   true>
54 {
55  public:
56   // Reflect the existing Gradient().
Gradient(const MatType & coordinates,GradType & gradient)57   void Gradient(const MatType& coordinates, GradType& gradient)
58   {
59     static_cast<FunctionType*>(
60         static_cast<Function<FunctionType,
61                              MatType,
62                              GradType>*>(this))->Gradient(coordinates,
63                                                           gradient);
64   }
65 };
66 
67 /**
68  * If we have EvaluateWithGradient() but no existing Gradient(), add an
69  * Gradient() without a using directive to make the base Gradient() accessible.
70  */
71 template<typename FunctionType, typename MatType, typename GradType>
72 class AddGradient<FunctionType, MatType, GradType, true, false>
73 {
74  public:
75   /**
76    * Calculate the gradient and store it in the given matrix.
77    *
78    * @param coordinates Coordinates to evaluate the function at.
79    * @param gradient Matrix to store the gradient into.
80    */
Gradient(const MatType & coordinates,GradType & gradient)81   void Gradient(const MatType& coordinates, GradType& gradient)
82   {
83     // The returned objective value will be ignored.
84     (void) static_cast<Function<FunctionType,
85                                 MatType,
86                                 GradType>*>(this)->EvaluateWithGradient(
87         coordinates, gradient);
88   }
89 };
90 
91 /**
92  * The AddGradient mixin class will provide a const Gradient() method if the
93  * given FunctionType has EvaluateWithGradient() const, or nothing otherwise.
94  */
95 template<typename FunctionType,
96          typename MatType,
97          typename GradType,
98          bool HasEvaluateWithGradient =
99              traits::HasEvaluateWithGradient<FunctionType,
100                  traits::TypedForms<MatType,
101                                     GradType>::template
102                      EvaluateWithGradientConstForm
103              >::value,
104          bool HasGradient = traits::HasGradient<FunctionType,
105              traits::TypedForms<MatType, GradType>::template GradientConstForm
106          >::value>
107 class AddGradientConst
108 {
109  public:
110   // Provide a dummy overload so the name 'Gradient' exists for this object.
Gradient(traits::UnconstructableType &) const111   void Gradient(traits::UnconstructableType&) const { }
112 };
113 
114 /**
115  * Reflect the existing Gradient().
116  */
117 template<typename FunctionType,
118          typename MatType,
119          typename GradType,
120          bool HasEvaluateWithGradient>
121 class AddGradientConst<FunctionType,
122                        MatType,
123                        GradType,
124                        HasEvaluateWithGradient,
125                        true>
126 {
127  public:
128   // Reflect the existing Gradient().
Gradient(const MatType & coordinates,GradType & gradient) const129   void Gradient(const MatType& coordinates, GradType& gradient) const
130   {
131     static_cast<const FunctionType*>(
132         static_cast<const Function<FunctionType,
133                                    MatType,
134                                    GradType>*>(this))->Gradient(coordinates,
135                                                                 gradient);
136   }
137 };
138 
139 /**
140  * If we have EvaluateWithGradient() but no existing Gradient(), add a
141  * Gradient() without a using directive to make the base Gradient() accessible.
142  */
143 template<typename FunctionType, typename MatType, typename GradType>
144 class AddGradientConst<FunctionType, MatType, GradType, true, false>
145 {
146  public:
147   /**
148    * Calculate the gradient and store it in the given matrix.
149    *
150    * @param coordinates Coordinates to evaluate the function at.
151    * @param gradient Matrix to store the gradient into.
152    */
Gradient(const MatType & coordinates,GradType & gradient) const153   void Gradient(const MatType& coordinates, GradType& gradient) const
154   {
155     // The returned objective value will be ignored.
156     (void) static_cast<
157         const Function<FunctionType,
158                        MatType,
159                        GradType>*>(this)->EvaluateWithGradient(coordinates,
160                                                                gradient);
161   }
162 };
163 
164 /**
165  * The AddGradient mixin class will provide a static Gradient() method if the
166  * given FunctionType has static EvaluateWithGradient(), or nothing otherwise.
167  */
168 template<typename FunctionType,
169          typename MatType,
170          typename GradType,
171          bool HasEvaluateWithGradient =
172              traits::HasEvaluateWithGradient<FunctionType,
173                  traits::TypedForms<MatType,
174                                     GradType>::template
175                      EvaluateWithGradientStaticForm
176              >::value,
177          bool HasGradient = traits::HasGradient<FunctionType,
178              traits::TypedForms<MatType, GradType>::template GradientStaticForm
179          >::value>
180 class AddGradientStatic
181 {
182  public:
183   // Provide a dummy overload so the name 'Gradient' exists for this object.
Gradient(traits::UnconstructableType &)184   static void Gradient(traits::UnconstructableType&) { }
185 };
186 
187 /**
188  * Reflect the existing Gradient().
189  */
190 template<typename FunctionType,
191          typename MatType,
192          typename GradType,
193          bool HasEvaluateWithGradient>
194 class AddGradientStatic<FunctionType,
195                         MatType,
196                         GradType,
197                         HasEvaluateWithGradient,
198                         true>
199 {
200  public:
201   // Reflect the existing Gradient().
Gradient(const MatType & coordinates,GradType & gradient)202   static void Gradient(const MatType& coordinates, GradType& gradient)
203   {
204     FunctionType::Gradient(coordinates, gradient);
205   }
206 };
207 
208 /**
209  * If we have EvaluateWithGradient() but no existing Gradient(), add a
210  * Gradient() without a using directive to make the base Gradient() accessible.
211  */
212 template<typename FunctionType, typename MatType, typename GradType>
213 class AddGradientStatic<FunctionType, MatType, GradType, true, false>
214 {
215  public:
216   /**
217    * Calculate the gradient and store it in the given matrix.
218    *
219    * @param coordinates Coordinates to evaluate the function at.
220    * @param gradient Matrix to store the gradient into.
221    */
Gradient(const MatType & coordinates,GradType & gradient)222   static void Gradient(const MatType& coordinates, GradType& gradient)
223   {
224     // The returned objective value will be ignored.
225     (void) FunctionType::EvaluateWithGradient(coordinates, gradient);
226   }
227 };
228 
229 } // namespace ens
230 
231 #endif
232