1 /**
2  * @file methods/ann/activation_functions/multi_quadratic_function.hpp
3  * @author Himanshu Pathak
4  *
 * Definition and implementation of the Multi Quadratic activation function.
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_MULTIQUAD_FUNCTION_HPP
13 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_MULTIQUAD_FUNCTION_HPP
14 
#include <mlpack/prereqs.hpp>
#include <cmath>
16 
17 namespace mlpack {
18 namespace ann /** Artificial Neural Network. */ {
19 
20 /**
21  * The Multi Quadratic function, defined by
22  *
23  * @f{eqnarray*}{
24  * f(x) = \sqrt(1 + x^2) \\
25  * f'(x) = x / f(x) \\
26  * @f}
27  */
28 class MultiQuadFunction
29 {
30  public:
31   /**
32    * Computes the Multi Quadratic function.
33    *
34    * @param x Input data.
35    * @return f(x).
36    */
Fn(const double x)37   static double Fn(const double x)
38   {
39     return std::pow(1 + x * x, 0.5);
40   }
41 
42   /**
43    * Computes the Multi Quadratic function.
44    *
45    * @param x Input data.
46    * @param y The resulting output activation.
47    */
48   template<typename InputVecType, typename OutputVecType>
Fn(const InputVecType & x,OutputVecType & y)49   static void Fn(const InputVecType& x, OutputVecType& y)
50   {
51     y = arma::pow((1 + arma::pow(x, 2)), 0.5);
52   }
53 
54   /**
55    * Computes the first derivative of the Multi Quadratic function.
56    *
57    * @param y Input data.
58    * @return f'(x)
59    */
Deriv(const double y)60   static double Deriv(const double y)
61   {
62     return  y / std::pow(1 + y * y, 0.5);
63   }
64 
65   /**
66    * Computes the first derivatives of the Multi Quadratic function.
67    *
68    * @param y Input data.
69    * @param x The resulting derivatives.
70    */
71   template<typename InputVecType, typename OutputVecType>
Deriv(const InputVecType & x,OutputVecType & y)72   static void Deriv(const InputVecType& x, OutputVecType& y)
73   {
74     y = x / arma::pow((1 + arma::pow(x, 2)), 0.5);
75   }
76 }; // class MultiquadFunction
77 
78 } // namespace ann
79 } // namespace mlpack
80 
81 #endif
82