1 /** 2 * @file add_gradient.hpp 3 * @author Ryan Curtin 4 * 5 * This file defines a mixin for the Function class that will ensure that the 6 * function Gradient() is avaiable if EvaluateWithGradient() is available. 7 * 8 * ensmallen is free software; you may redistribute it and/or modify it under 9 * the terms of the 3-clause BSD license. You should have received a copy of 10 * the 3-clause BSD license along with ensmallen. If not, see 11 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 12 */ 13 #ifndef ENSMALLEN_FUNCTION_ADD_GRADIENT_HPP 14 #define ENSMALLEN_FUNCTION_ADD_GRADIENT_HPP 15 16 #include "traits.hpp" 17 18 namespace ens { 19 20 /** 21 * The AddGradient mixin class will provide a Gradient() method if the given 22 * FunctionType has EvaluateWithGradient(), or nothing otherwise. 23 */ 24 template<typename FunctionType, 25 typename MatType, 26 typename GradType, 27 bool HasEvaluateWithGradient = 28 traits::HasEvaluateWithGradient<FunctionType, 29 traits::TypedForms<MatType, GradType>::template 30 EvaluateWithGradientForm 31 >::value, 32 bool HasGradient = traits::HasGradient<FunctionType, 33 traits::TypedForms<MatType, GradType>::template 34 GradientForm>::value> 35 class AddGradient 36 { 37 public: 38 // Provide a dummy overload so the name 'Gradient' exists for this object. Gradient(traits::UnconstructableType &)39 void Gradient(traits::UnconstructableType&) { } 40 }; 41 42 /** 43 * Reflect the existing Gradient(). 44 */ 45 template<typename FunctionType, 46 typename MatType, 47 typename GradType, 48 bool HasEvaluateWithGradient> 49 class AddGradient<FunctionType, 50 MatType, 51 GradType, 52 HasEvaluateWithGradient, 53 true> 54 { 55 public: 56 // Reflect the existing Gradient(). Gradient(const MatType & coordinates,GradType & gradient)57 void Gradient(const MatType& coordinates, GradType& gradient) 58 { 59 static_cast<FunctionType*>( 60 static_cast<Function<FunctionType, 61 MatType, 62 GradType>*>(this))->Gradient(coordinates, 63 gradient); 64 } 65 }; 66 67 /** 68 * If we have EvaluateWithGradient() but no existing Gradient(), add an 69 * Gradient() without a using directive to make the base Gradient() accessible. 70 */ 71 template<typename FunctionType, typename MatType, typename GradType> 72 class AddGradient<FunctionType, MatType, GradType, true, false> 73 { 74 public: 75 /** 76 * Calculate the gradient and store it in the given matrix. 77 * 78 * @param coordinates Coordinates to evaluate the function at. 79 * @param gradient Matrix to store the gradient into. 80 */ Gradient(const MatType & coordinates,GradType & gradient)81 void Gradient(const MatType& coordinates, GradType& gradient) 82 { 83 // The returned objective value will be ignored. 84 (void) static_cast<Function<FunctionType, 85 MatType, 86 GradType>*>(this)->EvaluateWithGradient( 87 coordinates, gradient); 88 } 89 }; 90 91 /** 92 * The AddGradient mixin class will provide a const Gradient() method if the 93 * given FunctionType has EvaluateWithGradient() const, or nothing otherwise. 94 */ 95 template<typename FunctionType, 96 typename MatType, 97 typename GradType, 98 bool HasEvaluateWithGradient = 99 traits::HasEvaluateWithGradient<FunctionType, 100 traits::TypedForms<MatType, 101 GradType>::template 102 EvaluateWithGradientConstForm 103 >::value, 104 bool HasGradient = traits::HasGradient<FunctionType, 105 traits::TypedForms<MatType, GradType>::template GradientConstForm 106 >::value> 107 class AddGradientConst 108 { 109 public: 110 // Provide a dummy overload so the name 'Gradient' exists for this object. Gradient(traits::UnconstructableType &) const111 void Gradient(traits::UnconstructableType&) const { } 112 }; 113 114 /** 115 * Reflect the existing Gradient(). 116 */ 117 template<typename FunctionType, 118 typename MatType, 119 typename GradType, 120 bool HasEvaluateWithGradient> 121 class AddGradientConst<FunctionType, 122 MatType, 123 GradType, 124 HasEvaluateWithGradient, 125 true> 126 { 127 public: 128 // Reflect the existing Gradient(). Gradient(const MatType & coordinates,GradType & gradient) const129 void Gradient(const MatType& coordinates, GradType& gradient) const 130 { 131 static_cast<const FunctionType*>( 132 static_cast<const Function<FunctionType, 133 MatType, 134 GradType>*>(this))->Gradient(coordinates, 135 gradient); 136 } 137 }; 138 139 /** 140 * If we have EvaluateWithGradient() but no existing Gradient(), add a 141 * Gradient() without a using directive to make the base Gradient() accessible. 142 */ 143 template<typename FunctionType, typename MatType, typename GradType> 144 class AddGradientConst<FunctionType, MatType, GradType, true, false> 145 { 146 public: 147 /** 148 * Calculate the gradient and store it in the given matrix. 149 * 150 * @param coordinates Coordinates to evaluate the function at. 151 * @param gradient Matrix to store the gradient into. 152 */ Gradient(const MatType & coordinates,GradType & gradient) const153 void Gradient(const MatType& coordinates, GradType& gradient) const 154 { 155 // The returned objective value will be ignored. 156 (void) static_cast< 157 const Function<FunctionType, 158 MatType, 159 GradType>*>(this)->EvaluateWithGradient(coordinates, 160 gradient); 161 } 162 }; 163 164 /** 165 * The AddGradient mixin class will provide a static Gradient() method if the 166 * given FunctionType has static EvaluateWithGradient(), or nothing otherwise. 167 */ 168 template<typename FunctionType, 169 typename MatType, 170 typename GradType, 171 bool HasEvaluateWithGradient = 172 traits::HasEvaluateWithGradient<FunctionType, 173 traits::TypedForms<MatType, 174 GradType>::template 175 EvaluateWithGradientStaticForm 176 >::value, 177 bool HasGradient = traits::HasGradient<FunctionType, 178 traits::TypedForms<MatType, GradType>::template GradientStaticForm 179 >::value> 180 class AddGradientStatic 181 { 182 public: 183 // Provide a dummy overload so the name 'Gradient' exists for this object. Gradient(traits::UnconstructableType &)184 static void Gradient(traits::UnconstructableType&) { } 185 }; 186 187 /** 188 * Reflect the existing Gradient(). 189 */ 190 template<typename FunctionType, 191 typename MatType, 192 typename GradType, 193 bool HasEvaluateWithGradient> 194 class AddGradientStatic<FunctionType, 195 MatType, 196 GradType, 197 HasEvaluateWithGradient, 198 true> 199 { 200 public: 201 // Reflect the existing Gradient(). Gradient(const MatType & coordinates,GradType & gradient)202 static void Gradient(const MatType& coordinates, GradType& gradient) 203 { 204 FunctionType::Gradient(coordinates, gradient); 205 } 206 }; 207 208 /** 209 * If we have EvaluateWithGradient() but no existing Gradient(), add a 210 * Gradient() without a using directive to make the base Gradient() accessible. 211 */ 212 template<typename FunctionType, typename MatType, typename GradType> 213 class AddGradientStatic<FunctionType, MatType, GradType, true, false> 214 { 215 public: 216 /** 217 * Calculate the gradient and store it in the given matrix. 218 * 219 * @param coordinates Coordinates to evaluate the function at. 220 * @param gradient Matrix to store the gradient into. 221 */ Gradient(const MatType & coordinates,GradType & gradient)222 static void Gradient(const MatType& coordinates, GradType& gradient) 223 { 224 // The returned objective value will be ignored. 225 (void) FunctionType::EvaluateWithGradient(coordinates, gradient); 226 } 227 }; 228 229 } // namespace ens 230 231 #endif 232