1 /** 2 * @file methods/ann/activation_functions/spline_function.hpp 3 * @author Himanshu Pathak 4 * 5 * Definition and implementation of Spline function. 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_SPLINE_FUNCTION_HPP 13 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_SPLINE_FUNCTION_HPP 14 15 #include <mlpack/prereqs.hpp> 16 17 namespace mlpack { 18 namespace ann /** Artificial Neural Network. */ { 19 20 /** 21 * The Spline function, defined by 22 * 23 * @f{eqnarray*}{ 24 * f(x) = x^2 * log(1 + x) \\ 25 * f'(x) = 2 * x * log(1 + x) + x^2 / (1 + x)\\ 26 * @f} 27 */ 28 class SplineFunction 29 { 30 public: 31 /** 32 * Computes the Spline function. 33 * 34 * @param x Input data. 35 * @return f(x). 36 */ Fn(const double x)37 static double Fn(const double x) 38 { 39 return std::pow(x, 2) * std::log(1 + x); 40 } 41 42 /** 43 * Computes the Spline function. 44 * 45 * @param x Input data. 46 * @param y The resulting output activation. 47 */ 48 template<typename InputVecType, typename OutputVecType> Fn(const InputVecType & x,OutputVecType & y)49 static void Fn(const InputVecType& x, OutputVecType& y) 50 { 51 y = arma::pow(x, 2) % arma::log(1 + x); 52 } 53 54 /** 55 * Computes the first derivative of the Spline function. 56 * 57 * @param y Input data. 58 * @return f'(x) 59 */ Deriv(const double y)60 static double Deriv(const double y) 61 { 62 return 2 * y * std::log(1 + y) + std::pow(y, 2) / (1 + y); 63 } 64 65 /** 66 * Computes the first derivatives of the Spline function. 67 * 68 * @param y Input data. 69 * @param x The resulting derivatives. 70 */ 71 template<typename InputVecType, typename OutputVecType> Deriv(const InputVecType & x,OutputVecType & y)72 static void Deriv(const InputVecType& x, OutputVecType& y) 73 { 74 y = 2 * x % arma::log(1 + x) + arma::pow(x, 2) / (1 + x); 75 } 76 }; // class SplineFunction 77 78 } // namespace ann 79 } // namespace mlpack 80 81 #endif 82