1 /** 2 * @file radial_basis_function.hpp 3 * @author Himanshu Pathak 4 * 5 * Definition of the Radial Basis Function module class. 6 * 7 * 8 * mlpack is free software; you may redistribute it and/or modify it under the 9 * terms of the 3-clause BSD license. You should have received a copy of the 10 * 3-clause BSD license along with mlpack. If not, see 11 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 12 */ 13 #ifndef MLPACK_METHODS_ANN_LAYER_RBF_HPP 14 #define MLPACK_METHODS_ANN_LAYER_RBF_HPP 15 16 #include <mlpack/prereqs.hpp> 17 #include <mlpack/methods/ann/activation_functions/gaussian_function.hpp> 18 19 #include "layer_types.hpp" 20 21 namespace mlpack { 22 namespace ann /** Artificial Neural Network. */ { 23 24 25 /** 26 * Implementation of the Radial Basis Function layer. The RBF class when use with a 27 * non-linear activation function acts as a Radial Basis Function which can be used 28 * with Feed-Forward neural network. 29 * 30 * For more information, refer to the following paper, 31 * 32 * @code 33 * @article{Volume 51: Artificial Intelligence and Statistics, 34 * author = {Qichao Que, Mikhail Belkin}, 35 * title = {Back to the Future: Radial Basis Function Networks Revisited}, 36 * year = {2016}, 37 * url = {http://proceedings.mlr.press/v51/que16.pdf}, 38 * } 39 * @endcode 40 * 41 * @tparam InputDataType Type of the input data (arma::colvec, arma::mat, 42 * arma::sp_mat or arma::cube). 43 * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat, 44 * arma::sp_mat or arma::cube). 45 * @tparam Activation Type of the activation function (mlpack::ann::Gaussian). 46 */ 47 48 template < 49 typename InputDataType = arma::mat, 50 typename OutputDataType = arma::mat, 51 typename Activation = GaussianFunction 52 > 53 class RBF 54 { 55 public: 56 //! Create the RBF object. 57 RBF(); 58 59 /** 60 * Create the Radial Basis Function layer object using the specified 61 * parameters. 62 * 63 * @param inSize The number of input units. 64 * @param outSize The number of output units. 65 * @param centres The centres calculated using k-means of data. 66 * @param betas The beta value to be used with centres. 67 */ 68 RBF(const size_t inSize, 69 const size_t outSize, 70 arma::mat& centres, 71 double betas = 0); 72 73 /** 74 * Ordinary feed forward pass of the radial basis function. 75 * 76 * @param input Input data used for evaluating the specified function. 77 * @param output Resulting output activation. 78 */ 79 template<typename eT> 80 void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output); 81 82 /** 83 * Ordinary feed backward pass of the radial basis function. 84 * 85 */ 86 template<typename eT> 87 void Backward(const arma::Mat<eT>& /* input */, 88 const arma::Mat<eT>& /* gy */, 89 arma::Mat<eT>& /* g */); 90 91 //! Get the output parameter. OutputParameter() const92 OutputDataType const& OutputParameter() const { return outputParameter; } 93 //! Modify the output parameter. OutputParameter()94 OutputDataType& OutputParameter() { return outputParameter; } 95 //! Get the parameters. 96 97 //! Get the input parameter. InputParameter() const98 InputDataType const& InputParameter() const { return inputParameter; } 99 //! Modify the input parameter. InputParameter()100 InputDataType& InputParameter() { return inputParameter; } 101 102 //! Get the input size. InputSize() const103 size_t InputSize() const { return inSize; } 104 105 //! Get the output size. OutputSize() const106 size_t OutputSize() const { return outSize; } 107 108 //! Get the detla. Delta() const109 OutputDataType const& Delta() const { return delta; } 110 //! Modify the delta. Delta()111 OutputDataType& Delta() { return delta; } 112 113 /** 114 * Serialize the layer. 115 */ 116 template<typename Archive> 117 void serialize(Archive& ar, const unsigned int /* version */); 118 119 private: 120 //! Locally-stored number of input units. 121 size_t inSize; 122 123 //! Locally-stored number of output units. 124 size_t outSize; 125 126 //! Locally-stored delta object. 127 OutputDataType delta; 128 129 //! Locally-stored output parameter object. 130 OutputDataType outputParameter; 131 132 //! Locally-stored the sigmas values. 133 double sigmas; 134 135 //! Locally-stored the betas values. 136 double betas; 137 138 //! Locally-stored the learnable centre of the shape. 139 InputDataType centres; 140 141 //! Locally-stored input parameter object. 142 InputDataType inputParameter; 143 144 //! Locally-stored the output distances of the shape. 145 OutputDataType distances; 146 }; // class RBF 147 148 } // namespace ann 149 } // namespace mlpack 150 151 // Include implementation. 152 #include "radial_basis_function_impl.hpp" 153 154 #endif 155