1 /** 2 * @file methods/ann/activation_functions/multi_quadratic_function.hpp 3 * @author Himanshu Pathak 4 * 5 * Definition and implementation of multi quadratic function. 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_MULTIQUAD_FUNCTION_HPP 13 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_MULTIQUAD_FUNCTION_HPP 14 15 #include <mlpack/prereqs.hpp> 16 17 namespace mlpack { 18 namespace ann /** Artificial Neural Network. */ { 19 20 /** 21 * The Multi Quadratic function, defined by 22 * 23 * @f{eqnarray*}{ 24 * f(x) = \sqrt(1 + x^2) \\ 25 * f'(x) = x / f(x) \\ 26 * @f} 27 */ 28 class MultiQuadFunction 29 { 30 public: 31 /** 32 * Computes the Multi Quadratic function. 33 * 34 * @param x Input data. 35 * @return f(x). 36 */ Fn(const double x)37 static double Fn(const double x) 38 { 39 return std::pow(1 + x * x, 0.5); 40 } 41 42 /** 43 * Computes the Multi Quadratic function. 44 * 45 * @param x Input data. 46 * @param y The resulting output activation. 47 */ 48 template<typename InputVecType, typename OutputVecType> Fn(const InputVecType & x,OutputVecType & y)49 static void Fn(const InputVecType& x, OutputVecType& y) 50 { 51 y = arma::pow((1 + arma::pow(x, 2)), 0.5); 52 } 53 54 /** 55 * Computes the first derivative of the Multi Quadratic function. 56 * 57 * @param y Input data. 58 * @return f'(x) 59 */ Deriv(const double y)60 static double Deriv(const double y) 61 { 62 return y / std::pow(1 + y * y, 0.5); 63 } 64 65 /** 66 * Computes the first derivatives of the Multi Quadratic function. 67 * 68 * @param y Input data. 69 * @param x The resulting derivatives. 70 */ 71 template<typename InputVecType, typename OutputVecType> Deriv(const InputVecType & x,OutputVecType & y)72 static void Deriv(const InputVecType& x, OutputVecType& y) 73 { 74 y = x / arma::pow((1 + arma::pow(x, 2)), 0.5); 75 } 76 }; // class MultiquadFunction 77 78 } // namespace ann 79 } // namespace mlpack 80 81 #endif 82