1 /** 2 * @file add_evaluate_with_gradient.hpp 3 * @author Ryan Curtin 4 * 5 * This file defines a mixin for the Function class that will ensure that the 6 * EvaluateWithGradient() function is available if possible. 7 * 8 * ensmallen is free software; you may redistribute it and/or modify it under 9 * the terms of the 3-clause BSD license. You should have received a copy of 10 * the 3-clause BSD license along with ensmallen. If not, see 11 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 12 */ 13 #ifndef ENSMALLEN_FUNCTION_ADD_EVALUATE_WITH_GRADIENT_HPP 14 #define ENSMALLEN_FUNCTION_ADD_EVALUATE_WITH_GRADIENT_HPP 15 16 #include "sfinae_utility.hpp" 17 #include "traits.hpp" 18 19 namespace ens { 20 21 /** 22 * The AddEvaluateWithGradient mixin class will provide an 23 * EvaluateWithGradient() method if the given FunctionType has both Evaluate() 24 * and Gradient(), or it will provide nothing otherwise. 25 */ 26 template<typename FunctionType, 27 typename MatType, 28 typename GradType, 29 // Check if there is at least one non-const Evaluate() or Gradient(). 30 bool HasEvaluateGradient = traits::HasNonConstSignatures< 31 FunctionType, 32 traits::HasEvaluate, 33 traits::TypedForms<MatType, GradType>::template EvaluateForm, 34 traits::TypedForms<MatType, GradType>::template EvaluateConstForm, 35 traits::TypedForms<MatType, GradType>::template EvaluateStaticForm, 36 traits::HasGradient, 37 traits::TypedForms<MatType, GradType>::template GradientForm, 38 traits::TypedForms<MatType, GradType>::template GradientConstForm, 39 traits::TypedForms<MatType, GradType>::template GradientStaticForm 40 >::value, 41 bool HasEvaluateWithGradient = traits::HasEvaluateWithGradient< 42 FunctionType, 43 traits::TypedForms<MatType, GradType>::template 44 EvaluateWithGradientForm>::value> 45 class AddEvaluateWithGradient 46 { 47 public: 48 // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this 49 // object. 50 typename MatType::elem_type EvaluateWithGradient( 51 traits::UnconstructableType&); 52 }; 53 54 /** 55 * Reflect the existing EvaluateWithGradient(). 56 */ 57 template<typename FunctionType, 58 typename MatType, 59 typename GradType, 60 bool HasEvaluateGradient> 61 class AddEvaluateWithGradient<FunctionType, 62 MatType, 63 GradType, 64 HasEvaluateGradient, 65 true> 66 { 67 public: 68 // Reflect the existing EvaluateWithGradient(). EvaluateWithGradient(const MatType & coordinates,GradType & gradient)69 typename MatType::elem_type EvaluateWithGradient( 70 const MatType& coordinates, GradType& gradient) 71 { 72 return static_cast<FunctionType*>( 73 static_cast<Function<FunctionType, 74 MatType, 75 GradType>*>(this))->EvaluateWithGradient( 76 coordinates, gradient); 77 } 78 }; 79 80 /** 81 * If the FunctionType has Evaluate() and Gradient(), provide 82 * EvaluateWithGradient(). 83 */ 84 template<typename FunctionType, typename MatType, typename GradType> 85 class AddEvaluateWithGradient<FunctionType, MatType, GradType, true, false> 86 { 87 public: 88 /** 89 * Return both the evaluated objective function and its gradient, storing the 90 * gradient in the given matrix. 91 * 92 * @param coordinates Coordinates to evaluate the function at. 93 * @param gradient Matrix to store the gradient into. 94 */ EvaluateWithGradient(const MatType & coordinates,GradType & gradient)95 typename MatType::elem_type EvaluateWithGradient(const MatType& coordinates, 96 GradType& gradient) 97 { 98 const typename MatType::elem_type objective = 99 static_cast<Function<FunctionType, 100 MatType, GradType>*>(this)->Evaluate(coordinates); 101 static_cast<Function<FunctionType, 102 MatType, 103 GradType>*>(this)->Gradient(coordinates, gradient); 104 return objective; 105 } 106 }; 107 108 /** 109 * The AddEvaluateWithGradient mixin class will provide an 110 * EvaluateWithGradient() const method if the given FunctionType has both 111 * Evaluate() const and Gradient() const, or it will provide nothing otherwise. 112 */ 113 template<typename FunctionType, 114 typename MatType, 115 typename GradType, 116 // Check if there is at least one const Evaluate() or Gradient(). 117 bool HasEvaluateGradient = traits::HasConstSignatures< 118 FunctionType, 119 traits::HasEvaluate, 120 traits::TypedForms<MatType, GradType>::template EvaluateConstForm, 121 traits::TypedForms<MatType, GradType>::template EvaluateStaticForm, 122 traits::HasGradient, 123 traits::TypedForms<MatType, GradType>::template GradientConstForm, 124 traits::TypedForms<MatType, GradType>::template GradientStaticForm 125 >::value, 126 bool HasEvaluateWithGradient = traits::HasEvaluateWithGradient< 127 FunctionType, 128 traits::TypedForms< 129 MatType, GradType 130 >::template EvaluateWithGradientConstForm>::value> 131 class AddEvaluateWithGradientConst 132 { 133 public: 134 // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this 135 // object. 136 typename MatType::elem_type EvaluateWithGradient( 137 traits::UnconstructableType&) const; 138 }; 139 140 /** 141 * Reflect the existing EvaluateWithGradient(). 142 */ 143 template<typename FunctionType, 144 typename MatType, 145 typename GradType, 146 bool HasEvaluateGradient> 147 class AddEvaluateWithGradientConst<FunctionType, 148 MatType, 149 GradType, 150 HasEvaluateGradient, 151 true> 152 { 153 public: 154 // Reflect the existing EvaluateWithGradient(). EvaluateWithGradient(const MatType & coordinates,GradType & gradient) const155 typename MatType::elem_type EvaluateWithGradient( 156 const MatType& coordinates, GradType& gradient) const 157 { 158 return static_cast<const FunctionType*>( 159 static_cast<const Function<FunctionType, 160 MatType, 161 GradType>*>(this))->EvaluateWithGradient( 162 coordinates, gradient); 163 } 164 }; 165 166 /** 167 * If the FunctionType has Evaluate() const and Gradient() const, provide 168 * EvaluateWithGradient() const. 169 */ 170 template<typename FunctionType, typename MatType, typename GradType> 171 class AddEvaluateWithGradientConst<FunctionType, MatType, GradType, true, false> 172 { 173 public: 174 /** 175 * Return both the evaluated objective function and its gradient, storing the 176 * gradient in the given matrix. 177 * 178 * @param coordinates Coordinates to evaluate the function at. 179 * @param gradient Matrix to store the gradient into. 180 */ EvaluateWithGradient(const MatType & coordinates,GradType & gradient) const181 typename MatType::elem_type EvaluateWithGradient(const MatType& coordinates, 182 GradType& gradient) const 183 { 184 const typename MatType::elem_type objective = 185 static_cast<const Function<FunctionType, 186 MatType, 187 GradType>*>(this)->Evaluate(coordinates); 188 static_cast<const Function<FunctionType, 189 MatType, 190 GradType>*>(this)->Gradient(coordinates, 191 gradient); 192 return objective; 193 } 194 }; 195 196 /** 197 * The AddEvaluateWithGradientStatic mixin class will provide a 198 * static EvaluateWithGradient() method if the given FunctionType has both 199 * static Evaluate() and static Gradient(), or it will provide nothing 200 * otherwise. 201 */ 202 template<typename FunctionType, 203 typename MatType, 204 typename GradType, 205 bool HasEvaluateGradient = 206 traits::HasEvaluate<FunctionType, 207 traits::TypedForms<MatType, GradType>::template 208 EvaluateStaticForm 209 >::value && 210 traits::HasGradient<FunctionType, 211 traits::TypedForms<MatType, GradType>::template 212 GradientStaticForm 213 >::value, 214 bool HasEvaluateWithGradient = 215 traits::HasEvaluateWithGradient<FunctionType, 216 traits::TypedForms<MatType, 217 GradType>::template 218 EvaluateWithGradientStaticForm 219 >::value> 220 class AddEvaluateWithGradientStatic 221 { 222 public: 223 // Provide a dummy overload so the name 'EvaluateWithGradient' exists for this 224 // object. 225 static typename MatType::elem_type EvaluateWithGradient( 226 traits::UnconstructableType&); 227 }; 228 229 /** 230 * Reflect the existing EvaluateWithGradient(). 231 */ 232 template<typename FunctionType, 233 typename MatType, 234 typename GradType, 235 bool HasEvaluateGradient> 236 class AddEvaluateWithGradientStatic<FunctionType, 237 MatType, 238 GradType, 239 HasEvaluateGradient, 240 true> 241 { 242 public: 243 // Reflect the existing EvaluateWithGradient(). EvaluateWithGradient(const MatType & coordinates,GradType & gradient)244 static typename MatType::elem_type EvaluateWithGradient( 245 const MatType& coordinates, GradType& gradient) 246 { 247 return FunctionType::EvaluateWithGradient(coordinates, gradient); 248 } 249 }; 250 251 /** 252 * If the FunctionType has static Evaluate() and static Gradient(), provide 253 * static EvaluateWithGradient(). 254 */ 255 template<typename FunctionType, typename MatType, typename GradType> 256 class AddEvaluateWithGradientStatic<FunctionType, 257 MatType, 258 GradType, 259 true, 260 false> 261 { 262 public: 263 /** 264 * Return both the evaluated objective function and its gradient, storing the 265 * gradient in the given matrix. 266 * 267 * @param coordinates Coordinates to evaluate the function at. 268 * @param gradient Matrix to store the gradient into. 269 */ EvaluateWithGradient(const MatType & coordinates,GradType & gradient)270 static typename MatType::elem_type EvaluateWithGradient( 271 const MatType& coordinates, GradType& gradient) 272 { 273 const typename MatType::elem_type objective = 274 FunctionType::Evaluate(coordinates); 275 FunctionType::Gradient(coordinates, gradient); 276 return objective; 277 } 278 }; 279 280 } // namespace ens 281 282 #endif 283