/**
 * @file methods/ann/layer/batch_norm.hpp
 * @author Praveen Ch
 * @author Manthan-R-Sheth
 *
 * Definition of the Batch Normalization layer class.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
#define MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
 * Declaration of the Batch Normalization layer class. The layer transforms
 * the input data into zero mean and unit variance, and then scales and shifts
 * the data by the parameters gamma and beta respectively. These parameters
 * are learnt by the network.
 *
 * If deterministic is false (training), the mean and variance over the batch
 * are calculated and the data is normalized. If it is set to true (testing),
 * then the mean and variance accrued over the training set are used.
 *
 * For more information, refer to the following paper:
 *
 * @code
 * @article{Ioffe15,
 *   author  = {Sergey Ioffe and Christian Szegedy},
 *   title   = {Batch Normalization: Accelerating Deep Network Training by
 *              Reducing Internal Covariate Shift},
 *   journal = {CoRR},
 *   volume  = {abs/1502.03167},
 *   year    = {2015},
 *   url     = {http://arxiv.org/abs/1502.03167},
 *   eprint  = {1502.03167},
 * }
 * @endcode
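 *
 * As a rough usage sketch (assuming the layer is used standalone, the way
 * mlpack's layer tests exercise it; in a full model the layer is normally
 * owned by a network class that manages its parameters), the interface looks
 * like this:
 *
 * @code
 * BatchNorm<> bn(10);          // 10 input units / channels.
 * bn.Reset();                  // Initialize the learnable gamma / beta.
 *
 * arma::mat input(10, 32, arma::fill::randn), output;
 * bn.Deterministic() = false;  // Training: normalize with batch statistics.
 * bn.Forward(input, output);
 * bn.Deterministic() = true;   // Testing: use the running mean / variance.
 * bn.Forward(input, output);
 * @endcode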
 *
 * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
 *         arma::sp_mat or arma::cube).
 * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
 *         arma::sp_mat or arma::cube).
 */
template <
    typename InputDataType = arma::mat,
    typename OutputDataType = arma::mat
>
class BatchNorm
{
 public:
  //! Create the BatchNorm object.
  BatchNorm();

  /**
   * Create the BatchNorm layer object for a specified number of input units.
   *
   * @param size The number of input units / channels.
   * @param eps The epsilon added to the variance to ensure numerical
   *        stability.
   * @param average If true, the cumulative average is used to update the
   *        running mean and variance; otherwise, momentum is used.
   * @param momentum Parameter used to update the running mean and variance.
   */
  BatchNorm(const size_t size,
            const double eps = 1e-8,
            const bool average = true,
            const double momentum = 0.1);

  /**
   * Reset the layer parameters.
   */
  void Reset();

  /**
   * Forward pass of the Batch Normalization layer. Transforms the input data
   * into zero mean and unit variance, scales the data by a factor gamma, and
   * shifts it by beta.
   *
   * @param input Input data for the layer.
   * @param output Resulting output activations.
   */
  template<typename eT>
  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);

  /**
   * Backward pass through the layer.
   *
   * @param input The input activations.
   * @param gy The backpropagated error.
   * @param g The calculated gradient.
   */
  template<typename eT>
  void Backward(const arma::Mat<eT>& input,
                const arma::Mat<eT>& gy,
                arma::Mat<eT>& g);

  /**
   * Calculate the gradient using the output delta and the input activations.
   *
   * @param input The input activations.
   * @param error The calculated error.
   * @param gradient The calculated gradient.
   */
  template<typename eT>
  void Gradient(const arma::Mat<eT>& input,
                const arma::Mat<eT>& error,
                arma::Mat<eT>& gradient);

  //! Get the parameters.
  OutputDataType const& Parameters() const { return weights; }
  //! Modify the parameters.
  OutputDataType& Parameters() { return weights; }

  //! Get the output parameter.
  OutputDataType const& OutputParameter() const { return outputParameter; }
  //! Modify the output parameter.
  OutputDataType& OutputParameter() { return outputParameter; }

  //! Get the delta.
  OutputDataType const& Delta() const { return delta; }
  //! Modify the delta.
  OutputDataType& Delta() { return delta; }

  //! Get the gradient.
  OutputDataType const& Gradient() const { return gradient; }
  //! Modify the gradient.
  OutputDataType& Gradient() { return gradient; }

  //! Get the value of the deterministic parameter.
  bool Deterministic() const { return deterministic; }
  //! Modify the value of the deterministic parameter.
  bool& Deterministic() { return deterministic; }

  //! Get the mean over the training data.
  OutputDataType const& TrainingMean() const { return runningMean; }
  //! Modify the mean over the training data.
  OutputDataType& TrainingMean() { return runningMean; }

  //! Get the variance over the training data.
  OutputDataType const& TrainingVariance() const { return runningVariance; }
  //! Modify the variance over the training data.
  OutputDataType& TrainingVariance() { return runningVariance; }

  //! Get the number of input units / channels.
  size_t InputSize() const { return size; }

  //! Get the epsilon value.
  double Epsilon() const { return eps; }

  //! Get the momentum value.
  double Momentum() const { return momentum; }

  //! Get the average parameter.
  bool Average() const { return average; }

  //! Get the size of the weights.
  size_t WeightSize() const { return 2 * size; }

  /**
   * Serialize the layer.
   */
  template<typename Archive>
  void serialize(Archive& ar, const unsigned int /* version */);

 private:
  //! Locally-stored number of input units.
  size_t size;

  //! Locally-stored epsilon value.
  double eps;

  //! If true, use the cumulative average to compute the running mean and
  //! variance; otherwise, use momentum.
  bool average;

  //! Locally-stored value for momentum.
  double momentum;

  //! Variable to keep track of whether we are in loading or saving mode.
  bool loading;

  //! Locally-stored scale parameter.
  OutputDataType gamma;

  //! Locally-stored shift parameter.
  OutputDataType beta;

  //! Locally-stored mean object.
  OutputDataType mean;

  //! Locally-stored variance object.
  OutputDataType variance;

  //! Locally-stored parameters.
  OutputDataType weights;

  /**
   * If true, then the mean and variance accrued over the training set will be
   * used instead of being calculated over the batch.
   */
  bool deterministic;
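
  // A sketch of how the running statistics below are updated during training
  // (the exact computation lives in batch_norm_impl.hpp; this is the standard
  // batch norm running-average form): after each batch, with f denoting
  // averageFactor,
  //
  //   runningMean     = (1 - f) * runningMean     + f * mean
  //   runningVariance = (1 - f) * runningVariance + f * variance
  //
  // where f is the cumulative-average weight 1 / count when average is true,
  // and momentum otherwise.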

  //! Locally-stored running mean/variance counter.
  size_t count;

  //! Locally-stored value of the average factor, which is used to update the
  //! running mean and variance.
  double averageFactor;

  //! Locally-stored running mean object.
  OutputDataType runningMean;

  //! Locally-stored running variance object.
  OutputDataType runningVariance;

  //! Locally-stored gradient object.
  OutputDataType gradient;

  //! Locally-stored delta object.
  OutputDataType delta;

  //! Locally-stored output parameter object.
  OutputDataType outputParameter;

  //! Locally-stored normalized input.
  arma::cube normalized;

  //! Locally-stored zero-mean input.
  arma::cube inputMean;
}; // class BatchNorm

} // namespace ann
} // namespace mlpack

// Include the implementation.
#include "batch_norm_impl.hpp"

#endif