1 /**
2  * @file methods/ann/layer/batch_norm.hpp
3  * @author Praveen Ch
4  * @author Manthan-R-Sheth
5  *
6  * Definition of the Batch Normalization layer class.
7  *
8  * mlpack is free software; you may redistribute it and/or modify it under the
9  * terms of the 3-clause BSD license.  You should have received a copy of the
10  * 3-clause BSD license along with mlpack.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann /** Artificial Neural Network. */ {
20 
21 /**
22  * Declaration of the Batch Normalization layer class. The layer transforms
23  * the input data into zero mean and unit variance and then scales and shifts
24  * the data by parameters, gamma and beta respectively. These parameters are
25  * learnt by the network.
26  *
27  * If deterministic is false (training), the mean and variance over the batch is
28  * calculated and the data is normalized. If it is set to true (testing) then
29  * the mean and variance accrued over the training set is used.
30  *
31  * For more information, refer to the following paper,
32  *
33  * @code
34  * @article{Ioffe15,
35  *   author    = {Sergey Ioffe and
36  *                Christian Szegedy},
37  *   title     = {Batch Normalization: Accelerating Deep Network Training by
38  *                Reducing Internal Covariate Shift},
39  *   journal   = {CoRR},
40  *   volume    = {abs/1502.03167},
41  *   year      = {2015},
42  *   url       = {http://arxiv.org/abs/1502.03167},
43  *   eprint    = {1502.03167},
44  * }
45  * @endcode
46  *
47  * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
48  *         arma::sp_mat or arma::cube).
49  * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
50  *         arma::sp_mat or arma::cube).
51  */
52 template <
53   typename InputDataType = arma::mat,
54   typename OutputDataType = arma::mat
55 >
56 class BatchNorm
57 {
58  public:
59   //! Create the BatchNorm object.
60   BatchNorm();
61 
62   /**
63    * Create the BatchNorm layer object for a specified number of input units.
64    *
65    * @param size The number of input units / channels.
66    * @param eps The epsilon added to variance to ensure numerical stability.
67    * @param average Boolean to determine whether cumulative average is used for
68    *                updating the parameters or momentum is used.
69    * @param momentum Parameter used to to update the running mean and variance.
70    */
71   BatchNorm(const size_t size,
72             const double eps = 1e-8,
73             const bool average = true,
74             const double momentum = 0.1);
75 
76   /**
77    * Reset the layer parameters
78    */
79   void Reset();
80 
81   /**
82    * Forward pass of the Batch Normalization layer. Transforms the input data
83    * into zero mean and unit variance, scales the data by a factor gamma and
84    * shifts it by beta.
85    *
86    * @param input Input data for the layer
87    * @param output Resulting output activations.
88    */
89   template<typename eT>
90   void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
91 
92   /**
93    * Backward pass through the layer.
94    *
95    * @param input The input activations
96    * @param gy The backpropagated error.
97    * @param g The calculated gradient.
98    */
99   template<typename eT>
100   void Backward(const arma::Mat<eT>& input,
101                 const arma::Mat<eT>& gy,
102                 arma::Mat<eT>& g);
103 
104   /**
105    * Calculate the gradient using the output delta and the input activations.
106    *
107    * @param input The input activations
108    * @param error The calculated error
109    * @param gradient The calculated gradient.
110    */
111   template<typename eT>
112   void Gradient(const arma::Mat<eT>& input,
113                 const arma::Mat<eT>& error,
114                 arma::Mat<eT>& gradient);
115 
116   //! Get the parameters.
Parameters() const117   OutputDataType const& Parameters() const { return weights; }
118   //! Modify the parameters.
Parameters()119   OutputDataType& Parameters() { return weights; }
120 
121   //! Get the output parameter.
OutputParameter() const122   OutputDataType const& OutputParameter() const { return outputParameter; }
123   //! Modify the output parameter.
OutputParameter()124   OutputDataType& OutputParameter() { return outputParameter; }
125 
126   //! Get the delta.
Delta() const127   OutputDataType const& Delta() const { return delta; }
128   //! Modify the delta.
Delta()129   OutputDataType& Delta() { return delta; }
130 
131   //! Get the gradient.
Gradient() const132   OutputDataType const& Gradient() const { return gradient; }
133   //! Modify the gradient.
Gradient()134   OutputDataType& Gradient() { return gradient; }
135 
136   //! Get the value of deterministic parameter.
Deterministic() const137   bool Deterministic() const { return deterministic; }
138   //! Modify the value of deterministic parameter.
Deterministic()139   bool& Deterministic() { return deterministic; }
140 
141   //! Get the mean over the training data.
TrainingMean() const142   OutputDataType const& TrainingMean() const { return runningMean; }
143   //! Modify the mean over the training data.
TrainingMean()144   OutputDataType& TrainingMean() { return runningMean; }
145 
146   //! Get the variance over the training data.
TrainingVariance() const147   OutputDataType const& TrainingVariance() const { return runningVariance; }
148   //! Modify the variance over the training data.
TrainingVariance()149   OutputDataType& TrainingVariance() { return runningVariance; }
150 
151   //! Get the number of input units / channels.
InputSize() const152   size_t InputSize() const { return size; }
153 
154   //! Get the epsilon value.
Epsilon() const155   double Epsilon() const { return eps; }
156 
157   //! Get the momentum value.
Momentum() const158   double Momentum() const { return momentum; }
159 
160   //! Get the average parameter.
Average() const161   bool Average() const { return average; }
162 
163   //! Get size of weights.
WeightSize() const164   size_t WeightSize() const { return 2 * size; }
165 
166   /**
167    * Serialize the layer
168    */
169   template<typename Archive>
170   void serialize(Archive& ar, const unsigned int /* version */);
171 
172  private:
173   //! Locally-stored number of input units.
174   size_t size;
175 
176   //! Locally-stored epsilon value.
177   double eps;
178 
179   //! If true use average else use momentum for computing running mean
180   //! and variance
181   bool average;
182 
183   //! Locally-stored value for momentum.
184   double momentum;
185 
186   //! Variable to keep track of whether we are in loading or saving mode.
187   bool loading;
188 
189   //! Locally-stored scale parameter.
190   OutputDataType gamma;
191 
192   //! Locally-stored shift parameter.
193   OutputDataType beta;
194 
195   //! Locally-stored mean object.
196   OutputDataType mean;
197 
198   //! Locally-stored variance object.
199   OutputDataType variance;
200 
201   //! Locally-stored parameters.
202   OutputDataType weights;
203 
204   /**
205    * If true then mean and variance over the training set will be considered
206    * instead of being calculated over the batch.
207    */
208   bool deterministic;
209 
210   //! Locally-stored running mean/variance counter.
211   size_t count;
212 
213   //! Locally-stored value for average factor which used to update running
214   //! mean and variance.
215   double averageFactor;
216 
217   //! Locally-stored mean object.
218   OutputDataType runningMean;
219 
220   //! Locally-stored variance object.
221   OutputDataType runningVariance;
222 
223   //! Locally-stored gradient object.
224   OutputDataType gradient;
225 
226   //! Locally-stored delta object.
227   OutputDataType delta;
228 
229   //! Locally-stored output parameter object.
230   OutputDataType outputParameter;
231 
232   //! Locally-stored normalized input.
233   arma::cube normalized;
234 
235   //! Locally-stored zero mean input.
236   arma::cube inputMean;
237 }; // class BatchNorm
238 
239 } // namespace ann
240 } // namespace mlpack
241 
242 // Include the implementation.
243 #include "batch_norm_impl.hpp"
244 
245 #endif
246