Batch Normalization {#dev_guide_batch_normalization}
====================================================

>
> [API Reference](@ref dnnl_api_batch_normalization)
>

## General

The batch normalization primitive performs a forward or backward batch
normalization operation on tensors with two or more dimensions.

### Forward

The batch normalization operation is defined by the following formulas. They
are shown only for 2D spatial data and generalize straightforwardly to higher
and lower dimensions. Variable names follow the standard
@ref dev_guide_conventions.

\f[
    \dst(n, c, h, w) =
       \gamma(c) \cdot
       \frac{\src(n, c, h, w) - \mu(c)} {\sqrt{\sigma^2(c) + \varepsilon}}
       + \beta(c),
\f]

where

- \f$\gamma(c), \beta(c)\f$ are the optional scale and shift for a channel
  (see the #dnnl_use_scaleshift, #dnnl_use_scale, and #dnnl_use_shift flags),

- \f$\mu(c), \sigma^2(c)\f$ are the mean and variance for a channel (see the
  #dnnl_use_global_stats flag), and

- \f$\varepsilon\f$ is a constant that improves numerical stability.

The mean and variance are either computed at runtime or provided by the user.
When they are computed at runtime, the following formulas are used:

- \f$\mu(c) = \frac{1}{NHW} \sum\limits_{nhw} \src(n, c, h, w)\f$,

- \f$\sigma^2(c) = \frac{1}{NHW} \sum\limits_{nhw} (\src(n, c, h, w) - \mu(c))^2\f$.

The \f$\gamma(c)\f$ and \f$\beta(c)\f$ tensors are considered learnable.

In training mode, the primitive also optionally supports fusion with a ReLU
activation with zero negative slope applied to the result
(see the #dnnl_fuse_norm_relu flag).

@note
* The batch normalization primitive computes the population mean and variance,
  not the sample (unbiased) estimates that are typically used to compute the
  running mean and variance.
* Using the mean and variance computed by the batch normalization primitive,
  the running mean and variance \f$\hat\mu\f$ and \f$\hat\sigma^2\f$ can be
  computed as
  \f[
    \begin{aligned}
      \hat\mu &:= \alpha \cdot \hat\mu + (1 - \alpha) \cdot \mu, \\
      \hat\sigma^2 &:= \alpha \cdot \hat\sigma^2 + (1 - \alpha) \cdot \sigma^2,
    \end{aligned}
  \f]
  where \f$\alpha\f$ is a user-chosen decay (momentum) factor.

#### Difference Between Forward Training and Forward Inference

 * If the mean and variance are computed at runtime (i.e., #dnnl_use_global_stats
   is not set), they become outputs for the propagation kind
   #dnnl_forward_training (because they are required during backward
   propagation) and are not exposed for the propagation kind
   #dnnl_forward_inference.

 * If batch normalization is created with ReLU fusion (i.e.,
   #dnnl_fuse_norm_relu is set), then for the propagation kind
   #dnnl_forward_training the primitive produces a `workspace` memory as one
   extra output. This memory is required to compute the backward propagation.
   For the propagation kind #dnnl_forward_inference, no workspace is produced,
   and the behavior is the same as that of a batch normalization primitive
   created with ReLU as a post-op (see the Post-ops and Attributes section
   below).

### Backward

The backward propagation computes
\f$\diffsrc(n, c, h, w)\f$,
\f$\diffgamma(c)^*\f$, and \f$\diffbeta(c)^*\f$
based on
\f$\diffdst(n, c, h, w)\f$, \f$\src(n, c, h, w)\f$, \f$\mu(c)\f$,
\f$\sigma^2(c)\f$, \f$\gamma(c)^*\f$, and \f$\beta(c)^*\f$.

The tensors marked with an asterisk are used only when the primitive is
configured to use \f$\gamma(c)\f$ and \f$\beta(c)\f$ (i.e.,
#dnnl_use_scaleshift, #dnnl_use_scale, or #dnnl_use_shift is set).

## Execution Arguments

Depending on the [flags](@ref dnnl_normalization_flags_t) and
[propagation kind](@ref dnnl_prop_kind_t), the batch normalization primitive
requires different inputs and outputs. The following table summarizes them.

| | #dnnl_forward_inference | #dnnl_forward_training | #dnnl_backward | #dnnl_backward_data |
| :-- | :-- | :-- | :-- | :-- |
| #dnnl_normalization_flags_none | *Inputs*: \src <br><br> *Outputs*: \dst | *Inputs*: \src <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \diffsrc | Same as for #dnnl_backward |
| #dnnl_use_global_stats | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \dst | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$ <br><br> *Outputs*: \diffsrc | Same as for #dnnl_backward |
| #dnnl_use_scaleshift | *Inputs*: \src, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$, \f$\diffbeta\f$ | Not supported |
| #dnnl_use_scale | *Inputs*: \src, \f$\gamma\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\gamma\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$ | Not supported |
| #dnnl_use_shift | *Inputs*: \src, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\beta\f$ <br><br> *Outputs*: \dst, \f$\mu\f$, \f$\sigma^2\f$ | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\beta\f$ <br><br> *Outputs*: \diffsrc, \f$\diffbeta\f$ | Not supported |
| #dnnl_use_global_stats \| #dnnl_use_scaleshift | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \dst | *Inputs*: \diffdst, \src, \f$\mu\f$, \f$\sigma^2\f$, \f$\gamma\f$, \f$\beta\f$ <br><br> *Outputs*: \diffsrc, \f$\diffgamma\f$, \f$\diffbeta\f$ | Not supported |
| `flags` \| #dnnl_fuse_norm_relu | *Inputs*: same as with `flags` <br><br> *Outputs*: same as with `flags` | *Inputs*: same as with `flags` <br><br> *Outputs*: same as with `flags`, [Workspace](@ref dev_guide_inference_and_training_aspects_workspace) | *Inputs*: same as with `flags`, [Workspace](@ref dev_guide_inference_and_training_aspects_workspace) <br><br> *Outputs*: same as with `flags` | Same as for #dnnl_backward if `flags` contain none of #dnnl_use_scaleshift, #dnnl_use_scale, and #dnnl_use_shift; not supported otherwise |

When executed, the inputs and outputs should be mapped to an execution
argument index as specified by the following table.

| Primitive input/output | Execution argument index |
| --- | --- |
| \src | DNNL_ARG_SRC |
| \f$\gamma, \beta\f$ | DNNL_ARG_SCALE_SHIFT |
| \f$\gamma\f$ | DNNL_ARG_SCALE |
| \f$\beta\f$ | DNNL_ARG_SHIFT |
| mean (\f$\mu\f$) | DNNL_ARG_MEAN |
| variance (\f$\sigma^2\f$) | DNNL_ARG_VARIANCE |
| \dst | DNNL_ARG_DST |
| workspace | DNNL_ARG_WORKSPACE |
| \diffdst | DNNL_ARG_DIFF_DST |
| \diffsrc | DNNL_ARG_DIFF_SRC |
| \f$\diffgamma, \diffbeta\f$ | DNNL_ARG_DIFF_SCALE_SHIFT |
| \f$\diffgamma\f$ | DNNL_ARG_DIFF_SCALE |
| \f$\diffbeta\f$ | DNNL_ARG_DIFF_SHIFT |

## Implementation Details

### General Notes

1. The different flavors of the primitive are partially controlled by the @p
   flags parameter that is passed to the operation descriptor initialization
   function (e.g., dnnl::batch_normalization_forward::desc::desc()). Multiple
   flags can be set using the bitwise OR operator (`|`). The
   #dnnl_use_scaleshift flag cannot be combined with #dnnl_use_scale or
   #dnnl_use_shift.

2. For forward propagation, the mean and variance can be either computed at
   runtime (in which case they are outputs of the primitive) or provided by
   the user (in which case they are inputs). In the latter case, the user must
   set the #dnnl_use_global_stats flag.
   For backward propagation, the mean and variance are always inputs.

3. The memory format and data type for `src` and `dst` are assumed to be the
   same, and in the API they are typically referred to as `data` (e.g., see
   `data_desc` in dnnl::batch_normalization_forward::desc::desc()). The same is
   true for `diff_src` and `diff_dst`, whose memory descriptor is referred to
   as `diff_data_desc`.

4. Both forward and backward propagation support in-place operations, meaning
   that \src can be used as both input and output for forward propagation, and
   \diffdst can be used as both input and output for backward propagation. In
   an in-place operation, the original data is overwritten. Note, however,
   that backward propagation requires the original \src, so the corresponding
   forward propagation should not be performed in-place.

5. As mentioned above, the batch normalization primitive can be fused with a
   ReLU activation even in training mode. In this case, the forward
   propagation has one additional output, `workspace`, which must be passed
   to the backward propagation.

### Data Type Support

The operation supports the following combinations of data types:

| Propagation        | Source / Destination | Mean / Variance / Scale / Shift
| :--                | :--                  | :--
| forward / backward | f32, bf16            | f32
| forward            | f16                  | f32
| forward            | s8                   | f32

@warning
    There might be hardware- or implementation-specific restrictions. Check the
    [Implementation Limitations](@ref dg_bnorm_impl_limits) section below.

### Data Representation

#### Mean and Variance

The mean (\f$\mu\f$) and variance (\f$\sigma^2\f$) are separate 1D tensors of
size \f$C\f$.

The format of the corresponding memory objects must be #dnnl_x (#dnnl_a).
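
To make the forward formulas and the per-channel 1D layout of \f$\mu\f$ and
\f$\sigma^2\f$ concrete, the following library-independent C++ sketch computes
the statistics and applies the normalization to a dense f32 NCHW buffer. It is
not the oneDNN API and not how the library implements the primitive; the names
`BnormStats`, `compute_stats`, `bnorm_forward`, and `update_running_stats` are
hypothetical. The variance divides by \f$NHW\f$ (population variance), matching
the note above, and the running-statistics helper applies the \f$\alpha\f$
update from that note.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for the per-channel statistics (1D, size C each).
struct BnormStats {
    std::vector<float> mean;     // mu(c)
    std::vector<float> variance; // sigma^2(c), population (biased) variance
};

// Computes mu(c) and sigma^2(c) over n, h, w for a dense NCHW f32 tensor.
BnormStats compute_stats(const std::vector<float> &src, std::size_t N,
        std::size_t C, std::size_t H, std::size_t W) {
    assert(src.size() == N * C * H * W);
    const std::size_t SP = H * W;
    BnormStats s{std::vector<float>(C, 0.f), std::vector<float>(C, 0.f)};
    for (std::size_t c = 0; c < C; ++c) {
        double sum = 0.0; // accumulate in double to limit rounding error
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t sp = 0; sp < SP; ++sp)
                sum += src[(n * C + c) * SP + sp];
        const double mu = sum / double(N * SP);
        double sq = 0.0;
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t sp = 0; sp < SP; ++sp) {
                const double d = src[(n * C + c) * SP + sp] - mu;
                sq += d * d;
            }
        s.mean[c] = float(mu);
        s.variance[c] = float(sq / double(N * SP)); // divide by NHW, not NHW - 1
    }
    return s;
}

// dst = gamma * (src - mu) / sqrt(sigma^2 + eps) + beta.
// An empty gamma (beta) vector means "scale (shift) not used": 1 (0).
std::vector<float> bnorm_forward(const std::vector<float> &src, std::size_t N,
        std::size_t C, std::size_t H, std::size_t W, const BnormStats &s,
        const std::vector<float> &gamma, const std::vector<float> &beta,
        float eps) {
    assert(src.size() == N * C * H * W);
    assert(gamma.empty() || gamma.size() == C);
    assert(beta.empty() || beta.size() == C);
    const std::size_t SP = H * W;
    std::vector<float> dst(src.size());
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t c = 0; c < C; ++c) {
            const float inv_std = 1.f / std::sqrt(s.variance[c] + eps);
            const float g = gamma.empty() ? 1.f : gamma[c];
            const float b = beta.empty() ? 0.f : beta[c];
            for (std::size_t sp = 0; sp < SP; ++sp) {
                const std::size_t i = (n * C + c) * SP + sp;
                dst[i] = g * (src[i] - s.mean[c]) * inv_std + b;
            }
        }
    return dst;
}

// Running statistics: hat := alpha * hat + (1 - alpha) * batch statistic.
void update_running_stats(std::vector<float> &run_mean,
        std::vector<float> &run_var, const BnormStats &s, float alpha) {
    assert(run_mean.size() == s.mean.size());
    assert(run_var.size() == s.variance.size());
    for (std::size_t c = 0; c < run_mean.size(); ++c) {
        run_mean[c] = alpha * run_mean[c] + (1.f - alpha) * s.mean[c];
        run_var[c] = alpha * run_var[c] + (1.f - alpha) * s.variance[c];
    }
}
```

For example, with \f$N = 1, C = 2, H = 1, W = 2\f$ and source values
`{1, 3}` in channel 0 and `{2, 6}` in channel 1, the sketch yields means
`{2, 4}` and population variances `{1, 4}` (the unbiased estimates would be
`{2, 8}`).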

#### Scale and Shift

If #dnnl_use_scaleshift is used, the scale (\f$\gamma\f$) and shift
(\f$\beta\f$) are combined in a single 2D tensor of shape \f$2 \times C\f$.
The format of this memory object must be #dnnl_nc (#dnnl_ab).

If #dnnl_use_scale or #dnnl_use_shift is used, the scale (\f$\gamma\f$) and
shift (\f$\beta\f$) are separate 1D tensors of shape \f$C\f$.

#### Source, Destination, and Their Gradients

Like other CNN primitives, the batch normalization primitive expects data
to be an \f$N \times C \times SP_n \times \cdots \times SP_0\f$ tensor.

The batch normalization primitive is optimized for the following memory formats:

| Spatial | Logical tensor | Implementations optimized for memory formats
| :--     | :--            | :--
| 0D      | NC             | #dnnl_nc (#dnnl_ab)
| 1D      | NCW            | #dnnl_ncw (#dnnl_abc), #dnnl_nwc (#dnnl_acb), *optimized^*
| 2D      | NCHW           | #dnnl_nchw (#dnnl_abcd), #dnnl_nhwc (#dnnl_acdb), *optimized^*
| 3D      | NCDHW          | #dnnl_ncdhw (#dnnl_abcde), #dnnl_ndhwc (#dnnl_acdeb), *optimized^*

Here *optimized^* means the format that
[comes out](@ref memory_format_propagation_cpp)
of any preceding compute-intensive primitive.

### Post-ops and Attributes

Post-ops and attributes enable you to modify the behavior of the batch
normalization primitive by chaining certain operations after the batch
normalization operation. The following post-ops are supported by batch
normalization primitives:

| Propagation | Type    | Operation | Description
| :--         | :--     | :--       | :--
| forward     | post-op | eltwise   | Applies an @ref dnnl_api_eltwise operation to the result (currently only the #dnnl_eltwise_relu algorithm is supported)

@note As mentioned in @ref dev_guide_attributes, post-ops should be used
for inference only.
For instance, using ReLU as a post-op does not produce the
additional `workspace` output that is required to compute the backward
propagation correctly. Hence, for training, use the #dnnl_fuse_norm_relu flag
instead.

@anchor dg_bnorm_impl_limits
## Implementation Limitations

1. Refer to @ref dev_guide_data_types for limitations related to data type
   support.

2. For the data types that have forward propagation support only, the mean and
   variance must be provided by the user (i.e., #dnnl_use_global_stats is set).

## Performance Tips

1. For backward propagation, use the same memory format for `src`, `diff_dst`,
   and `diff_src` (the formats of `diff_dst` and `diff_src` are always the
   same because of the API). Different formats are functionally supported but
   lead to highly suboptimal performance.

2. Use in-place operations whenever possible (see the caveats in General Notes).

## Examples

| Engine  | Name                                 | Comments
| :--     | :--                                  | :--
| CPU/GPU | @ref batch_normalization_example_cpp | @copydetails batch_normalization_example_cpp_short
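
In addition to the library example above, the following library-independent
C++ sketch illustrates why training with ReLU fusion needs a workspace: the
forward pass records which outputs the ReLU let through, and the backward pass
uses that record to mask \diffdst before the batch normalization gradient is
applied. It works on a single channel (the \f$M = NHW\f$ values of one channel
gathered into a flat array), assumes #dnnl_use_global_stats is not set and that
\f$\gamma\f$ and \f$\beta\f$ are used, and uses the standard backward formulas
for runtime-computed statistics, which this guide does not spell out. The
one-byte-per-element mask is a hypothetical stand-in: the actual oneDNN
workspace layout is implementation-specific. The names `ChannelFwd`,
`ChannelBwd`, `fwd_bnorm_relu`, and `bwd_bnorm_relu` are also hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Forward results for ONE channel (hypothetical structure).
struct ChannelFwd {
    float mean, variance;         // mu(c) and population sigma^2(c)
    std::vector<float> dst;       // max(gamma * x_hat + beta, 0)
    std::vector<std::uint8_t> ws; // hypothetical workspace: 1 where ReLU passed
};

// Backward results for ONE channel (hypothetical structure).
struct ChannelBwd {
    std::vector<float> diff_src;
    float diff_gamma, diff_beta;
};

// Forward training with fused ReLU; x holds the M source values of a channel.
ChannelFwd fwd_bnorm_relu(const std::vector<float> &x, float gamma,
        float beta, float eps) {
    const std::size_t M = x.size();
    assert(M > 0);
    double sum = 0.0, sq = 0.0;
    for (float v : x) sum += v;
    const double mu = sum / M;
    for (float v : x) sq += (v - mu) * (v - mu);
    ChannelFwd r{float(mu), float(sq / M), std::vector<float>(M),
            std::vector<std::uint8_t>(M)};
    const float inv_std = 1.f / std::sqrt(r.variance + eps);
    for (std::size_t i = 0; i < M; ++i) {
        const float y = gamma * (x[i] - r.mean) * inv_std + beta;
        r.ws[i] = static_cast<std::uint8_t>(y > 0.f); // remember ReLU decision
        r.dst[i] = y > 0.f ? y : 0.f;
    }
    return r;
}

// Backward: mask diff_dst with the workspace (ReLU backward), then apply the
// batch normalization backward for runtime-computed mean and variance.
ChannelBwd bwd_bnorm_relu(const std::vector<float> &x,
        const std::vector<float> &diff_dst, const ChannelFwd &f, float gamma,
        float eps) {
    const std::size_t M = x.size();
    assert(diff_dst.size() == M && f.ws.size() == M);
    const float inv_std = 1.f / std::sqrt(f.variance + eps);
    std::vector<float> dy(M);
    float d_beta = 0.f, d_gamma = 0.f;
    for (std::size_t i = 0; i < M; ++i) {
        dy[i] = f.ws[i] ? diff_dst[i] : 0.f;
        const float x_hat = (x[i] - f.mean) * inv_std;
        d_beta += dy[i];          // diff_beta = sum(dy)
        d_gamma += dy[i] * x_hat; // diff_gamma = sum(dy * x_hat)
    }
    ChannelBwd r{std::vector<float>(M), d_gamma, d_beta};
    for (std::size_t i = 0; i < M; ++i) {
        const float x_hat = (x[i] - f.mean) * inv_std;
        r.diff_src[i] = gamma * inv_std
                * (dy[i] - d_beta / M - x_hat * d_gamma / M);
    }
    return r;
}
```

Without the recorded mask, the backward pass could not tell which elements of
\diffdst the ReLU blocked, which is why a ReLU post-op (which produces no
workspace) is not suitable for training.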