Convolution {#dev_guide_convolution}
=====================================

>
> [API Reference](@ref dnnl_api_convolution)
>

## General

The convolution primitive computes forward, backward, or weight update for a
batched convolution operation on 1D, 2D, or 3D spatial data with bias.

The convolution operation is defined by the following formulas. We show
formulas only for 2D spatial data, as they are straightforward to generalize to
cases of higher and lower dimensions. Variable names follow the standard
@ref dev_guide_conventions.

Let \src, \weights and \dst be \f$N \times IC \times IH \times
IW\f$, \f$OC \times IC \times KH \times KW\f$, and \f$N \times OC \times OH
\times OW\f$ tensors respectively. Let \bias be a 1D tensor with \f$OC\f$
elements.

Furthermore, let the remaining convolution parameters be:

| Parameter                            | Depth      | Height     | Width      | Comment                                                                                                                |
| :--                                  | :--        | :--        | :--        | :--                                                                                                                    |
| Padding: <br>Front, top, and left    | \f$PD_L\f$ | \f$PH_L\f$ | \f$PW_L\f$ | In the API we use `padding_l` to indicate the corresponding vector of paddings (`_l` in the name stands for **left**)  |
| Padding: <br>Back, bottom, and right | \f$PD_R\f$ | \f$PH_R\f$ | \f$PW_R\f$ | In the API we use `padding_r` to indicate the corresponding vector of paddings (`_r` in the name stands for **right**) |
| Stride                               | \f$SD\f$   | \f$SH\f$   | \f$SW\f$   | Convolution without strides is defined by setting the stride parameters to 1                                           |
| Dilation                             | \f$DD\f$   | \f$DH\f$   | \f$DW\f$   | Non-dilated convolution is defined by setting the dilation parameters to 0                                             |

The following formulas show how oneDNN computes convolutions. They are
broken down into several types to simplify the exposition, but in reality the
convolution types can be combined.

To further simplify the formulas, we assume that \f$\src(n, ic, ih, iw) = 0\f$
if \f$ih < 0\f$, or \f$ih \geq IH\f$, or \f$iw < 0\f$, or \f$iw \geq IW\f$.

### Forward

#### Regular Convolution

\f[\dst(n, oc, oh, ow) = \bias(oc) \\
    + \sum_{ic=0}^{IC-1}\sum_{kh=0}^{KH-1}\sum_{kw=0}^{KW-1}
        \src(n, ic, oh \cdot SH + kh - PH_L, ow \cdot SW + kw - PW_L)
        \cdot
        \weights(oc, ic, kh, kw).\f]

Here:

- \f$OH = \left\lfloor{\frac{IH - KH + PH_L + PH_R}{SH}} \right\rfloor + 1,\f$

- \f$OW = \left\lfloor{\frac{IW - KW + PW_L + PW_R}{SW}} \right\rfloor + 1.\f$
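
As an illustration, the formula above maps directly onto a naive scalar loop
nest. The sketch below is illustrative only (not the library implementation);
the variable names mirror the formula, and the arrays are assumed to be dense
in NCHW / OIHW order:

~~~cpp
// Naive reference of the forward formula. Out-of-bounds reads are skipped,
// which implements the zero padding described above.
for (int n = 0; n < N; ++n)
for (int oc = 0; oc < OC; ++oc)
for (int oh = 0; oh < OH; ++oh)
for (int ow = 0; ow < OW; ++ow) {
    float d = bias[oc];
    for (int ic = 0; ic < IC; ++ic)
    for (int kh = 0; kh < KH; ++kh)
    for (int kw = 0; kw < KW; ++kw) {
        const int ih = oh * SH + kh - PH_L;
        const int iw = ow * SW + kw - PW_L;
        if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) continue;
        d += src[((n * IC + ic) * IH + ih) * IW + iw]
                * weights[((oc * IC + ic) * KH + kh) * KW + kw];
    }
    dst[((n * OC + oc) * OH + oh) * OW + ow] = d;
}
~~~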

#### Convolution with Groups

In the API, oneDNN adds a separate groups dimension to memory objects
representing \weights tensors and represents weights as \f$G \times OC_G \times
IC_G \times KH \times KW \f$ 5D tensors for 2D convolutions with groups.

\f[
    \dst(n, g \cdot OC_G + oc_g, oh, ow) =
        \bias(g \cdot OC_G + oc_g) \\
        +
        \sum_{ic_g=0}^{IC_G-1}\sum_{kh=0}^{KH-1}\sum_{kw=0}^{KW-1}
            \src(n, g \cdot IC_G + ic_g, oh \cdot SH + kh - PH_L,
                    ow \cdot SW + kw - PW_L)
            \cdot
            \weights(g, oc_g, ic_g, kh, kw),
\f]

where
- \f$IC_G = \frac{IC}{G}\f$,
- \f$OC_G = \frac{OC}{G}\f$, and
- \f$oc_g \in [0, OC_G).\f$

The case when \f$OC_G = IC_G = 1\f$ is also known as a *depthwise convolution*.
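
For example, a depthwise weights memory descriptor for a 2D convolution could
be created with the C++ API as in this minimal sketch (the dimension values
are illustrative):

~~~cpp
#include "dnnl.hpp"

// Grouped weights for 2D convolutions are 5D: {G, OC_G, IC_G, KH, KW}.
// Depthwise case: OC_G = IC_G = 1; here G = 32 and the kernel is 3x3.
dnnl::memory::dims weights_dims = {32, 1, 1, 3, 3};
auto weights_md = dnnl::memory::desc(weights_dims,
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::goihw);
~~~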

#### Convolution with Dilation

\f[
    \dst(n, oc, oh, ow) =
        \bias(oc) \\
        +
        \sum_{ic=0}^{IC-1}\sum_{kh=0}^{KH-1}\sum_{kw=0}^{KW-1}
            \src(n, ic, oh \cdot SH + kh \cdot (DH + 1) - PH_L,
                    ow \cdot SW + kw \cdot (DW + 1) - PW_L)
            \cdot
            \weights(oc, ic, kh, kw).
\f]

Here:

- \f$OH = \left\lfloor{\frac{IH - DKH + PH_L + PH_R}{SH}}
        \right\rfloor + 1,\f$ where \f$DKH = 1 + (KH - 1) \cdot (DH + 1)\f$, and

- \f$OW = \left\lfloor{\frac{IW - DKW + PW_L + PW_R}{SW}}
        \right\rfloor + 1,\f$ where \f$DKW = 1 + (KW - 1) \cdot (DW + 1)\f$.
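
These formulas reduce to the non-dilated case when the dilation is 0. A small
helper computing the output size along one dimension, assuming a valid
(non-negative) numerator, could look like this sketch:

~~~cpp
// Output spatial size along one dimension. Note the oneDNN convention:
// dilation D = 0 means a non-dilated convolution.
inline int conv_out_size(int I, int K, int P_L, int P_R, int S, int D) {
    const int DK = 1 + (K - 1) * (D + 1); // effective (dilated) kernel size
    return (I - DK + P_L + P_R) / S + 1;  // integer division acts as floor
}
// For example: OH = conv_out_size(IH, KH, PH_L, PH_R, SH, DH);
~~~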

#### Deconvolution (Transposed Convolution)

Deconvolutions (also called fractionally strided convolutions or transposed
convolutions) work by swapping the forward and backward passes of a
convolution. One way to put it is to note that the weights define a
convolution, but whether it is a direct convolution or a transposed
convolution is determined by how the forward and backward passes are computed.

#### Difference Between Forward Training and Forward Inference

There is no difference between the #dnnl_forward_training
and #dnnl_forward_inference propagation kinds.

### Backward

The backward propagation computes \diffsrc based on \diffdst and
\weights.

The weights update computes \diffweights and \diffbias based on
\diffdst and \src.

@note The *optimized* memory formats of \src and \weights might be
different on forward propagation, backward propagation, and weights
update.
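
For instance, a backward-by-data primitive descriptor could be created as in
the following sketch. Here `fwd_pd` (the matching forward primitive descriptor
used as a hint), the memory descriptors, strides, and paddings are assumed to
be created elsewhere:

~~~cpp
// Backward propagation: computes diff_src from diff_dst and weights.
auto bwd_d = dnnl::convolution_backward_data::desc(
        dnnl::algorithm::convolution_direct,
        diff_src_md, weights_md, diff_dst_md,
        strides, padding_l, padding_r);
auto bwd_pd = dnnl::convolution_backward_data::primitive_desc(
        bwd_d, eng, fwd_pd);
~~~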

## Execution Arguments

When executed, the inputs and outputs should be mapped to an execution
argument index as specified by the following table.

| Primitive input/output      | Execution argument index                                                  |
| ---                         | ---                                                                       |
| \src                        | DNNL_ARG_SRC                                                              |
| \weights                    | DNNL_ARG_WEIGHTS                                                          |
| \bias                       | DNNL_ARG_BIAS                                                             |
| \dst                        | DNNL_ARG_DST                                                              |
| \diffsrc                    | DNNL_ARG_DIFF_SRC                                                         |
| \diffweights                | DNNL_ARG_DIFF_WEIGHTS                                                     |
| \diffbias                   | DNNL_ARG_DIFF_BIAS                                                        |
| \diffdst                    | DNNL_ARG_DIFF_DST                                                         |
| \f$depthwise\f$             | DNNL_ARG_ATTR_POST_OP_DW                                                  |
| \f$\text{binary post-op}\f$ | DNNL_ARG_ATTR_MULTIPLE_POST_OP(binary_post_op_position) \| DNNL_ARG_SRC_1 |
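
For example, a forward convolution with bias could be executed as in the
following sketch (the primitive, stream, and memory objects are assumed to be
created elsewhere):

~~~cpp
// Map each memory object to its execution argument index.
conv_fwd.execute(engine_stream, {
        {DNNL_ARG_SRC, src_mem},
        {DNNL_ARG_WEIGHTS, weights_mem},
        {DNNL_ARG_BIAS, bias_mem},
        {DNNL_ARG_DST, dst_mem}});
~~~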

## Implementation Details

### General Notes

N/A.

### Data Types

The convolution primitive supports the following combinations of data types for
source, destination, and weights memory objects:

| Propagation        | Source    | Weights   | Destination                  | Bias                        |
| :--                | :--       | :--       | :--                          | :--                         |
| forward            | f32       | f32       | f32, s8                      | f32                         |
| forward            | f16       | f16       | f16, f32, u8, s8             | f16, f32                    |
| forward            | u8, s8    | s8        | u8, s8, s32, f32, f16, bf16  | u8, s8, s32, f32, f16, bf16 |
| forward            | bf16      | bf16      | f32, bf16                    | f32, bf16                   |
| backward           | f32, bf16 | bf16      | bf16                         |                             |
| backward           | f32       | f32       | f32                          | f32                         |
| weights update     | bf16      | f32, bf16 | bf16, s8, u8                 | f32, bf16                   |

@warning
    There might be hardware- and/or implementation-specific restrictions.
    Check the [Implementation Limitations](@ref dg_conv_impl_limits) section
    below.

### Data Representation

Like other CNN primitives, the convolution primitive expects the following
tensors:

| Spatial | Source / Destination                        | Weights
| :--     | :--                                         | :--
| 1D      | \f$N \times C \times W\f$                   | \f$[G \times ] OC \times IC \times KW\f$
| 2D      | \f$N \times C \times H \times W\f$          | \f$[G \times ] OC \times IC \times KH \times KW\f$
| 3D      | \f$N \times C \times D \times H \times W\f$ | \f$[G \times ] OC \times IC \times KD \times KH \times KW\f$

The physical format of data and weights memory objects is critical for
convolution primitive performance. In the oneDNN programming model, convolution
is one of the few primitives that support the placeholder memory format tag
#dnnl::memory::format_tag::any (shortened to `any` from now on) and can
define data and weights memory object formats based on the primitive
parameters. When using `any`, it is necessary to first create a convolution
primitive descriptor and then query it for the actual data and weights memory
object formats.

While convolution primitives can be created with memory formats specified
explicitly, the performance is likely to be suboptimal.
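
The following sketch shows this flow with the C++ API; the shapes, strides,
and paddings are illustrative placeholders:

~~~cpp
// Let the library choose the formats, then query the actual ones.
auto src_md = dnnl::memory::desc({N, IC, IH, IW},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);
auto weights_md = dnnl::memory::desc({OC, IC, KH, KW},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);
auto dst_md = dnnl::memory::desc({N, OC, OH, OW},
        dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);

auto conv_d = dnnl::convolution_forward::desc(
        dnnl::prop_kind::forward_inference,
        dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md,
        strides, padding_l, padding_r);
auto conv_pd = dnnl::convolution_forward::primitive_desc(conv_d, eng);

// Formats actually chosen by the implementation:
auto chosen_src_md = conv_pd.src_desc();
auto chosen_weights_md = conv_pd.weights_desc();
~~~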

The table below shows the combinations of **plain** memory formats for which
the convolution primitive is optimized.

| Spatial    | Convolution Type | Data / Weights logical tensor | Implementation optimized for memory formats
| :--        | :--              | :--                           | :--
| 1D, 2D, 3D |                  | `any`                         | *optimized*
| 1D         | f32, bf16        | NCW / OIW, GOIW               | #dnnl_ncw (#dnnl_abc) / #dnnl_oiw (#dnnl_abc), #dnnl_goiw (#dnnl_abcd)
| 1D         | \"               | \"                            | #dnnl_nwc (#dnnl_acb) / #dnnl_wio (#dnnl_cba), #dnnl_wigo (#dnnl_dcab)
| 1D         | int8             | NCW / OIW                     | #dnnl_nwc (#dnnl_acb) / #dnnl_wio (#dnnl_cba)
| 2D         | f32, bf16        | NCHW / OIHW, GOIHW            | #dnnl_nchw (#dnnl_abcd) / #dnnl_oihw (#dnnl_abcd), #dnnl_goihw (#dnnl_abcde)
| 2D         | \"               | \"                            | #dnnl_nhwc (#dnnl_acdb) / #dnnl_hwio (#dnnl_cdba), #dnnl_hwigo (#dnnl_decab)
| 2D         | int8             | NCHW / OIHW, GOIHW            | #dnnl_nhwc (#dnnl_acdb) / #dnnl_hwio (#dnnl_cdba), #dnnl_hwigo (#dnnl_decab)
| 3D         | f32, bf16        | NCDHW / OIDHW, GOIDHW         | #dnnl_ncdhw (#dnnl_abcde) / #dnnl_oidhw (#dnnl_abcde), #dnnl_goidhw (#dnnl_abcdef)
| 3D         | \"               | \"                            | #dnnl_ndhwc (#dnnl_acdeb) / #dnnl_dhwio (#dnnl_cdeba), #dnnl_dhwigo (#dnnl_defcab)
| 3D         | int8             | NCDHW / OIDHW                 | #dnnl_ndhwc (#dnnl_acdeb) / #dnnl_dhwio (#dnnl_cdeba)

### Post-ops and Attributes

Post-ops and attributes enable you to modify the behavior of the convolution
primitive by applying the output scale to the result of the primitive and by
chaining certain operations after the primitive. The following attributes and
post-ops are supported:

| Propagation | Type      | Operation                                                    | Description                                                                   | Restrictions                        |
| :--         | :--       | :--                                                          | :--                                                                           | :--                                 |
| forward     | attribute | [Output scale](@ref dnnl::primitive_attr::set_output_scales) | Scales the result of convolution by given scale factor(s)                     | int8 convolutions only              |
| forward     | attribute | [Zero points](@ref dnnl::primitive_attr::set_zero_points)    | Sets zero point(s) for the corresponding tensors                              | int8 convolutions only              |
| forward     | post-op   | [Eltwise](@ref dnnl::post_ops::append_eltwise)               | Applies an @ref dnnl_api_eltwise operation to the result                      |                                     |
| forward     | post-op   | [Sum](@ref dnnl::post_ops::append_sum)                       | Adds the operation result to the destination tensor instead of overwriting it |                                     |
| forward     | post-op   | [Binary](@ref dnnl::post_ops::append_binary)                 | Applies a @ref dnnl_api_binary operation to the result                        | General binary post-op restrictions |

To facilitate dynamic quantization, the primitive supports run-time output
scales. That means a user could configure attributes with output scales set to
the #DNNL_RUNTIME_F32_VAL wildcard value instead of the actual scales,
if the scales are not known at the primitive descriptor creation stage.
In this case, the user must provide the scales as an additional input memory
object with argument `DNNL_ARG_ATTR_OUTPUT_SCALES` during the execution stage.
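
A sketch of this flow (primitive creation and the remaining memory objects
are assumed to be handled elsewhere):

~~~cpp
// Creation stage: mark the output scale as a run-time value, mask = 0.
dnnl::primitive_attr attr;
attr.set_output_scales(0, {DNNL_RUNTIME_F32_VAL});
// ... create the primitive descriptor and the primitive with attr ...

// Execution stage: pass the actual scale as a 1-element f32 memory object.
dnnl::memory scales_mem({{1}, dnnl::memory::data_type::f32,
        dnnl::memory::format_tag::x}, eng);
conv.execute(s, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, weights_mem},
        {DNNL_ARG_DST, dst_mem},
        {DNNL_ARG_ATTR_OUTPUT_SCALES, scales_mem}});
~~~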

Similarly to run-time output scales, the primitive supports run-time zero
points. The wildcard value for zero points is #DNNL_RUNTIME_S32_VAL. The
following masks are supported by the primitive:
- 0, which applies one zero point value to an entire tensor, and
- 2, which applies a zero point value per element of the `IC` or `OC`
  dimension for the `DNNL_ARG_SRC` or `DNNL_ARG_DST` arguments, respectively.

During the execution stage, the corresponding memory object must be passed as an
argument with its index set to
(`DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_${MEMORY_INDEX}`). Possible
`${MEMORY_INDEX}` values are `DNNL_ARG_SRC` and `DNNL_ARG_DST`.
- For instance, a source tensor zero points memory argument would be passed with
  index (`DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC`), as in the sketch below.
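
The sketch combines the creation and execution stages (illustrative only;
the remaining setup is omitted):

~~~cpp
// Creation stage: mark the source zero point as a run-time value, mask = 0.
dnnl::primitive_attr attr;
attr.set_zero_points(DNNL_ARG_SRC, 0, {DNNL_RUNTIME_S32_VAL});
// ... create the primitive descriptor and the primitive with attr ...

// Execution stage: pass the actual zero point as a 1-element s32 memory
// object under the combined argument index.
dnnl::memory src_zp_mem({{1}, dnnl::memory::data_type::s32,
        dnnl::memory::format_tag::x}, eng);
conv.execute(s, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, weights_mem},
        {DNNL_ARG_DST, dst_mem},
        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_mem}});
~~~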

@note The library does not prevent using post-ops in training, but note that
not all post-ops are feasible for training usage. For instance, using ReLU
with non-zero negative slope parameter as a post-op would not produce an
additional output `workspace` that is required to compute backward propagation
correctly. Hence, in this particular case one should use separate convolution
and eltwise primitives for training.

The library supports any number and order of post operations, but only the
following sequences deploy optimized code:

| Type of convolutions      | Post-ops sequence supported
| :--                       | :--
| f32 and bf16 convolution  | eltwise, sum, sum -> eltwise
| int8 convolution          | eltwise, sum, sum -> eltwise, eltwise -> sum

The attributes and post-ops take effect in the following sequence:
- Source zero point attribute,
- Output scale attribute,
- Post-ops, in the order they were attached,
- Destination zero point attribute.

The operations applied by attributes and post-ops are computed in
single-precision floating point. The result is converted to the actual
destination data type just before storing.

#### Example 1

Consider the following pseudo-code:

~~~
    primitive_attr attr;
    attr.set_output_scale(mask=0, alpha);
    attr.set_post_ops({
            { sum={scale=beta} },
            { eltwise={scale=gamma, type=tanh, alpha=ignored, beta=ignored } }
        });

    convolution_forward(src, weights, dst, attr);
~~~

This would lead to the following:

\f[
    \dst(\overline{x}) =
        \gamma \cdot \tanh \left(
            \alpha \cdot conv(\src, \weights) +
            \beta  \cdot \dst(\overline{x})
        \right)
\f]
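
With the actual C++ API, the pseudo-code of Example 1 could be expressed as in
the sketch below; the operation descriptor `conv_d`, the engine `eng`, and the
scale values are assumed to exist:

~~~cpp
dnnl::primitive_attr attr;
attr.set_output_scales(0, {alpha});

dnnl::post_ops po;
po.append_sum(beta);
// tanh ignores the trailing eltwise alpha and beta parameters.
po.append_eltwise(gamma, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
attr.set_post_ops(po);

auto conv_pd = dnnl::convolution_forward::primitive_desc(conv_d, attr, eng);
~~~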

#### Example 2

The following pseudo-code:

~~~
    primitive_attr attr;
    attr.set_output_scale(mask=0, alpha);
    attr.set_post_ops({
            { eltwise={scale=gamma, type=relu, alpha=eta, beta=ignored } },
            { sum={scale=beta} }
        });

    convolution_forward(src, weights, dst, attr);
~~~

That would lead to the following:

\f[
    \dst(\overline{x}) =
        \beta \cdot \dst(\overline{x}) +
        \gamma \cdot ReLU \left(
            \alpha \cdot conv(\src, \weights),
            \eta
        \right)
\f]

#### Example 3

The following pseudo-code:

~~~
    primitive_attr attr;
    attr.set_output_scale(mask=0, alpha);
    attr.set_zero_point(src, mask=0, shift_src);
    attr.set_zero_point(dst, mask=0, shift_dst);
    attr.set_post_ops({
            { eltwise={scale=gamma, type=relu, alpha=eta, beta=ignored } }
        });

    convolution_forward(src, weights, dst, attr);
~~~

That would lead to the following:

\f[
    \dst(\overline{x}) =
        \gamma \cdot ReLU \left(
            \alpha \cdot conv(\src - shift_{src}, \weights),
            \eta
        \right) + shift_{dst}
\f]

## Algorithms

oneDNN implements convolution primitives using several different
algorithms:

- _Direct_. The convolution operation is computed directly using SIMD
  instructions. This is the algorithm used for most shapes and supports
  int8, f32, and bf16 data types.

- _Winograd_. This algorithm reduces computational complexity of convolution
  at the expense of accuracy loss and additional memory operations. The
  implementation is based on the [Fast Algorithms for Convolutional Neural
  Networks by A. Lavin and S. Gray](https://arxiv.org/abs/1509.09308). The
  Winograd algorithm often results in the best performance, but it is
  applicable only to particular shapes. Moreover, Winograd only supports
  int8 and f32 data types.

- _Implicit GEMM_. The convolution operation is reinterpreted in terms of
  matrix-matrix multiplication by rearranging the source data into a
  [scratchpad memory](@ref dev_guide_attributes_scratchpad). This is a fallback
  algorithm that is dispatched automatically when other implementations are
  not available. GEMM convolution supports the int8, f32, and bf16 data types.

### Direct Algorithm

oneDNN supports the direct convolution algorithm on all supported
platforms under the following conditions:

- Data and weights memory formats are defined by the convolution primitive
  (user passes `any`).

- The number of channels per group is a multiple of SIMD width for grouped
  convolutions.

- For each spatial direction, padding does not exceed one half of the
  corresponding dimension of the weights tensor.

- Weights tensor width does not exceed 14.

If any of these constraints is not met, the implementation will silently
fall back to an explicit GEMM algorithm.

### Winograd Convolution

oneDNN supports the Winograd convolution algorithm on systems with
Intel(R) Advanced Vector Extensions 512 (Intel(R) AVX-512) support and
Intel Deep Learning Boost (Intel DL Boost)
under the following conditions:

- Data and weights memory formats are defined by the convolution primitive
  (user passes `any` as the data format).

- The spatial domain is two-dimensional.

- The weights shape is 3x3, and there are no groups, dilation, or strides
  (\f$KH = KW = 3\f$, \f$SH = SW = 1\f$, and \f$DH = DW = 0\f$).

- The data type is either int8 or f32.

If any of these constraints is not met, the implementation will silently
fall back to the direct algorithm.

The Winograd convolution algorithm implementation additionally chooses tile
size based on the problem shape and
[propagation kind](@ref dnnl_prop_kind_t):

- For `forward_inference` oneDNN supports
  \f$F(2 \times 2, 3 \times 3)\f$ or
  \f$F(4 \times 4, 3 \times 3)\f$.

- oneDNN supports only \f$F(4 \times 4, 3 \times 3)\f$ Winograd for all
  the training propagation kinds.

The following side effects should be weighed against the (potential)
performance boost achieved from using the Winograd algorithm:

- _Memory consumption_. The Winograd implementation in oneDNN requires
  additional scratchpad memory to store intermediate results. As more
  convolutions using Winograd are added to the topology, the amount of memory
  required can grow significantly. This growth can be controlled if the
  scratchpad memory can be reused across multiple primitives. See
  @ref dev_guide_attributes_scratchpad for more details.

- _Accuracy_. In some cases Winograd convolution produces results that are
  significantly less accurate than results from the direct convolution.

Create a Winograd convolution by simply creating a convolution descriptor
(step 6 in the [simple network example](@ref cnn_inference_f32_cpp)) specifying
the Winograd algorithm. The rest of the steps are exactly the same.

~~~cpp
auto conv1_desc = convolution_forward::desc(
    prop_kind::forward_inference, algorithm::convolution_winograd,
    conv1_src_md, conv1_weights_md, conv1_bias_md, conv1_dst_md,
    conv1_strides, conv1_padding_l, conv1_padding_r);
~~~

### Automatic Algorithm Selection

oneDNN supports the `dnnl::algorithm::convolution_auto` algorithm that
instructs the library to automatically select the *best* algorithm based on
heuristics that take into account tensor shapes and the number of logical
processors available. (For automatic selection to work as intended, use the
same thread affinity settings when creating the convolution as when executing
the convolution.)
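
To request automatic selection, pass `convolution_auto` when creating the
operation descriptor, as in this sketch (memory descriptors and convolution
parameters are assumed to exist):

~~~cpp
auto conv_d = dnnl::convolution_forward::desc(
        dnnl::prop_kind::forward_inference,
        dnnl::algorithm::convolution_auto,
        src_md, weights_md, bias_md, dst_md,
        strides, padding_l, padding_r);
~~~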

@anchor dg_conv_impl_limits
## Implementation Limitations

1. Refer to @ref dev_guide_data_types for limitations related to data types
   support.

2. **CPU**
   - Winograd convolution is implemented only for processors with
     Intel AVX-512 and Intel DL Boost instruction sets
   - Run-time output scales are not supported
   - Integer \dst is not supported for floating point \src and \weights
   - f16 \dst is not supported for integer \src and \weights
   - Backward convolution with bias is not supported

## Performance Tips

- Use #dnnl::memory::format_tag::any for source, weights, and destination
  memory format tags when creating a convolution primitive to allow the library
  to choose the most appropriate memory format.

## Example

[Convolution Primitive Example](@ref convolution_example_cpp)

@copydetails convolution_example_cpp_short