1 /**
2  * @file methods/ann/visitor/weight_size_visitor.hpp
3  * @author Marcus Edel
4  *
5  * This file provides an abstraction for the WeightSize() function for
6  * different layers and automatically directs any parameter to the right layer
7  * type.
8  *
9  * mlpack is free software; you may redistribute it and/or modify it under the
10  * terms of the 3-clause BSD license.  You should have received a copy of the
11  * 3-clause BSD license along with mlpack.  If not, see
12  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13  */
14 #ifndef MLPACK_METHODS_ANN_VISITOR_WEIGHT_SIZE_VISITOR_HPP
15 #define MLPACK_METHODS_ANN_VISITOR_WEIGHT_SIZE_VISITOR_HPP
16 
17 #include <mlpack/methods/ann/layer/layer_traits.hpp>
18 
19 #include <boost/variant.hpp>
20 
21 namespace mlpack {
22 namespace ann {
23 
/**
 * WeightSizeVisitor returns the number of weights of the given module.
 *
 * The class derives from boost::static_visitor<size_t>, so it is meant to be
 * applied to a layer variant via boost::apply_visitor(). Each visit yields
 * the weight count (as a size_t) of whichever layer the variant currently
 * holds.
 */
class WeightSizeVisitor : public boost::static_visitor<size_t>
{
 public:
  //! Return the number of weights.
  //!
  //! @param layer Pointer to the layer whose weights are counted.
  //! NOTE(review): the definition lives in weight_size_visitor_impl.hpp and
  //! presumably forwards to one of the private LayerSize() overloads below —
  //! confirm there.
  template<typename LayerType>
  size_t operator()(LayerType* layer) const;

  //! Return the number of weights for a layer held in the MoreTypes
  //! alternative of the layer variant.
  //!
  //! @param layer The MoreTypes value to inspect (taken by value).
  //! NOTE(review): MoreTypes is not declared in this header. It is assumed to
  //! be a nested variant of additional layer types that the implementation
  //! visits in turn — confirm against the layer type definitions.
  size_t operator()(MoreTypes layer) const;

 private:
  // The four LayerSize() overloads below are selected via SFINAE on two
  // traits:
  //   - HasParametersCheck<T, P&(T::*)()>, which presumably tests whether T
  //     has a non-const, zero-argument member Parameters() returning P&;
  //   - HasModelCheck<T>, which presumably tests whether T provides Model().
  // Their enable_if conditions cover all four true/false combinations of the
  // two traits and are mutually exclusive, so exactly one overload is viable
  // for any (T, P).
  //
  // These return-type expressions must stay token-equivalent to the
  // out-of-line definitions in weight_size_visitor_impl.hpp; otherwise the
  // definitions no longer match these declarations.
  //
  // NOTE(review): the second argument is named `output`, but from this header
  // it only fixes the type P used in the Parameters() signature check. Its
  // actual role is in the implementation file — confirm.

  //! If the module implements neither the Parameters() nor the Model()
  //! function, return 0.
  template<typename T, typename P>
  typename std::enable_if<
      !HasParametersCheck<T, P&(T::*)()>::value &&
      !HasModelCheck<T>::value, size_t>::type
  LayerSize(T* layer, P& output) const;

  //! Return the number of parameters if the module implements the Model()
  //! function (but not Parameters()).
  template<typename T, typename P>
  typename std::enable_if<
      !HasParametersCheck<T, P&(T::*)()>::value &&
      HasModelCheck<T>::value, size_t>::type
  LayerSize(T* layer, P& output) const;

  //! Return the number of parameters if the module implements the
  //! Parameters() function (but not Model()).
  template<typename T, typename P>
  typename std::enable_if<
      HasParametersCheck<T, P&(T::*)()>::value &&
      !HasModelCheck<T>::value, size_t>::type
  LayerSize(T* layer, P& output) const;

  //! Return the accumulated number of parameters if the module implements
  //! both the Parameters() and Model() functions.
  template<typename T, typename P>
  typename std::enable_if<
      HasParametersCheck<T, P&(T::*)()>::value &&
      HasModelCheck<T>::value, size_t>::type
  LayerSize(T* layer, P& output) const;
};
69 
70 } // namespace ann
71 } // namespace mlpack
72 
73 // Include implementation.
74 #include "weight_size_visitor_impl.hpp"
75 
76 #endif
77