1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/nstl.hpp"
23 #include "common/type_helpers.hpp"
24
25 #include "cpu/simple_q10n.hpp"
26
27 #include "cpu/nchw_pooling.hpp"
28
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32
33 using namespace nstl;
34
35 template <data_type_t d_type>
nchw_pooling_fwd_t(const pd_t * apd)36 nchw_pooling_fwd_t<d_type>::nchw_pooling_fwd_t(const pd_t *apd)
37 : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
38
39 template <data_type_t d_type>
execute_forward(const exec_ctx_t & ctx) const40 status_t nchw_pooling_fwd_t<d_type>::execute_forward(
41 const exec_ctx_t &ctx) const {
42 const auto alg = pd()->desc()->alg_kind;
43 const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
44 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
45 auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
46
47 const memory_desc_wrapper ws_d(pd()->workspace_md());
48 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
49
50 const int MB = pd()->MB();
51 const int C = pd()->IC();
52 const int OD = pd()->OD();
53 const int OH = pd()->OH();
54 const int OW = pd()->OW();
55 const int ID = pd()->ID();
56 const int IH = pd()->IH();
57 const int IW = pd()->IW();
58 const int KD = pd()->KD();
59 const int KH = pd()->KH();
60 const int KW = pd()->KW();
61 const int SD = pd()->KSD();
62 const int SH = pd()->KSH();
63 const int SW = pd()->KSW();
64 const int padF = pd()->padFront();
65 const int padT = pd()->padT();
66 const int padL = pd()->padL();
67
68 const auto apply_offset = [](int index, int offset) {
69 return (index > offset) ? index - offset : 0;
70 };
71
72 const auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
73 if (ws) {
74 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
75 const size_t ws_offset = (size_t)OW * OH * OD * C * mb
76 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
77 + (size_t)OW * oh + (size_t)ow;
78 if (ws_dt == data_type::u8) {
79 assert(0 <= value
80 && value <= numeric_limits<typename prec_traits<
81 data_type::u8>::type>::max());
82 ws[ws_offset] = value;
83 } else
84 reinterpret_cast<int *>(ws)[ws_offset] = value;
85 }
86 };
87
88 const auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
89 for_(int kd = 0; kd < KD; ++kd)
90 for_(int kh = 0; kh < KH; ++kh)
91 for (int kw = 0; kw < KW; ++kw) {
92 const int id = od * SD - padF + kd;
93 const int ih = oh * SH - padT + kh;
94 const int iw = ow * SW - padL + kw;
95
96 if (id < 0 || id >= ID) continue;
97 if (ih < 0 || ih >= IH) continue;
98 if (iw < 0 || iw >= IW) continue;
99
100 const auto src_offset = (size_t)IW * IH * ID * C * mb
101 + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
102 + (size_t)IW * ih + (size_t)iw;
103 const auto &s = src[src_offset];
104 if (s > d[0]) {
105 d[0] = s;
106 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
107 }
108 }
109 };
110
111 const auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
112 const auto id_start = apply_offset(od * SD, padF);
113 const auto ih_start = apply_offset(oh * SH, padT);
114 const auto iw_start = apply_offset(ow * SW, padL);
115 const auto id_end = min(od * SD - padF + KD, ID);
116 const auto ih_end = min(oh * SH - padT + KH, IH);
117 const auto iw_end = min(ow * SW - padL + KW, IW);
118
119 const auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
120 ? KD * KW * KH
121 : (id_end - id_start) * (ih_end - ih_start)
122 * (iw_end - iw_start);
123
124 float d_val = 0;
125 for_(int id = id_start; id < id_end; ++id)
126 for_(int ih = ih_start; ih < ih_end; ++ih)
127 for (int iw = iw_start; iw < iw_end; ++iw) {
128 const auto src_offset = (size_t)IW * IH * ID * C * mb
129 + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
130 + (size_t)IW * ih + (size_t)iw;
131 d_val += src[src_offset];
132 }
133
134 return d_val / num_summands;
135 };
136
137 if (alg == alg_kind::pooling_max) {
138 parallel_nd(
139 MB, C, OD, OH, OW, [&](int mb, int c, int od, int oh, int ow) {
140 const size_t dst_offset = (size_t)OW * OH * OD * C * mb
141 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
142 + (size_t)OW * oh + (size_t)ow;
143 data_t *d = &dst[dst_offset];
144 d[0] = numeric_limits<data_t>::lowest();
145 set_ws(mb, c, od, oh, ow, 0);
146 ker_max(d, mb, c, od, oh, ow);
147
148 ref_post_ops_t::args_t args;
149 args.ctx = &ctx;
150 args.l_offset = dst_offset;
151 args.dst_md = pd()->dst_md();
152 ref_post_ops_.execute(dst[dst_offset], args);
153 dst[dst_offset]
154 = saturate_and_round<data_t>(dst[dst_offset]);
155 });
156 } else {
157 parallel_nd(
158 MB, C, OD, OH, OW, [&](int mb, int c, int od, int oh, int ow) {
159 const size_t dst_offset = (size_t)OW * OH * OD * C * mb
160 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
161 + (size_t)OW * oh + (size_t)ow;
162 data_t *d = &dst[dst_offset];
163 d[0] = 0;
164 auto res = ker_avg(d, mb, c, od, oh, ow);
165
166 ref_post_ops_t::args_t args;
167 args.ctx = &ctx;
168 args.l_offset = dst_offset;
169 args.dst_md = pd()->dst_md();
170 ref_post_ops_.execute(res, args);
171 d[0] = saturate_and_round<data_t>(res);
172 });
173 }
174
175 return status::success;
176 }
177
178 template <>
execute_forward(const exec_ctx_t & ctx) const179 status_t nchw_pooling_fwd_t<data_type::bf16>::execute_forward(
180 const exec_ctx_t &ctx) const {
181
182 auto alg = pd()->desc()->alg_kind;
183
184 auto src = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_SRC);
185 auto dst = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DST);
186 auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
187 memory_desc_wrapper dst_d(pd()->dst_md());
188
189 auto scratchpad = ctx.get_scratchpad_grantor();
190 float *bf16cvt_wsp = scratchpad.template get<float>(
191 memory_tracking::names::key_pool_src_bf16cvt);
192
193 const memory_desc_wrapper ws_d(pd()->workspace_md());
194 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
195
196 const int MB = pd()->MB();
197 const int C = pd()->IC();
198 const int OD = pd()->OD();
199 const int OH = pd()->OH();
200 const int OW = pd()->OW();
201 const int ID = pd()->ID();
202 const int IH = pd()->IH();
203 const int IW = pd()->IW();
204 const int KD = pd()->KD();
205 const int KH = pd()->KH();
206 const int KW = pd()->KW();
207 const int SD = pd()->KSD();
208 const int SH = pd()->KSH();
209 const int SW = pd()->KSW();
210 const int padF = pd()->padFront();
211 const int padT = pd()->padT();
212 const int padL = pd()->padL();
213
214 const size_t simd_w = 16;
215 const size_t src_size = MB * C * ID * IH * IW;
216 const size_t blocked_size = src_size / simd_w;
217 const size_t tail_size = src_size % simd_w;
218
219 auto apply_offset = [=](int index, int offset) {
220 return (index > offset) ? index - offset : 0;
221 };
222
223 auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
224 if (ws) {
225 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
226 size_t ws_offset = (size_t)OW * OH * OD * C * mb
227 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
228 + (size_t)OW * oh + (size_t)ow;
229 if (ws_dt == data_type::u8) {
230 assert(0 <= value
231 && value <= numeric_limits<typename prec_traits<
232 data_type::u8>::type>::max());
233 ws[ws_offset] = value;
234 } else
235 reinterpret_cast<int *>(ws)[ws_offset] = value;
236 }
237 };
238
239 auto ker_max = [=](float *d, int mb, int c, int od, int oh, int ow) {
240 for_(int kd = 0; kd < KD; ++kd)
241 for_(int kh = 0; kh < KH; ++kh)
242 for (int kw = 0; kw < KW; ++kw) {
243 const int id = od * SD - padF + kd;
244 const int ih = oh * SH - padT + kh;
245 const int iw = ow * SW - padL + kw;
246
247 if (id < 0 || id >= ID) continue;
248 if (ih < 0 || ih >= IH) continue;
249 if (iw < 0 || iw >= IW) continue;
250
251 auto src_offset = (size_t)IW * IH * ID * C * mb
252 + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
253 + (size_t)IW * ih + (size_t)iw;
254 auto &s = bf16cvt_wsp[src_offset];
255
256 if (s > d[0]) {
257 d[0] = s;
258 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
259 }
260 }
261 };
262
263 auto ker_avg = [=](float *d, int mb, int c, int od, int oh, int ow) {
264 auto id_start = apply_offset(od * SD, padF);
265 auto ih_start = apply_offset(oh * SH, padT);
266 auto iw_start = apply_offset(ow * SW, padL);
267 auto id_end = min(od * SD - padF + KD, ID);
268 auto ih_end = min(oh * SH - padT + KH, IH);
269 auto iw_end = min(ow * SW - padL + KW, IW);
270
271 auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
272 ? KD * KW * KH
273 : (id_end - id_start) * (ih_end - ih_start)
274 * (iw_end - iw_start);
275
276 for_(int id = id_start; id < id_end; ++id)
277 for_(int ih = ih_start; ih < ih_end; ++ih)
278 for (int iw = iw_start; iw < iw_end; ++iw) {
279 auto src_offset = (size_t)IW * IH * ID * C * mb
280 + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
281 + (size_t)IW * ih + (size_t)iw;
282 d[0] += bf16cvt_wsp[src_offset];
283 }
284
285 d[0] = out_round<float>((float)d[0] / num_summands);
286 };
287 parallel_nd(blocked_size, [&](size_t i) {
288 cvt_bfloat16_to_float(
289 &bf16cvt_wsp[i * simd_w], &src[i * simd_w], simd_w);
290 });
291 if (tail_size)
292 cvt_bfloat16_to_float(&bf16cvt_wsp[blocked_size * simd_w],
293 &src[blocked_size * simd_w], tail_size);
294 if (alg == alg_kind::pooling_max) {
295 parallel_nd(
296 MB, C, OD, OH, OW, [&](int mb, int c, int od, int oh, int ow) {
297 size_t dst_offset = (size_t)OW * OH * OD * C * mb
298 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
299 + (size_t)OW * oh + (size_t)ow;
300 float d_fp32 = numeric_limits<bfloat16_t>::lowest();
301
302 set_ws(mb, c, od, oh, ow, 0);
303
304 ker_max(&d_fp32, mb, c, od, oh, ow);
305
306 ref_post_ops_t::args_t args;
307 args.ctx = &ctx;
308 args.l_offset = dst_offset;
309 args.dst_md = pd()->dst_md();
310 ref_post_ops_.execute(d_fp32, args);
311
312 dst[dst_offset] = static_cast<bfloat16_t>(d_fp32);
313 });
314 } else {
315 parallel_nd(
316 MB, C, OD, OH, OW, [&](int mb, int c, int od, int oh, int ow) {
317 size_t dst_offset = (size_t)OW * OH * OD * C * mb
318 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
319 + (size_t)OW * oh + (size_t)ow;
320 float d_fp32 = 0.0f;
321 ker_avg(&d_fp32, mb, c, od, oh, ow);
322 ref_post_ops_t::args_t args;
323 args.ctx = &ctx;
324 args.l_offset = dst_offset;
325 args.dst_md = pd()->dst_md();
326 ref_post_ops_.execute(d_fp32, args);
327 dst[dst_offset] = static_cast<bfloat16_t>(d_fp32);
328 });
329 }
330
331 return status::success;
332 }
333
334 template <data_type_t d_type>
execute_backward(const exec_ctx_t & ctx) const335 status_t nchw_pooling_bwd_t<d_type>::execute_backward(
336 const exec_ctx_t &ctx) const {
337 auto alg = pd()->desc()->alg_kind;
338 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
339 const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
340
341 auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
342 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
343 auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
344
345 const memory_desc_wrapper ws_d(pd()->workspace_md());
346
347 const int MB = pd()->MB();
348 const int C = pd()->IC();
349 const int OD = pd()->OD();
350 const int OH = pd()->OH();
351 const int OW = pd()->OW();
352 const int ID = pd()->ID();
353 const int IH = pd()->IH();
354 const int IW = pd()->IW();
355 const int KD = pd()->KD();
356 const int KH = pd()->KH();
357 const int KW = pd()->KW();
358 const int SD = pd()->KSD();
359 const int SH = pd()->KSH();
360 const int SW = pd()->KSW();
361 const int padF = pd()->padFront();
362 const int padT = pd()->padT();
363 const int padL = pd()->padL();
364
365 auto apply_offset = [=](int index, int offset) {
366 return (index > offset) ? index - offset : 0;
367 };
368
369 auto ker_zero = [=](int mb, int c) {
370 size_t diff_src_offset
371 = (size_t)mb * C * ID * IH * IW + (size_t)c * ID * IH * IW;
372 for_(int id = 0; id < ID; ++id)
373 for_(int ih = 0; ih < IH; ++ih)
374 for (int iw = 0; iw < IW; ++iw) {
375 diff_src[diff_src_offset++] = 0;
376 }
377 };
378
379 auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
380 auto b_c = ws_d.blocking_desc().inner_nblks == 0
381 ? 1
382 : ws_d.blocking_desc().inner_blks[0];
383 auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
384 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
385 : ws_d.blk_off(mb, c / b_c, ow))
386 + c % b_c;
387
388 const int index = ws_d.data_type() == data_type::u8
389 ? (int)ws[ws_offset]
390 : ((const int *)ws)[ws_offset];
391 const int kw = index % KW;
392 const int kh = (index / KW) % KH;
393 const int kd = (index / KW) / KH;
394
395 const int id = od * SD - padF + kd;
396 const int ih = oh * SH - padT + kh;
397 const int iw = ow * SW - padL + kw;
398
399 // If padding area could fit the kernel,
400 // then input displacement would be out of bounds.
401 // No need to back propagate there as padding is
402 // virtual in pooling_max case.
403 if (id < 0 || id >= ID) return;
404 if (ih < 0 || ih >= IH) return;
405 if (iw < 0 || iw >= IW) return;
406
407 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
408 + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
409 + (size_t)ih * IW + (size_t)iw;
410 diff_src[diff_src_offset] += d[0];
411 };
412
413 auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
414 auto id_start = apply_offset(od * SD, padF);
415 auto ih_start = apply_offset(oh * SH, padT);
416 auto iw_start = apply_offset(ow * SW, padL);
417 auto id_end = min(od * SD - padF + KD, ID);
418 auto ih_end = min(oh * SH - padT + KH, IH);
419 auto iw_end = min(ow * SW - padL + KW, IW);
420
421 size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
422 ? (size_t)KW * KH * KD
423 : (size_t)(id_end - id_start) * (ih_end - ih_start)
424 * (iw_end - iw_start);
425
426 for_(int id = id_start; id < id_end; ++id)
427 for_(int ih = ih_start; ih < ih_end; ++ih)
428 for (int iw = iw_start; iw < iw_end; ++iw) {
429 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
430 + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
431 + (size_t)ih * IW + (size_t)iw;
432 diff_src[diff_src_offset] += d[0] / num_summands;
433 }
434 };
435
436 int ow_start = max(0, utils::div_up(padL - KW + 1, SW));
437 int ow_end = min(OW, 1 + (padL + IW - 1) / SW);
438
439 int oh_start = max(0, utils::div_up(padT - KH + 1, SH));
440 int oh_end = min(OH, 1 + (padT + IH - 1) / SH);
441
442 int od_start = max(0, utils::div_up(padF - KD + 1, SD));
443 int od_end = min(OD, 1 + (padF + ID - 1) / SD);
444
445 if (alg == alg_kind::pooling_max) {
446 parallel_nd(MB, C, [&](int mb, int c) {
447 size_t diff_dst_offset_b
448 = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
449 ker_zero(mb, c);
450 for_(int od = od_start; od < od_end; ++od)
451 for (int oh = oh_start; oh < oh_end; ++oh) {
452 size_t diff_dst_offset = diff_dst_offset_b
453 + (size_t)od * OH * OW + (size_t)oh * OW;
454 for (int ow = ow_start; ow < ow_end; ++ow) {
455 const data_t *d = &diff_dst[diff_dst_offset + ow];
456 ker_max(d, mb, c, od, oh, ow);
457 }
458 }
459 });
460 } else {
461 parallel_nd(MB, C, [&](int mb, int c) {
462 size_t diff_dst_offset_b
463 = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
464 ker_zero(mb, c);
465 for_(int od = od_start; od < od_end; ++od)
466 for (int oh = oh_start; oh < oh_end; ++oh) {
467 size_t diff_dst_offset = diff_dst_offset_b
468 + (size_t)od * OH * OW + (size_t)oh * OW;
469 for (int ow = ow_start; ow < ow_end; ++ow) {
470 const data_t *d = &diff_dst[diff_dst_offset + ow];
471 ker_avg(d, mb, c, od, oh, ow);
472 }
473 }
474 });
475 }
476
477 return status::success;
478 }
479
480 template <>
execute_backward(const exec_ctx_t & ctx) const481 status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
482 const exec_ctx_t &ctx) const {
483
484 auto alg = pd()->desc()->alg_kind;
485 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
486 const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
487
488 auto diff_src = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_SRC);
489 auto diff_dst = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_DIFF_DST);
490 auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
491
492 auto scratchpad = ctx.get_scratchpad_grantor();
493 float *bf16cvt_src = scratchpad.template get<float>(
494 memory_tracking::names::key_pool_src_bf16cvt);
495 float *bf16cvt_dst = scratchpad.template get<float>(
496 memory_tracking::names::key_pool_dst_bf16cvt);
497
498 const memory_desc_wrapper ws_d(pd()->workspace_md());
499
500 const int MB = pd()->MB();
501 const int C = pd()->IC();
502 const int OD = pd()->OD();
503 const int OH = pd()->OH();
504 const int OW = pd()->OW();
505 const int ID = pd()->ID();
506 const int IH = pd()->IH();
507 const int IW = pd()->IW();
508 const int KD = pd()->KD();
509 const int KH = pd()->KH();
510 const int KW = pd()->KW();
511 const int SD = pd()->KSD();
512 const int SH = pd()->KSH();
513 const int SW = pd()->KSW();
514 const int padF = pd()->padFront();
515 const int padT = pd()->padT();
516 const int padL = pd()->padL();
517
518 const size_t dst_sp_size = pd()->OD() * pd()->OH() * pd()->OW();
519 const size_t src_sp_size = pd()->ID() * pd()->IH() * pd()->IW();
520
521 auto apply_offset = [=](int index, int offset) {
522 return (index > offset) ? index - offset : 0;
523 };
524
525 auto ker_zero = [=](float *diff_src, int c_block_size) {
526 size_t diff_src_offset = 0;
527 for_(int c = 0; c < c_block_size; ++c)
528 for_(int id = 0; id < ID; ++id)
529 for_(int ih = 0; ih < IH; ++ih)
530 for (int iw = 0; iw < IW; ++iw) {
531 diff_src[diff_src_offset++] = 0.0f;
532 }
533 };
534
535 auto ker_max = [=](const float *d, float *diff_src, int mb, int c, int od,
536 int oh, int ow) {
537 auto b_c = ws_d.blocking_desc().inner_nblks == 0
538 ? 1
539 : ws_d.blocking_desc().inner_blks[0];
540 auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
541 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
542 : ws_d.blk_off(mb, c / b_c, ow))
543 + c % b_c;
544
545 const int index = ws_d.data_type() == data_type::u8
546 ? (int)ws[ws_offset]
547 : ((const int *)ws)[ws_offset];
548 const int kw = index % KW;
549 const int kh = (index / KW) % KH;
550 const int kd = (index / KW) / KH;
551
552 const int id = od * SD - padF + kd;
553 const int ih = oh * SH - padT + kh;
554 const int iw = ow * SW - padL + kw;
555
556 // If padding area could fit the kernel,
557 // then input displacement would be out of bounds.
558 // No need to back propagate there as padding is
559 // virtual in pooling_max case.
560 if (id < 0 || id >= ID) return;
561 if (ih < 0 || ih >= IH) return;
562 if (iw < 0 || iw >= IW) return;
563
564 size_t diff_src_offset
565 = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
566 diff_src[diff_src_offset] += d[0];
567 };
568
569 auto ker_avg = [=](const float *d, float *diff_src, int mb, int c, int od,
570 int oh, int ow) {
571 auto id_start = apply_offset(od * SD, padF);
572 auto ih_start = apply_offset(oh * SH, padT);
573 auto iw_start = apply_offset(ow * SW, padL);
574 auto id_end = min(od * SD - padF + KD, ID);
575 auto ih_end = min(oh * SH - padT + KH, IH);
576 auto iw_end = min(ow * SW - padL + KW, IW);
577
578 size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
579 ? (size_t)KW * KH * KD
580 : (size_t)(id_end - id_start) * (ih_end - ih_start)
581 * (iw_end - iw_start);
582
583 for_(int id = id_start; id < id_end; ++id)
584 for_(int ih = ih_start; ih < ih_end; ++ih)
585 for (int iw = iw_start; iw < iw_end; ++iw) {
586 size_t diff_src_offset
587 = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
588 diff_src[diff_src_offset] += d[0] / num_summands;
589 }
590 };
591
592 int ow_start = max(0, utils::div_up(padL - KW + 1, SW));
593 int ow_end = min(OW, 1 + (padL + IW - 1) / SW);
594
595 int oh_start = max(0, utils::div_up(padT - KH + 1, SH));
596 int oh_end = min(OH, 1 + (padT + IH - 1) / SH);
597
598 int od_start = max(0, utils::div_up(padF - KD + 1, SD));
599 int od_end = min(OD, 1 + (padF + ID - 1) / SD);
600
601 dim_t c_blk = pd()->channel_block_size_;
602 int c_blk_tail = C % c_blk;
603 if (alg == alg_kind::pooling_max) {
604 parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
605 [&](int ithr, int, int mb, int cb) {
606 bool is_last_c_block
607 = c_blk_tail > 0 && (cb + 1) * c_blk > C;
608 int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
609 size_t diff_dst_offset_b
610 = ((size_t)mb * C + (size_t)cb * c_blk) * OD * OH
611 * OW;
612 size_t diff_src_offset
613 = ((size_t)mb * C + (size_t)cb * c_blk) * ID * IH
614 * IW;
615 float *diff_dst_fp32
616 = &bf16cvt_dst[ithr * dst_sp_size * c_blk];
617 float *diff_src_fp32
618 = &bf16cvt_src[ithr * src_sp_size * c_blk];
619
620 ker_zero(diff_src_fp32, curr_c_block);
621
622 cvt_bfloat16_to_float(diff_dst_fp32,
623 &diff_dst[diff_dst_offset_b],
624 dst_sp_size * curr_c_block);
625
626 for_(int c = 0; c < curr_c_block; ++c)
627 for_(int od = od_start; od < od_end; ++od)
628 for (int oh = oh_start; oh < oh_end; ++oh) {
629 size_t diff_dst_offset = (size_t)c * OD * OH * OW
630 + (size_t)od * OH * OW + (size_t)oh * OW;
631 for (int ow = ow_start; ow < ow_end; ++ow) {
632 const float *d
633 = &diff_dst_fp32[diff_dst_offset + ow];
634 ker_max(d, &diff_src_fp32[c * ID * IH * IW], mb,
635 cb * c_blk + c, od, oh, ow);
636 }
637 }
638 cvt_float_to_bfloat16(&diff_src[diff_src_offset],
639 diff_src_fp32, src_sp_size * curr_c_block);
640 });
641 } else {
642 parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
643 [&](int ithr, int, int mb, int cb) {
644 bool is_last_c_block
645 = c_blk_tail > 0 && (cb + 1) * c_blk > C;
646 int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
647 size_t diff_dst_offset_b = (size_t)mb * C * OD * OH * OW
648 + (size_t)cb * c_blk * OD * OH * OW;
649 float *diff_dst_fp32
650 = &bf16cvt_dst[ithr * dst_sp_size * c_blk];
651 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
652 + (size_t)cb * c_blk * ID * IH * IW;
653 float *diff_src_fp32
654 = &bf16cvt_src[ithr * src_sp_size * c_blk];
655
656 ker_zero(diff_src_fp32, curr_c_block);
657
658 cvt_bfloat16_to_float(diff_dst_fp32,
659 &diff_dst[diff_dst_offset_b],
660 dst_sp_size * curr_c_block);
661 for_(int c = 0; c < curr_c_block; ++c)
662 for_(int od = od_start; od < od_end; ++od)
663 for (int oh = oh_start; oh < oh_end; ++oh) {
664 size_t diff_dst_offset = (size_t)c * OD * OH * OW
665 + (size_t)od * OH * OW + (size_t)oh * OW;
666 for (int ow = ow_start; ow < ow_end; ++ow) {
667 const float *d
668 = &diff_dst_fp32[diff_dst_offset + ow];
669 ker_avg(d, &diff_src_fp32[c * ID * IH * IW], mb,
670 cb * c_blk + c, od, oh, ow);
671 }
672 }
673 cvt_float_to_bfloat16(&diff_src[diff_src_offset],
674 diff_src_fp32, src_sp_size * curr_c_block);
675 });
676 }
677
678 return status::success;
679 }
680 template struct nchw_pooling_fwd_t<data_type::f32>;
681 template struct nchw_pooling_bwd_t<data_type::f32>;
682 template struct nchw_pooling_fwd_t<data_type::bf16>;
683 template struct nchw_pooling_bwd_t<data_type::bf16>;
684 } // namespace cpu
685 } // namespace impl
686 } // namespace dnnl
687
688 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
689