1 /** 2 * @file methods/ann/visitor/weight_size_visitor.hpp 3 * @author Marcus Edel 4 * 5 * This file provides an abstraction for the WeightSize() function for 6 * different layers and automatically directs any parameter to the right layer 7 * type. 8 * 9 * mlpack is free software; you may redistribute it and/or modify it under the 10 * terms of the 3-clause BSD license. You should have received a copy of the 11 * 3-clause BSD license along with mlpack. If not, see 12 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 13 */ 14 #ifndef MLPACK_METHODS_ANN_VISITOR_WEIGHT_SIZE_VISITOR_HPP 15 #define MLPACK_METHODS_ANN_VISITOR_WEIGHT_SIZE_VISITOR_HPP 16 17 #include <mlpack/methods/ann/layer/layer_traits.hpp> 18 19 #include <boost/variant.hpp> 20 21 namespace mlpack { 22 namespace ann { 23 24 /** 25 * WeightSizeVisitor returns the number of weights of the given module. 26 */ 27 class WeightSizeVisitor : public boost::static_visitor<size_t> 28 { 29 public: 30 //! Return the number of weights. 31 template<typename LayerType> 32 size_t operator()(LayerType* layer) const; 33 34 size_t operator()(MoreTypes layer) const; 35 36 private: 37 //! If the module doesn't implement the Parameters() or Model() function 38 //! return 0. 39 template<typename T, typename P> 40 typename std::enable_if< 41 !HasParametersCheck<T, P&(T::*)()>::value && 42 !HasModelCheck<T>::value, size_t>::type 43 LayerSize(T* layer, P& output) const; 44 45 //! Return the number of parameters if the module implements the Model() 46 //! function. 47 template<typename T, typename P> 48 typename std::enable_if< 49 !HasParametersCheck<T, P&(T::*)()>::value && 50 HasModelCheck<T>::value, size_t>::type 51 LayerSize(T* layer, P& output) const; 52 53 //! Return the number of parameters if the module implements the Parameters() 54 //! function. 55 template<typename T, typename P> 56 typename std::enable_if< 57 HasParametersCheck<T, P&(T::*)()>::value && 58 !HasModelCheck<T>::value, size_t>::type 59 LayerSize(T* layer, P& output) const; 60 61 //! Return the accumulated number of parameters if the module implements the 62 //! Parameters() and Model() function. 63 template<typename T, typename P> 64 typename std::enable_if< 65 HasParametersCheck<T, P&(T::*)()>::value && 66 HasModelCheck<T>::value, size_t>::type 67 LayerSize(T* layer, P& output) const; 68 }; 69 70 } // namespace ann 71 } // namespace mlpack 72 73 // Include implementation. 74 #include "weight_size_visitor_impl.hpp" 75 76 #endif 77