1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "common/c_types_map.hpp"
21 #include "common/compiler_workarounds.hpp"
22 #include "common/dnnl_thread.hpp"
23 #include "common/math_utils.hpp"
24 #include "common/nstl.hpp"
25 #include "common/type_helpers.hpp"
26
27 #include "cpu/simple_q10n.hpp"
28
29 #include "cpu/nhwc_pooling.hpp"
30
31 namespace dnnl {
32 namespace impl {
33 namespace cpu {
34
35 // Intel's LLVM-based compiler on Windows generates incorrect code with
36 // PRAGMA_OMP_SIMD in some particular cases.
37 // TODO: The issue above seems to be an additional one to the issue mentioned
38 // in `CLANG_WA_01_SAFE_TO_USE_OMP_SIMD`. Once the later is resolved,
39 // check specifically the former one, maybe it will go away as well.
40 #if ((defined _WIN32) && (defined __INTEL_CLANG_COMPILER))
41 #define SAFE_TO_USE_OMP_SIMD (0 && CLANG_WA_01_SAFE_TO_USE_OMP_SIMD)
42 #else
43 #define SAFE_TO_USE_OMP_SIMD (1 && CLANG_WA_01_SAFE_TO_USE_OMP_SIMD)
44 #endif
45
46 #define MEM_D(name) name##_d
47
48 #define DECLARE_READ_STRIDES(name) \
49 const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \
50 const size_t name##_d_stride \
51 = is_3d ? MEM_D(name).blocking_desc().strides[ndims - 3] : 0; \
52 const size_t name##_h_stride \
53 = is_1d ? 0 : MEM_D(name).blocking_desc().strides[ndims - 2]; \
54 const size_t name##_w_stride \
55 = MEM_D(name).blocking_desc().strides[ndims - 1];
56
57 namespace nhwc_pooling {
// Computes the flat element offset of point (n, d, h, w) given the
// per-dimension strides of a plain (blocking) layout.
size_t strided_offset(const int _n, const size_t _sn, const int _d,
        const size_t _sd, const int _h, const size_t _sh, const int _w,
        const size_t _sw) {
    size_t offset = _n * _sn;
    offset += _d * _sd;
    offset += _h * _sh;
    offset += _w * _sw;
    return offset;
}
63 } // namespace nhwc_pooling
64
// Constructor: stores the primitive descriptor and builds the reference
// post-ops executor from the attribute's post-ops chain (applied after
// pooling in execute_forward).
template <data_type_t d_type>
nhwc_pooling_fwd_t<d_type>::nhwc_pooling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
68
69 template <data_type_t d_type>
array_div_by_const(const int n,const ker_data_t * src,const size_t num,ker_data_t * dst) const70 void nhwc_pooling_fwd_t<d_type>::array_div_by_const(const int n,
71 const ker_data_t *src, const size_t num, ker_data_t *dst) const {
72 for (int i = 0; i < n; ++i) {
73 const float ftmp = ((float)src[i]) / num;
74 dst[i] = out_round<ker_data_t>(ftmp);
75 }
76 }
77
78 template <data_type_t d_type>
array_add(const int n,const ker_data_t * src,ker_data_t * dst) const79 void nhwc_pooling_fwd_t<d_type>::array_add(
80 const int n, const ker_data_t *src, ker_data_t *dst) const {
81 for (int i = 0; i < n; ++i) {
82 dst[i] += src[i];
83 }
84 }
85
// Fused max + argmax update over one channel vector: for each of the n
// channels, takes the max of src and dst into dst and, where src wins,
// records `index` (the kernel-window position kd*KH*KW + kh*KW + kw) into
// the workspace at ws_offset. The workspace element type is either u8 or
// s32, selected by ws_dt.
template <data_type_t d_type>
void nhwc_pooling_fwd_t<d_type>::array_nhwc_max(const int n, ker_data_t *dst,
        const ker_data_t *src, unsigned char *ws, const size_t ws_offset,
        const data_type_t ws_dt, const int index) const {
    assert(ws);
#if SAFE_TO_USE_OMP_SIMD
    PRAGMA_OMP_SIMD()
#endif
    for (int oc = 0; oc < n; ++oc) {
        const auto s = src[oc];
        ker_data_t mv = dst[oc];

        // update index of maximum
#if defined __INTEL_COMPILER
        // Intel compiler vectorizes the straightforward conditional store.
        if (s > mv) {
            // if (ws && (s > mv)) {
            assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
            if (ws_dt == data_type::u8) {
                // u8 workspace can only hold window indices up to 255.
                assert(0 <= index && index <= 255);
                ws[ws_offset + oc] = index;
            } else
                reinterpret_cast<int *>(ws)[ws_offset + oc] = index;
        }
#else
        // Need to add explicit predicates for GCC to vectorize this.
        // And although the resulting code is ugly, it is still 4 times
        // faster than scalar
        assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);

        if (ws_dt == data_type::u8) {
            assert(0 <= index && index <= 255);
            // Branch-free select: all-ones mask picks `index`, zero mask
            // keeps the previously stored value.
            const unsigned char predicate = (s > mv) ? 0xff : 0;
            unsigned char current_value = ws[ws_offset + oc];
            current_value = (predicate & (unsigned char)index)
                    | ((~predicate) & current_value);
            ws[ws_offset + oc] = current_value;
        } else {
            auto wint = reinterpret_cast<int *>(ws);
            const unsigned int predicate = (s > mv) ? 0xffffffff : 0;
            unsigned int current_value = wint[ws_offset + oc];
            current_value = (predicate & (unsigned int)index)
                    | ((~predicate) & current_value);
            wint[ws_offset + oc] = current_value;
        }
#endif
        // update maximum
        dst[oc] = nstl::max(s, mv);
    }
}
135
136 template <data_type_t d_type>
array_nhwc_initialize(const int n,ker_data_t * dst,unsigned char * ws,const size_t ws_offset,const data_type_t ws_dt) const137 void nhwc_pooling_fwd_t<d_type>::array_nhwc_initialize(const int n,
138 ker_data_t *dst, unsigned char *ws, const size_t ws_offset,
139 const data_type_t ws_dt) const {
140 assert(ws && (ws_dt == data_type::u8 || ws_dt == data_type::s32));
141 #if SAFE_TO_USE_OMP_SIMD
142 PRAGMA_OMP_SIMD()
143 #endif
144 for (int oc = 0; oc < n; ++oc) {
145 if (ws_dt == data_type::u8)
146 ws[ws_offset + oc] = 0;
147 else
148 reinterpret_cast<int *>(ws)[ws_offset + oc] = 0;
149 dst[oc] = nstl::numeric_limits<data_t>::lowest();
150 }
151 }
152
153 using namespace nstl;
154 using namespace nhwc_pooling;
155
// Forward pooling (max or avg) for plain strided NHWC-like layouts.
// Parallelizes over output spatial points (mb, od, oh, ow); each task
// reduces the kernel window across all OC channels at once, OC being the
// innermost (dense) axis in NHWC.
template <data_type_t d_type>
status_t nhwc_pooling_fwd_t<d_type>::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto alg = pd()->desc()->alg_kind;

    const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    // Workspace stores argmax indices; may be null (e.g. no ws requested).
    auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);

    // MEM_D(name) expands to name##_d (memory descriptor wrappers).
    const memory_desc_wrapper MEM_D(src)(pd()->src_md());
    const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const int MB = pd()->MB();
    const int OD = pd()->OD();
    const int OC = pd()->OC();
    const int OH = pd()->OH();
    const int OW = pd()->OW();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();

    const bool is_1d = pd()->desc()->src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    // Pulls {name}_n/d/h/w_stride locals out of the blocking descriptor,
    // with d/h strides forced to 0 for the missing dims in 1d/2d cases.
    DECLARE_READ_STRIDES(src);
    DECLARE_READ_STRIDES(dst);

    // First valid input index after removing left/top/front padding.
    const auto apply_offset = [](int index, int offset) {
        return (index > offset) ? index - offset : 0;
    };

    const dim_t SP = OW * OH;
    const dim_t OSP = SP * OD;

    // Dense logical offset (mb, oc, od, oh, ow) used by post-ops to locate
    // the element in the conceptual NC[D]HW index space.
    const auto get_logical_offset
            = [&](int mb, int oc, int od, int oh, int ow) -> dim_t {
        return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
    };
    const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());

    parallel_nd(MB, OD, OH, OW, [&](int mb, int od, int oh, int ow) {
        const size_t dst_offset_init = strided_offset(mb, dst_n_stride, od,
                dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
        if (alg == alg_kind::pooling_max) {
            size_t ws_offset_init = 0;
            if (ws) {
                DECLARE_READ_STRIDES(ws);
                ws_offset_init = strided_offset(mb, ws_n_stride, od,
                        ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
            }
            // Note: GCC 4.8.5 won't vectorize below
            // simple loops unless they are singled out
            // into separate helper routines:
            // array_nhwc_initialize, array_nhwc_max
            if (!ws) {
                // No workspace: plain max reduction, no argmax tracking.
                auto *const d = dst + dst_offset_init;
                PRAGMA_OMP_SIMD()
                for (int oc = 0; oc < OC; ++oc) {
                    d[oc] = nstl::numeric_limits<data_t>::lowest();
                }
            } else {
                array_nhwc_initialize(
                        OC, dst + dst_offset_init, ws, ws_offset_init, ws_dt);
            }

            for_(int kd = 0; kd < KD; ++kd)
            for_(int kh = 0; kh < KH; ++kh)
            for (int kw = 0; kw < KW; ++kw) {
                const int id = od * SD - padF + kd;
                const int ih = oh * SH - padT + kh;
                const int iw = ow * SW - padL + kw;

                // Skip window taps that fall into the padding area.
                if (id < 0 || id >= ID) continue;
                if (ih < 0 || ih >= IH) continue;
                if (iw < 0 || iw >= IW) continue;

                const size_t src_offset_init = strided_offset(mb, src_n_stride,
                        id, src_d_stride, ih, src_h_stride, iw, src_w_stride);

                if (!ws) {
                    auto *const s = src + src_offset_init;
                    auto *const d = dst + dst_offset_init;
                    PRAGMA_OMP_SIMD()
                    for (int oc = 0; oc < OC; ++oc) {
                        d[oc] = nstl::max(s[oc], d[oc]);
                    }
                } else {
                    array_nhwc_max(OC, dst + dst_offset_init,
                            src + src_offset_init, ws, ws_offset_init, ws_dt,
                            kd * KH * KW + kh * KW + kw);
                }
            }
        } else {
            // pooling_avg
            const auto d = dst + dst_offset_init;

            utils::array_set(d, 0, OC);

            const auto id_start = apply_offset(od * SD, padF);
            const auto ih_start = apply_offset(oh * SH, padT);
            const auto iw_start = apply_offset(ow * SW, padL);
            const auto id_end = min(od * SD - padF + KD, ID);
            const auto ih_end = min(oh * SH - padT + KH, IH);
            const auto iw_end = min(ow * SW - padL + KW, IW);

            // it is cheaper to actually count this in a loop
            // as the typical kernel is small
            size_t num_summands = 0;

            for_(int id = id_start; id < id_end; ++id)
            for_(int ih = ih_start; ih < ih_end; ++ih)
            for (int iw = iw_start; iw < iw_end; ++iw) {
                const size_t src_offset_init = strided_offset(mb, src_n_stride,
                        id, src_d_stride, ih, src_h_stride, iw, src_w_stride);
                const auto s = src + src_offset_init;

                // need to move the loop to separate function
                // for GCC 4.8.5 to vectorize
                array_add(OC, s, d);

                num_summands++;
            }

            // include_padding averages over the full kernel volume;
            // exclude_padding averages over the actually-summed taps.
            num_summands = (alg == alg_kind::pooling_avg_include_padding)
                    ? KW * KH * KD
                    : num_summands;

            // need to move the loop to separate function
            // for GCC 4.8.5 to vectorize
            array_div_by_const(OC, d, num_summands, d);
        }

        if (are_postops_set) {
            // Apply post-ops in place, channel by channel; l_offset steps by
            // OSP because consecutive channels are OSP apart logically.
            auto *const d = dst + dst_offset_init;
            ref_post_ops_t::args_t args;
            args.ctx = &ctx;
            args.l_offset = get_logical_offset(mb, 0, od, oh, ow);
            args.dst_md = pd()->dst_md();

            for (int oc = 0; oc < OC; ++oc) {
                ref_post_ops_.execute(d[oc], args);
                args.l_offset += OSP;
            }
        }
    });
    return status::success;
}
316
// bf16 specialization of the forward pass: accumulates in f32 using
// per-thread scratchpad buffers (one OC-sized row per thread for src and
// dst), converting bf16 -> f32 on load and f32 -> bf16 on the final store.
template <>
status_t nhwc_pooling_fwd_t<data_type::bf16>::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto alg = pd()->desc()->alg_kind;

    const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);

    // Per-thread f32 conversion buffers reserved by the pd in the scratchpad.
    auto scratchpad = ctx.get_scratchpad_grantor();
    float *const bf16cvt_src_wsp = scratchpad.template get<float>(
            memory_tracking::names::key_pool_src_bf16cvt);
    float *const bf16cvt_dst_wsp = scratchpad.template get<float>(
            memory_tracking::names::key_pool_dst_bf16cvt);

    const memory_desc_wrapper MEM_D(src)(pd()->src_md());
    const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const int MB = pd()->MB();
    const int OD = pd()->OD();
    const int OC = pd()->OC();
    const int OH = pd()->OH();
    const int OW = pd()->OW();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();

    const bool is_1d = pd()->desc()->src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;

    DECLARE_READ_STRIDES(src);
    DECLARE_READ_STRIDES(dst);

    // First valid input index after removing left/top/front padding.
    const auto apply_offset = [&](int index, int offset) {
        return (index > offset) ? index - offset : 0;
    };

    const dim_t SP = OW * OH;
    const dim_t OSP = SP * OD;

    // Dense logical offset (mb, oc, od, oh, ow) used by post-ops.
    const auto get_logical_offset
            = [&](int mb, int oc, int od, int oh, int ow) -> dim_t {
        return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
    };
    const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());

    // parallel_nd_ext supplies ithr, used to pick this thread's f32 rows.
    parallel_nd_ext(0, MB, OD, OH, OW,
            [&](int ithr, int, int mb, int od, int oh, int ow) {
                const size_t dst_offset_init = strided_offset(mb, dst_n_stride,
                        od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
                float *const dst_f32 = &bf16cvt_dst_wsp[ithr * OC];
                float *const src_f32 = &bf16cvt_src_wsp[ithr * OC];

                if (alg == alg_kind::pooling_max) {
                    size_t ws_offset_init = 0;
                    if (ws) {
                        DECLARE_READ_STRIDES(ws);
                        ws_offset_init = strided_offset(mb, ws_n_stride, od,
                                ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
                    };
                    // Note: GCC 4.8.5 won't vectorize below
                    // simple loops unless they are singled out
                    // into separate helper routines:
                    // array_nhwc_initialize, array_nhwc_max
                    if (!ws) {
                        PRAGMA_OMP_SIMD()
                        for (int oc = 0; oc < OC; ++oc) {
                            dst_f32[oc]
                                    = nstl::numeric_limits<data_t>::lowest();
                        }
                    } else {
                        array_nhwc_initialize(
                                OC, dst_f32, ws, ws_offset_init, ws_dt);
                    }

                    for_(int kd = 0; kd < KD; ++kd)
                    for_(int kh = 0; kh < KH; ++kh)
                    for (int kw = 0; kw < KW; ++kw) {
                        const int id = od * SD - padF + kd;
                        const int ih = oh * SH - padT + kh;
                        const int iw = ow * SW - padL + kw;

                        // Skip taps that fall into the padding area.
                        if (id < 0 || id >= ID) continue;
                        if (ih < 0 || ih >= IH) continue;
                        if (iw < 0 || iw >= IW) continue;

                        const size_t src_offset_init = strided_offset(mb,
                                src_n_stride, id, src_d_stride, ih,
                                src_h_stride, iw, src_w_stride);

                        // Up-convert one OC row of bf16 src to f32.
                        cvt_bfloat16_to_float(
                                src_f32, &src[src_offset_init], OC);

                        if (!ws) {
                            PRAGMA_OMP_SIMD()
                            for (int oc = 0; oc < OC; ++oc) {
                                dst_f32[oc]
                                        = nstl::max(src_f32[oc], dst_f32[oc]);
                            }
                        } else {
                            array_nhwc_max(OC, dst_f32, src_f32, ws,
                                    ws_offset_init, ws_dt,
                                    kd * KH * KW + kh * KW + kw);
                        }
                    }
                } else {
                    // pooling_avg
                    utils::array_set(dst_f32, 0, OC);

                    const auto id_start = apply_offset(od * SD, padF);
                    const auto ih_start = apply_offset(oh * SH, padT);
                    const auto iw_start = apply_offset(ow * SW, padL);
                    const auto id_end = min(od * SD - padF + KD, ID);
                    const auto ih_end = min(oh * SH - padT + KH, IH);
                    const auto iw_end = min(ow * SW - padL + KW, IW);

                    // it is cheaper to actually count this in a loop
                    // as the typical kernel is small
                    size_t num_summands = 0;

                    for_(int id = id_start; id < id_end; ++id)
                    for_(int ih = ih_start; ih < ih_end; ++ih)
                    for (int iw = iw_start; iw < iw_end; ++iw) {
                        size_t src_offset_init = strided_offset(mb,
                                src_n_stride, id, src_d_stride, ih,
                                src_h_stride, iw, src_w_stride);
                        cvt_bfloat16_to_float(
                                src_f32, &src[src_offset_init], OC);

                        // need to move the loop to separate function
                        // for GCC 4.8.5 to vectorize
                        array_add(OC, src_f32, dst_f32);
                        num_summands++;
                    }

                    // include_padding divides by the full kernel volume.
                    num_summands
                            = (alg == alg_kind::pooling_avg_include_padding)
                            ? KW * KH * KD
                            : num_summands;

                    // need to move the loop to separate function
                    // for GCC 4.8.5 to vectorize
                    array_div_by_const(OC, dst_f32, num_summands, dst_f32);
                }

                if (are_postops_set) {
                    // Post-ops run on the f32 row before down-conversion.
                    ref_post_ops_t::args_t args;
                    args.ctx = &ctx;
                    args.l_offset = get_logical_offset(mb, 0, od, oh, ow);
                    args.dst_md = pd()->dst_md();

                    for (int oc = 0; oc < OC; ++oc) {
                        ref_post_ops_.execute(dst_f32[oc], args);
                        args.l_offset += OSP;
                    }
                }
                // Single down-conversion of the finished OC row to bf16 dst.
                cvt_float_to_bfloat16(dst + dst_offset_init, dst_f32, OC);
            });
    return status::success;
}
490
// Backward pooling for plain strided NHWC-like layouts. Parallelizes over
// input points (mb, id, ih, iw); for each input point it enumerates the
// output windows that could have included it and scatters/accumulates the
// corresponding diff_dst contributions across all OC channels.
template <data_type_t d_type>
status_t nhwc_pooling_bwd_t<d_type>::execute_backward(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    // Workspace holds the forward argmax indices (max pooling only).
    auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md());
    const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const int MB = pd()->MB();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int OC = pd()->OC();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

    const bool is_1d = pd()->desc()->diff_src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    auto alg = pd()->desc()->alg_kind;

    DECLARE_READ_STRIDES(diff_src);
    DECLARE_READ_STRIDES(diff_dst);

    // First valid input index after removing padding.
    auto apply_offset = [=](int index, int offset) {
        return (index > offset) ? index - offset : 0;
    };

    parallel_nd(MB, ID, IH, IW, [&](int mb, int id, int ih, int iw) {
        size_t src_offset_init
                = strided_offset(mb, diff_src_n_stride, id, diff_src_d_stride,
                        ih, diff_src_h_stride, iw, diff_src_w_stride);

        // Zero-initialize; gradients are accumulated below.
        // NOTE(review): `data_type_t(0)` relies on an implicit enum-to-data_t
        // conversion; presumably `data_t(0)` was meant — confirm intent.
        for (int oc = 0; oc < OC; ++oc)
            diff_src[src_offset_init + oc] = data_type_t(0);

        // Find out which output cells may correspond to current
        // input position. Current input postition divided by
        // stride, with integer divide rounding down, is the
        // right-most output.
        // Left-most output may be computed if we decrement input
        // by (kernel_size - 1) and then do the same division by
        // stride.
        int od_left = max((id + padF - KD + 1) / SD, 0);
        int oh_left = max((ih + padT - KH + 1) / SH, 0);
        int ow_left = max((iw + padL - KW + 1) / SW, 0);
        // Notice +1 here to preserve the C loop "less than"
        // condition for continuing the for loop.
        int od_right = min((id + padF) / SD + 1, OD);
        int oh_right = min((ih + padT) / SH + 1, OH);
        int ow_right = min((iw + padL) / SW + 1, OW);

        for_(int od = od_left; od < od_right; ++od)
        for_(int oh = oh_left; oh < oh_right; ++oh)
        for (int ow = ow_left; ow < ow_right; ++ow) {
            // Kernel coordinates of this input point inside the window.
            const int kd = id - od * SD + padF;
            const int kh = ih - oh * SH + padT;
            const int kw = iw - ow * SW + padL;

            if (kd < 0 || kd >= KD) continue;
            if (kh < 0 || kh >= KH) continue;
            if (kw < 0 || kw >= KW) continue;

            size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, od,
                    diff_dst_d_stride, oh, diff_dst_h_stride, ow,
                    diff_dst_w_stride);

            if (alg == alg_kind::pooling_max) {
                DECLARE_READ_STRIDES(ws);
                size_t ws_offset_init = strided_offset(mb, ws_n_stride, od,
                        ws_d_stride, oh, ws_h_stride, ow, ws_w_stride);
                const int index = kd * KH * KW + kh * KW + kw;
                // Workspace may be u8 or s32; pick the matching view.
                const unsigned char *ws_ = ws + ws_offset_init;
                const int *intws_ = (int *)ws + ws_offset_init;
                const bool ws_is_u8 = MEM_D(ws).data_type() == data_type::u8;

#if SAFE_TO_USE_OMP_SIMD
                PRAGMA_OMP_SIMD()
#endif
                for (int oc = 0; oc < OC; ++oc) {
                    const int index_from_ws = ws_is_u8 ? ws_[oc] : intws_[oc];
                    const data_t d = diff_dst[dst_offset_init + oc];

                    // Check if kernel windows are disjoint, in this case
                    // there's no update needed and we just write there once
                    // otherwise we add value to the contents.
                    auto value = (index_from_ws == index) ? d : data_type_t(0);
                    if (!(KD == SD && KH == SH && KW == SW))
                        diff_src[src_offset_init + oc] += value;
                    else
                        diff_src[src_offset_init + oc] = value;
                }
            } else {
                // pooling_avg
                auto id_start = apply_offset(od * SD, padF);
                auto ih_start = apply_offset(oh * SH, padT);
                auto iw_start = apply_offset(ow * SW, padL);
                auto id_end = min(od * SD - padF + KD, ID);
                auto ih_end = min(oh * SH - padT + KH, IH);
                auto iw_end = min(ow * SW - padL + KW, IW);

                // include_padding divides by the full kernel volume;
                // exclude_padding by the number of non-padded taps.
                auto num_summands
                        = (alg == alg_kind::pooling_avg_include_padding)
                        ? KW * KH * KD
                        : (ih_end - ih_start) * (iw_end - iw_start)
                                * (id_end - id_start);

                PRAGMA_OMP_SIMD()
                for (int oc = 0; oc < OC; ++oc) {
                    const data_t d = diff_dst[dst_offset_init + oc];
                    // Check if kernel windows are disjoint, in this case
                    // there's no update needed and we just write there once
                    // otherwise we add value to the contents.
                    if (!(KD == SD && KH == SH && KW == SW))
                        diff_src[src_offset_init + oc] += d / num_summands;
                    else
                        diff_src[src_offset_init + oc] = d / num_summands;
                }
            }
        }
    });
    return status::success;
}
627
// bf16 specialization of the backward pass: accumulates gradients in f32
// using per-thread scratchpad rows, converting diff_dst bf16 -> f32 on load
// and the accumulated f32 row back to bf16 diff_src after each window update.
template <>
status_t nhwc_pooling_bwd_t<data_type::bf16>::execute_backward(
        const exec_ctx_t &ctx) const {

    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    // Per-thread f32 conversion buffers reserved by the pd in the scratchpad.
    auto scratchpad = ctx.get_scratchpad_grantor();
    float *bf16cvt_dsrc = scratchpad.template get<float>(
            memory_tracking::names::key_pool_src_bf16cvt);
    float *bf16cvt_ddst = scratchpad.template get<float>(
            memory_tracking::names::key_pool_dst_bf16cvt);

    const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md());
    const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md());
    const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());

    const int MB = pd()->MB();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int KD = pd()->KD();
    const int KH = pd()->KH();
    const int KW = pd()->KW();
    const int SD = pd()->KSD();
    const int SH = pd()->KSH();
    const int SW = pd()->KSW();
    const int OC = pd()->OC();
    const int padF = pd()->padFront();
    const int padT = pd()->padT();
    const int padL = pd()->padL();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

    const bool is_1d = pd()->desc()->diff_src_desc.ndims == 3;
    const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
    const int ndims = pd()->ndims();
    auto alg = pd()->desc()->alg_kind;

    DECLARE_READ_STRIDES(diff_src);
    DECLARE_READ_STRIDES(diff_dst);

    // First valid input index after removing padding.
    auto apply_offset = [=](int index, int offset) {
        return (index > offset) ? index - offset : 0;
    };

    parallel_nd_ext(0, MB, ID, IH, IW,
            [&](int ithr, int, int mb, int id, int ih, int iw) {
                size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
                        id, diff_src_d_stride, ih, diff_src_h_stride, iw,
                        diff_src_w_stride);

                float *diff_dst_fp32 = &bf16cvt_ddst[ithr * OC];
                float *diff_src_fp32 = &bf16cvt_dsrc[ithr * OC];

                // Zero both the f32 accumulator row and the bf16 output.
                for (int oc = 0; oc < OC; ++oc) {
                    diff_src_fp32[oc] = 0.f;
                    diff_src[src_offset_init + oc] = (bfloat16_t)0.f;
                }

                // Find out which output cells may correspond to current
                // input position. Current input postition divided by
                // stride, with integer divide rounding down, is the
                // right-most output.
                // Left-most output may be computed if we decrement input
                // by (kernel_size - 1) and then do the same division by
                // stride.
                int od_left = max((id + padF - KD + 1) / SD, 0);
                int oh_left = max((ih + padT - KH + 1) / SH, 0);
                int ow_left = max((iw + padL - KW + 1) / SW, 0);
                // Notice +1 here to preserve the C loop "less than"
                // condition for continuing the for loop.
                int od_right = min((id + padF) / SD + 1, OD);
                int oh_right = min((ih + padT) / SH + 1, OH);
                int ow_right = min((iw + padL) / SW + 1, OW);

                for_(int od = od_left; od < od_right; ++od)
                for_(int oh = oh_left; oh < oh_right; ++oh)
                for (int ow = ow_left; ow < ow_right; ++ow) {
                    // Kernel coordinates of this input point in the window.
                    const int kd = id - od * SD + padF;
                    const int kh = ih - oh * SH + padT;
                    const int kw = iw - ow * SW + padL;

                    if (kd < 0 || kd >= KD) continue;
                    if (kh < 0 || kh >= KH) continue;
                    if (kw < 0 || kw >= KW) continue;

                    size_t dst_offset_init = strided_offset(mb,
                            diff_dst_n_stride, od, diff_dst_d_stride, oh,
                            diff_dst_h_stride, ow, diff_dst_w_stride);
                    // Up-convert one OC row of bf16 diff_dst to f32.
                    cvt_bfloat16_to_float(
                            diff_dst_fp32, &diff_dst[dst_offset_init], OC);

                    if (alg == alg_kind::pooling_max) {
                        DECLARE_READ_STRIDES(ws);
                        size_t ws_offset_init = strided_offset(mb, ws_n_stride,
                                od, ws_d_stride, oh, ws_h_stride, ow,
                                ws_w_stride);
                        const int index = kd * KH * KW + kh * KW + kw;
                        // Workspace may be u8 or s32; pick the matching view.
                        const unsigned char *ws_ = ws + ws_offset_init;
                        const int *intws_ = (int *)ws + ws_offset_init;
                        const bool ws_is_u8
                                = MEM_D(ws).data_type() == data_type::u8;

#if SAFE_TO_USE_OMP_SIMD
                        PRAGMA_OMP_SIMD()
#endif
                        for (int oc = 0; oc < OC; ++oc) {
                            const int index_from_ws
                                    = ws_is_u8 ? ws_[oc] : intws_[oc];

                            // Check if kernel windows are disjoint, in this case
                            // there's no update needed and we just write there once
                            // otherwise we add value to the contents.
                            float value = (index_from_ws == index)
                                    ? diff_dst_fp32[oc]
                                    : 0.0f;
                            if (!(KD == SD && KH == SH && KW == SW))
                                diff_src_fp32[oc] += value;
                            else
                                diff_src_fp32[oc] = value;
                        }
                    } else {
                        // pooling_avg
                        auto id_start = apply_offset(od * SD, padF);
                        auto ih_start = apply_offset(oh * SH, padT);
                        auto iw_start = apply_offset(ow * SW, padL);
                        auto id_end = min(od * SD - padF + KD, ID);
                        auto ih_end = min(oh * SH - padT + KH, IH);
                        auto iw_end = min(ow * SW - padL + KW, IW);

                        // include_padding divides by the full kernel volume;
                        // exclude_padding by the number of non-padded taps.
                        auto num_summands
                                = (alg == alg_kind::pooling_avg_include_padding)
                                ? KW * KH * KD
                                : (ih_end - ih_start) * (iw_end - iw_start)
                                        * (id_end - id_start);

                        PRAGMA_OMP_SIMD()
                        for (int oc = 0; oc < OC; ++oc) {
                            // Check if kernel windows are disjoint, in this case
                            // there's no update needed and we just write there once
                            // otherwise we add value to the contents.
                            if (!(KD == SD && KH == SH && KW == SW))
                                diff_src_fp32[oc]
                                        += diff_dst_fp32[oc] / num_summands;
                            else
                                diff_src_fp32[oc]
                                        = diff_dst_fp32[oc] / num_summands;
                        }
                    }
                    // Down-convert the accumulated row after each window
                    // update (note: repeated per overlapping window; the last
                    // conversion holds the final result).
                    cvt_float_to_bfloat16(
                            &diff_src[src_offset_init], diff_src_fp32, OC);
                }
            });
    return status::success;
}
786
787 template struct nhwc_pooling_fwd_t<data_type::f32>;
788 template struct nhwc_pooling_bwd_t<data_type::f32>;
789 template struct nhwc_pooling_fwd_t<data_type::bf16>;
790 template struct nhwc_pooling_bwd_t<data_type::bf16>;
791
792 } // namespace cpu
793 } // namespace impl
794 } // namespace dnnl
795
796 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
797