1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 #include <math.h>
19 
20 #include "common/c_types_map.hpp"
21 #include "common/compiler_workarounds.hpp"
22 #include "common/dnnl_thread.hpp"
23 #include "common/math_utils.hpp"
24 #include "common/nstl.hpp"
25 #include "common/type_helpers.hpp"
26 
27 #include "cpu/simple_q10n.hpp"
28 
29 #include "cpu/nhwc_pooling.hpp"
30 
31 namespace dnnl {
32 namespace impl {
33 namespace cpu {
34 
35 // Intel's LLVM-based compiler on Windows generates incorrect code with
36 // PRAGMA_OMP_SIMD in some particular cases.
37 // TODO: The issue above seems to be an additional one to the issue mentioned
38 //       in `CLANG_WA_01_SAFE_TO_USE_OMP_SIMD`. Once the later is resolved,
39 //       check specifically the former one, maybe it will go away as well.
40 #if ((defined _WIN32) && (defined __INTEL_CLANG_COMPILER))
41 #define SAFE_TO_USE_OMP_SIMD (0 && CLANG_WA_01_SAFE_TO_USE_OMP_SIMD)
42 #else
43 #define SAFE_TO_USE_OMP_SIMD (1 && CLANG_WA_01_SAFE_TO_USE_OMP_SIMD)
44 #endif
45 
// The memory descriptor wrapper for tensor `name` is spelled `name_d`.
#define MEM_D(name) name##_d

// Fetch the N/D/H/W element strides of tensor `name` from its blocking
// descriptor. Spatial dims absent from the problem (D unless 3D, H for 1D)
// get a zero stride so strided_offset() can be called uniformly for
// 1D/2D/3D cases. Requires `is_1d`, `is_3d` and `ndims` in scope.
#define DECLARE_READ_STRIDES(name) \
    const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \
    const size_t name##_d_stride \
            = is_3d ? MEM_D(name).blocking_desc().strides[ndims - 3] : 0; \
    const size_t name##_h_stride \
            = is_1d ? 0 : MEM_D(name).blocking_desc().strides[ndims - 2]; \
    const size_t name##_w_stride \
            = MEM_D(name).blocking_desc().strides[ndims - 1];
56 
namespace nhwc_pooling {
// Linear memory offset of element (n, d, h, w) given the element stride of
// each dimension. Dimensions that do not apply must pass a zero stride.
size_t strided_offset(const int _n, const size_t _sn, const int _d,
        const size_t _sd, const int _h, const size_t _sh, const int _w,
        const size_t _sw) {
    size_t offset = _n * _sn;
    offset += _d * _sd;
    offset += _h * _sh;
    offset += _w * _sw;
    return offset;
}
} // namespace nhwc_pooling
64 
// Construct the forward pooling primitive; the reference post-ops executor
// is built once here from the descriptor's attributes.
template <data_type_t d_type>
nhwc_pooling_fwd_t<d_type>::nhwc_pooling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
68 
69 template <data_type_t d_type>
array_div_by_const(const int n,const ker_data_t * src,const size_t num,ker_data_t * dst) const70 void nhwc_pooling_fwd_t<d_type>::array_div_by_const(const int n,
71         const ker_data_t *src, const size_t num, ker_data_t *dst) const {
72     for (int i = 0; i < n; ++i) {
73         const float ftmp = ((float)src[i]) / num;
74         dst[i] = out_round<ker_data_t>(ftmp);
75     }
76 }
77 
78 template <data_type_t d_type>
array_add(const int n,const ker_data_t * src,ker_data_t * dst) const79 void nhwc_pooling_fwd_t<d_type>::array_add(
80         const int n, const ker_data_t *src, ker_data_t *dst) const {
81     for (int i = 0; i < n; ++i) {
82         dst[i] += src[i];
83     }
84 }
85 
// Running max over n channels: dst[oc] = max(dst[oc], src[oc]), while
// recording in the workspace `ws` the kernel-local index (`index`) of the
// element that produced the maximum. The workspace entry is u8 or s32
// depending on ws_dt; backward max-pooling reads these indices back.
template <data_type_t d_type>
void nhwc_pooling_fwd_t<d_type>::array_nhwc_max(const int n, ker_data_t *dst,
        const ker_data_t *src, unsigned char *ws, const size_t ws_offset,
        const data_type_t ws_dt, const int index) const {
    assert(ws);
#if SAFE_TO_USE_OMP_SIMD
    PRAGMA_OMP_SIMD()
#endif
    for (int oc = 0; oc < n; ++oc) {
        const auto s = src[oc];
        ker_data_t mv = dst[oc];

        // update index of maximum
#if defined __INTEL_COMPILER
        // ICC vectorizes the straightforward branchy form.
        if (s > mv) {
            // if (ws && (s > mv)) {
            assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
            if (ws_dt == data_type::u8) {
                // u8 workspace can only hold kernel indices up to 255.
                assert(0 <= index && index <= 255);
                ws[ws_offset + oc] = index;
            } else
                reinterpret_cast<int *>(ws)[ws_offset + oc] = index;
        }
#else
        // Need to add explicit predicates for GCC to vectorize this.
        // And although the resulting code is ugly, it is still 4 times
        // faster than scalar
        assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);

        if (ws_dt == data_type::u8) {
            assert(0 <= index && index <= 255);
            // Branch-free select: predicate is all-ones when src wins, so the
            // store picks `index`; otherwise it rewrites the current value.
            const unsigned char predicate = (s > mv) ? 0xff : 0;
            unsigned char current_value = ws[ws_offset + oc];
            current_value = (predicate & (unsigned char)index)
                    | ((~predicate) & current_value);
            ws[ws_offset + oc] = current_value;
        } else {
            // Same predicated select, but on the s32 view of the workspace.
            auto wint = reinterpret_cast<int *>(ws);
            const unsigned int predicate = (s > mv) ? 0xffffffff : 0;
            unsigned int current_value = wint[ws_offset + oc];
            current_value = (predicate & (unsigned int)index)
                    | ((~predicate) & current_value);
            wint[ws_offset + oc] = current_value;
        }
#endif
        // update maximum
        dst[oc] = nstl::max(s, mv);
    }
}
135 
136 template <data_type_t d_type>
array_nhwc_initialize(const int n,ker_data_t * dst,unsigned char * ws,const size_t ws_offset,const data_type_t ws_dt) const137 void nhwc_pooling_fwd_t<d_type>::array_nhwc_initialize(const int n,
138         ker_data_t *dst, unsigned char *ws, const size_t ws_offset,
139         const data_type_t ws_dt) const {
140     assert(ws && (ws_dt == data_type::u8 || ws_dt == data_type::s32));
141 #if SAFE_TO_USE_OMP_SIMD
142     PRAGMA_OMP_SIMD()
143 #endif
144     for (int oc = 0; oc < n; ++oc) {
145         if (ws_dt == data_type::u8)
146             ws[ws_offset + oc] = 0;
147         else
148             reinterpret_cast<int *>(ws)[ws_offset + oc] = 0;
149         dst[oc] = nstl::numeric_limits<data_t>::lowest();
150     }
151 }
152 
153 using namespace nstl;
154 using namespace nhwc_pooling;
155 
// Forward pooling over a channels-last tensor. Each parallel task owns one
// output point (mb, od, oh, ow) and processes all OC channels; the channel
// index is added directly to the strided offset, i.e. channels are assumed
// dense/innermost, which is what makes the SIMD helper loops effective.
template <data_type_t d_type>
status_t nhwc_pooling_fwd_t<d_type>::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto alg = pd()->desc()->alg_kind;

    const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    // Workspace (argmax indices) is only produced for max pooling when the
    // primitive descriptor requested one; may be null.
    auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);

    const memory_desc_wrapper MEM_D(src)(pd()->src_md());
    const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    // Problem geometry: output/input spatial sizes, kernel, strides, padding.
    const int MB = pd()->MB();
    const int OD = pd()->OD();
    const int OC = pd()->OC();
    const int OH = pd()->OH();
    const int OW = pd()->OW();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();

    const bool is_1d = pd()->desc()->src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    // Workspace stores argmax as u8 or s32 (see array_nhwc_max).
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    DECLARE_READ_STRIDES(src);
    DECLARE_READ_STRIDES(dst);

    // Distance from `index` back to `offset`, clamped at zero; used to clip
    // the pooling window against the front/top/left padding.
    const auto apply_offset = [](int index, int offset) {
        return (index > offset) ? index - offset : 0;
    };

    const dim_t SP = OW * OH;
    const dim_t OSP = SP * OD;

    // Dense, channel-major logical offset of an output element, consumed by
    // the reference post-ops to locate the current value.
    const auto get_logical_offset
            = [&](int mb, int oc, int od, int oh, int ow) -> dim_t {
        return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
    };
    const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());

    parallel_nd(MB, OD, OH, OW, [&](int mb, int od, int oh, int ow) {
        const size_t dst_offset_init = strided_offset(mb, dst_n_stride, od,
                dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
        if (alg == alg_kind::pooling_max) {
            size_t ws_offset_init = 0;
            if (ws) {
                DECLARE_READ_STRIDES(ws);
                ws_offset_init = strided_offset(mb, ws_n_stride, od,
                        ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
            }
            // Note: GCC 4.8.5 won't vectorize below
            // simple loops unless they are singled out
            // into separate helper routines:
            //    array_nhwc_initialize, array_nhwc_max
            if (!ws) {
                auto *const d = dst + dst_offset_init;
                PRAGMA_OMP_SIMD()
                for (int oc = 0; oc < OC; ++oc) {
                    d[oc] = nstl::numeric_limits<data_t>::lowest();
                }
            } else {
                array_nhwc_initialize(
                        OC, dst + dst_offset_init, ws, ws_offset_init, ws_dt);
            }

            // Walk the kernel window; out-of-bounds taps (padding) are
            // skipped, so they never participate in the max.
            for_(int kd = 0; kd < KD; ++kd)
            for_(int kh = 0; kh < KH; ++kh)
            for (int kw = 0; kw < KW; ++kw) {
                const int id = od * SD - padF + kd;
                const int ih = oh * SH - padT + kh;
                const int iw = ow * SW - padL + kw;

                if (id < 0 || id >= ID) continue;
                if (ih < 0 || ih >= IH) continue;
                if (iw < 0 || iw >= IW) continue;

                const size_t src_offset_init = strided_offset(mb, src_n_stride,
                        id, src_d_stride, ih, src_h_stride, iw, src_w_stride);

                if (!ws) {
                    auto *const s = src + src_offset_init;
                    auto *const d = dst + dst_offset_init;
                    PRAGMA_OMP_SIMD()
                    for (int oc = 0; oc < OC; ++oc) {
                        d[oc] = nstl::max(s[oc], d[oc]);
                    }
                } else {
                    // kd/kh/kw linearized into the argmax index stored in ws.
                    array_nhwc_max(OC, dst + dst_offset_init,
                            src + src_offset_init, ws, ws_offset_init, ws_dt,
                            kd * KH * KW + kh * KW + kw);
                }
            }
        } else {
            // pooling_avg
            const auto d = dst + dst_offset_init;

            utils::array_set(d, 0, OC);

            // Clip the averaging window to the valid input region.
            const auto id_start = apply_offset(od * SD, padF);
            const auto ih_start = apply_offset(oh * SH, padT);
            const auto iw_start = apply_offset(ow * SW, padL);
            const auto id_end = min(od * SD - padF + KD, ID);
            const auto ih_end = min(oh * SH - padT + KH, IH);
            const auto iw_end = min(ow * SW - padL + KW, IW);

            // it is cheaper to actually count this in a loop
            // as the typical kernel is small
            size_t num_summands = 0;

            for_(int id = id_start; id < id_end; ++id)
            for_(int ih = ih_start; ih < ih_end; ++ih)
            for (int iw = iw_start; iw < iw_end; ++iw) {
                const size_t src_offset_init = strided_offset(mb, src_n_stride,
                        id, src_d_stride, ih, src_h_stride, iw, src_w_stride);
                const auto s = src + src_offset_init;

                // need to move the loop to separate function
                // for GCC 4.8.5 to vectorize
                array_add(OC, s, d);

                num_summands++;
            }

            // include_padding averages over the full kernel volume;
            // exclude_padding divides by the clipped count found above.
            num_summands = (alg == alg_kind::pooling_avg_include_padding)
                    ? KW * KH * KD
                    : num_summands;

            // need to move the loop to separate function
            // for GCC 4.8.5 to vectorize
            array_div_by_const(OC, d, num_summands, d);
        }

        if (are_postops_set) {
            auto *const d = dst + dst_offset_init;
            ref_post_ops_t::args_t args;
            args.ctx = &ctx;
            args.l_offset = get_logical_offset(mb, 0, od, oh, ow);
            args.dst_md = pd()->dst_md();

            // Apply post-ops in place, channel by channel; consecutive
            // channels are OSP apart in the logical (dense) layout.
            for (int oc = 0; oc < OC; ++oc) {
                ref_post_ops_.execute(d[oc], args);
                args.l_offset += OSP;
            }
        }
    });
    return status::success;
}
316 
// bf16 specialization of forward pooling: accumulates in f32 to avoid bf16
// rounding in intermediate results. Each thread owns one OC-sized f32 row
// for src and one for dst in the scratchpad; the dst row is converted back
// to bf16 exactly once per output point.
template <>
status_t nhwc_pooling_fwd_t<data_type::bf16>::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto alg = pd()->desc()->alg_kind;

    const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);

    // Per-thread f32 conversion buffers sized ithr * OC (see indexing below).
    auto scratchpad = ctx.get_scratchpad_grantor();
    float *const bf16cvt_src_wsp = scratchpad.template get<float>(
            memory_tracking::names::key_pool_src_bf16cvt);
    float *const bf16cvt_dst_wsp = scratchpad.template get<float>(
            memory_tracking::names::key_pool_dst_bf16cvt);

    const memory_desc_wrapper MEM_D(src)(pd()->src_md());
    const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    // Problem geometry: output/input spatial sizes, kernel, strides, padding.
    const int MB = pd()->MB();
    const int OD = pd()->OD();
    const int OC = pd()->OC();
    const int OH = pd()->OH();
    const int OW = pd()->OW();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();

    const bool is_1d = pd()->desc()->src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    // Workspace stores argmax as u8 or s32 (see array_nhwc_max).
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    DECLARE_READ_STRIDES(src);
    DECLARE_READ_STRIDES(dst);

    // Distance from `index` back to `offset`, clamped at zero; used to clip
    // the pooling window against the front/top/left padding.
    const auto apply_offset = [&](int index, int offset) {
        return (index > offset) ? index - offset : 0;
    };

    const dim_t SP = OW * OH;
    const dim_t OSP = SP * OD;

    // Dense, channel-major logical offset of an output element, consumed by
    // the reference post-ops to locate the current value.
    const auto get_logical_offset
            = [&](int mb, int oc, int od, int oh, int ow) -> dim_t {
        return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
    };
    const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());

    // parallel_nd_ext exposes the thread id (ithr) needed to pick this
    // thread's slice of the f32 scratchpad.
    parallel_nd_ext(0, MB, OD, OH, OW,
            [&](int ithr, int, int mb, int od, int oh, int ow) {
                const size_t dst_offset_init = strided_offset(mb, dst_n_stride,
                        od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
                float *const dst_f32 = &bf16cvt_dst_wsp[ithr * OC];
                float *const src_f32 = &bf16cvt_src_wsp[ithr * OC];

                if (alg == alg_kind::pooling_max) {
                    size_t ws_offset_init = 0;
                    if (ws) {
                        DECLARE_READ_STRIDES(ws);
                        ws_offset_init = strided_offset(mb, ws_n_stride, od,
                                ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
                    };
                    // Note: GCC 4.8.5 won't vectorize below
                    // simple loops unless they are singled out
                    // into separate helper routines:
                    //    array_nhwc_initialize, array_nhwc_max
                    if (!ws) {
                        PRAGMA_OMP_SIMD()
                        for (int oc = 0; oc < OC; ++oc) {
                            dst_f32[oc]
                                    = nstl::numeric_limits<data_t>::lowest();
                        }
                    } else {
                        array_nhwc_initialize(
                                OC, dst_f32, ws, ws_offset_init, ws_dt);
                    }

                    // Walk the kernel window; out-of-bounds taps (padding)
                    // are skipped, so they never participate in the max.
                    for_(int kd = 0; kd < KD; ++kd)
                    for_(int kh = 0; kh < KH; ++kh)
                    for (int kw = 0; kw < KW; ++kw) {
                        const int id = od * SD - padF + kd;
                        const int ih = oh * SH - padT + kh;
                        const int iw = ow * SW - padL + kw;

                        if (id < 0 || id >= ID) continue;
                        if (ih < 0 || ih >= IH) continue;
                        if (iw < 0 || iw >= IW) continue;

                        const size_t src_offset_init = strided_offset(mb,
                                src_n_stride, id, src_d_stride, ih,
                                src_h_stride, iw, src_w_stride);

                        // Upconvert this input row once, then reduce in f32.
                        cvt_bfloat16_to_float(
                                src_f32, &src[src_offset_init], OC);

                        if (!ws) {
                            PRAGMA_OMP_SIMD()
                            for (int oc = 0; oc < OC; ++oc) {
                                dst_f32[oc]
                                        = nstl::max(src_f32[oc], dst_f32[oc]);
                            }
                        } else {
                            // kd/kh/kw linearized into the stored argmax.
                            array_nhwc_max(OC, dst_f32, src_f32, ws,
                                    ws_offset_init, ws_dt,
                                    kd * KH * KW + kh * KW + kw);
                        }
                    }
                } else {
                    // pooling_avg
                    utils::array_set(dst_f32, 0, OC);

                    // Clip the averaging window to the valid input region.
                    const auto id_start = apply_offset(od * SD, padF);
                    const auto ih_start = apply_offset(oh * SH, padT);
                    const auto iw_start = apply_offset(ow * SW, padL);
                    const auto id_end = min(od * SD - padF + KD, ID);
                    const auto ih_end = min(oh * SH - padT + KH, IH);
                    const auto iw_end = min(ow * SW - padL + KW, IW);

                    // it is cheaper to actually count this in a loop
                    // as the typical kernel is small
                    size_t num_summands = 0;

                    for_(int id = id_start; id < id_end; ++id)
                    for_(int ih = ih_start; ih < ih_end; ++ih)
                    for (int iw = iw_start; iw < iw_end; ++iw) {
                        size_t src_offset_init = strided_offset(mb,
                                src_n_stride, id, src_d_stride, ih,
                                src_h_stride, iw, src_w_stride);
                        cvt_bfloat16_to_float(
                                src_f32, &src[src_offset_init], OC);

                        // need to move the loop to separate function
                        // for GCC 4.8.5 to vectorize
                        array_add(OC, src_f32, dst_f32);
                        num_summands++;
                    }

                    // include_padding averages over the full kernel volume;
                    // exclude_padding divides by the clipped count above.
                    num_summands
                            = (alg == alg_kind::pooling_avg_include_padding)
                            ? KW * KH * KD
                            : num_summands;

                    // need to move the loop to separate function
                    // for GCC 4.8.5 to vectorize
                    array_div_by_const(OC, dst_f32, num_summands, dst_f32);
                }

                if (are_postops_set) {
                    ref_post_ops_t::args_t args;
                    args.ctx = &ctx;
                    args.l_offset = get_logical_offset(mb, 0, od, oh, ow);
                    args.dst_md = pd()->dst_md();

                    // Post-ops run on the f32 row, before downconversion;
                    // consecutive channels are OSP apart logically.
                    for (int oc = 0; oc < OC; ++oc) {
                        ref_post_ops_.execute(dst_f32[oc], args);
                        args.l_offset += OSP;
                    }
                }
                // Single bf16 store of the finished output row.
                cvt_float_to_bfloat16(dst + dst_offset_init, dst_f32, OC);
            });
    return status::success;
}
490 
491 template <data_type_t d_type>
execute_backward(const exec_ctx_t & ctx) const492 status_t nhwc_pooling_bwd_t<d_type>::execute_backward(
493         const exec_ctx_t &ctx) const {
494     auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
495     auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
496     auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
497 
498     const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md());
499     const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md());
500     const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());
501 
502     const int MB = pd()->MB();
503     const int ID = pd()->ID();
504     const int IH = pd()->IH();
505     const int IW = pd()->IW();
506     const int KD = pd()->KD();
507     const int KH = pd()->KH();
508     const int KW = pd()->KW();
509     const int SD = pd()->KSD();
510     const int SH = pd()->KSH();
511     const int SW = pd()->KSW();
512     const int OC = pd()->OC();
513     const int padF = pd()->padFront();
514     const int padT = pd()->padT();
515     const int padL = pd()->padL();
516     const int OD = pd()->OD();
517     const int OH = pd()->OH();
518     const int OW = pd()->OW();
519 
520     const bool is_1d = pd()->desc()->diff_src_desc.ndims == 3;
521     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
522     const int ndims = pd()->ndims();
523     auto alg = pd()->desc()->alg_kind;
524 
525     DECLARE_READ_STRIDES(diff_src);
526     DECLARE_READ_STRIDES(diff_dst);
527 
528     auto apply_offset = [=](int index, int offset) {
529         return (index > offset) ? index - offset : 0;
530     };
531 
532     parallel_nd(MB, ID, IH, IW, [&](int mb, int id, int ih, int iw) {
533         size_t src_offset_init
534                 = strided_offset(mb, diff_src_n_stride, id, diff_src_d_stride,
535                         ih, diff_src_h_stride, iw, diff_src_w_stride);
536 
537         for (int oc = 0; oc < OC; ++oc)
538             diff_src[src_offset_init + oc] = data_type_t(0);
539 
540         // Find out which output cells may correspond to current
541         // input position. Current input postition divided by
542         // stride, with integer divide rounding down, is the
543         // right-most output.
544         // Left-most output may be computed if we decrement input
545         // by (kernel_size - 1) and then do the same division by
546         // stride.
547         int od_left = max((id + padF - KD + 1) / SD, 0);
548         int oh_left = max((ih + padT - KH + 1) / SH, 0);
549         int ow_left = max((iw + padL - KW + 1) / SW, 0);
550         // Notice +1 here to preserve the C loop "less than"
551         // condition for continuing the for loop.
552         int od_right = min((id + padF) / SD + 1, OD);
553         int oh_right = min((ih + padT) / SH + 1, OH);
554         int ow_right = min((iw + padL) / SW + 1, OW);
555 
556         for_(int od = od_left; od < od_right; ++od)
557         for_(int oh = oh_left; oh < oh_right; ++oh)
558         for (int ow = ow_left; ow < ow_right; ++ow) {
559             const int kd = id - od * SD + padF;
560             const int kh = ih - oh * SH + padT;
561             const int kw = iw - ow * SW + padL;
562 
563             if (kd < 0 || kd >= KD) continue;
564             if (kh < 0 || kh >= KH) continue;
565             if (kw < 0 || kw >= KW) continue;
566 
567             size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, od,
568                     diff_dst_d_stride, oh, diff_dst_h_stride, ow,
569                     diff_dst_w_stride);
570 
571             if (alg == alg_kind::pooling_max) {
572                 DECLARE_READ_STRIDES(ws);
573                 size_t ws_offset_init = strided_offset(mb, ws_n_stride, od,
574                         ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
575                 const int index = kd * KH * KW + kh * KW + kw;
576                 const unsigned char *ws_ = ws + ws_offset_init;
577                 const int *intws_ = (int *)ws + ws_offset_init;
578                 const bool ws_is_u8 = MEM_D(ws).data_type() == data_type::u8;
579 
580 #if SAFE_TO_USE_OMP_SIMD
581                 PRAGMA_OMP_SIMD()
582 #endif
583                 for (int oc = 0; oc < OC; ++oc) {
584                     const int index_from_ws = ws_is_u8 ? ws_[oc] : intws_[oc];
585                     const data_t d = diff_dst[dst_offset_init + oc];
586 
587                     // Check if kernel windows are disjoint, in this case
588                     // there's no update needed and we just write there once
589                     // otherwise we add value to the contents.
590                     auto value = (index_from_ws == index) ? d : data_type_t(0);
591                     if (!(KD == SD && KH == SH && KW == SW))
592                         diff_src[src_offset_init + oc] += value;
593                     else
594                         diff_src[src_offset_init + oc] = value;
595                 }
596             } else {
597                 // pooling_avg
598                 auto id_start = apply_offset(od * SD, padF);
599                 auto ih_start = apply_offset(oh * SH, padT);
600                 auto iw_start = apply_offset(ow * SW, padL);
601                 auto id_end = min(od * SD - padF + KD, ID);
602                 auto ih_end = min(oh * SH - padT + KH, IH);
603                 auto iw_end = min(ow * SW - padL + KW, IW);
604 
605                 auto num_summands
606                         = (alg == alg_kind::pooling_avg_include_padding)
607                         ? KW * KH * KD
608                         : (ih_end - ih_start) * (iw_end - iw_start)
609                                 * (id_end - id_start);
610 
611                 PRAGMA_OMP_SIMD()
612                 for (int oc = 0; oc < OC; ++oc) {
613                     const data_t d = diff_dst[dst_offset_init + oc];
614                     // Check if kernel windows are disjoint, in this case
615                     // there's no update needed and we just write there once
616                     // otherwise we add value to the contents.
617                     if (!(KD == SD && KH == SH && KW == SW))
618                         diff_src[src_offset_init + oc] += d / num_summands;
619                     else
620                         diff_src[src_offset_init + oc] = d / num_summands;
621                 }
622             }
623         }
624     });
625     return status::success;
626 }
627 
628 template <>
execute_backward(const exec_ctx_t & ctx) const629 status_t nhwc_pooling_bwd_t<data_type::bf16>::execute_backward(
630         const exec_ctx_t &ctx) const {
631 
632     auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
633     auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
634     auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
635 
636     auto scratchpad = ctx.get_scratchpad_grantor();
637     float *bf16cvt_dsrc = scratchpad.template get<float>(
638             memory_tracking::names::key_pool_src_bf16cvt);
639     float *bf16cvt_ddst = scratchpad.template get<float>(
640             memory_tracking::names::key_pool_dst_bf16cvt);
641 
642     const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md());
643     const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md());
644     const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());
645 
646     const int MB = pd()->MB();
647     const int ID = pd()->ID();
648     const int IH = pd()->IH();
649     const int IW = pd()->IW();
650     const int KD = pd()->KD();
651     const int KH = pd()->KH();
652     const int KW = pd()->KW();
653     const int SD = pd()->KSD();
654     const int SH = pd()->KSH();
655     const int SW = pd()->KSW();
656     const int OC = pd()->OC();
657     const int padF = pd()->padFront();
658     const int padT = pd()->padT();
659     const int padL = pd()->padL();
660     const int OD = pd()->OD();
661     const int OH = pd()->OH();
662     const int OW = pd()->OW();
663 
664     const bool is_1d = pd()->desc()->diff_src_desc.ndims == 3;
665     const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
666     const int ndims = pd()->ndims();
667     auto alg = pd()->desc()->alg_kind;
668 
669     DECLARE_READ_STRIDES(diff_src);
670     DECLARE_READ_STRIDES(diff_dst);
671 
672     auto apply_offset = [=](int index, int offset) {
673         return (index > offset) ? index - offset : 0;
674     };
675 
676     parallel_nd_ext(0, MB, ID, IH, IW,
677             [&](int ithr, int, int mb, int id, int ih, int iw) {
678                 size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
679                         id, diff_src_d_stride, ih, diff_src_h_stride, iw,
680                         diff_src_w_stride);
681 
682                 float *diff_dst_fp32 = &bf16cvt_ddst[ithr * OC];
683                 float *diff_src_fp32 = &bf16cvt_dsrc[ithr * OC];
684 
685                 for (int oc = 0; oc < OC; ++oc) {
686                     diff_src_fp32[oc] = 0.f;
687                     diff_src[src_offset_init + oc] = (bfloat16_t)0.f;
688                 }
689 
690                 // Find out which output cells may correspond to current
691                 // input position. Current input postition divided by
692                 // stride, with integer divide rounding down, is the
693                 // right-most output.
694                 // Left-most output may be computed if we decrement input
695                 // by (kernel_size - 1) and then do the same division by
696                 // stride.
697                 int od_left = max((id + padF - KD + 1) / SD, 0);
698                 int oh_left = max((ih + padT - KH + 1) / SH, 0);
699                 int ow_left = max((iw + padL - KW + 1) / SW, 0);
700                 // Notice +1 here to preserve the C loop "less than"
701                 // condition for continuing the for loop.
702                 int od_right = min((id + padF) / SD + 1, OD);
703                 int oh_right = min((ih + padT) / SH + 1, OH);
704                 int ow_right = min((iw + padL) / SW + 1, OW);
705 
706                 for_(int od = od_left; od < od_right; ++od)
707                 for_(int oh = oh_left; oh < oh_right; ++oh)
708                 for (int ow = ow_left; ow < ow_right; ++ow) {
709                     const int kd = id - od * SD + padF;
710                     const int kh = ih - oh * SH + padT;
711                     const int kw = iw - ow * SW + padL;
712 
713                     if (kd < 0 || kd >= KD) continue;
714                     if (kh < 0 || kh >= KH) continue;
715                     if (kw < 0 || kw >= KW) continue;
716 
717                     size_t dst_offset_init = strided_offset(mb,
718                             diff_dst_n_stride, od, diff_dst_d_stride, oh,
719                             diff_dst_h_stride, ow, diff_dst_w_stride);
720                     cvt_bfloat16_to_float(
721                             diff_dst_fp32, &diff_dst[dst_offset_init], OC);
722 
723                     if (alg == alg_kind::pooling_max) {
724                         DECLARE_READ_STRIDES(ws);
725                         size_t ws_offset_init = strided_offset(mb, ws_n_stride,
726                                 od, ws_d_stride, oh, ws_h_stride, ow,
727                                 ws_w_stride);
728                         const int index = kd * KH * KW + kh * KW + kw;
729                         const unsigned char *ws_ = ws + ws_offset_init;
730                         const int *intws_ = (int *)ws + ws_offset_init;
731                         const bool ws_is_u8
732                                 = MEM_D(ws).data_type() == data_type::u8;
733 
734 #if SAFE_TO_USE_OMP_SIMD
735                         PRAGMA_OMP_SIMD()
736 #endif
737                         for (int oc = 0; oc < OC; ++oc) {
738                             const int index_from_ws
739                                     = ws_is_u8 ? ws_[oc] : intws_[oc];
740 
741                             // Check if kernel windows are disjoint, in this case
742                             // there's no update needed and we just write there once
743                             // otherwise we add value to the contents.
744                             float value = (index_from_ws == index)
745                                     ? diff_dst_fp32[oc]
746                                     : 0.0f;
747                             if (!(KD == SD && KH == SH && KW == SW))
748                                 diff_src_fp32[oc] += value;
749                             else
750                                 diff_src_fp32[oc] = value;
751                         }
752                     } else {
753                         // pooling_avg
754                         auto id_start = apply_offset(od * SD, padF);
755                         auto ih_start = apply_offset(oh * SH, padT);
756                         auto iw_start = apply_offset(ow * SW, padL);
757                         auto id_end = min(od * SD - padF + KD, ID);
758                         auto ih_end = min(oh * SH - padT + KH, IH);
759                         auto iw_end = min(ow * SW - padL + KW, IW);
760 
761                         auto num_summands
762                                 = (alg == alg_kind::pooling_avg_include_padding)
763                                 ? KW * KH * KD
764                                 : (ih_end - ih_start) * (iw_end - iw_start)
765                                         * (id_end - id_start);
766 
767                         PRAGMA_OMP_SIMD()
768                         for (int oc = 0; oc < OC; ++oc) {
769                             // Check if kernel windows are disjoint, in this case
770                             // there's no update needed and we just write there once
771                             // otherwise we add value to the contents.
772                             if (!(KD == SD && KH == SH && KW == SW))
773                                 diff_src_fp32[oc]
774                                         += diff_dst_fp32[oc] / num_summands;
775                             else
776                                 diff_src_fp32[oc]
777                                         = diff_dst_fp32[oc] / num_summands;
778                         }
779                     }
780                     cvt_float_to_bfloat16(
781                             &diff_src[src_offset_init], diff_src_fp32, OC);
782                 }
783             });
784     return status::success;
785 }
786 
// Explicit template instantiations for the data types this implementation
// supports; they emit the fwd/bwd definitions in this translation unit so
// callers can link against them without seeing the template bodies.
template struct nhwc_pooling_fwd_t<data_type::f32>;
template struct nhwc_pooling_bwd_t<data_type::f32>;
template struct nhwc_pooling_fwd_t<data_type::bf16>;
template struct nhwc_pooling_bwd_t<data_type::bf16>;
791 
792 } // namespace cpu
793 } // namespace impl
794 } // namespace dnnl
795 
796 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
797