1 /** 2 * @file add_separable_evaluate.hpp 3 * @author Ryan Curtin 4 * 5 * Adds a separable Evaluate() function if a separable 6 * EvaluateWithGradient() function exists. 7 * 8 * ensmallen is free software; you may redistribute it and/or modify it under 9 * the terms of the 3-clause BSD license. You should have received a copy of 10 * the 3-clause BSD license along with ensmallen. If not, see 11 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 12 */ 13 #ifndef ENSMALLEN_FUNCTION_ADD_DECOMPOSABLE_EVALUATE_HPP 14 #define ENSMALLEN_FUNCTION_ADD_DECOMPOSABLE_EVALUATE_HPP 15 16 #include "traits.hpp" 17 18 namespace ens { 19 20 /** 21 * The AddSeparableEvaluate mixin class will add a separable Evaluate() 22 * method if a separable EvaluateWithGradient() function exists, or nothing 23 * otherwise. 24 */ 25 template<typename FunctionType, 26 typename MatType, 27 typename GradType, 28 bool HasSeparableEvaluateWithGradient = 29 traits::HasEvaluateWithGradient<FunctionType, 30 traits::TypedForms<MatType, GradType>::template 31 SeparableEvaluateWithGradientForm 32 >::value, 33 bool HasSeparableEvaluate = 34 traits::HasEvaluate<FunctionType, 35 traits::TypedForms<MatType, GradType>::template 36 SeparableEvaluateForm>::value> 37 class AddSeparableEvaluate 38 { 39 public: 40 // Provide a dummy overload so the name 'Evaluate' exists for this object. 41 typename MatType::elem_type Evaluate(traits::UnconstructableType&, 42 const size_t, 43 const size_t); 44 }; 45 46 /** 47 * Reflect the existing Evaluate(). 48 */ 49 template<typename FunctionType, 50 typename MatType, 51 typename GradType, 52 bool HasSeparableEvaluateWithGradient> 53 class AddSeparableEvaluate<FunctionType, MatType, GradType, 54 HasSeparableEvaluateWithGradient, true> 55 { 56 public: 57 // Reflect the existing Evaluate(). Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize)58 typename MatType::elem_type Evaluate(const MatType& coordinates, 59 const size_t begin, 60 const size_t batchSize) 61 { 62 return static_cast<FunctionType*>( 63 static_cast<Function<FunctionType, 64 MatType, 65 GradType>*>(this))->Evaluate(coordinates, 66 begin, 67 batchSize); 68 } 69 }; 70 71 /** 72 * If we have a separable EvaluateWithGradient() but not a separable 73 * Evaluate(), add a separable Evaluate() method. 74 */ 75 template<typename FunctionType, typename MatType, typename GradType> 76 class AddSeparableEvaluate<FunctionType, MatType, GradType, true, false> 77 { 78 public: 79 /** 80 * Return the objective function for the given coordinates, starting at the 81 * given separable function using the given batch size. 82 * 83 * @param coordinates Coordinates to evaluate the function at. 84 * @param begin Index of first function to evaluate. 85 * @param batchSize Number of functions to evaluate. 86 */ Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize)87 double Evaluate(const MatType& coordinates, 88 const size_t begin, 89 const size_t batchSize) 90 { 91 GradType gradient; // This will be ignored. 92 return static_cast<Function<FunctionType, 93 MatType, 94 GradType>*>(this)->EvaluateWithGradient( 95 coordinates, begin, gradient, batchSize); 96 } 97 }; 98 99 /** 100 * The AddSeparableEvaluateConst mixin class will add a separable const 101 * Evaluate() method if a separable const EvaluateWithGradient() function 102 * exists, or nothing otherwise. 103 */ 104 template<typename FunctionType, 105 typename MatType, 106 typename GradType, 107 bool HasSeparableEvaluateWithGradient = 108 traits::HasEvaluateWithGradient<FunctionType, 109 traits::TypedForms<MatType, GradType>::template 110 SeparableEvaluateWithGradientConstForm>::value, 111 bool HasSeparableEvaluate = 112 traits::HasEvaluate<FunctionType, 113 traits::TypedForms<MatType, GradType>::template 114 SeparableEvaluateConstForm>::value> 115 class AddSeparableEvaluateConst 116 { 117 public: 118 // Provide a dummy overload so the name 'Evaluate' exists for this object. 119 typename MatType::elem_type Evaluate(traits::UnconstructableType&, 120 const size_t, 121 const size_t) const; 122 }; 123 124 /** 125 * Reflect the existing Evaluate(). 126 */ 127 template<typename FunctionType, 128 typename MatType, 129 typename GradType, 130 bool HasSeparableEvaluateWithGradient> 131 class AddSeparableEvaluateConst<FunctionType, MatType, GradType, 132 HasSeparableEvaluateWithGradient, true> 133 { 134 public: 135 // Reflect the existing Evaluate(). Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize) const136 typename MatType::elem_type Evaluate(const MatType& coordinates, 137 const size_t begin, 138 const size_t batchSize) const 139 { 140 return static_cast<const FunctionType*>( 141 static_cast<const Function<FunctionType, 142 MatType, 143 GradType>*>(this))->Evaluate(coordinates, 144 begin, 145 batchSize); 146 } 147 }; 148 149 /** 150 * If we have a separable const EvaluateWithGradient() but not a separable 151 * const Evaluate(), add a separable const Evaluate() method. 152 */ 153 template<typename FunctionType, typename MatType, typename GradType> 154 class AddSeparableEvaluateConst<FunctionType, MatType, GradType, true, false> 155 { 156 public: 157 /** 158 * Return the objective function for the given coordinates, starting at the 159 * given separable function using the given batch size. 160 * 161 * @param coordinates Coordinates to evaluate the function at. 162 * @param begin Index of first function to evaluate. 163 * @param batchSize Number of functions to evaluate. 164 */ Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize) const165 typename MatType::elem_type Evaluate(const MatType& coordinates, 166 const size_t begin, 167 const size_t batchSize) const 168 { 169 GradType gradient; // This will be ignored. 170 return static_cast<const Function<FunctionType, 171 MatType, 172 GradType>*>(this)->EvaluateWithGradient( 173 coordinates, begin, gradient, batchSize); 174 } 175 }; 176 177 /** 178 * The AddSeparableEvaluateStatic mixin class will add a separable static 179 * Evaluate() method if a separable static EvaluateWithGradient() function 180 * exists, or nothing otherwise. 181 */ 182 template<typename FunctionType, 183 typename MatType, 184 typename GradType, 185 bool HasSeparableEvaluateWithGradient = 186 traits::HasEvaluateWithGradient<FunctionType, 187 traits::TypedForms<MatType, GradType>::template 188 SeparableEvaluateWithGradientStaticForm>::value, 189 bool HasSeparableEvaluate = 190 traits::HasEvaluate<FunctionType, 191 traits::TypedForms<MatType, GradType>::template 192 SeparableEvaluateStaticForm>::value> 193 class AddSeparableEvaluateStatic 194 { 195 public: 196 // Provide a dummy overload so the name 'Evaluate' exists for this object. 197 static typename MatType::elem_type Evaluate(traits::UnconstructableType&, 198 const size_t, 199 const size_t); 200 }; 201 202 /** 203 * Reflect the existing Evaluate(). 204 */ 205 template<typename FunctionType, 206 typename MatType, 207 typename GradType, 208 bool HasSeparableEvaluateWithGradient> 209 class AddSeparableEvaluateStatic<FunctionType, MatType, GradType, 210 HasSeparableEvaluateWithGradient, true> 211 { 212 public: 213 // Reflect the existing Evaluate(). Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize)214 static typename MatType::elem_type Evaluate(const MatType& coordinates, 215 const size_t begin, 216 const size_t batchSize) 217 { 218 return FunctionType::Evaluate(coordinates, begin, batchSize); 219 } 220 }; 221 222 /** 223 * If we have a separable EvaluateWithGradient() but not a separable 224 * Evaluate(), add a separable Evaluate() method. 225 */ 226 template<typename FunctionType, typename MatType, typename GradType> 227 class AddSeparableEvaluateStatic<FunctionType, MatType, GradType, true, 228 false> 229 { 230 public: 231 /** 232 * Return the objective function for the given coordinates, starting at the 233 * given separable function using the given batch size. 234 * 235 * @param coordinates Coordinates to evaluate the function at. 236 * @param begin Index of first function to evaluate. 237 * @param batchSize Number of functions to evaluate. 238 */ Evaluate(const MatType & coordinates,const size_t begin,const size_t batchSize)239 static typename MatType::elem_type Evaluate(const MatType& coordinates, 240 const size_t begin, 241 const size_t batchSize) 242 { 243 GradType gradient; // This will be ignored. 244 return FunctionType::EvaluateWithGradient(coordinates, begin, gradient, 245 batchSize); 246 } 247 }; 248 249 } // namespace ens 250 251 #endif 252