1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <math.h>
19
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/nstl.hpp"
23 #include "common/type_helpers.hpp"
24
25 #include "cpu/simple_q10n.hpp"
26
27 #include "cpu/nchw_pooling.hpp"
28
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32
33 using namespace nstl;
34
35 template <data_type_t d_type>
nchw_pooling_fwd_t(const pd_t * apd)36 nchw_pooling_fwd_t<d_type>::nchw_pooling_fwd_t(const pd_t *apd)
37 : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}
38
39 template <data_type_t d_type>
execute_forward(const exec_ctx_t & ctx) const40 status_t nchw_pooling_fwd_t<d_type>::execute_forward(
41 const exec_ctx_t &ctx) const {
42 const auto alg = pd()->desc()->alg_kind;
43 const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
44 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
45 auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
46
47 const memory_desc_wrapper ws_d(pd()->workspace_md());
48 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
49
50 const dim_t MB = pd()->MB();
51 const dim_t C = pd()->OC();
52 const dim_t OD = pd()->OD();
53 const dim_t OH = pd()->OH();
54 const dim_t OW = pd()->OW();
55 const dim_t ID = pd()->ID();
56 const dim_t IH = pd()->IH();
57 const dim_t IW = pd()->IW();
58 const dim_t KD = pd()->KD();
59 const dim_t KH = pd()->KH();
60 const dim_t KW = pd()->KW();
61 const dim_t SD = pd()->KSD();
62 const dim_t SH = pd()->KSH();
63 const dim_t SW = pd()->KSW();
64 const dim_t padF = pd()->padFront();
65 const dim_t padT = pd()->padT();
66 const dim_t padL = pd()->padL();
67
68 const auto apply_offset = [](int index, int offset) {
69 return (index > offset) ? index - offset : 0;
70 };
71
72 const auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow,
73 dim_t value) {
74 if (ws) {
75 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
76 const size_t ws_offset = (size_t)OW * OH * OD * C * mb
77 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
78 + (size_t)OW * oh + (size_t)ow;
79 if (ws_dt == data_type::u8) {
80 assert(0 <= value
81 && value <= numeric_limits<typename prec_traits<
82 data_type::u8>::type>::max());
83 ws[ws_offset] = value;
84 } else
85 reinterpret_cast<int *>(ws)[ws_offset] = value;
86 }
87 };
88
89 const auto ker_max = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
90 dim_t ow) {
91 for_(dim_t kd = 0; kd < KD; ++kd)
92 for_(dim_t kh = 0; kh < KH; ++kh)
93 for (dim_t kw = 0; kw < KW; ++kw) {
94 const dim_t id = od * SD - padF + kd;
95 const dim_t ih = oh * SH - padT + kh;
96 const dim_t iw = ow * SW - padL + kw;
97
98 if (id < 0 || id >= ID) continue;
99 if (ih < 0 || ih >= IH) continue;
100 if (iw < 0 || iw >= IW) continue;
101
102 const auto src_offset = (size_t)IW * IH * ID * C * mb
103 + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
104 + (size_t)IW * ih + (size_t)iw;
105 const auto &s = src[src_offset];
106 if (s > d[0]) {
107 d[0] = s;
108 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
109 }
110 }
111 };
112
113 const auto ker_avg = [=](data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
114 dim_t ow) {
115 const auto id_start = apply_offset(od * SD, padF);
116 const auto ih_start = apply_offset(oh * SH, padT);
117 const auto iw_start = apply_offset(ow * SW, padL);
118 const auto id_end = min(od * SD - padF + KD, ID);
119 const auto ih_end = min(oh * SH - padT + KH, IH);
120 const auto iw_end = min(ow * SW - padL + KW, IW);
121
122 const auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
123 ? KD * KW * KH
124 : (id_end - id_start) * (ih_end - ih_start)
125 * (iw_end - iw_start);
126
127 float d_val = 0;
128 for_(dim_t id = id_start; id < id_end; ++id)
129 for_(dim_t ih = ih_start; ih < ih_end; ++ih)
130 for (dim_t iw = iw_start; iw < iw_end; ++iw) {
131 const auto src_offset = (size_t)IW * IH * ID * C * mb
132 + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
133 + (size_t)IW * ih + (size_t)iw;
134 d_val += src[src_offset];
135 }
136
137 return d_val / num_summands;
138 };
139
140 if (alg == alg_kind::pooling_max) {
141 parallel_nd(MB, C, OD, OH, OW,
142 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
143 const size_t dst_offset = (size_t)OW * OH * OD * C * mb
144 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
145 + (size_t)OW * oh + (size_t)ow;
146 data_t *d = &dst[dst_offset];
147 d[0] = numeric_limits<data_t>::lowest();
148 set_ws(mb, c, od, oh, ow, 0);
149 ker_max(d, mb, c, od, oh, ow);
150
151 ref_post_ops_t::args_t args;
152 args.ctx = &ctx;
153 args.l_offset = dst_offset;
154 args.dst_md = pd()->dst_md();
155 ref_post_ops_.execute(dst[dst_offset], args);
156 dst[dst_offset]
157 = saturate_and_round<data_t>(dst[dst_offset]);
158 });
159 } else {
160 parallel_nd(MB, C, OD, OH, OW,
161 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
162 const size_t dst_offset = (size_t)OW * OH * OD * C * mb
163 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
164 + (size_t)OW * oh + (size_t)ow;
165 data_t *d = &dst[dst_offset];
166 d[0] = 0;
167 auto res = ker_avg(d, mb, c, od, oh, ow);
168
169 ref_post_ops_t::args_t args;
170 args.ctx = &ctx;
171 args.l_offset = dst_offset;
172 args.dst_md = pd()->dst_md();
173 ref_post_ops_.execute(res, args);
174 d[0] = saturate_and_round<data_t>(res);
175 });
176 }
177
178 return status::success;
179 }
180
181 template <>
execute_forward(const exec_ctx_t & ctx) const182 status_t nchw_pooling_fwd_t<data_type::bf16>::execute_forward(
183 const exec_ctx_t &ctx) const {
184
185 auto alg = pd()->desc()->alg_kind;
186
187 auto src = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_SRC);
188 auto dst = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DST);
189 auto ws = CTX_OUT_MEM(unsigned char *, DNNL_ARG_WORKSPACE);
190 memory_desc_wrapper dst_d(pd()->dst_md());
191
192 auto scratchpad = ctx.get_scratchpad_grantor();
193 float *bf16cvt_wsp = scratchpad.template get<float>(
194 memory_tracking::names::key_pool_src_bf16cvt);
195
196 const memory_desc_wrapper ws_d(pd()->workspace_md());
197 const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
198
199 const dim_t MB = pd()->MB();
200 const dim_t C = pd()->OC();
201 const dim_t OD = pd()->OD();
202 const dim_t OH = pd()->OH();
203 const dim_t OW = pd()->OW();
204 const dim_t ID = pd()->ID();
205 const dim_t IH = pd()->IH();
206 const dim_t IW = pd()->IW();
207 const dim_t KD = pd()->KD();
208 const dim_t KH = pd()->KH();
209 const dim_t KW = pd()->KW();
210 const dim_t SD = pd()->KSD();
211 const dim_t SH = pd()->KSH();
212 const dim_t SW = pd()->KSW();
213 const dim_t padF = pd()->padFront();
214 const dim_t padT = pd()->padT();
215 const dim_t padL = pd()->padL();
216
217 const size_t simd_w = 16;
218 const size_t src_size = MB * C * ID * IH * IW;
219 const size_t blocked_size = src_size / simd_w;
220 const size_t tail_size = src_size % simd_w;
221
222 auto apply_offset = [=](int index, int offset) {
223 return (index > offset) ? index - offset : 0;
224 };
225
226 auto set_ws = [=](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow,
227 dim_t value) {
228 if (ws) {
229 assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
230 size_t ws_offset = (size_t)OW * OH * OD * C * mb
231 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
232 + (size_t)OW * oh + (size_t)ow;
233 if (ws_dt == data_type::u8) {
234 assert(0 <= value
235 && value <= numeric_limits<typename prec_traits<
236 data_type::u8>::type>::max());
237 ws[ws_offset] = value;
238 } else
239 reinterpret_cast<int *>(ws)[ws_offset] = value;
240 }
241 };
242
243 auto ker_max = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
244 dim_t ow) {
245 for_(dim_t kd = 0; kd < KD; ++kd)
246 for_(dim_t kh = 0; kh < KH; ++kh)
247 for (dim_t kw = 0; kw < KW; ++kw) {
248 const dim_t id = od * SD - padF + kd;
249 const dim_t ih = oh * SH - padT + kh;
250 const dim_t iw = ow * SW - padL + kw;
251
252 if (id < 0 || id >= ID) continue;
253 if (ih < 0 || ih >= IH) continue;
254 if (iw < 0 || iw >= IW) continue;
255
256 auto src_offset = (size_t)IW * IH * ID * C * mb
257 + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
258 + (size_t)IW * ih + (size_t)iw;
259 auto &s = bf16cvt_wsp[src_offset];
260
261 if (s > d[0]) {
262 d[0] = s;
263 set_ws(mb, c, od, oh, ow, kd * KH * KW + kh * KW + kw);
264 }
265 }
266 };
267
268 auto ker_avg = [=](float *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
269 dim_t ow) {
270 auto id_start = apply_offset(od * SD, padF);
271 auto ih_start = apply_offset(oh * SH, padT);
272 auto iw_start = apply_offset(ow * SW, padL);
273 auto id_end = min(od * SD - padF + KD, ID);
274 auto ih_end = min(oh * SH - padT + KH, IH);
275 auto iw_end = min(ow * SW - padL + KW, IW);
276
277 auto num_summands = (alg == alg_kind::pooling_avg_include_padding)
278 ? KD * KW * KH
279 : (id_end - id_start) * (ih_end - ih_start)
280 * (iw_end - iw_start);
281
282 for_(dim_t id = id_start; id < id_end; ++id)
283 for_(dim_t ih = ih_start; ih < ih_end; ++ih)
284 for (dim_t iw = iw_start; iw < iw_end; ++iw) {
285 auto src_offset = (size_t)IW * IH * ID * C * mb
286 + (size_t)IW * IH * ID * c + (size_t)IW * IH * id
287 + (size_t)IW * ih + (size_t)iw;
288 d[0] += bf16cvt_wsp[src_offset];
289 }
290
291 d[0] = out_round<float>((float)d[0] / num_summands);
292 };
293 parallel_nd(blocked_size, [&](size_t i) {
294 cvt_bfloat16_to_float(
295 &bf16cvt_wsp[i * simd_w], &src[i * simd_w], simd_w);
296 });
297 if (tail_size)
298 cvt_bfloat16_to_float(&bf16cvt_wsp[blocked_size * simd_w],
299 &src[blocked_size * simd_w], tail_size);
300 if (alg == alg_kind::pooling_max) {
301 parallel_nd(MB, C, OD, OH, OW,
302 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
303 size_t dst_offset = (size_t)OW * OH * OD * C * mb
304 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
305 + (size_t)OW * oh + (size_t)ow;
306 float d_fp32 = numeric_limits<bfloat16_t>::lowest();
307
308 set_ws(mb, c, od, oh, ow, 0);
309
310 ker_max(&d_fp32, mb, c, od, oh, ow);
311
312 ref_post_ops_t::args_t args;
313 args.ctx = &ctx;
314 args.l_offset = dst_offset;
315 args.dst_md = pd()->dst_md();
316 ref_post_ops_.execute(d_fp32, args);
317
318 dst[dst_offset] = static_cast<bfloat16_t>(d_fp32);
319 });
320 } else {
321 parallel_nd(MB, C, OD, OH, OW,
322 [&](dim_t mb, dim_t c, dim_t od, dim_t oh, dim_t ow) {
323 size_t dst_offset = (size_t)OW * OH * OD * C * mb
324 + (size_t)OW * OH * OD * c + (size_t)OW * OH * od
325 + (size_t)OW * oh + (size_t)ow;
326 float d_fp32 = 0.0f;
327 ker_avg(&d_fp32, mb, c, od, oh, ow);
328 ref_post_ops_t::args_t args;
329 args.ctx = &ctx;
330 args.l_offset = dst_offset;
331 args.dst_md = pd()->dst_md();
332 ref_post_ops_.execute(d_fp32, args);
333 dst[dst_offset] = static_cast<bfloat16_t>(d_fp32);
334 });
335 }
336
337 return status::success;
338 }
339
340 template <data_type_t d_type>
execute_backward(const exec_ctx_t & ctx) const341 status_t nchw_pooling_bwd_t<d_type>::execute_backward(
342 const exec_ctx_t &ctx) const {
343 auto alg = pd()->desc()->alg_kind;
344 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
345 const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
346
347 auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
348 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
349 auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
350
351 const memory_desc_wrapper ws_d(pd()->workspace_md());
352
353 const dim_t MB = pd()->MB();
354 const dim_t C = pd()->OC();
355 const dim_t OD = pd()->OD();
356 const dim_t OH = pd()->OH();
357 const dim_t OW = pd()->OW();
358 const dim_t ID = pd()->ID();
359 const dim_t IH = pd()->IH();
360 const dim_t IW = pd()->IW();
361 const dim_t KD = pd()->KD();
362 const dim_t KH = pd()->KH();
363 const dim_t KW = pd()->KW();
364 const dim_t SD = pd()->KSD();
365 const dim_t SH = pd()->KSH();
366 const dim_t SW = pd()->KSW();
367 const dim_t padF = pd()->padFront();
368 const dim_t padT = pd()->padT();
369 const dim_t padL = pd()->padL();
370
371 auto apply_offset = [=](int index, int offset) {
372 return (index > offset) ? index - offset : 0;
373 };
374
375 auto ker_zero = [=](dim_t mb, dim_t c) {
376 size_t diff_src_offset
377 = (size_t)mb * C * ID * IH * IW + (size_t)c * ID * IH * IW;
378 for_(dim_t id = 0; id < ID; ++id)
379 for_(dim_t ih = 0; ih < IH; ++ih)
380 for (dim_t iw = 0; iw < IW; ++iw) {
381 diff_src[diff_src_offset++] = 0;
382 }
383 };
384
385 auto ker_max = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
386 dim_t ow) {
387 auto b_c = ws_d.blocking_desc().inner_nblks == 0
388 ? 1
389 : ws_d.blocking_desc().inner_blks[0];
390 auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
391 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
392 : ws_d.blk_off(mb, c / b_c, ow))
393 + c % b_c;
394
395 const int index = ws_d.data_type() == data_type::u8
396 ? (int)ws[ws_offset]
397 : ((const int *)ws)[ws_offset];
398 const dim_t kw = index % KW;
399 const dim_t kh = (index / KW) % KH;
400 const dim_t kd = (index / KW) / KH;
401
402 const dim_t id = od * SD - padF + kd;
403 const dim_t ih = oh * SH - padT + kh;
404 const dim_t iw = ow * SW - padL + kw;
405
406 // If padding area could fit the kernel,
407 // then input displacement would be out of bounds.
408 // No need to back propagate there as padding is
409 // virtual in pooling_max case.
410 if (id < 0 || id >= ID) return;
411 if (ih < 0 || ih >= IH) return;
412 if (iw < 0 || iw >= IW) return;
413
414 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
415 + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
416 + (size_t)ih * IW + (size_t)iw;
417 diff_src[diff_src_offset] += d[0];
418 };
419
420 auto ker_avg = [=](const data_t *d, dim_t mb, dim_t c, dim_t od, dim_t oh,
421 dim_t ow) {
422 dim_t id_start = apply_offset(od * SD, padF);
423 dim_t ih_start = apply_offset(oh * SH, padT);
424 dim_t iw_start = apply_offset(ow * SW, padL);
425 dim_t id_end = min(od * SD - padF + KD, ID);
426 dim_t ih_end = min(oh * SH - padT + KH, IH);
427 dim_t iw_end = min(ow * SW - padL + KW, IW);
428
429 size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
430 ? (size_t)KW * KH * KD
431 : (size_t)(id_end - id_start) * (ih_end - ih_start)
432 * (iw_end - iw_start);
433
434 for_(dim_t id = id_start; id < id_end; ++id)
435 for_(dim_t ih = ih_start; ih < ih_end; ++ih)
436 for (dim_t iw = iw_start; iw < iw_end; ++iw) {
437 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
438 + (size_t)c * ID * IH * IW + (size_t)id * IH * IW
439 + (size_t)ih * IW + (size_t)iw;
440 diff_src[diff_src_offset] += d[0] / num_summands;
441 }
442 };
443
444 dim_t ow_start = max(dim_t(0), utils::div_up(padL - KW + 1, SW));
445 dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW);
446
447 dim_t oh_start = max(dim_t(0), utils::div_up(padT - KH + 1, SH));
448 dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH);
449
450 dim_t od_start = max(dim_t(0), utils::div_up(padF - KD + 1, SD));
451 dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD);
452
453 if (alg == alg_kind::pooling_max) {
454 parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
455 size_t diff_dst_offset_b
456 = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
457 ker_zero(mb, c);
458 for_(dim_t od = od_start; od < od_end; ++od)
459 for (dim_t oh = oh_start; oh < oh_end; ++oh) {
460 size_t diff_dst_offset = diff_dst_offset_b
461 + (size_t)od * OH * OW + (size_t)oh * OW;
462 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
463 const data_t *d = &diff_dst[diff_dst_offset + ow];
464 ker_max(d, mb, c, od, oh, ow);
465 }
466 }
467 });
468 } else {
469 parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
470 size_t diff_dst_offset_b
471 = (size_t)mb * C * OD * OH * OW + (size_t)c * OD * OH * OW;
472 ker_zero(mb, c);
473 for_(dim_t od = od_start; od < od_end; ++od)
474 for (dim_t oh = oh_start; oh < oh_end; ++oh) {
475 size_t diff_dst_offset = diff_dst_offset_b
476 + (size_t)od * OH * OW + (size_t)oh * OW;
477 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
478 const data_t *d = &diff_dst[diff_dst_offset + ow];
479 ker_avg(d, mb, c, od, oh, ow);
480 }
481 }
482 });
483 }
484
485 return status::success;
486 }
487
488 template <>
execute_backward(const exec_ctx_t & ctx) const489 status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
490 const exec_ctx_t &ctx) const {
491
492 auto alg = pd()->desc()->alg_kind;
493 const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
494 const bool is_2d = pd()->desc()->diff_src_desc.ndims == 4;
495
496 auto diff_src = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_SRC);
497 auto diff_dst = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_DIFF_DST);
498 auto ws = CTX_IN_MEM(const unsigned char *, DNNL_ARG_WORKSPACE);
499
500 auto scratchpad = ctx.get_scratchpad_grantor();
501 float *bf16cvt_src = scratchpad.template get<float>(
502 memory_tracking::names::key_pool_src_bf16cvt);
503 float *bf16cvt_dst = scratchpad.template get<float>(
504 memory_tracking::names::key_pool_dst_bf16cvt);
505
506 const memory_desc_wrapper ws_d(pd()->workspace_md());
507
508 const dim_t MB = pd()->MB();
509 const dim_t C = pd()->OC();
510 const dim_t OD = pd()->OD();
511 const dim_t OH = pd()->OH();
512 const dim_t OW = pd()->OW();
513 const dim_t ID = pd()->ID();
514 const dim_t IH = pd()->IH();
515 const dim_t IW = pd()->IW();
516 const dim_t KD = pd()->KD();
517 const dim_t KH = pd()->KH();
518 const dim_t KW = pd()->KW();
519 const dim_t SD = pd()->KSD();
520 const dim_t SH = pd()->KSH();
521 const dim_t SW = pd()->KSW();
522 const dim_t padF = pd()->padFront();
523 const dim_t padT = pd()->padT();
524 const dim_t padL = pd()->padL();
525
526 const size_t dst_sp_size = pd()->OD() * pd()->OH() * pd()->OW();
527 const size_t src_sp_size = pd()->ID() * pd()->IH() * pd()->IW();
528
529 auto apply_offset = [=](int index, int offset) {
530 return (index > offset) ? index - offset : 0;
531 };
532
533 auto ker_zero = [=](float *diff_src, dim_t c_block_size) {
534 size_t diff_src_offset = 0;
535 for_(dim_t c = 0; c < c_block_size; ++c)
536 for_(dim_t id = 0; id < ID; ++id)
537 for_(dim_t ih = 0; ih < IH; ++ih)
538 for (dim_t iw = 0; iw < IW; ++iw) {
539 diff_src[diff_src_offset++] = 0.0f;
540 }
541 };
542
543 auto ker_max = [=](const float *d, float *diff_src, dim_t mb, dim_t c,
544 dim_t od, dim_t oh, dim_t ow) {
545 auto b_c = ws_d.blocking_desc().inner_nblks == 0
546 ? 1
547 : ws_d.blocking_desc().inner_blks[0];
548 auto ws_offset = (is_3d ? ws_d.blk_off(mb, c / b_c, od, oh, ow)
549 : is_2d ? ws_d.blk_off(mb, c / b_c, oh, ow)
550 : ws_d.blk_off(mb, c / b_c, ow))
551 + c % b_c;
552
553 const int index = ws_d.data_type() == data_type::u8
554 ? (int)ws[ws_offset]
555 : ((const int *)ws)[ws_offset];
556 const dim_t kw = index % KW;
557 const dim_t kh = (index / KW) % KH;
558 const dim_t kd = (index / KW) / KH;
559
560 const dim_t id = od * SD - padF + kd;
561 const dim_t ih = oh * SH - padT + kh;
562 const dim_t iw = ow * SW - padL + kw;
563
564 // If padding area could fit the kernel,
565 // then input displacement would be out of bounds.
566 // No need to back propagate there as padding is
567 // virtual in pooling_max case.
568 if (id < 0 || id >= ID) return;
569 if (ih < 0 || ih >= IH) return;
570 if (iw < 0 || iw >= IW) return;
571
572 size_t diff_src_offset
573 = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
574 diff_src[diff_src_offset] += d[0];
575 };
576
577 auto ker_avg = [=](const float *d, float *diff_src, dim_t mb, dim_t c,
578 dim_t od, dim_t oh, dim_t ow) {
579 auto id_start = apply_offset(od * SD, padF);
580 auto ih_start = apply_offset(oh * SH, padT);
581 auto iw_start = apply_offset(ow * SW, padL);
582 auto id_end = min(od * SD - padF + KD, ID);
583 auto ih_end = min(oh * SH - padT + KH, IH);
584 auto iw_end = min(ow * SW - padL + KW, IW);
585
586 size_t num_summands = (alg == alg_kind::pooling_avg_include_padding)
587 ? (size_t)KW * KH * KD
588 : (size_t)(id_end - id_start) * (ih_end - ih_start)
589 * (iw_end - iw_start);
590
591 for_(dim_t id = id_start; id < id_end; ++id)
592 for_(dim_t ih = ih_start; ih < ih_end; ++ih)
593 for (dim_t iw = iw_start; iw < iw_end; ++iw) {
594 size_t diff_src_offset
595 = (size_t)id * IH * IW + (size_t)ih * IW + (size_t)iw;
596 diff_src[diff_src_offset] += d[0] / num_summands;
597 }
598 };
599
600 dim_t ow_start = max(dim_t(0), utils::div_up(padL - KW + 1, SW));
601 dim_t ow_end = min(OW, 1 + (padL + IW - 1) / SW);
602
603 dim_t oh_start = max(dim_t(0), utils::div_up(padT - KH + 1, SH));
604 dim_t oh_end = min(OH, 1 + (padT + IH - 1) / SH);
605
606 dim_t od_start = max(dim_t(0), utils::div_up(padF - KD + 1, SD));
607 dim_t od_end = min(OD, 1 + (padF + ID - 1) / SD);
608
609 dim_t c_blk = pd()->channel_block_size_;
610 dim_t c_blk_tail = C % c_blk;
611 const int nthr = pd()->nthr_;
612
613 if (alg == alg_kind::pooling_max) {
614 parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
615 [&](int ithr, int, dim_t mb, dim_t cb) {
616 bool is_last_c_block
617 = c_blk_tail > 0 && (cb + 1) * c_blk > C;
618 dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
619 size_t diff_dst_offset_b
620 = ((size_t)mb * C + (size_t)cb * c_blk) * OD * OH
621 * OW;
622 size_t diff_src_offset
623 = ((size_t)mb * C + (size_t)cb * c_blk) * ID * IH
624 * IW;
625 float *diff_dst_fp32
626 = &bf16cvt_dst[ithr * dst_sp_size * c_blk];
627 float *diff_src_fp32
628 = &bf16cvt_src[ithr * src_sp_size * c_blk];
629
630 ker_zero(diff_src_fp32, curr_c_block);
631
632 cvt_bfloat16_to_float(diff_dst_fp32,
633 &diff_dst[diff_dst_offset_b],
634 dst_sp_size * curr_c_block);
635
636 for_(dim_t c = 0; c < curr_c_block; ++c)
637 for_(dim_t od = od_start; od < od_end; ++od)
638 for (dim_t oh = oh_start; oh < oh_end; ++oh) {
639 size_t diff_dst_offset = (size_t)c * OD * OH * OW
640 + (size_t)od * OH * OW + (size_t)oh * OW;
641 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
642 const float *d
643 = &diff_dst_fp32[diff_dst_offset + ow];
644 ker_max(d, &diff_src_fp32[c * ID * IH * IW], mb,
645 cb * c_blk + c, od, oh, ow);
646 }
647 }
648 cvt_float_to_bfloat16(&diff_src[diff_src_offset],
649 diff_src_fp32, src_sp_size * curr_c_block);
650 });
651 } else {
652 parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
653 [&](int ithr, int, dim_t mb, dim_t cb) {
654 bool is_last_c_block
655 = c_blk_tail > 0 && (cb + 1) * c_blk > C;
656 dim_t curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
657 size_t diff_dst_offset_b = (size_t)mb * C * OD * OH * OW
658 + (size_t)cb * c_blk * OD * OH * OW;
659 float *diff_dst_fp32
660 = &bf16cvt_dst[ithr * dst_sp_size * c_blk];
661 size_t diff_src_offset = (size_t)mb * C * ID * IH * IW
662 + (size_t)cb * c_blk * ID * IH * IW;
663 float *diff_src_fp32
664 = &bf16cvt_src[ithr * src_sp_size * c_blk];
665
666 ker_zero(diff_src_fp32, curr_c_block);
667
668 cvt_bfloat16_to_float(diff_dst_fp32,
669 &diff_dst[diff_dst_offset_b],
670 dst_sp_size * curr_c_block);
671 for_(dim_t c = 0; c < curr_c_block; ++c)
672 for_(dim_t od = od_start; od < od_end; ++od)
673 for (dim_t oh = oh_start; oh < oh_end; ++oh) {
674 size_t diff_dst_offset = (size_t)c * OD * OH * OW
675 + (size_t)od * OH * OW + (size_t)oh * OW;
676 for (dim_t ow = ow_start; ow < ow_end; ++ow) {
677 const float *d
678 = &diff_dst_fp32[diff_dst_offset + ow];
679 ker_avg(d, &diff_src_fp32[c * ID * IH * IW], mb,
680 cb * c_blk + c, od, oh, ow);
681 }
682 }
683 cvt_float_to_bfloat16(&diff_src[diff_src_offset],
684 diff_src_fp32, src_sp_size * curr_c_block);
685 });
686 }
687
688 return status::success;
689 }
690 template struct nchw_pooling_fwd_t<data_type::f32>;
691 template struct nchw_pooling_bwd_t<data_type::f32>;
692 template struct nchw_pooling_fwd_t<data_type::bf16>;
693 template struct nchw_pooling_bwd_t<data_type::bf16>;
694 } // namespace cpu
695 } // namespace impl
696 } // namespace dnnl
697
698 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
699