1 /** 2 * @file traits.hpp 3 * @author Ryan Curtin 4 * 5 * This file provides metaprogramming utilities for detecting certain members of 6 * FunctionType classes. 7 * 8 * ensmallen is free software; you may redistribute it and/or modify it under 9 * the terms of the 3-clause BSD license. You should have received a copy of 10 * the 3-clause BSD license along with ensmallen. If not, see 11 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 12 */ 13 #ifndef ENSMALLEN_FUNCTION_TRAITS_HPP 14 #define ENSMALLEN_FUNCTION_TRAITS_HPP 15 16 #include "sfinae_utility.hpp" 17 #include "arma_traits.hpp" 18 19 namespace ens { 20 namespace traits { 21 22 //! Detect an Evaluate() method. 23 ENS_HAS_EXACT_METHOD_FORM(Evaluate, HasEvaluate) 24 //! Detect a Gradient() method. 25 ENS_HAS_EXACT_METHOD_FORM(Gradient, HasGradient) 26 //! Detect an EvaluateWithGradient() method. 27 ENS_HAS_EXACT_METHOD_FORM(EvaluateWithGradient, HasEvaluateWithGradient) 28 //! Detect a NumFunctions() method. 29 ENS_HAS_EXACT_METHOD_FORM(NumFunctions, HasNumFunctions) 30 //! Detect a Shuffle() method. 31 ENS_HAS_EXACT_METHOD_FORM(Shuffle, HasShuffle) 32 //! Detect a NumConstraints() method. 33 ENS_HAS_EXACT_METHOD_FORM(NumConstraints, HasNumConstraints) 34 //! Detect an EvaluateConstraint() method. 35 ENS_HAS_EXACT_METHOD_FORM(EvaluateConstraint, HasEvaluateConstraint) 36 //! Detect a GradientConstraint() method. 37 ENS_HAS_EXACT_METHOD_FORM(GradientConstraint, HasGradientConstraint) 38 //! Detect a NumFeatures() method. 39 ENS_HAS_EXACT_METHOD_FORM(NumFeatures, HasNumFeatures) 40 //! Detect a PartialGradient() method. 41 ENS_HAS_EXACT_METHOD_FORM(PartialGradient, HasPartialGradient) 42 //! Detect an MaxIterations() method. 43 ENS_HAS_EXACT_METHOD_FORM(MaxIterations, HasMaxIterations) 44 //! Detect an ResetPolicy() method. 45 ENS_HAS_EXACT_METHOD_FORM(ResetPolicy, HasResetPolicy) 46 //! Detect an BatchSize() method. 47 ENS_HAS_EXACT_METHOD_FORM(BatchSize, HasBatchSize) 48 //! Detect an StepSize() method. 49 ENS_HAS_EXACT_METHOD_FORM(StepSize, HasStepSize) 50 51 template<typename MatType, typename GradType> 52 struct TypedForms 53 { 54 typedef typename MatTypeTraits<MatType>::BaseMatType BaseMatType; 55 typedef typename MatTypeTraits<GradType>::BaseMatType BaseGradType; 56 57 //! This is the form of a non-const Evaluate() method. 58 template<typename FunctionType> 59 using EvaluateForm = 60 typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&); 61 62 //! This is the form of a const Evaluate() method. 63 template<typename FunctionType> 64 using EvaluateConstForm = typename BaseMatType::elem_type(FunctionType::*)( 65 const BaseMatType&) const; 66 67 //! This is the form of a static Evaluate() method. 68 template<typename FunctionType> 69 using EvaluateStaticForm = typename BaseMatType::elem_type(*)( 70 const BaseMatType&); 71 72 //! This is the form of a non-const Gradient() method. 73 template<typename FunctionType> 74 using GradientForm = void(FunctionType::*)(const BaseMatType&, BaseGradType&); 75 76 //! This is the form of a const Gradient() method. 77 template<typename FunctionType> 78 using GradientConstForm = 79 void(FunctionType::*)(const BaseMatType&, BaseGradType&) const; 80 81 //! This is the form of a static Gradient() method. 82 template<typename FunctionType> 83 using GradientStaticForm = void(*)(const BaseMatType&, BaseGradType&); 84 85 //! This is the form of a non-const EvaluateWithGradient() method. 86 template<typename FunctionType> 87 using EvaluateWithGradientForm = 88 typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&, 89 BaseGradType&); 90 91 //! This is the form of a const EvaluateWithGradient() method. 92 template<typename FunctionType> 93 using EvaluateWithGradientConstForm = 94 typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&, 95 BaseGradType&) const; 96 97 //! This is the form of a static EvaluateWithGradient() method. 98 template<typename FunctionType> 99 using EvaluateWithGradientStaticForm = typename BaseMatType::elem_type(*)( 100 const BaseMatType&, BaseGradType&); 101 102 //! This is the form of a non-const NumFunctions() method. 103 template <typename FunctionType> 104 using NumFunctionsForm = size_t(FunctionType::*)(); 105 106 //! This is the form of a const NumFunctions() method. 107 template <typename FunctionType> 108 using NumFunctionsConstForm = size_t(FunctionType::*)() const; 109 110 //! This is the form of a static NumFunctions() method. 111 template<typename FunctionType> 112 using NumFunctionsStaticForm = size_t(*)(); 113 114 //! This is the form of a non-const Shuffle() method. 115 template<typename FunctionType> 116 using ShuffleForm = void(FunctionType::*)(); 117 118 //! This is the form of a const Shuffle() method. 119 template<typename FunctionType> 120 using ShuffleConstForm = void(FunctionType::*)() const; 121 122 //! This is the form of a static Shuffle() method. 123 template<typename FunctionType> 124 using ShuffleStaticForm = void(*)(); 125 126 //! This is the form of a separable Evaluate() method. 127 template<typename FunctionType> 128 using SeparableEvaluateForm = 129 typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&, 130 const size_t, 131 const size_t); 132 133 //! This is the form of a separable const Evaluate() method. 134 template<typename FunctionType> 135 using SeparableEvaluateConstForm = 136 typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&, 137 const size_t, 138 const size_t) const; 139 140 //! This is the form of a separable static Evaluate() method. 141 template<typename FunctionType> 142 using SeparableEvaluateStaticForm = typename BaseMatType::elem_type(*)( 143 const BaseMatType&, const size_t, const size_t); 144 145 //! This is the form of a separable non-const Gradient() method. 146 template<typename FunctionType> 147 using SeparableGradientForm = void(FunctionType::*)( 148 const BaseMatType&, const size_t, BaseGradType&, const size_t); 149 150 //! This the form of a separable const Gradient() method. 151 template<typename FunctionType> 152 using SeparableGradientConstForm = void(FunctionType::*)( 153 const BaseMatType&, const size_t, BaseGradType&, const size_t) const; 154 155 //! This is the form of a separable static Gradient() method. 156 template<typename FunctionType> 157 using SeparableGradientStaticForm = void(*)( 158 const BaseMatType&, const size_t, BaseGradType&, const size_t); 159 160 //! This is the form of a separable non-const EvaluateWithGradient() 161 //! method. 162 template<typename FunctionType> 163 using SeparableEvaluateWithGradientForm = 164 typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&, 165 const size_t, 166 BaseGradType&, 167 const size_t); 168 169 //! This is the form of a separable const EvaluateWithGradient() method. 170 template<typename FunctionType> 171 using SeparableEvaluateWithGradientConstForm = 172 typename BaseMatType::elem_type(FunctionType::*)(const BaseMatType&, 173 const size_t, 174 BaseGradType&, 175 const size_t) const; 176 177 //! This is the form of a separable static EvaluateWithGradient() method. 178 template<typename FunctionType> 179 using SeparableEvaluateWithGradientStaticForm = 180 typename BaseMatType::elem_type(*)(const BaseMatType&, 181 const size_t, 182 BaseGradType&, 183 const size_t); 184 185 //! This is the form of a non-const NumConstraints() method. 186 template<typename FunctionType> 187 using NumConstraintsForm = size_t(FunctionType::*)(); 188 189 //! This is the form of a const NumConstraints() method. 190 template<typename FunctionType> 191 using NumConstraintsConstForm = size_t(FunctionType::*)() const; 192 193 //! This is the form of a static NumConstraints() method. 194 template<typename FunctionType> 195 using NumConstraintsStaticForm = size_t(*)(); 196 197 //! This is the form of a non-const EvaluateConstraint() method. 198 template <typename FunctionType> 199 using EvaluateConstraintForm = 200 typename BaseMatType::elem_type(FunctionType::*)(const size_t, 201 const BaseMatType&); 202 203 //! This is the form of a const EvaluateConstraint() method. 204 template<typename FunctionType> 205 using EvaluateConstraintConstForm = 206 typename BaseMatType::elem_type(FunctionType::*)(const size_t, 207 const BaseMatType&) 208 const; 209 210 //! This is the form of a static EvaluateConstraint() method. 211 template<typename FunctionType> 212 using EvaluateConstraintStaticForm = typename BaseMatType::elem_type(*)( 213 const size_t, const BaseMatType&); 214 215 //! This is the form of a non-const GradientConstraint() method. 216 template <typename FunctionType> 217 using GradientConstraintForm = void(FunctionType::*)( 218 const size_t, const BaseMatType&, BaseGradType&); 219 220 //! This is the form of a const GradientConstraint() method. 221 template<typename FunctionType> 222 using GradientConstraintConstForm = void(FunctionType::*)( 223 const size_t, const BaseMatType&, BaseGradType&) const; 224 225 //! This is the form of a static GradientConstraint() method. 226 template<typename Class, typename... Ts> 227 using GradientConstraintStaticForm = void(*)( 228 const size_t, const BaseMatType&, BaseGradType&); 229 230 //! This is the form of a non-const sparse Gradient() method. 231 //! This check isn't particularly useful---the user needs to specify a sparse 232 //! gradient type... 233 template<typename FunctionType> 234 using SparseGradientForm = void(FunctionType::*)( 235 const BaseMatType&, const size_t, BaseGradType&, const size_t); 236 237 //! This is the form of a const sparse Gradient() method. 238 //! This check isn't particularly useful---the user needs to specify a sparse 239 //! gradient type... 240 template<typename FunctionType> 241 using SparseGradientConstForm = void(FunctionType::*)( 242 const BaseMatType&, const size_t, BaseGradType&, const size_t) const; 243 244 //! This is the form of a static sparse Gradient() method. 245 //! This check isn't particularly useful---the user needs to specify a sparse 246 //! gradient type... 247 template<typename FunctionType> 248 using SparseGradientStaticForm = void(*)( 249 const BaseMatType&, const size_t, BaseGradType&, const size_t); 250 251 //! This is the form of a non-const NumFeatures() method. 252 template<typename FunctionType> 253 using NumFeaturesForm = size_t(FunctionType::*)(); 254 255 //! This is the form of a const NumFeatures() method. 256 template<typename FunctionType> 257 using NumFeaturesConstForm = size_t(FunctionType::*)() const; 258 259 //! This is the form of a static NumFeatures() method. 260 template<typename FunctionType> 261 using NumFeaturesStaticForm = size_t(*)(); 262 263 //! This is the form of a non-const PartialGradient() method. 264 template<typename FunctionType> 265 using PartialGradientForm = void(FunctionType::*)( 266 const BaseMatType&, const size_t, BaseGradType&); 267 268 //! This is the form of a const PartialGradient() method. 269 template<typename FunctionType> 270 using PartialGradientConstForm = void(FunctionType::*)( 271 const BaseMatType&, const size_t, BaseGradType&) const; 272 273 //! This is the form of a static PartialGradient() method. 274 template<typename FunctionType> 275 using PartialGradientStaticForm = void(*)( 276 const BaseMatType&, const size_t, BaseGradType&); 277 278 //! This is a utility struct that will match any non-const form. 279 template<typename FunctionType, typename... Ts> 280 using OtherForm = typename BaseMatType::elem_type(FunctionType::*)(Ts...); 281 282 //! This is a utility struct that will match any const form. 283 template<typename FunctionType, typename... Ts> 284 using OtherConstForm = typename BaseMatType::elem_type(FunctionType::*)(Ts...) 285 const; 286 287 //! This is a utility struct that will match any static form. 288 template<typename FunctionType, typename... Ts> 289 using OtherStaticForm = typename BaseMatType::elem_type(*)(Ts...); 290 }; 291 292 /** 293 * This is a utility type used to provide unusable overloads from each of the 294 * mixin classes. If you are seeing an error mentioning this class, the most 295 * likely issue is that you have not implemented the right methods for your 296 * FunctionType class. 297 */ 298 struct UnconstructableType 299 { 300 private: UnconstructableTypeens::traits::UnconstructableType301 UnconstructableType() { } 302 }; 303 304 /** 305 * Utility struct: sometimes we want to know if we have two functions available, 306 * and that at least one of them is non-const and non-static. If the 307 * corresponding checkers (from ENS_HAS_METHOD_FORM()) are given as CheckerA and 308 * CheckerB, and the corresponding non-const, const, and static function 309 * signatures are given as SignatureA, ConstSignatureA, StaticSignatureA, 310 * SignatureB, ConstSignatureB, and StaticSignatureB, then 'value' will be true 311 * if methods with the correct names exist in the given ClassType and at least 312 * one of those two methods is non-const and non-static. 313 */ 314 template<typename ClassType, 315 template<typename, template<typename...> class, size_t> class CheckerA, 316 template<typename...> class SignatureA, 317 template<typename...> class ConstSignatureA, 318 template<typename...> class StaticSignatureA, 319 template<typename, template<typename...> class, size_t> class CheckerB, 320 template<typename...> class SignatureB, 321 template<typename...> class ConstSignatureB, 322 template<typename...> class StaticSignatureB> 323 struct HasNonConstSignatures 324 { 325 // Check if any const or static version of method A exists. 326 const static bool HasAnyFormA = 327 CheckerA<ClassType, SignatureA, 0>::value || 328 CheckerA<ClassType, ConstSignatureA, 0>::value || 329 CheckerA<ClassType, StaticSignatureA, 0>::value; 330 // Check if any const or static versino of method B exists. 331 const static bool HasAnyFormB = 332 CheckerB<ClassType, SignatureB, 0>::value || 333 CheckerB<ClassType, ConstSignatureB, 0>::value || 334 CheckerB<ClassType, StaticSignatureB, 0>::value; 335 336 // Make sure at least one const version exists. 337 const static bool HasEitherNonConstForm = 338 CheckerA<ClassType, SignatureA, 0>::value || 339 CheckerB<ClassType, SignatureB, 0>::value; 340 341 const static bool value = HasEitherNonConstForm && HasAnyFormA && HasAnyFormB; 342 }; 343 344 /** 345 * Utility struct: sometimes we want to know if we have two functions available, 346 * and that at least one of them is const and both of them are not non-const and 347 * non-static. If the corresponding checkers (from ENS_HAS_METHOD_FORM()) are 348 * given as CheckerA and CheckerB, and the corresponding const and static 349 * function signatures are given as ConstSignatureA, StaticSignatureA, 350 * ConstSignatureB, and StaticSignatureB, then 'value' will be true if methods 351 * with the correct names exist in the given ClassType and at least one of those 352 * two methods is const, and neither method is non-const and non-static. 353 */ 354 template<typename ClassType, 355 template<typename, template<typename...> class, size_t> class CheckerA, 356 template<typename...> class ConstSignatureA, 357 template<typename...> class StaticSignatureA, 358 template<typename, template<typename...> class, size_t> class CheckerB, 359 template<typename...> class ConstSignatureB, 360 template<typename...> class StaticSignatureB> 361 struct HasConstSignatures 362 { 363 // Check if any const or static version of method A exists. 364 const static bool HasAnyFormA = 365 CheckerA<ClassType, ConstSignatureA, 0>::value || 366 CheckerA<ClassType, StaticSignatureA, 0>::value; 367 // Check if any const or static version of method B exists. 368 const static bool HasAnyFormB = 369 CheckerB<ClassType, ConstSignatureB, 0>::value || 370 CheckerB<ClassType, StaticSignatureB, 0>::value; 371 372 // Make sure at least one const version exists. 373 const static bool HasEitherConstForm = 374 CheckerA<ClassType, ConstSignatureA, 0>::value || 375 CheckerB<ClassType, ConstSignatureB, 0>::value; 376 377 const static bool value = HasEitherConstForm && HasAnyFormA && HasAnyFormB; 378 }; 379 380 //! Utility struct, check if size_t BatchSize() const or size_t BatchSize() 381 //! exists. 382 template<typename OptimizerType> 383 struct HasBatchSizeSignature 384 { 385 template<typename C> 386 using BatchSizeConstForm = size_t(C::*)(void) const; 387 388 template<typename C> 389 using BatchSizeForm = size_t(C::*)(void); 390 391 const static bool value = 392 HasBatchSize<OptimizerType, BatchSizeForm>::value || 393 HasBatchSize<OptimizerType, BatchSizeConstForm>::value; 394 }; 395 396 //! Utility struct, check if size_t StepSize() const or size_t StepSize() 397 //! exists. 398 template<typename OptimizerType> 399 struct HasStepSizeSignature 400 { 401 template<typename C> 402 using StepSizeConstForm = double(C::*)(void) const; 403 404 template<typename C> 405 using StepSizeForm = double(C::*)(void); 406 407 const static bool value = 408 HasStepSize<OptimizerType, StepSizeForm>::value || 409 HasStepSize<OptimizerType, StepSizeConstForm>::value; 410 }; 411 412 //! Utility struct, check if size_t MaxIterations() const exists. 413 template<typename OptimizerType> 414 struct HasMaxIterationsSignature 415 { 416 template<typename C> 417 using HasMaxIterationsForm = size_t(C::*)(void) const; 418 419 const static bool value = 420 HasMaxIterations<OptimizerType, HasMaxIterationsForm>::value; 421 }; 422 423 //! Utility struct, check if size_t NumFunctions() const or 424 //! size_t NumFunctions() exists. 425 template<typename OptimizerType> 426 struct HasNumFunctionsSignature 427 { 428 template<typename C> 429 using NumFunctionsConstForm = size_t(C::*)(void) const; 430 431 template<typename C> 432 using NumFunctionsForm = size_t(C::*)(void); 433 434 const static bool value = 435 HasNumFunctions<OptimizerType, NumFunctionsForm>::value || 436 HasNumFunctions<OptimizerType, NumFunctionsConstForm>::value; 437 }; 438 439 //! Utility struct, check if bool ResetPolicy() exists. 440 template<typename OptimizerType> 441 struct HasResetPolicySignature 442 { 443 template<typename C> 444 using HasResetPolicyForm = bool&(C::*)(void); 445 446 const static bool value = 447 HasResetPolicy<OptimizerType, HasResetPolicyForm>::value; 448 }; 449 450 } // namespace traits 451 } // namespace ens 452 453 #endif 454