Batch Normalization {#dev_guide_batch_normalization}
====================================================

>
> [API Reference](@ref dnnl_api_batch_normalization)
>

## General

The batch normalization primitive performs a forward or backward batch
normalization operation on tensors with two or more dimensions.

### Forward

The batch normalization operation is defined by the following formulas. We show
formulas only for 2D spatial data, which are straightforward to generalize to
cases of higher and lower dimensions. Variable names follow the standard
@ref dev_guide_conventions.

\f[
    \dst(n, c, h, w) =
       \gamma(c) \cdot
       \frac{\src(n, c, h, w) - \mu(c)} {\sqrt{\sigma^2(c) + \varepsilon}}
       + \beta(c),
\f]

where

- \f$\gamma(c), \beta(c)\f$ are optional scale and shift for a channel
(see #dnnl_use_scaleshift, #dnnl_use_scale and #dnnl_use_shift flags),

- \f$\mu(c), \sigma^2(c)\f$ are mean and variance for a channel (see
  #dnnl_use_global_stats flag), and

- \f$\varepsilon\f$ is a constant to improve numerical stability.

Mean and variance are computed at runtime or provided by a user. When mean and
variance are computed at runtime, the following formulas are used:

- \f$\mu(c) = \frac{1}{NHW} \sum\limits_{nhw} \src(n, c, h, w)\f$,

- \f$\sigma^2(c) = \frac{1}{NHW} \sum\limits_{nhw} (\src(n, c, h, w) - \mu(c))^2\f$.

The \f$\gamma(c)\f$ and \f$\beta(c)\f$ tensors are considered learnable.
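
For illustration, a plain reference computation of the forward formulas over an
NCHW f32 buffer could look like the sketch below (naive loops with names chosen
for this example only; the oneDNN primitive implements the same math internally
and should be used instead):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Naive reference of the forward formulas above for an NCHW f32 tensor.
// Illustration only; not part of the oneDNN API.
void bnorm_fwd_ref(const std::vector<float> &src, std::vector<float> &dst,
        const std::vector<float> &gamma, const std::vector<float> &beta,
        int N, int C, int H, int W, float eps) {
    const std::size_t sp = (std::size_t)H * W; // spatial size per channel
    for (int c = 0; c < C; ++c) {
        // mu(c): mean over the N, H, and W dimensions
        double mean = 0.0;
        for (int n = 0; n < N; ++n)
            for (std::size_t s = 0; s < sp; ++s)
                mean += src[((std::size_t)n * C + c) * sp + s];
        mean /= (double)N * sp;

        // sigma^2(c): population variance over the same dimensions
        double var = 0.0;
        for (int n = 0; n < N; ++n)
            for (std::size_t s = 0; s < sp; ++s) {
                double d = src[((std::size_t)n * C + c) * sp + s] - mean;
                var += d * d;
            }
        var /= (double)N * sp;

        // Normalize, then apply the optional scale and shift
        for (int n = 0; n < N; ++n)
            for (std::size_t s = 0; s < sp; ++s) {
                std::size_t off = ((std::size_t)n * C + c) * sp + s;
                dst[off] = gamma[c]
                                * (float)((src[off] - mean)
                                        / std::sqrt(var + eps))
                        + beta[c];
            }
    }
}
```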

In training mode, the primitive also optionally supports fusion with ReLU
activation (with zero negative slope) applied to the result
(see #dnnl_fuse_norm_relu flag).

@note
* The batch normalization primitive computes population mean and variance and
  not the sample or unbiased versions that are typically used to compute
  running mean and variance.
* Using the mean and variance computed by the batch normalization primitive,
  running mean and variance \f$\hat\mu\f$ and \f$\hat\sigma^2\f$ can be
  computed as \f[
    \hat\mu := \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu, \\
    \hat\sigma^2 := \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2.
  \f]
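
As an illustration, an application-side update of the running statistics using
the batch mean and variance produced by the primitive might look as follows
(the buffers and the momentum `alpha` belong to the application; oneDNN does
not maintain running statistics itself):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical application-side update of running statistics from the batch
// mean and variance that the primitive outputs; alpha is the momentum.
void update_running_stats(std::vector<float> &running_mean,
        std::vector<float> &running_var, const std::vector<float> &batch_mean,
        const std::vector<float> &batch_var, float alpha) {
    for (std::size_t c = 0; c < running_mean.size(); ++c) {
        running_mean[c]
                = alpha * running_mean[c] + (1.f - alpha) * batch_mean[c];
        running_var[c] = alpha * running_var[c] + (1.f - alpha) * batch_var[c];
    }
}
```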

#### Difference Between Forward Training and Forward Inference

 * If mean and variance are computed at runtime (i.e., #dnnl_use_global_stats
   is not set), they become outputs for the propagation kind
   #dnnl_forward_training (because they would be required during the backward
   propagation) and are not exposed for the propagation kind
   #dnnl_forward_inference.

 * If batch normalization is created with ReLU fusion (i.e.,
   #dnnl_fuse_norm_relu is set), for the propagation kind
   #dnnl_forward_training the primitive produces a `workspace`
   memory as one extra output. This memory is required to compute the backward
   propagation. When the primitive is executed with propagation kind
   #dnnl_forward_inference, the workspace is not produced. The behavior is then
   the same as creating a batch normalization primitive with ReLU as a
   post-op (see section below).
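
For example, with the oneDNN C++ API (assuming the v2.x `desc`-based interface
referenced in this document), a training-time primitive with fused ReLU might
be set up as in the following sketch; `eng`, `data_md`, and `eps` are assumed
to come from the application:

```cpp
#include "dnnl.hpp"
using namespace dnnl;

// Sketch: forward training batch normalization fused with ReLU. For
// forward_training the primitive descriptor also exposes the mean, variance,
// and workspace outputs, which are needed later for backward propagation.
batch_normalization_forward::primitive_desc make_bnorm_fwd_pd(
        const engine &eng, const memory::desc &data_md, float eps) {
    auto flags = normalization_flags::use_scale_shift
            | normalization_flags::fuse_norm_relu;
    batch_normalization_forward::desc d(
            prop_kind::forward_training, data_md, eps, flags);
    batch_normalization_forward::primitive_desc pd(d, eng);
    // Buffers for the extra training-time outputs can be allocated from
    // pd.mean_desc(), pd.variance_desc(), and pd.workspace_desc().
    return pd;
}
```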

### Backward

The backward propagation computes
\f$\diffsrc(n, c, h, w)\f$,
\f$\diffgamma(c)^*\f$, and \f$\diffbeta(c)^*\f$
based on
\f$\diffdst(n, c, h, w)\f$, \f$\src(n, c, h, w)\f$, \f$\mu(c)\f$,
\f$\sigma^2(c)\f$, \f$\gamma(c)^*\f$, and \f$\beta(c)^*\f$.

The tensors marked with an asterisk are used only when the primitive is
configured to use \f$\gamma(c)\f$ and \f$\beta(c)\f$ (i.e.,
#dnnl_use_scaleshift, #dnnl_use_scale, or #dnnl_use_shift is set).
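
For example, a backward primitive configured to use \f$\gamma(c)\f$ and
\f$\beta(c)\f$ could be created as in the sketch below (again assuming the
v2.x C++ API); `fwd_pd` is a forward training primitive descriptor created
with the same flags, and `data_md`, `diff_md`, and `eps` come from the
application:

```cpp
#include "dnnl.hpp"
using namespace dnnl;

// Sketch: backward batch normalization producing diff_src, diff_gamma, and
// diff_beta. prop_kind::backward (rather than backward_data) is used because
// the scale and shift are enabled.
batch_normalization_backward::primitive_desc make_bnorm_bwd_pd(
        const engine &eng, const memory::desc &diff_md,
        const memory::desc &data_md, float eps,
        const batch_normalization_forward::primitive_desc &fwd_pd) {
    batch_normalization_backward::desc d(prop_kind::backward, diff_md, data_md,
            eps, normalization_flags::use_scale_shift);
    // The forward primitive descriptor serves as a hint.
    return batch_normalization_backward::primitive_desc(d, eng, fwd_pd);
}
```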

## Execution Arguments

Depending on the [flags](@ref dnnl_normalization_flags_t) and
[propagation kind](@ref dnnl_prop_kind_t), the batch normalization primitive
requires different inputs and outputs. For clarity, a summary is shown below.

| Flags | #dnnl_forward_inference | #dnnl_forward_training | #dnnl_backward | #dnnl_backward_data |
| :-- | :-- | :-- | :-- | :-- |
| #dnnl_normalization_flags_none | *Inputs*: \src <br><br> *Outputs*: \dst | *Inputs*: \src <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \diffsrc | Same as for #dnnl_backward |
| #dnnl_use_global_stats | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \dst | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \diffsrc | Same as for #dnnl_backward |
| #dnnl_use_scaleshift | *Inputs*: \src, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$, \f$\diffbeta\f$ | Not supported |
| #dnnl_use_scale | *Inputs*: \src, \f$\gamma\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\gamma\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$ | Not supported |
| #dnnl_use_shift | *Inputs*: \src, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\beta\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\beta\f$ <br><br> *Outputs*: \diffsrc, \f$\diffbeta\f$ | Not supported |
| #dnnl_use_global_stats \| #dnnl_use_scaleshift | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$, \f$\diffbeta\f$ | Not supported |
| `flags` \| #dnnl_fuse_norm_relu | *Inputs*: same as with `flags` <br><br> *Outputs*: same as with `flags` | *Inputs*: same as with `flags` <br><br> *Outputs*: same as with `flags`, [Workspace](@ref dev_guide_inference_and_training_aspects_workspace) | *Inputs*: same as with `flags`, [Workspace](@ref dev_guide_inference_and_training_aspects_workspace) <br><br> *Outputs*: same as with `flags` | Same as for #dnnl_backward if `flags` do not contain #dnnl_use_scaleshift; not supported otherwise |

When executed, the inputs and outputs should be mapped to an execution
argument index as specified by the following table.

| Primitive input/output      | Execution argument index  |
| ---                         | ---                       |
| \src                        | DNNL_ARG_SRC              |
| \f$\gamma, \beta\f$         | DNNL_ARG_SCALE_SHIFT      |
| \f$\gamma\f$                | DNNL_ARG_SCALE            |
| \f$\beta\f$                 | DNNL_ARG_SHIFT            |
| mean (\f$\mu\f$)            | DNNL_ARG_MEAN             |
| variance (\f$\sigma^2\f$)   | DNNL_ARG_VARIANCE         |
| \dst                        | DNNL_ARG_DST              |
| workspace                   | DNNL_ARG_WORKSPACE        |
| \diffdst                    | DNNL_ARG_DIFF_DST         |
| \diffsrc                    | DNNL_ARG_DIFF_SRC         |
| \f$\diffgamma, \diffbeta\f$ | DNNL_ARG_DIFF_SCALE_SHIFT |
| \f$\diffgamma\f$            | DNNL_ARG_DIFF_SCALE       |
| \f$\diffbeta\f$             | DNNL_ARG_DIFF_SHIFT       |
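
For example, for a forward training primitive created with #dnnl_use_scaleshift
and #dnnl_fuse_norm_relu, the execution call might look like the following
sketch (all memory objects are assumed to have been created by the application
from the corresponding memory descriptors):

```cpp
#include <unordered_map>
#include "dnnl.hpp"
using namespace dnnl;

// Sketch: mapping memory objects to the execution argument indices above for
// forward training with use_scaleshift | fuse_norm_relu.
void run_bnorm_fwd(const batch_normalization_forward &bnorm_fwd, stream &strm,
        const memory &src_mem, const memory &scale_shift_mem,
        const memory &dst_mem, const memory &mean_mem, const memory &var_mem,
        const memory &ws_mem) {
    bnorm_fwd.execute(strm,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_SCALE_SHIFT, scale_shift_mem},
                    {DNNL_ARG_DST, dst_mem}, {DNNL_ARG_MEAN, mean_mem},
                    {DNNL_ARG_VARIANCE, var_mem},
                    {DNNL_ARG_WORKSPACE, ws_mem}});
}
```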

## Implementation Details

### General Notes

1. The different flavors of the primitive are partially controlled by the @p
   flags parameter that is passed to the operation descriptor initialization
   function (e.g., dnnl::batch_normalization_forward::desc::desc()). Multiple
   flags can be set using the bitwise OR operator (`|`). The
   #dnnl_use_scaleshift flag cannot be mixed with #dnnl_use_scale or
   #dnnl_use_shift.

2. For forward propagation, the mean and variance might be either computed at
   runtime (in which case they are outputs of the primitive) or provided by
   a user (in which case they are inputs). In the latter case, a user must set
   the #dnnl_use_global_stats flag. For the backward propagation, the mean and
   variance are always input parameters.

3. The memory format and data type for `src` and `dst` are assumed to be the
   same, and in the API they are typically referred to as `data` (e.g., see
   `data_desc` in dnnl::batch_normalization_forward::desc::desc()). The same is
   true for `diff_src` and `diff_dst`. The corresponding memory descriptors are
   referred to as `diff_data_desc`.

4. Both forward and backward propagation support in-place operations, meaning
   that \src can be used as input and output for forward propagation, and
   \diffdst can be used as input and output for backward propagation. In case of
   an in-place operation, the original data will be overwritten. Note, however,
   that backward propagation requires the original \src, hence the corresponding
   forward propagation should not be performed in-place (see the sketch after
   this list).

5. As mentioned above, the batch normalization primitive can be fused with
   ReLU activation even in the training mode. In this case, on the forward
   propagation the primitive has one additional output, `workspace`, that
   should be passed during the backward propagation.
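
To illustrate the in-place operation mentioned in item 4, the sketch below
executes a forward inference primitive (here assumed to be created with
#dnnl_use_global_stats) with the same memory object bound to both the source
and destination arguments, so the input is overwritten:

```cpp
#include <unordered_map>
#include "dnnl.hpp"
using namespace dnnl;

// Sketch: in-place forward inference execution; data_mem is both the source
// and the destination, so its contents are overwritten. Because backward
// propagation needs the original src, training flows typically should not
// run the forward pass in place.
void run_bnorm_inplace(const batch_normalization_forward &bnorm, stream &strm,
        const memory &data_mem, const memory &mean_mem,
        const memory &var_mem) {
    bnorm.execute(strm,
            {{DNNL_ARG_SRC, data_mem}, {DNNL_ARG_DST, data_mem},
                    {DNNL_ARG_MEAN, mean_mem}, {DNNL_ARG_VARIANCE, var_mem}});
}
```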

### Data Type Support

The operation supports the following combinations of data types:

| Propagation        | Source / Destination | Mean / Variance / ScaleShift
| :--                | :--                  | :--
| forward / backward | f32, bf16            | f32
| forward            | f16                  | f32
| forward            | s8                   | f32

@warning
    There might be hardware- or implementation-specific restrictions. Check the
    [Implementation Limitations](@ref dg_bnorm_impl_limits) section below.

### Data Representation

#### Mean and Variance

The mean (\f$\mu\f$) and variance (\f$\sigma^2\f$) are separate 1D tensors of
size \f$C\f$.

The format of the corresponding memory object must be #dnnl_x (#dnnl_a).

#### Scale and Shift

If #dnnl_use_scaleshift is used, the scale (\f$\gamma\f$) and shift
(\f$\beta\f$) are combined in a single 2D tensor of shape \f$2 \times C\f$.

If #dnnl_use_scale or #dnnl_use_shift is used, the scale (\f$\gamma\f$) and
shift (\f$\beta\f$) are separate 1D tensors of size \f$C\f$.

The format of the corresponding memory object must be #dnnl_nc (#dnnl_ab).
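
For instance, the combined scale and shift tensor used with
#dnnl_use_scaleshift could be described as in the sketch below; `C` is the
number of channels supplied by the application:

```cpp
#include "dnnl.hpp"
using namespace dnnl;

// Sketch: memory descriptor for the combined 2 x C scale-and-shift tensor
// (gamma in row 0, beta in row 1) in the required nc (ab) format.
memory::desc make_scale_shift_md(memory::dim C) {
    return memory::desc(
            {2, C}, memory::data_type::f32, memory::format_tag::nc);
}
```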

#### Source, Destination, and Their Gradients

Like other CNN primitives, the batch normalization primitive expects data
to be an \f$N \times C \times SP_n \times \cdots \times SP_0\f$ tensor.

The batch normalization primitive is optimized for the following memory formats:

| Spatial | Logical tensor | Implementations optimized for memory formats
| :--     | :--            | :--
| 0D      | NC             | #dnnl_nc (#dnnl_ab)
| 1D      | NCW            | #dnnl_ncw (#dnnl_abc), #dnnl_nwc (#dnnl_acb), *optimized^*
| 2D      | NCHW           | #dnnl_nchw (#dnnl_abcd), #dnnl_nhwc (#dnnl_acdb), *optimized^*
| 3D      | NCDHW          | #dnnl_ncdhw (#dnnl_abcde), #dnnl_ndhwc (#dnnl_acdeb), *optimized^*

Here *optimized^* means the format that
[comes out](@ref memory_format_propagation_cpp)
of any preceding compute-intensive primitive.
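
As an illustration of picking up the *optimized^* format, the sketch below
takes the destination descriptor of a hypothetical preceding convolution
(`conv_pd`) as the batch normalization `data_desc`, so no reorder is needed
between the two primitives:

```cpp
#include "dnnl.hpp"
using namespace dnnl;

// Sketch: let batch normalization consume whatever memory format the
// preceding compute-intensive primitive produced; conv_pd is a hypothetical
// convolution primitive descriptor created earlier by the application.
batch_normalization_forward::desc make_bnorm_after_conv_desc(
        const convolution_forward::primitive_desc &conv_pd, float eps) {
    return batch_normalization_forward::desc(prop_kind::forward_inference,
            conv_pd.dst_desc(), eps, normalization_flags::use_scale_shift);
}
```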

### Post-ops and Attributes

Post-ops and attributes enable you to modify the behavior of the batch
normalization primitive by chaining certain operations after the batch
normalization operation. The following post-ops are supported by batch
normalization primitives:

| Propagation | Type    | Operation | Description
| :--         | :--     | :--       | :--
| forward     | post-op | eltwise   | Applies an @ref dnnl_api_eltwise operation to the result (currently only the #dnnl_eltwise_relu algorithm is supported)

@note As mentioned in @ref dev_guide_attributes, post-ops should be used
for inference only. For instance, using ReLU as a post-op would not produce the
additional output `workspace` that is required to compute backward propagation
correctly. Hence, for training one should use the #dnnl_fuse_norm_relu flag
directly.
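
For inference, attaching the ReLU post-op might look like the following sketch
(again assuming the v2.x C++ API; `eng`, `data_md`, and `eps` are assumed to be
provided by the application):

```cpp
#include "dnnl.hpp"
using namespace dnnl;

// Sketch: inference-only batch normalization with ReLU expressed as an
// eltwise post-op attached through the primitive attributes.
batch_normalization_forward::primitive_desc make_bnorm_relu_postop_pd(
        const engine &eng, const memory::desc &data_md, float eps) {
    post_ops po;
    po.append_eltwise(
            1.f, algorithm::eltwise_relu, 0.f, 0.f); // scale, alg, alpha, beta
    primitive_attr attr;
    attr.set_post_ops(po);
    batch_normalization_forward::desc d(prop_kind::forward_inference, data_md,
            eps, normalization_flags::use_scale_shift);
    return batch_normalization_forward::primitive_desc(d, attr, eng);
}
```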

@anchor dg_bnorm_impl_limits
## Implementation Limitations

1. Refer to @ref dev_guide_data_types for limitations related to data type
   support.

2. For the data types that have forward propagation support only, mean and
   variance must be provided by a user (i.e., #dnnl_use_global_stats is set).

## Performance Tips

1. For backward propagation, use the same memory format for `src`, `diff_dst`,
   and `diff_src` (the formats of `diff_dst` and `diff_src` are always the
   same because of the API). Different formats are functionally supported but
   lead to highly suboptimal performance.

2. Use in-place operations whenever possible (see caveats in General Notes).

## Examples

| Engine  | Name                                 | Comments
| :--     | :--                                  | :--
| CPU/GPU | @ref batch_normalization_example_cpp | @copydetails batch_normalization_example_cpp_short
253