1 /**
2  * @file radial_basis_function_impl.hpp
3  * @author Himanshu Pathak
4  *
5  *
6  * mlpack is free software; you may redistribute it and/or modify it under the
7  * terms of the 3-clause BSD license.  You should have received a copy of the
8  * 3-clause BSD license along with mlpack.  If not, see
9  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
10  */
11 #ifndef MLPACK_METHODS_ANN_LAYER_RBF_IMPL_HPP
12 #define MLPACK_METHODS_ANN_LAYER_RBF_IMPL_HPP
13 
// In case it hasn't yet been included.
#include "radial_basis_function.hpp"

// For std::sqrt and std::invalid_argument.
#include <cmath>
#include <stdexcept>
16 
17 namespace mlpack {
18 namespace ann /** Artificial Neural Network. */ {
19 
20 template<typename InputDataType, typename OutputDataType,
21          typename Activation>
RBF()22 RBF<InputDataType, OutputDataType, Activation>::RBF() :
23     inSize(0),
24     outSize(0),
25     sigmas(0),
26     betas(0)
27 {
28   // Nothing to do here.
29 }
30 
31 template<typename InputDataType, typename OutputDataType,
32          typename Activation>
RBF(const size_t inSize,const size_t outSize,arma::mat & centres,double betas)33 RBF<InputDataType, OutputDataType, Activation>::RBF(
34     const size_t inSize,
35     const size_t outSize,
36     arma::mat& centres,
37     double betas) :
38     inSize(inSize),
39     outSize(outSize),
40     betas(betas),
41     centres(centres)
42 {
43   sigmas = 0;
44   if (betas == 0)
45   {
46     for (size_t i = 0; i < centres.n_cols; i++)
47     {
48       double max_dis = 0;
49       arma::mat temp = centres.each_col() - centres.col(i);
50       max_dis = arma::accu(arma::max(arma::pow(arma::sum(
51           arma::pow((temp), 2), 0), 0.5).t()));
52       if (max_dis > sigmas)
53         sigmas = max_dis;
54     }
55     this->betas = std::pow(2 * outSize, 0.5) / sigmas;
56   }
57 }
58 
59 template<typename InputDataType, typename OutputDataType,
60          typename Activation>
61 template<typename eT>
Forward(const arma::Mat<eT> & input,arma::Mat<eT> & output)62 void RBF<InputDataType, OutputDataType, Activation>::Forward(
63     const arma::Mat<eT>& input,
64     arma::Mat<eT>& output)
65 {
66   distances = arma::mat(outSize, input.n_cols);
67 
68   for (size_t i = 0; i < input.n_cols; i++)
69   {
70     arma::mat temp = centres.each_col() - input.col(i);
71     distances.col(i) = arma::pow(arma::sum(
72       arma::pow((temp), 2), 0), 0.5).t();
73   }
74   Activation::Fn(distances * std::pow(betas, 0.5),
75       output);
76 }
77 
78 
79 template<typename InputDataType, typename OutputDataType,
80          typename Activation>
81 template<typename eT>
Backward(const arma::Mat<eT> &,const arma::Mat<eT> &,arma::Mat<eT> &)82 void RBF<InputDataType, OutputDataType, Activation>::Backward(
83     const arma::Mat<eT>& /* input */,
84     const arma::Mat<eT>& /* gy */,
85     arma::Mat<eT>& /* g */)
86 {
87   // Nothing to do here.
88 }
89 
90 template<typename InputDataType, typename OutputDataType,
91          typename Activation>
92 template<typename Archive>
serialize(Archive & ar,const unsigned int)93 void RBF<InputDataType, OutputDataType, Activation>::serialize(
94     Archive& ar,
95     const unsigned int /* version */)
96 {
97   ar & BOOST_SERIALIZATION_NVP(distances);
98   ar & BOOST_SERIALIZATION_NVP(centres);
99 }
100 
101 } // namespace ann
102 } // namespace mlpack
103 
104 #endif
105