1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 #include <math.h>
19 
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/nstl.hpp"
23 #include "common/type_helpers.hpp"
24 
25 #include "cpu/simple_q10n.hpp"
26 
27 #include "cpu/nchw_pooling.hpp"
28 
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 
33 using namespace nstl;
34 
35 template <data_type_t d_type>
nchw_pooling_fwd_t(const pd_t * apd)36 nchw_pooling_fwd_t<d_type>::nchw_pooling_fwd_t(const pd_t *apd)
37     : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
38 
39 template <data_type_t d_type>
execute_forward(const exec_ctx_t & ctx) const40 status_t nchw_pooling_fwd_t<d_type>::execute_forward(
41         const exec_ctx_t &ctx) const {
42     const auto alg = pd()->desc()->alg_kind;
43     const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
44     auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
45     auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
46 
47     const memory_desc_wrapper ws_d(pd()->workspace_md());
48     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
49 
50     const int MB = pd()->MB();
51     const int C = pd()->IC();
52     const int OD = pd()->OD();
53     const int OH = pd()->OH();
54     const int OW = pd()->OW();
55     const int ID = pd()->ID();
56     const int IH = pd()->IH();
57     const int IW = pd()->IW();
58     const int KD = pd()->KD();
59     const int KH = pd()->KH();
60     const int KW = pd()->KW();
61     const int SD = pd()->KSD();
62     const int SH = pd()->KSH();
63     const int SW = pd()->KSW();
64     const int padF = pd()->padFront();
65     const int padT = pd()->padT();
66     const int padL = pd()->padL();
67 
68     const auto apply_offset = [](int index, int offset) {
69         return (index > offset) ? index - offset : 0;
70     };
71 
72     const auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
73         if (ws) {
74             assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
75             const size_t ws_offset = (size_t)OW * OH * OD * C * mb
76                     + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
77                     + (size_t)OW * oh + (size_t)ow;
78             if (ws_dt == data_type::u8) {
79                 assert(0 <= value
80                         && value <= numeric_limits<typename prec_traits<
81                                         data_type::u8>::type>::max());
82                 ws[ws_offset] = value;
83             } else
84                 reinterpret_cast<int *>(ws)[ws_offset] = value;
85         }
86     };
87 
88     const auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
89         for_(int kd = 0; kd < KD; ++kd)
90         for_(int kh = 0; kh < KH; ++kh)
91         for (int kw = 0; kw < KW; ++kw) {
92             const int id = od * SD - padF + kd;
93             const int ih = oh * SH - padT + kh;
94             const int iw = ow * SW - padL + kw;
95 
96             if (id < 0 || id >= ID) continue;
97             if (ih < 0 || ih >= IH) continue;
98             if (iw < 0 || iw >= IW) continue;
99 
100             const auto src_offset = (size_t)IW * IH * ID * C * mb
101                     + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
102                     + (size_t)IW * ih + (size_t)iw;
103             const auto &s = src[src_offset];
104             if (s > d[0]) {
105                 d[0] = s;
106                 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
107             }
108         }
109     };
110 
111     const auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
112         const auto id_start = apply_offset(od * SD, padF);
113         const auto ih_start = apply_offset(oh * SH, padT);
114         const auto iw_start = apply_offset(ow * SW, padL);
115         const auto id_end = min(od * SD - padF + KD, ID);
116         const auto ih_end = min(oh * SH - padT + KH, IH);
117         const auto iw_end = min(ow * SW - padL + KW, IW);
118 
119         const auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
120                 ? KD * KW * KH
121                 : (id_end - id_start) * (ih_end - ih_start)
122                         * (iw_end - iw_start);
123 
124         float d_val = 0;
125         for_(int id = id_start; id < id_end; ++id)
126         for_(int ih = ih_start; ih < ih_end; ++ih)
127         for (int iw = iw_start; iw < iw_end; ++iw) {
128             const auto src_offset = (size_t)IW * IH * ID * C * mb
129                     + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
130                     + (size_t)IW * ih + (size_t)iw;
131             d_val += src[src_offset];
132         }
133 
134         return d_val / num_summands;
135     };
136 
137     if (alg == alg_kind::pooling_max) {
138         parallel_nd(
139                 MB, C, OD, OH, OW, [&](int mb, int c, int od, int oh, int ow) {
140                     const size_t dst_offset = (size_t)OW * OH * OD * C * mb
141                             + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
142                             + (size_t)OW * oh + (size_t)ow;
143                     data_t *d = &dst[dst_offset];
144                     d[0] = numeric_limits<data_t>::lowest();
145                     set_ws(mb, c, od, oh, ow, 0);
146                     ker_max(d, mb, c, od, oh, ow);
147 
148                     ref_post_ops_t::args_t args;
149                     args.ctx = &ctx;
150                     args.l_offset = dst_offset;
151                     args.dst_md = pd()->dst_md();
152                     ref_post_ops_.execute(dst[dst_offset], args);
153                     dst[dst_offset]
154                             = saturate_and_round<data_t>(dst[dst_offset]);
155                 });
156     } else {
157         parallel_nd(
158                 MB, C, OD, OH, OW, [&](int mb, int c, int od, int oh, int ow) {
159                     const size_t dst_offset = (size_t)OW * OH * OD * C * mb
160                             + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
161                             + (size_t)OW * oh + (size_t)ow;
162                     data_t *d = &dst[dst_offset];
163                     d[0] = 0;
164                     auto res = ker_avg(d, mb, c, od, oh, ow);
165 
166                     ref_post_ops_t::args_t args;
167                     args.ctx = &ctx;
168                     args.l_offset = dst_offset;
169                     args.dst_md = pd()->dst_md();
170                     ref_post_ops_.execute(res, args);
171                     d[0] = saturate_and_round<data_t>(res);
172                 });
173     }
174 
175     return status::success;
176 }
177 
178 template <>
execute_forward(const exec_ctx_t & ctx) const179 status_t nchw_pooling_fwd_t<data_type::bf16>::execute_forward(
180         const exec_ctx_t &ctx) const {
181 
182     auto alg = pd()->desc()->alg_kind;
183 
184     auto src = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_SRC);
185     auto dst = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DST);
186     auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
187     memory_desc_wrapper dst_d(pd()->dst_md());
188 
189     auto scratchpad = ctx.get_scratchpad_grantor();
190     float *bf16cvt_wsp = scratchpad.template get<float>(
191             memory_tracking::names::key_pool_src_bf16cvt);
192 
193     const memory_desc_wrapper ws_d(pd()->workspace_md());
194     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
195 
196     const int MB = pd()->MB();
197     const int C = pd()->IC();
198     const int OD = pd()->OD();
199     const int OH = pd()->OH();
200     const int OW = pd()->OW();
201     const int ID = pd()->ID();
202     const int IH = pd()->IH();
203     const int IW = pd()->IW();
204     const int KD = pd()->KD();
205     const int KH = pd()->KH();
206     const int KW = pd()->KW();
207     const int SD = pd()->KSD();
208     const int SH = pd()->KSH();
209     const int SW = pd()->KSW();
210     const int padF = pd()->padFront();
211     const int padT = pd()->padT();
212     const int padL = pd()->padL();
213 
214     const size_t simd_w = 16;
215     const size_t src_size = MB * C * ID * IH * IW;
216     const size_t blocked_size = src_size / simd_w;
217     const size_t tail_size = src_size % simd_w;
218 
219     auto apply_offset = [=](int index, int offset) {
220         return (index > offset) ? index - offset : 0;
221     };
222 
223     auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
224         if (ws) {
225             assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
226             size_t ws_offset = (size_t)OW * OH * OD * C * mb
227                     + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
228                     + (size_t)OW * oh + (size_t)ow;
229             if (ws_dt == data_type::u8) {
230                 assert(0 <= value
231                         && value <= numeric_limits<typename prec_traits<
232                                         data_type::u8>::type>::max());
233                 ws[ws_offset] = value;
234             } else
235                 reinterpret_cast<int *>(ws)[ws_offset] = value;
236         }
237     };
238 
239     auto ker_max = [=](float *d, int mb, int c, int od, int oh, int ow) {
240         for_(int kd = 0; kd < KD; ++kd)
241         for_(int kh = 0; kh < KH; ++kh)
242         for (int kw = 0; kw < KW; ++kw) {
243             const int id = od * SD - padF + kd;
244             const int ih = oh * SH - padT + kh;
245             const int iw = ow * SW - padL + kw;
246 
247             if (id < 0 || id >= ID) continue;
248             if (ih < 0 || ih >= IH) continue;
249             if (iw < 0 || iw >= IW) continue;
250 
251             auto src_offset = (size_t)IW * IH * ID * C * mb
252                     + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
253                     + (size_t)IW * ih + (size_t)iw;
254             auto &s = bf16cvt_wsp[src_offset];
255 
256             if (s > d[0]) {
257                 d[0] = s;
258                 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
259             }
260         }
261     };
262 
263     auto ker_avg = [=](float *d, int mb, int c, int od, int oh, int ow) {
264         auto id_start = apply_offset(od * SD, padF);
265         auto ih_start = apply_offset(oh * SH, padT);
266         auto iw_start = apply_offset(ow * SW, padL);
267         auto id_end = min(od * SD - padF + KD, ID);
268         auto ih_end = min(oh * SH - padT + KH, IH);
269         auto iw_end = min(ow * SW - padL + KW, IW);
270 
271         auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
272                 ? KD * KW * KH
273                 : (id_end - id_start) * (ih_end - ih_start)
274                         * (iw_end - iw_start);
275 
276         for_(int id = id_start; id < id_end; ++id)
277         for_(int ih = ih_start; ih < ih_end; ++ih)
278         for (int iw = iw_start; iw < iw_end; ++iw) {
279             auto src_offset = (size_t)IW * IH * ID * C * mb
280                     + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
281                     + (size_t)IW * ih + (size_t)iw;
282             d[0] += bf16cvt_wsp[src_offset];
283         }
284 
285         d[0] = out_round<float>((float)d[0] / num_summands);
286     };
287     parallel_nd(blocked_size, [&](size_t i) {
288         cvt_bfloat16_to_float(
289                 &bf16cvt_wsp[i * simd_w], &src[i * simd_w], simd_w);
290     });
291     if (tail_size)
292         cvt_bfloat16_to_float(&bf16cvt_wsp[blocked_size * simd_w],
293                 &src[blocked_size * simd_w], tail_size);
294     if (alg == alg_kind::pooling_max) {
295         parallel_nd(
296                 MB, C, OD, OH, OW, [&](int mb, int c, int od, int oh, int ow) {
297                     size_t dst_offset = (size_t)OW * OH * OD * C * mb
298                             + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
299                             + (size_t)OW * oh + (size_t)ow;
300                     float d_fp32 = numeric_limits<bfloat16_t>::lowest();
301 
302                     set_ws(mb, c, od, oh, ow, 0);
303 
304                     ker_max(&d_fp32, mb, c, od, oh, ow);
305 
306                     ref_post_ops_t::args_t args;
307                     args.ctx = &ctx;
308                     args.l_offset = dst_offset;
309                     args.dst_md = pd()->dst_md();
310                     ref_post_ops_.execute(d_fp32, args);
311 
312                     dst[dst_offset] = static_cast<bfloat16_t>(d_fp32);
313                 });
314     } else {
315         parallel_nd(
316                 MB, C, OD, OH, OW, [&](int mb, int c, int od, int oh, int ow) {
317                     size_t dst_offset = (size_t)OW * OH * OD * C * mb
318                             + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
319                             + (size_t)OW * oh + (size_t)ow;
320                     float d_fp32 = 0.0f;
321                     ker_avg(&d_fp32, mb, c, od, oh, ow);
322                     ref_post_ops_t::args_t args;
323                     args.ctx = &ctx;
324                     args.l_offset = dst_offset;
325                     args.dst_md = pd()->dst_md();
326                     ref_post_ops_.execute(d_fp32, args);
327                     dst[dst_offset] = static_cast<bfloat16_t>(d_fp32);
328                 });
329     }
330 
331     return status::success;
332 }
333 
334 template <data_type_t d_type>
execute_backward(const exec_ctx_t & ctx) const335 status_t nchw_pooling_bwd_t<d_type>::execute_backward(
336         const exec_ctx_t &ctx) const {
337     auto alg = pd()->desc()->alg_kind;
338     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
339     const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
340 
341     auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
342     auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
343     auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
344 
345     const memory_desc_wrapper ws_d(pd()->workspace_md());
346 
347     const int MB = pd()->MB();
348     const int C = pd()->IC();
349     const int OD = pd()->OD();
350     const int OH = pd()->OH();
351     const int OW = pd()->OW();
352     const int ID = pd()->ID();
353     const int IH = pd()->IH();
354     const int IW = pd()->IW();
355     const int KD = pd()->KD();
356     const int KH = pd()->KH();
357     const int KW = pd()->KW();
358     const int SD = pd()->KSD();
359     const int SH = pd()->KSH();
360     const int SW = pd()->KSW();
361     const int padF = pd()->padFront();
362     const int padT = pd()->padT();
363     const int padL = pd()->padL();
364 
365     auto apply_offset = [=](int index, int offset) {
366         return (index > offset) ? index - offset : 0;
367     };
368 
369     auto ker_zero = [=](int mb, int c) {
370         size_t diff_src_offset
371                 = (size_t)mb * C * ID * IH * IW + (size_t)c * ID * IH * IW;
372         for_(int id = 0; id < ID; ++id)
373         for_(int ih = 0; ih < IH; ++ih)
374         for (int iw = 0; iw < IW; ++iw) {
375             diff_src[diff_src_offset++] = 0;
376         }
377     };
378 
379     auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
380         auto b_c = ws_d.blocking_desc().inner_nblks == 0
381                 ? 1
382                 : ws_d.blocking_desc().inner_blks[0];
383         auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
384                                 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
385                                         : ws_d.blk_off(mb, c / b_c, ow))
386                 + c % b_c;
387 
388         const int index = ws_d.data_type() == data_type::u8
389                 ? (int)ws[ws_offset]
390                 : ((const int *)ws)[ws_offset];
391         const int kw = index % KW;
392         const int kh = (index / KW) % KH;
393         const int kd = (index / KW) / KH;
394 
395         const int id = od * SD - padF + kd;
396         const int ih = oh * SH - padT + kh;
397         const int iw = ow * SW - padL + kw;
398 
399         // If padding area could fit the kernel,
400         // then input displacement would be out of bounds.
401         // No need to back propagate there as padding is
402         // virtual in pooling_max case.
403         if (id < 0 || id >= ID) return;
404         if (ih < 0 || ih >= IH) return;
405         if (iw < 0 || iw >= IW) return;
406 
407         size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
408                 + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
409                 + (size_t)ih * IW + (size_t)iw;
410         diff_src[diff_src_offset] += d[0];
411     };
412 
413     auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
414         auto id_start = apply_offset(od * SD, padF);
415         auto ih_start = apply_offset(oh * SH, padT);
416         auto iw_start = apply_offset(ow * SW, padL);
417         auto id_end = min(od * SD - padF + KD, ID);
418         auto ih_end = min(oh * SH - padT + KH, IH);
419         auto iw_end = min(ow * SW - padL + KW, IW);
420 
421         size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
422                 ? (size_t)KW * KH * KD
423                 : (size_t)(id_end - id_start) * (ih_end - ih_start)
424                         * (iw_end - iw_start);
425 
426         for_(int id = id_start; id < id_end; ++id)
427         for_(int ih = ih_start; ih < ih_end; ++ih)
428         for (int iw = iw_start; iw < iw_end; ++iw) {
429             size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
430                     + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
431                     + (size_t)ih * IW + (size_t)iw;
432             diff_src[diff_src_offset] += d[0] / num_summands;
433         }
434     };
435 
436     int ow_start = max(0, utils::div_up(padL - KW + 1, SW));
437     int ow_end = min(OW, 1 + (padL + IW - 1) / SW);
438 
439     int oh_start = max(0, utils::div_up(padT - KH + 1, SH));
440     int oh_end = min(OH, 1 + (padT + IH - 1) / SH);
441 
442     int od_start = max(0, utils::div_up(padF - KD + 1, SD));
443     int od_end = min(OD, 1 + (padF + ID - 1) / SD);
444 
445     if (alg == alg_kind::pooling_max) {
446         parallel_nd(MB, C, [&](int mb, int c) {
447             size_t diff_dst_offset_b
448                     = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
449             ker_zero(mb, c);
450             for_(int od = od_start; od < od_end; ++od)
451             for (int oh = oh_start; oh < oh_end; ++oh) {
452                 size_t diff_dst_offset = diff_dst_offset_b
453                         + (size_t)od * OH * OW + (size_t)oh * OW;
454                 for (int ow = ow_start; ow < ow_end; ++ow) {
455                     const data_t *d = &diff_dst[diff_dst_offset + ow];
456                     ker_max(d, mb, c, od, oh, ow);
457                 }
458             }
459         });
460     } else {
461         parallel_nd(MB, C, [&](int mb, int c) {
462             size_t diff_dst_offset_b
463                     = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
464             ker_zero(mb, c);
465             for_(int od = od_start; od < od_end; ++od)
466             for (int oh = oh_start; oh < oh_end; ++oh) {
467                 size_t diff_dst_offset = diff_dst_offset_b
468                         + (size_t)od * OH * OW + (size_t)oh * OW;
469                 for (int ow = ow_start; ow < ow_end; ++ow) {
470                     const data_t *d = &diff_dst[diff_dst_offset + ow];
471                     ker_avg(d, mb, c, od, oh, ow);
472                 }
473             }
474         });
475     }
476 
477     return status::success;
478 }
479 
480 template <>
execute_backward(const exec_ctx_t & ctx) const481 status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
482         const exec_ctx_t &ctx) const {
483 
484     auto alg = pd()->desc()->alg_kind;
485     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
486     const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
487 
488     auto diff_src = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_SRC);
489     auto diff_dst = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_DIFF_DST);
490     auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
491 
492     auto scratchpad = ctx.get_scratchpad_grantor();
493     float *bf16cvt_src = scratchpad.template get<float>(
494             memory_tracking::names::key_pool_src_bf16cvt);
495     float *bf16cvt_dst = scratchpad.template get<float>(
496             memory_tracking::names::key_pool_dst_bf16cvt);
497 
498     const memory_desc_wrapper ws_d(pd()->workspace_md());
499 
500     const int MB = pd()->MB();
501     const int C = pd()->IC();
502     const int OD = pd()->OD();
503     const int OH = pd()->OH();
504     const int OW = pd()->OW();
505     const int ID = pd()->ID();
506     const int IH = pd()->IH();
507     const int IW = pd()->IW();
508     const int KD = pd()->KD();
509     const int KH = pd()->KH();
510     const int KW = pd()->KW();
511     const int SD = pd()->KSD();
512     const int SH = pd()->KSH();
513     const int SW = pd()->KSW();
514     const int padF = pd()->padFront();
515     const int padT = pd()->padT();
516     const int padL = pd()->padL();
517 
518     const size_t dst_sp_size = pd()->OD() * pd()->OH() * pd()->OW();
519     const size_t src_sp_size = pd()->ID() * pd()->IH() * pd()->IW();
520 
521     auto apply_offset = [=](int index, int offset) {
522         return (index > offset) ? index - offset : 0;
523     };
524 
525     auto ker_zero = [=](float *diff_src, int c_block_size) {
526         size_t diff_src_offset = 0;
527         for_(int c = 0; c < c_block_size; ++c)
528         for_(int id = 0; id < ID; ++id)
529         for_(int ih = 0; ih < IH; ++ih)
530         for (int iw = 0; iw < IW; ++iw) {
531             diff_src[diff_src_offset++] = 0.0f;
532         }
533     };
534 
535     auto ker_max = [=](const float *d, float *diff_src, int mb, int c, int od,
536                            int oh, int ow) {
537         auto b_c = ws_d.blocking_desc().inner_nblks == 0
538                 ? 1
539                 : ws_d.blocking_desc().inner_blks[0];
540         auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
541                                 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
542                                         : ws_d.blk_off(mb, c / b_c, ow))
543                 + c % b_c;
544 
545         const int index = ws_d.data_type() == data_type::u8
546                 ? (int)ws[ws_offset]
547                 : ((const int *)ws)[ws_offset];
548         const int kw = index % KW;
549         const int kh = (index / KW) % KH;
550         const int kd = (index / KW) / KH;
551 
552         const int id = od * SD - padF + kd;
553         const int ih = oh * SH - padT + kh;
554         const int iw = ow * SW - padL + kw;
555 
556         // If padding area could fit the kernel,
557         // then input displacement would be out of bounds.
558         // No need to back propagate there as padding is
559         // virtual in pooling_max case.
560         if (id < 0 || id >= ID) return;
561         if (ih < 0 || ih >= IH) return;
562         if (iw < 0 || iw >= IW) return;
563 
564         size_t diff_src_offset
565                 = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
566         diff_src[diff_src_offset] += d[0];
567     };
568 
569     auto ker_avg = [=](const float *d, float *diff_src, int mb, int c, int od,
570                            int oh, int ow) {
571         auto id_start = apply_offset(od * SD, padF);
572         auto ih_start = apply_offset(oh * SH, padT);
573         auto iw_start = apply_offset(ow * SW, padL);
574         auto id_end = min(od * SD - padF + KD, ID);
575         auto ih_end = min(oh * SH - padT + KH, IH);
576         auto iw_end = min(ow * SW - padL + KW, IW);
577 
578         size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
579                 ? (size_t)KW * KH * KD
580                 : (size_t)(id_end - id_start) * (ih_end - ih_start)
581                         * (iw_end - iw_start);
582 
583         for_(int id = id_start; id < id_end; ++id)
584         for_(int ih = ih_start; ih < ih_end; ++ih)
585         for (int iw = iw_start; iw < iw_end; ++iw) {
586             size_t diff_src_offset
587                     = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
588             diff_src[diff_src_offset] += d[0] / num_summands;
589         }
590     };
591 
592     int ow_start = max(0, utils::div_up(padL - KW + 1, SW));
593     int ow_end = min(OW, 1 + (padL + IW - 1) / SW);
594 
595     int oh_start = max(0, utils::div_up(padT - KH + 1, SH));
596     int oh_end = min(OH, 1 + (padT + IH - 1) / SH);
597 
598     int od_start = max(0, utils::div_up(padF - KD + 1, SD));
599     int od_end = min(OD, 1 + (padF + ID - 1) / SD);
600 
601     dim_t c_blk = pd()->channel_block_size_;
602     int c_blk_tail = C % c_blk;
603     if (alg == alg_kind::pooling_max) {
604         parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
605                 [&](int ithr, int, int mb, int cb) {
606                     bool is_last_c_block
607                             = c_blk_tail > 0 && (cb + 1) * c_blk > C;
608                     int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
609                     size_t diff_dst_offset_b
610                             = ((size_t)mb * C + (size_t)cb * c_blk) * OD * OH
611                             * OW;
612                     size_t diff_src_offset
613                             = ((size_t)mb * C + (size_t)cb * c_blk) * ID * IH
614                             * IW;
615                     float *diff_dst_fp32
616                             = &bf16cvt_dst[ithr * dst_sp_size * c_blk];
617                     float *diff_src_fp32
618                             = &bf16cvt_src[ithr * src_sp_size * c_blk];
619 
620                     ker_zero(diff_src_fp32, curr_c_block);
621 
622                     cvt_bfloat16_to_float(diff_dst_fp32,
623                             &diff_dst[diff_dst_offset_b],
624                             dst_sp_size * curr_c_block);
625 
626                     for_(int c = 0; c < curr_c_block; ++c)
627                     for_(int od = od_start; od < od_end; ++od)
628                     for (int oh = oh_start; oh < oh_end; ++oh) {
629                         size_t diff_dst_offset = (size_t)c * OD * OH * OW
630                                 + (size_t)od * OH * OW + (size_t)oh * OW;
631                         for (int ow = ow_start; ow < ow_end; ++ow) {
632                             const float *d
633                                     = &diff_dst_fp32[diff_dst_offset + ow];
634                             ker_max(d, &diff_src_fp32[c * ID * IH * IW], mb,
635                                     cb * c_blk + c, od, oh, ow);
636                         }
637                     }
638                     cvt_float_to_bfloat16(&diff_src[diff_src_offset],
639                             diff_src_fp32, src_sp_size * curr_c_block);
640                 });
641     } else {
642         parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
643                 [&](int ithr, int, int mb, int cb) {
644                     bool is_last_c_block
645                             = c_blk_tail > 0 && (cb + 1) * c_blk > C;
646                     int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
647                     size_t diff_dst_offset_b = (size_t)mb * C * OD * OH * OW
648                             + (size_t)cb * c_blk * OD * OH * OW;
649                     float *diff_dst_fp32
650                             = &bf16cvt_dst[ithr * dst_sp_size * c_blk];
651                     size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
652                             + (size_t)cb * c_blk * ID * IH * IW;
653                     float *diff_src_fp32
654                             = &bf16cvt_src[ithr * src_sp_size * c_blk];
655 
656                     ker_zero(diff_src_fp32, curr_c_block);
657 
658                     cvt_bfloat16_to_float(diff_dst_fp32,
659                             &diff_dst[diff_dst_offset_b],
660                             dst_sp_size * curr_c_block);
661                     for_(int c = 0; c < curr_c_block; ++c)
662                     for_(int od = od_start; od < od_end; ++od)
663                     for (int oh = oh_start; oh < oh_end; ++oh) {
664                         size_t diff_dst_offset = (size_t)c * OD * OH * OW
665                                 + (size_t)od * OH * OW + (size_t)oh * OW;
666                         for (int ow = ow_start; ow < ow_end; ++ow) {
667                             const float *d
668                                     = &diff_dst_fp32[diff_dst_offset + ow];
669                             ker_avg(d, &diff_src_fp32[c * ID * IH * IW], mb,
670                                     cb * c_blk + c, od, oh, ow);
671                         }
672                     }
673                     cvt_float_to_bfloat16(&diff_src[diff_src_offset],
674                             diff_src_fp32, src_sp_size * curr_c_block);
675                 });
676     }
677 
678     return status::success;
679 }
680 template struct nchw_pooling_fwd_t<data_type::f32>;
681 template struct nchw_pooling_bwd_t<data_type::f32>;
682 template struct nchw_pooling_fwd_t<data_type::bf16>;
683 template struct nchw_pooling_bwd_t<data_type::bf16>;
684 } // namespace cpu
685 } // namespace impl
686 } // namespace dnnl
687 
688 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
689