1 /**
2  * @file methods/ann/activation_functions/spline_function.hpp
3  * @author Himanshu Pathak
4  *
5  * Definition and implementation of Spline function.
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_SPLINE_FUNCTION_HPP
13 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_SPLINE_FUNCTION_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann /** Artificial Neural Network. */ {
19 
20 /**
21  * The Spline function, defined by
22  *
23  * @f{eqnarray*}{
24  * f(x) = x^2 * log(1 + x) \\
25  * f'(x) = 2 * x * log(1 + x) + x^2 / (1 + x)\\
26  * @f}
27  */
28 class SplineFunction
29 {
30  public:
31   /**
32    * Computes the Spline function.
33    *
34    * @param x Input data.
35    * @return f(x).
36    */
Fn(const double x)37   static double Fn(const double x)
38   {
39     return std::pow(x, 2) * std::log(1 + x);
40   }
41 
42   /**
43    * Computes the Spline function.
44    *
45    * @param x Input data.
46    * @param y The resulting output activation.
47    */
48   template<typename InputVecType, typename OutputVecType>
Fn(const InputVecType & x,OutputVecType & y)49   static void Fn(const InputVecType& x, OutputVecType& y)
50   {
51     y = arma::pow(x, 2) % arma::log(1 + x);
52   }
53 
54   /**
55    * Computes the first derivative of the Spline function.
56    *
57    * @param y Input data.
58    * @return f'(x)
59    */
Deriv(const double y)60   static double Deriv(const double y)
61   {
62     return  2 * y * std::log(1 + y) + std::pow(y, 2) / (1 + y);
63   }
64 
65   /**
66    * Computes the first derivatives of the Spline function.
67    *
68    * @param y Input data.
69    * @param x The resulting derivatives.
70    */
71   template<typename InputVecType, typename OutputVecType>
Deriv(const InputVecType & x,OutputVecType & y)72   static void Deriv(const InputVecType& x, OutputVecType& y)
73   {
74     y = 2 * x % arma::log(1 + x) + arma::pow(x, 2) / (1 + x);
75   }
76 }; // class SplineFunction
77 
78 } // namespace ann
79 } // namespace mlpack
80 
81 #endif
82