1 /**
2  * @file methods/ann/layer/celu.hpp
3  * @author Gaurav Singh
4  *
5  * Definition of the CELU activation function as described by Jonathan T. Barron.
6  *
7  * For more information, read the following paper.
8  *
9  * @code
 * @article{Barron2017,
 *   author  = {Jonathan T. Barron},
 *   title   = {Continuously Differentiable Exponential Linear Units},
 *   journal = {arXiv preprint arXiv:1704.07483},
 *   year    = {2017},
 *   url     = {https://arxiv.org/abs/1704.07483}
 * }
16  * @endcode
17  *
18  * mlpack is free software; you may redistribute it and/or modify it under the
19  * terms of the 3-clause BSD license.  You should have received a copy of the
20  * 3-clause BSD license along with mlpack.  If not, see
21  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
22  */
23 #ifndef MLPACK_METHODS_ANN_LAYER_CELU_HPP
24 #define MLPACK_METHODS_ANN_LAYER_CELU_HPP
25 
26 #include <mlpack/prereqs.hpp>
27 
28 namespace mlpack {
29 namespace ann /** Artificial Neural Network. */ {
30 
31 /**
32  * The CELU activation function, defined by
33  *
34  * @f{eqnarray*}{
35  * f(x) &=& \left\{
36  *   \begin{array}{lr}
37  *    x & : x \ge 0 \\
38  *    \alpha(e^(\frac{x}{\alpha}) - 1) & : x < 0
39  *   \end{array}
40  * \right. \\
41  * f'(x) &=& \left\{
42  *   \begin{array}{lr}
43  *     1 & : x \ge 0 \\
44  *     (\frac{f(x)}{\alpha}) + 1 & : x < 0
45  *   \end{array}
46  * \right.
47  * @f}
48  *
49  * In the deterministic mode, there is no computation of the derivative.
50  *
51  * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
52  *         arma::sp_mat or arma::cube).
53  * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
54  *         arma::sp_mat or arma::cube).
55  */
56 template <
57     typename InputDataType = arma::mat,
58     typename OutputDataType = arma::mat
59 >
60 class CELU
61 {
62  public:
63   /**
64    * Create the CELU object using the specified parameter. The non zero
65    * gradient for negative inputs can be adjusted by specifying the CELU
66    * hyperparameter alpha (alpha > 0).
67    *
68    * @param alpha Scale parameter for the negative factor (default = 1.0).
69    */
70   CELU(const double alpha = 1.0);
71 
72   /**
73    * Ordinary feed forward pass of a neural network, evaluating the function
74    * f(x) by propagating the activity forward through f.
75    *
76    * @param input Input data used for evaluating the specified function.
77    * @param output Resulting output activation.
78    */
79   template<typename InputType, typename OutputType>
80   void Forward(const InputType& input, OutputType& output);
81 
82   /**
83    * Ordinary feed backward pass of a neural network, calculating the function
84    * f(x) by propagating x backwards through f. Using the results from the feed
85    * forward pass.
86    *
87    * @param input The propagated input activation f(x).
88    * @param gy The backpropagated error.
89    * @param g The calculated gradient.
90    */
91   template<typename DataType>
92   void Backward(const DataType& input, const DataType& gy, DataType& g);
93 
94   //! Get the output parameter.
OutputParameter() const95   OutputDataType const& OutputParameter() const { return outputParameter; }
96   //! Modify the output parameter.
OutputParameter()97   OutputDataType& OutputParameter() { return outputParameter; }
98 
99   //! Get the delta.
Delta() const100   OutputDataType const& Delta() const { return delta; }
101   //! Modify the delta.
Delta()102   OutputDataType& Delta() { return delta; }
103 
104   //! Get the non zero gradient.
Alpha() const105   double const& Alpha() const { return alpha; }
106   //! Modify the non zero gradient.
Alpha()107   double& Alpha() { return alpha; }
108 
109   //! Get the value of deterministic parameter.
Deterministic() const110   bool Deterministic() const { return deterministic; }
111   //! Modify the value of deterministic parameter.
Deterministic()112   bool& Deterministic() { return deterministic; }
113 
114   /**
115    * Serialize the layer.
116    */
117   template<typename Archive>
118   void serialize(Archive& ar, const unsigned int /* version */);
119 
120  private:
121   //! Locally-stored delta object.
122   OutputDataType delta;
123 
124   //! Locally-stored output parameter object.
125   OutputDataType outputParameter;
126 
127   //! Locally stored first derivative of the activation function.
128   arma::mat derivative;
129 
130   //! CELU Hyperparameter (alpha > 0).
131   double alpha;
132 
133   //! If true the derivative computation is disabled, see notes above.
134   bool deterministic;
135 }; // class CELU
136 
137 } // namespace ann
138 } // namespace mlpack
139 
140 // Include implementation.
141 #include "celu_impl.hpp"
142 
143 #endif
144