1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 #include <math.h>
19 
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/nstl.hpp"
23 #include "common/type_helpers.hpp"
24 
25 #include "cpu/simple_q10n.hpp"
26 
27 #include "cpu/nchw_pooling.hpp"
28 
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 
33 using namespace nstl;
34 
35 template <data_type_t d_type>
nchw_pooling_fwd_t(const pd_t * apd)36 nchw_pooling_fwd_t<d_type>::nchw_pooling_fwd_t(const pd_t *apd)
37     : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
38 
39 template <data_type_t d_type>
execute_forward(const exec_ctx_t & ctx) const40 status_t nchw_pooling_fwd_t<d_type>::execute_forward(
41         const exec_ctx_t &ctx) const {
42     const auto alg = pd()->desc()->alg_kind;
43     const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
44     auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
45     auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
46 
47     const memory_desc_wrapper ws_d(pd()->workspace_md());
48     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
49 
50     const dim_t MB = pd()->MB();
51     const dim_t C = pd()->OC();
52     const dim_t OD = pd()->OD();
53     const dim_t OH = pd()->OH();
54     const dim_t OW = pd()->OW();
55     const dim_t ID = pd()->ID();
56     const dim_t IH = pd()->IH();
57     const dim_t IW = pd()->IW();
58     const dim_t KD = pd()->KD();
59     const dim_t KH = pd()->KH();
60     const dim_t KW = pd()->KW();
61     const dim_t SD = pd()->KSD();
62     const dim_t SH = pd()->KSH();
63     const dim_t SW = pd()->KSW();
64     const dim_t padF = pd()->padFront();
65     const dim_t padT = pd()->padT();
66     const dim_t padL = pd()->padL();
67 
68     const auto apply_offset = [](int index, int offset) {
69         return (index > offset) ? index - offset : 0;
70     };
71 
72     const auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow,
73                                 dim_t value) {
74         if (ws) {
75             assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
76             const size_t ws_offset = (size_t)OW * OH * OD * C * mb
77                     + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
78                     + (size_t)OW * oh + (size_t)ow;
79             if (ws_dt == data_type::u8) {
80                 assert(0 <= value
81                         && value <= numeric_limits<typename prec_traits<
82                                         data_type::u8>::type>::max());
83                 ws[ws_offset] = value;
84             } else
85                 reinterpret_cast<int *>(ws)[ws_offset] = value;
86         }
87     };
88 
89     const auto ker_max = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
90                                  dim_t ow) {
91         for_(dim_t kd = 0; kd < KD; ++kd)
92         for_(dim_t kh = 0; kh < KH; ++kh)
93         for (dim_t kw = 0; kw < KW; ++kw) {
94             const dim_t id = od * SD - padF + kd;
95             const dim_t ih = oh * SH - padT + kh;
96             const dim_t iw = ow * SW - padL + kw;
97 
98             if (id < 0 || id >= ID) continue;
99             if (ih < 0 || ih >= IH) continue;
100             if (iw < 0 || iw >= IW) continue;
101 
102             const auto src_offset = (size_t)IW * IH * ID * C * mb
103                     + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
104                     + (size_t)IW * ih + (size_t)iw;
105             const auto &s = src[src_offset];
106             if (s > d[0]) {
107                 d[0] = s;
108                 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
109             }
110         }
111     };
112 
113     const auto ker_avg = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
114                                  dim_t ow) {
115         const auto id_start = apply_offset(od * SD, padF);
116         const auto ih_start = apply_offset(oh * SH, padT);
117         const auto iw_start = apply_offset(ow * SW, padL);
118         const auto id_end = min(od * SD - padF + KD, ID);
119         const auto ih_end = min(oh * SH - padT + KH, IH);
120         const auto iw_end = min(ow * SW - padL + KW, IW);
121 
122         const auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
123                 ? KD * KW * KH
124                 : (id_end - id_start) * (ih_end - ih_start)
125                         * (iw_end - iw_start);
126 
127         float d_val = 0;
128         for_(dim_t id = id_start; id < id_end; ++id)
129         for_(dim_t ih = ih_start; ih < ih_end; ++ih)
130         for (dim_t iw = iw_start; iw < iw_end; ++iw) {
131             const auto src_offset = (size_t)IW * IH * ID * C * mb
132                     + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
133                     + (size_t)IW * ih + (size_t)iw;
134             d_val += src[src_offset];
135         }
136 
137         return d_val / num_summands;
138     };
139 
140     if (alg == alg_kind::pooling_max) {
141         parallel_nd(MB, C, OD, OH, OW,
142                 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
143                     const size_t dst_offset = (size_t)OW * OH * OD * C * mb
144                             + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
145                             + (size_t)OW * oh + (size_t)ow;
146                     data_t *d = &dst[dst_offset];
147                     d[0] = numeric_limits<data_t>::lowest();
148                     set_ws(mb, c, od, oh, ow, 0);
149                     ker_max(d, mb, c, od, oh, ow);
150 
151                     ref_post_ops_t::args_t args;
152                     args.ctx = &ctx;
153                     args.l_offset = dst_offset;
154                     args.dst_md = pd()->dst_md();
155                     ref_post_ops_.execute(dst[dst_offset], args);
156                     dst[dst_offset]
157                             = saturate_and_round<data_t>(dst[dst_offset]);
158                 });
159     } else {
160         parallel_nd(MB, C, OD, OH, OW,
161                 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
162                     const size_t dst_offset = (size_t)OW * OH * OD * C * mb
163                             + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
164                             + (size_t)OW * oh + (size_t)ow;
165                     data_t *d = &dst[dst_offset];
166                     d[0] = 0;
167                     auto res = ker_avg(d, mb, c, od, oh, ow);
168 
169                     ref_post_ops_t::args_t args;
170                     args.ctx = &ctx;
171                     args.l_offset = dst_offset;
172                     args.dst_md = pd()->dst_md();
173                     ref_post_ops_.execute(res, args);
174                     d[0] = saturate_and_round<data_t>(res);
175                 });
176     }
177 
178     return status::success;
179 }
180 
181 template <>
execute_forward(const exec_ctx_t & ctx) const182 status_t nchw_pooling_fwd_t<data_type::bf16>::execute_forward(
183         const exec_ctx_t &ctx) const {
184 
185     auto alg = pd()->desc()->alg_kind;
186 
187     auto src = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_SRC);
188     auto dst = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DST);
189     auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
190     memory_desc_wrapper dst_d(pd()->dst_md());
191 
192     auto scratchpad = ctx.get_scratchpad_grantor();
193     float *bf16cvt_wsp = scratchpad.template get<float>(
194             memory_tracking::names::key_pool_src_bf16cvt);
195 
196     const memory_desc_wrapper ws_d(pd()->workspace_md());
197     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
198 
199     const dim_t MB = pd()->MB();
200     const dim_t C = pd()->OC();
201     const dim_t OD = pd()->OD();
202     const dim_t OH = pd()->OH();
203     const dim_t OW = pd()->OW();
204     const dim_t ID = pd()->ID();
205     const dim_t IH = pd()->IH();
206     const dim_t IW = pd()->IW();
207     const dim_t KD = pd()->KD();
208     const dim_t KH = pd()->KH();
209     const dim_t KW = pd()->KW();
210     const dim_t SD = pd()->KSD();
211     const dim_t SH = pd()->KSH();
212     const dim_t SW = pd()->KSW();
213     const dim_t padF = pd()->padFront();
214     const dim_t padT = pd()->padT();
215     const dim_t padL = pd()->padL();
216 
217     const size_t simd_w = 16;
218     const size_t src_size = MB * C * ID * IH * IW;
219     const size_t blocked_size = src_size / simd_w;
220     const size_t tail_size = src_size % simd_w;
221 
222     auto apply_offset = [=](int index, int offset) {
223         return (index > offset) ? index - offset : 0;
224     };
225 
226     auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow,
227                           dim_t value) {
228         if (ws) {
229             assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
230             size_t ws_offset = (size_t)OW * OH * OD * C * mb
231                     + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
232                     + (size_t)OW * oh + (size_t)ow;
233             if (ws_dt == data_type::u8) {
234                 assert(0 <= value
235                         && value <= numeric_limits<typename prec_traits<
236                                         data_type::u8>::type>::max());
237                 ws[ws_offset] = value;
238             } else
239                 reinterpret_cast<int *>(ws)[ws_offset] = value;
240         }
241     };
242 
243     auto ker_max = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
244                            dim_t ow) {
245         for_(dim_t kd = 0; kd < KD; ++kd)
246         for_(dim_t kh = 0; kh < KH; ++kh)
247         for (dim_t kw = 0; kw < KW; ++kw) {
248             const dim_t id = od * SD - padF + kd;
249             const dim_t ih = oh * SH - padT + kh;
250             const dim_t iw = ow * SW - padL + kw;
251 
252             if (id < 0 || id >= ID) continue;
253             if (ih < 0 || ih >= IH) continue;
254             if (iw < 0 || iw >= IW) continue;
255 
256             auto src_offset = (size_t)IW * IH * ID * C * mb
257                     + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
258                     + (size_t)IW * ih + (size_t)iw;
259             auto &s = bf16cvt_wsp[src_offset];
260 
261             if (s > d[0]) {
262                 d[0] = s;
263                 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
264             }
265         }
266     };
267 
268     auto ker_avg = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
269                            dim_t ow) {
270         auto id_start = apply_offset(od * SD, padF);
271         auto ih_start = apply_offset(oh * SH, padT);
272         auto iw_start = apply_offset(ow * SW, padL);
273         auto id_end = min(od * SD - padF + KD, ID);
274         auto ih_end = min(oh * SH - padT + KH, IH);
275         auto iw_end = min(ow * SW - padL + KW, IW);
276 
277         auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
278                 ? KD * KW * KH
279                 : (id_end - id_start) * (ih_end - ih_start)
280                         * (iw_end - iw_start);
281 
282         for_(dim_t id = id_start; id < id_end; ++id)
283         for_(dim_t ih = ih_start; ih < ih_end; ++ih)
284         for (dim_t iw = iw_start; iw < iw_end; ++iw) {
285             auto src_offset = (size_t)IW * IH * ID * C * mb
286                     + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
287                     + (size_t)IW * ih + (size_t)iw;
288             d[0] += bf16cvt_wsp[src_offset];
289         }
290 
291         d[0] = out_round<float>((float)d[0] / num_summands);
292     };
293     parallel_nd(blocked_size, [&](size_t i) {
294         cvt_bfloat16_to_float(
295                 &bf16cvt_wsp[i * simd_w], &src[i * simd_w], simd_w);
296     });
297     if (tail_size)
298         cvt_bfloat16_to_float(&bf16cvt_wsp[blocked_size * simd_w],
299                 &src[blocked_size * simd_w], tail_size);
300     if (alg == alg_kind::pooling_max) {
301         parallel_nd(MB, C, OD, OH, OW,
302                 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
303                     size_t dst_offset = (size_t)OW * OH * OD * C * mb
304                             + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
305                             + (size_t)OW * oh + (size_t)ow;
306                     float d_fp32 = numeric_limits<bfloat16_t>::lowest();
307 
308                     set_ws(mb, c, od, oh, ow, 0);
309 
310                     ker_max(&d_fp32, mb, c, od, oh, ow);
311 
312                     ref_post_ops_t::args_t args;
313                     args.ctx = &ctx;
314                     args.l_offset = dst_offset;
315                     args.dst_md = pd()->dst_md();
316                     ref_post_ops_.execute(d_fp32, args);
317 
318                     dst[dst_offset] = static_cast<bfloat16_t>(d_fp32);
319                 });
320     } else {
321         parallel_nd(MB, C, OD, OH, OW,
322                 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
323                     size_t dst_offset = (size_t)OW * OH * OD * C * mb
324                             + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
325                             + (size_t)OW * oh + (size_t)ow;
326                     float d_fp32 = 0.0f;
327                     ker_avg(&d_fp32, mb, c, od, oh, ow);
328                     ref_post_ops_t::args_t args;
329                     args.ctx = &ctx;
330                     args.l_offset = dst_offset;
331                     args.dst_md = pd()->dst_md();
332                     ref_post_ops_.execute(d_fp32, args);
333                     dst[dst_offset] = static_cast<bfloat16_t>(d_fp32);
334                 });
335     }
336 
337     return status::success;
338 }
339 
340 template <data_type_t d_type>
execute_backward(const exec_ctx_t & ctx) const341 status_t nchw_pooling_bwd_t<d_type>::execute_backward(
342         const exec_ctx_t &ctx) const {
343     auto alg = pd()->desc()->alg_kind;
344     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
345     const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
346 
347     auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
348     auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
349     auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
350 
351     const memory_desc_wrapper ws_d(pd()->workspace_md());
352 
353     const dim_t MB = pd()->MB();
354     const dim_t C = pd()->OC();
355     const dim_t OD = pd()->OD();
356     const dim_t OH = pd()->OH();
357     const dim_t OW = pd()->OW();
358     const dim_t ID = pd()->ID();
359     const dim_t IH = pd()->IH();
360     const dim_t IW = pd()->IW();
361     const dim_t KD = pd()->KD();
362     const dim_t KH = pd()->KH();
363     const dim_t KW = pd()->KW();
364     const dim_t SD = pd()->KSD();
365     const dim_t SH = pd()->KSH();
366     const dim_t SW = pd()->KSW();
367     const dim_t padF = pd()->padFront();
368     const dim_t padT = pd()->padT();
369     const dim_t padL = pd()->padL();
370 
371     auto apply_offset = [=](int index, int offset) {
372         return (index > offset) ? index - offset : 0;
373     };
374 
375     auto ker_zero = [=](dim_t mb, dim_t c) {
376         size_t diff_src_offset
377                 = (size_t)mb * C * ID * IH * IW + (size_t)c * ID * IH * IW;
378         for_(dim_t id = 0; id < ID; ++id)
379         for_(dim_t ih = 0; ih < IH; ++ih)
380         for (dim_t iw = 0; iw < IW; ++iw) {
381             diff_src[diff_src_offset++] = 0;
382         }
383     };
384 
385     auto ker_max = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
386                            dim_t ow) {
387         auto b_c = ws_d.blocking_desc().inner_nblks == 0
388                 ? 1
389                 : ws_d.blocking_desc().inner_blks[0];
390         auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
391                                 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
392                                         : ws_d.blk_off(mb, c / b_c, ow))
393                 + c % b_c;
394 
395         const int index = ws_d.data_type() == data_type::u8
396                 ? (int)ws[ws_offset]
397                 : ((const int *)ws)[ws_offset];
398         const dim_t kw = index % KW;
399         const dim_t kh = (index / KW) % KH;
400         const dim_t kd = (index / KW) / KH;
401 
402         const dim_t id = od * SD - padF + kd;
403         const dim_t ih = oh * SH - padT + kh;
404         const dim_t iw = ow * SW - padL + kw;
405 
406         // If padding area could fit the kernel,
407         // then input displacement would be out of bounds.
408         // No need to back propagate there as padding is
409         // virtual in pooling_max case.
410         if (id < 0 || id >= ID) return;
411         if (ih < 0 || ih >= IH) return;
412         if (iw < 0 || iw >= IW) return;
413 
414         size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
415                 + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
416                 + (size_t)ih * IW + (size_t)iw;
417         diff_src[diff_src_offset] += d[0];
418     };
419 
420     auto ker_avg = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
421                            dim_t ow) {
422         dim_t id_start = apply_offset(od * SD, padF);
423         dim_t ih_start = apply_offset(oh * SH, padT);
424         dim_t iw_start = apply_offset(ow * SW, padL);
425         dim_t id_end = min(od * SD - padF + KD, ID);
426         dim_t ih_end = min(oh * SH - padT + KH, IH);
427         dim_t iw_end = min(ow * SW - padL + KW, IW);
428 
429         size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
430                 ? (size_t)KW * KH * KD
431                 : (size_t)(id_end - id_start) * (ih_end - ih_start)
432                         * (iw_end - iw_start);
433 
434         for_(dim_t id = id_start; id < id_end; ++id)
435         for_(dim_t ih = ih_start; ih < ih_end; ++ih)
436         for (dim_t iw = iw_start; iw < iw_end; ++iw) {
437             size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
438                     + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
439                     + (size_t)ih * IW + (size_t)iw;
440             diff_src[diff_src_offset] += d[0] / num_summands;
441         }
442     };
443 
444     dim_t ow_start = max(dim_t(0), utils::div_up(padL - KW + 1, SW));
445     dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW);
446 
447     dim_t oh_start = max(dim_t(0), utils::div_up(padT - KH + 1, SH));
448     dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH);
449 
450     dim_t od_start = max(dim_t(0), utils::div_up(padF - KD + 1, SD));
451     dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD);
452 
453     if (alg == alg_kind::pooling_max) {
454         parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
455             size_t diff_dst_offset_b
456                     = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
457             ker_zero(mb, c);
458             for_(dim_t od = od_start; od < od_end; ++od)
459             for (dim_t oh = oh_start; oh < oh_end; ++oh) {
460                 size_t diff_dst_offset = diff_dst_offset_b
461                         + (size_t)od * OH * OW + (size_t)oh * OW;
462                 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
463                     const data_t *d = &diff_dst[diff_dst_offset + ow];
464                     ker_max(d, mb, c, od, oh, ow);
465                 }
466             }
467         });
468     } else {
469         parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
470             size_t diff_dst_offset_b
471                     = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
472             ker_zero(mb, c);
473             for_(dim_t od = od_start; od < od_end; ++od)
474             for (dim_t oh = oh_start; oh < oh_end; ++oh) {
475                 size_t diff_dst_offset = diff_dst_offset_b
476                         + (size_t)od * OH * OW + (size_t)oh * OW;
477                 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
478                     const data_t *d = &diff_dst[diff_dst_offset + ow];
479                     ker_avg(d, mb, c, od, oh, ow);
480                 }
481             }
482         });
483     }
484 
485     return status::success;
486 }
487 
488 template <>
execute_backward(const exec_ctx_t & ctx) const489 status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
490         const exec_ctx_t &ctx) const {
491 
492     auto alg = pd()->desc()->alg_kind;
493     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
494     const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
495 
496     auto diff_src = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_SRC);
497     auto diff_dst = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_DIFF_DST);
498     auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
499 
500     auto scratchpad = ctx.get_scratchpad_grantor();
501     float *bf16cvt_src = scratchpad.template get<float>(
502             memory_tracking::names::key_pool_src_bf16cvt);
503     float *bf16cvt_dst = scratchpad.template get<float>(
504             memory_tracking::names::key_pool_dst_bf16cvt);
505 
506     const memory_desc_wrapper ws_d(pd()->workspace_md());
507 
508     const dim_t MB = pd()->MB();
509     const dim_t C = pd()->OC();
510     const dim_t OD = pd()->OD();
511     const dim_t OH = pd()->OH();
512     const dim_t OW = pd()->OW();
513     const dim_t ID = pd()->ID();
514     const dim_t IH = pd()->IH();
515     const dim_t IW = pd()->IW();
516     const dim_t KD = pd()->KD();
517     const dim_t KH = pd()->KH();
518     const dim_t KW = pd()->KW();
519     const dim_t SD = pd()->KSD();
520     const dim_t SH = pd()->KSH();
521     const dim_t SW = pd()->KSW();
522     const dim_t padF = pd()->padFront();
523     const dim_t padT = pd()->padT();
524     const dim_t padL = pd()->padL();
525 
526     const size_t dst_sp_size = pd()->OD() * pd()->OH() * pd()->OW();
527     const size_t src_sp_size = pd()->ID() * pd()->IH() * pd()->IW();
528 
529     auto apply_offset = [=](int index, int offset) {
530         return (index > offset) ? index - offset : 0;
531     };
532 
533     auto ker_zero = [=](float *diff_src, dim_t c_block_size) {
534         size_t diff_src_offset = 0;
535         for_(dim_t c = 0; c < c_block_size; ++c)
536         for_(dim_t id = 0; id < ID; ++id)
537         for_(dim_t ih = 0; ih < IH; ++ih)
538         for (dim_t iw = 0; iw < IW; ++iw) {
539             diff_src[diff_src_offset++] = 0.0f;
540         }
541     };
542 
543     auto ker_max = [=](const float *d, float *diff_src, dim_t mb, dim_t c,
544                            dim_t od, dim_t oh, dim_t ow) {
545         auto b_c = ws_d.blocking_desc().inner_nblks == 0
546                 ? 1
547                 : ws_d.blocking_desc().inner_blks[0];
548         auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
549                                 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
550                                         : ws_d.blk_off(mb, c / b_c, ow))
551                 + c % b_c;
552 
553         const int index = ws_d.data_type() == data_type::u8
554                 ? (int)ws[ws_offset]
555                 : ((const int *)ws)[ws_offset];
556         const dim_t kw = index % KW;
557         const dim_t kh = (index / KW) % KH;
558         const dim_t kd = (index / KW) / KH;
559 
560         const dim_t id = od * SD - padF + kd;
561         const dim_t ih = oh * SH - padT + kh;
562         const dim_t iw = ow * SW - padL + kw;
563 
564         // If padding area could fit the kernel,
565         // then input displacement would be out of bounds.
566         // No need to back propagate there as padding is
567         // virtual in pooling_max case.
568         if (id < 0 || id >= ID) return;
569         if (ih < 0 || ih >= IH) return;
570         if (iw < 0 || iw >= IW) return;
571 
572         size_t diff_src_offset
573                 = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
574         diff_src[diff_src_offset] += d[0];
575     };
576 
577     auto ker_avg = [=](const float *d, float *diff_src, dim_t mb, dim_t c,
578                            dim_t od, dim_t oh, dim_t ow) {
579         auto id_start = apply_offset(od * SD, padF);
580         auto ih_start = apply_offset(oh * SH, padT);
581         auto iw_start = apply_offset(ow * SW, padL);
582         auto id_end = min(od * SD - padF + KD, ID);
583         auto ih_end = min(oh * SH - padT + KH, IH);
584         auto iw_end = min(ow * SW - padL + KW, IW);
585 
586         size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
587                 ? (size_t)KW * KH * KD
588                 : (size_t)(id_end - id_start) * (ih_end - ih_start)
589                         * (iw_end - iw_start);
590 
591         for_(dim_t id = id_start; id < id_end; ++id)
592         for_(dim_t ih = ih_start; ih < ih_end; ++ih)
593         for (dim_t iw = iw_start; iw < iw_end; ++iw) {
594             size_t diff_src_offset
595                     = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
596             diff_src[diff_src_offset] += d[0] / num_summands;
597         }
598     };
599 
600     dim_t ow_start = max(dim_t(0), utils::div_up(padL - KW + 1, SW));
601     dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW);
602 
603     dim_t oh_start = max(dim_t(0), utils::div_up(padT - KH + 1, SH));
604     dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH);
605 
606     dim_t od_start = max(dim_t(0), utils::div_up(padF - KD + 1, SD));
607     dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD);
608 
609     dim_t c_blk = pd()->channel_block_size_;
610     dim_t c_blk_tail = C % c_blk;
611     const int nthr = pd()->nthr_;
612 
613     if (alg == alg_kind::pooling_max) {
614         parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
615                 [&](int ithr, int, dim_t mb, dim_t cb) {
616                     bool is_last_c_block
617                             = c_blk_tail > 0 && (cb + 1) * c_blk > C;
618                     dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
619                     size_t diff_dst_offset_b
620                             = ((size_t)mb * C + (size_t)cb * c_blk) * OD * OH
621                             * OW;
622                     size_t diff_src_offset
623                             = ((size_t)mb * C + (size_t)cb * c_blk) * ID * IH
624                             * IW;
625                     float *diff_dst_fp32
626                             = &bf16cvt_dst[ithr * dst_sp_size * c_blk];
627                     float *diff_src_fp32
628                             = &bf16cvt_src[ithr * src_sp_size * c_blk];
629 
630                     ker_zero(diff_src_fp32, curr_c_block);
631 
632                     cvt_bfloat16_to_float(diff_dst_fp32,
633                             &diff_dst[diff_dst_offset_b],
634                             dst_sp_size * curr_c_block);
635 
636                     for_(dim_t c = 0; c < curr_c_block; ++c)
637                     for_(dim_t od = od_start; od < od_end; ++od)
638                     for (dim_t oh = oh_start; oh < oh_end; ++oh) {
639                         size_t diff_dst_offset = (size_t)c * OD * OH * OW
640                                 + (size_t)od * OH * OW + (size_t)oh * OW;
641                         for (dim_t ow = ow_start; ow < ow_end; ++ow) {
642                             const float *d
643                                     = &diff_dst_fp32[diff_dst_offset + ow];
644                             ker_max(d, &diff_src_fp32[c * ID * IH * IW], mb,
645                                     cb * c_blk + c, od, oh, ow);
646                         }
647                     }
648                     cvt_float_to_bfloat16(&diff_src[diff_src_offset],
649                             diff_src_fp32, src_sp_size * curr_c_block);
650                 });
651     } else {
652         parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
653                 [&](int ithr, int, dim_t mb, dim_t cb) {
654                     bool is_last_c_block
655                             = c_blk_tail > 0 && (cb + 1) * c_blk > C;
656                     dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
657                     size_t diff_dst_offset_b = (size_t)mb * C * OD * OH * OW
658                             + (size_t)cb * c_blk * OD * OH * OW;
659                     float *diff_dst_fp32
660                             = &bf16cvt_dst[ithr * dst_sp_size * c_blk];
661                     size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
662                             + (size_t)cb * c_blk * ID * IH * IW;
663                     float *diff_src_fp32
664                             = &bf16cvt_src[ithr * src_sp_size * c_blk];
665 
666                     ker_zero(diff_src_fp32, curr_c_block);
667 
668                     cvt_bfloat16_to_float(diff_dst_fp32,
669                             &diff_dst[diff_dst_offset_b],
670                             dst_sp_size * curr_c_block);
671                     for_(dim_t c = 0; c < curr_c_block; ++c)
672                     for_(dim_t od = od_start; od < od_end; ++od)
673                     for (dim_t oh = oh_start; oh < oh_end; ++oh) {
674                         size_t diff_dst_offset = (size_t)c * OD * OH * OW
675                                 + (size_t)od * OH * OW + (size_t)oh * OW;
676                         for (dim_t ow = ow_start; ow < ow_end; ++ow) {
677                             const float *d
678                                     = &diff_dst_fp32[diff_dst_offset + ow];
679                             ker_avg(d, &diff_src_fp32[c * ID * IH * IW], mb,
680                                     cb * c_blk + c, od, oh, ow);
681                         }
682                     }
683                     cvt_float_to_bfloat16(&diff_src[diff_src_offset],
684                             diff_src_fp32, src_sp_size * curr_c_block);
685                 });
686     }
687 
688     return status::success;
689 }
690 template struct nchw_pooling_fwd_t<data_type::f32>;
691 template struct nchw_pooling_bwd_t<data_type::f32>;
692 template struct nchw_pooling_fwd_t<data_type::bf16>;
693 template struct nchw_pooling_bwd_t<data_type::bf16>;
694 } // namespace cpu
695 } // namespace impl
696 } // namespace dnnl
697 
698 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
699