1 /**
2  * @file radial_basis_function.hpp
3  * @author Himanshu Pathak
4  *
5  * Definition of the Radial Basis Function module class.
6  *
7  *
8  * mlpack is free software; you may redistribute it and/or modify it under the
9  * terms of the 3-clause BSD license.  You should have received a copy of the
10  * 3-clause BSD license along with mlpack.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef MLPACK_METHODS_ANN_LAYER_RBF_HPP
14 #define MLPACK_METHODS_ANN_LAYER_RBF_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <mlpack/methods/ann/activation_functions/gaussian_function.hpp>
18 
19 #include "layer_types.hpp"
20 
21 namespace mlpack {
22 namespace ann /** Artificial Neural Network. */ {
23 
24 
25 /**
 * Implementation of the Radial Basis Function layer. When used with a
 * non-linear activation function, the RBF class acts as a radial basis
 * function layer that can be used in a feed-forward neural network.
29  *
30  * For more information, refer to the following paper,
31  *
32  * @code
 * @inproceedings{que2016back,
 *   author    = {Qichao Que and Mikhail Belkin},
 *   title     = {Back to the Future: Radial Basis Function Networks Revisited},
 *   booktitle = {Proceedings of Machine Learning Research, Volume 51:
 *                Artificial Intelligence and Statistics},
 *   year      = {2016},
 *   url       = {http://proceedings.mlr.press/v51/que16.pdf},
 * }
39  * @endcode
40  *
41  * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
42  *         arma::sp_mat or arma::cube).
43  * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
44  *         arma::sp_mat or arma::cube).
 * @tparam Activation Type of the activation function (default:
 *         mlpack::ann::GaussianFunction).
46  */
47 
48 template <
49     typename InputDataType = arma::mat,
50     typename OutputDataType = arma::mat,
51     typename Activation = GaussianFunction
52 >
53 class RBF
54 {
55  public:
56   //! Create the RBF object.
57   RBF();
58 
59   /**
60    * Create the Radial Basis Function layer object using the specified
61    * parameters.
62    *
63    * @param inSize The number of input units.
64    * @param outSize The number of output units.
65    * @param centres The centres calculated using k-means of data.
66    * @param betas The beta value to be used with centres.
67    */
68   RBF(const size_t inSize,
69       const size_t outSize,
70       arma::mat& centres,
71       double betas = 0);
72 
73   /**
74    * Ordinary feed forward pass of the radial basis function.
75    *
76    * @param input Input data used for evaluating the specified function.
77    * @param output Resulting output activation.
78    */
79   template<typename eT>
80   void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
81 
82   /**
83    * Ordinary feed backward pass of the radial basis function.
84    *
85    */
86   template<typename eT>
87   void Backward(const arma::Mat<eT>& /* input */,
88                 const arma::Mat<eT>& /* gy */,
89                 arma::Mat<eT>& /* g */);
90 
91   //! Get the output parameter.
OutputParameter() const92   OutputDataType const& OutputParameter() const { return outputParameter; }
93   //! Modify the output parameter.
OutputParameter()94   OutputDataType& OutputParameter() { return outputParameter; }
95   //! Get the parameters.
96 
97   //! Get the input parameter.
InputParameter() const98   InputDataType const& InputParameter() const { return inputParameter; }
99   //! Modify the input parameter.
InputParameter()100   InputDataType& InputParameter() { return inputParameter; }
101 
102   //! Get the input size.
InputSize() const103   size_t InputSize() const { return inSize; }
104 
105   //! Get the output size.
OutputSize() const106   size_t OutputSize() const { return outSize; }
107 
108   //! Get the detla.
Delta() const109   OutputDataType const& Delta() const { return delta; }
110   //! Modify the delta.
Delta()111   OutputDataType& Delta() { return delta; }
112 
113   /**
114    * Serialize the layer.
115    */
116   template<typename Archive>
117   void serialize(Archive& ar, const unsigned int /* version */);
118 
119  private:
120   //! Locally-stored number of input units.
121   size_t inSize;
122 
123   //! Locally-stored number of output units.
124   size_t outSize;
125 
126   //! Locally-stored delta object.
127   OutputDataType delta;
128 
129   //! Locally-stored output parameter object.
130   OutputDataType outputParameter;
131 
132   //! Locally-stored the sigmas values.
133   double sigmas;
134 
135   //! Locally-stored the betas values.
136   double betas;
137 
138   //! Locally-stored the learnable centre of the shape.
139   InputDataType centres;
140 
141   //! Locally-stored input parameter object.
142   InputDataType inputParameter;
143 
144   //! Locally-stored the output distances of the shape.
145   OutputDataType distances;
146 }; // class RBF
147 
148 } // namespace ann
149 } // namespace mlpack
150 
151 // Include implementation.
152 #include "radial_basis_function_impl.hpp"
153 
154 #endif
155