1 /**
2 * @file radial_basis_function_impl.hpp
3 * @author Himanshu Pathak
 *
 * Implementation of the radial basis function (RBF) layer.
 *
6 * mlpack is free software; you may redistribute it and/or modify it under the
7 * terms of the 3-clause BSD license. You should have received a copy of the
8 * 3-clause BSD license along with mlpack. If not, see
9 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
10 */
11 #ifndef MLPACK_METHODS_ANN_LAYER_RBF_IMPL_HPP
12 #define MLPACK_METHODS_ANN_LAYER_RBF_IMPL_HPP
13
// In case it hasn't yet been included.
#include "radial_basis_function.hpp"

#include <cmath>
#include <stdexcept>
16
17 namespace mlpack {
18 namespace ann /** Artificial Neural Network. */ {
19
20 template<typename InputDataType, typename OutputDataType,
21 typename Activation>
RBF()22 RBF<InputDataType, OutputDataType, Activation>::RBF() :
23 inSize(0),
24 outSize(0),
25 sigmas(0),
26 betas(0)
27 {
28 // Nothing to do here.
29 }
30
31 template<typename InputDataType, typename OutputDataType,
32 typename Activation>
RBF(const size_t inSize,const size_t outSize,arma::mat & centres,double betas)33 RBF<InputDataType, OutputDataType, Activation>::RBF(
34 const size_t inSize,
35 const size_t outSize,
36 arma::mat& centres,
37 double betas) :
38 inSize(inSize),
39 outSize(outSize),
40 betas(betas),
41 centres(centres)
42 {
43 sigmas = 0;
44 if (betas == 0)
45 {
46 for (size_t i = 0; i < centres.n_cols; i++)
47 {
48 double max_dis = 0;
49 arma::mat temp = centres.each_col() - centres.col(i);
50 max_dis = arma::accu(arma::max(arma::pow(arma::sum(
51 arma::pow((temp), 2), 0), 0.5).t()));
52 if (max_dis > sigmas)
53 sigmas = max_dis;
54 }
55 this->betas = std::pow(2 * outSize, 0.5) / sigmas;
56 }
57 }
58
59 template<typename InputDataType, typename OutputDataType,
60 typename Activation>
61 template<typename eT>
Forward(const arma::Mat<eT> & input,arma::Mat<eT> & output)62 void RBF<InputDataType, OutputDataType, Activation>::Forward(
63 const arma::Mat<eT>& input,
64 arma::Mat<eT>& output)
65 {
66 distances = arma::mat(outSize, input.n_cols);
67
68 for (size_t i = 0; i < input.n_cols; i++)
69 {
70 arma::mat temp = centres.each_col() - input.col(i);
71 distances.col(i) = arma::pow(arma::sum(
72 arma::pow((temp), 2), 0), 0.5).t();
73 }
74 Activation::Fn(distances * std::pow(betas, 0.5),
75 output);
76 }
77
78
79 template<typename InputDataType, typename OutputDataType,
80 typename Activation>
81 template<typename eT>
Backward(const arma::Mat<eT> &,const arma::Mat<eT> &,arma::Mat<eT> &)82 void RBF<InputDataType, OutputDataType, Activation>::Backward(
83 const arma::Mat<eT>& /* input */,
84 const arma::Mat<eT>& /* gy */,
85 arma::Mat<eT>& /* g */)
86 {
87 // Nothing to do here.
88 }
89
90 template<typename InputDataType, typename OutputDataType,
91 typename Activation>
92 template<typename Archive>
serialize(Archive & ar,const unsigned int)93 void RBF<InputDataType, OutputDataType, Activation>::serialize(
94 Archive& ar,
95 const unsigned int /* version */)
96 {
97 ar & BOOST_SERIALIZATION_NVP(distances);
98 ar & BOOST_SERIALIZATION_NVP(centres);
99 }
100
101 } // namespace ann
102 } // namespace mlpack
103
104 #endif
105