1 /** 2 * @file methods/ann/layer/celu.hpp 3 * @author Gaurav Singh 4 * 5 * Definition of the CELU activation function as described by Jonathan T. Barron. 6 * 7 * For more information, read the following paper. 8 * 9 * @code 10 * @article{ 11 * author = {Jonathan T. Barron}, 12 * title = {Continuously Differentiable Exponential Linear Units}, 13 * year = {2017}, 14 * url = {https://arxiv.org/pdf/1704.07483} 15 * } 16 * @endcode 17 * 18 * mlpack is free software; you may redistribute it and/or modify it under the 19 * terms of the 3-clause BSD license. You should have received a copy of the 20 * 3-clause BSD license along with mlpack. If not, see 21 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 22 */ 23 #ifndef MLPACK_METHODS_ANN_LAYER_CELU_HPP 24 #define MLPACK_METHODS_ANN_LAYER_CELU_HPP 25 26 #include <mlpack/prereqs.hpp> 27 28 namespace mlpack { 29 namespace ann /** Artificial Neural Network. */ { 30 31 /** 32 * The CELU activation function, defined by 33 * 34 * @f{eqnarray*}{ 35 * f(x) &=& \left\{ 36 * \begin{array}{lr} 37 * x & : x \ge 0 \\ 38 * \alpha(e^(\frac{x}{\alpha}) - 1) & : x < 0 39 * \end{array} 40 * \right. \\ 41 * f'(x) &=& \left\{ 42 * \begin{array}{lr} 43 * 1 & : x \ge 0 \\ 44 * (\frac{f(x)}{\alpha}) + 1 & : x < 0 45 * \end{array} 46 * \right. 47 * @f} 48 * 49 * In the deterministic mode, there is no computation of the derivative. 50 * 51 * @tparam InputDataType Type of the input data (arma::colvec, arma::mat, 52 * arma::sp_mat or arma::cube). 53 * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat, 54 * arma::sp_mat or arma::cube). 55 */ 56 template < 57 typename InputDataType = arma::mat, 58 typename OutputDataType = arma::mat 59 > 60 class CELU 61 { 62 public: 63 /** 64 * Create the CELU object using the specified parameter. The non zero 65 * gradient for negative inputs can be adjusted by specifying the CELU 66 * hyperparameter alpha (alpha > 0). 67 * 68 * @param alpha Scale parameter for the negative factor (default = 1.0). 69 */ 70 CELU(const double alpha = 1.0); 71 72 /** 73 * Ordinary feed forward pass of a neural network, evaluating the function 74 * f(x) by propagating the activity forward through f. 75 * 76 * @param input Input data used for evaluating the specified function. 77 * @param output Resulting output activation. 78 */ 79 template<typename InputType, typename OutputType> 80 void Forward(const InputType& input, OutputType& output); 81 82 /** 83 * Ordinary feed backward pass of a neural network, calculating the function 84 * f(x) by propagating x backwards through f. Using the results from the feed 85 * forward pass. 86 * 87 * @param input The propagated input activation f(x). 88 * @param gy The backpropagated error. 89 * @param g The calculated gradient. 90 */ 91 template<typename DataType> 92 void Backward(const DataType& input, const DataType& gy, DataType& g); 93 94 //! Get the output parameter. OutputParameter() const95 OutputDataType const& OutputParameter() const { return outputParameter; } 96 //! Modify the output parameter. OutputParameter()97 OutputDataType& OutputParameter() { return outputParameter; } 98 99 //! Get the delta. Delta() const100 OutputDataType const& Delta() const { return delta; } 101 //! Modify the delta. Delta()102 OutputDataType& Delta() { return delta; } 103 104 //! Get the non zero gradient. Alpha() const105 double const& Alpha() const { return alpha; } 106 //! Modify the non zero gradient. Alpha()107 double& Alpha() { return alpha; } 108 109 //! Get the value of deterministic parameter. Deterministic() const110 bool Deterministic() const { return deterministic; } 111 //! Modify the value of deterministic parameter. Deterministic()112 bool& Deterministic() { return deterministic; } 113 114 /** 115 * Serialize the layer. 116 */ 117 template<typename Archive> 118 void serialize(Archive& ar, const unsigned int /* version */); 119 120 private: 121 //! Locally-stored delta object. 122 OutputDataType delta; 123 124 //! Locally-stored output parameter object. 125 OutputDataType outputParameter; 126 127 //! Locally stored first derivative of the activation function. 128 arma::mat derivative; 129 130 //! CELU Hyperparameter (alpha > 0). 131 double alpha; 132 133 //! If true the derivative computation is disabled, see notes above. 134 bool deterministic; 135 }; // class CELU 136 137 } // namespace ann 138 } // namespace mlpack 139 140 // Include implementation. 141 #include "celu_impl.hpp" 142 143 #endif 144