1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <functional>
18 #include "common/c_types_map.hpp"
19 #include "common/dnnl_thread.hpp"
20 #include "common/dnnl_traits.hpp"
21 #include "common/math_utils.hpp"
22 #include "common/type_helpers.hpp"
23
24 #include "cpu/cpu_primitive.hpp"
25 #include "cpu/simple_q10n.hpp"
26
27 #include "cpu/ref_deconvolution.hpp"
28
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32
33 namespace {
get_data_off(const memory_desc_wrapper & mdw,int ndims,dim_t mb,dim_t c,dim_t id,dim_t ih,dim_t iw)34 dim_t get_data_off(const memory_desc_wrapper &mdw, int ndims, dim_t mb, dim_t c,
35 dim_t id, dim_t ih, dim_t iw) {
36 switch (ndims) {
37 case 5: return mdw.off(mb, c, id, ih, iw);
38 case 4: return mdw.off(mb, c, ih, iw);
39 case 3: return mdw.off(mb, c, iw);
40 default: assert(!"unsupported ndims"); return dim_t(0);
41 }
42 }
43
44 } // namespace
45
46 template <data_type_t dst_type>
compute_fwd_bias_common(const exec_ctx_t & ctx,typename prec_traits<dst_type>::type * dst,const float * conv_output) const47 void ref_deconvolution_fwd_t::compute_fwd_bias_common(const exec_ctx_t &ctx,
48 typename prec_traits<dst_type>::type *dst,
49 const float *conv_output) const {
50 using dst_data_t = typename prec_traits<dst_type>::type;
51 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
52 const memory_desc_wrapper dst_d(pd()->dst_md());
53 const memory_desc_wrapper bias_d(pd()->weights_md(1));
54
55 const auto G = pd()->G();
56 const auto MB = pd()->MB();
57 const auto OH = pd()->OH();
58 const auto OW = pd()->OW();
59 const auto OD = pd()->OD();
60 const auto OC = pd()->OC() / G;
61 const auto ndims = pd()->desc()->src_desc.ndims;
62
63 parallel_nd(MB, G, OC, OD, OH, OW,
64 [&](dim_t mb, dim_t g, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
65 const dim_t c = g * OC + oc;
66 const dim_t off = get_data_off(dst_d, ndims, mb, c, od, oh, ow);
67 float b = types::get_float_value(bias_d.data_type(), bias, c);
68 float d = conv_output[off];
69 dst[off] = cpu::saturate_and_round<dst_data_t>(d + b);
70 });
71 }
72
73 template <data_type_t dst_type>
compute_fwd_bias_ncdhw(const exec_ctx_t & ctx,typename prec_traits<dst_type>::type * dst,const float * conv_output) const74 void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const exec_ctx_t &ctx,
75 typename prec_traits<dst_type>::type *dst,
76 const float *conv_output) const {
77 using dst_data_t = typename prec_traits<dst_type>::type;
78 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
79 const memory_desc_wrapper dst_d(pd()->dst_md());
80 const memory_desc_wrapper bias_d(pd()->weights_md(1));
81
82 const auto MB = pd()->MB();
83 const auto OC = pd()->OC();
84 const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
85
86 parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
87 const dim_t off = (mb * OC + oc) * SP;
88 float b = types::get_float_value(bias_d.data_type(), bias, oc);
89 PRAGMA_OMP_SIMD()
90 for (dim_t sp = 0; sp < SP; ++sp) {
91 float d = conv_output[off + sp];
92 dst[off + sp] = cpu::saturate_and_round<dst_data_t>(d + b);
93 }
94 });
95 }
96
97 template <data_type_t dst_type>
compute_fwd_bias_ndhwc(const exec_ctx_t & ctx,typename prec_traits<dst_type>::type * dst,const float * conv_output) const98 void ref_deconvolution_fwd_t::compute_fwd_bias_ndhwc(const exec_ctx_t &ctx,
99 typename prec_traits<dst_type>::type *dst,
100 const float *conv_output) const {
101 using dst_data_t = typename prec_traits<dst_type>::type;
102 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
103 const memory_desc_wrapper dst_d(pd()->dst_md());
104 const memory_desc_wrapper bias_d(pd()->weights_md(1));
105
106 const auto MB = pd()->MB();
107 const auto OC = pd()->OC();
108 const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
109
110 parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
111 const dim_t off = (mb * SP + sp) * OC;
112 PRAGMA_OMP_SIMD()
113 for (dim_t oc = 0; oc < OC; ++oc) {
114 float b = types::get_float_value(bias_d.data_type(), bias, oc);
115 float d = conv_output[off + oc];
116 dst[off + oc] = cpu::saturate_and_round<dst_data_t>(d + b);
117 }
118 });
119 }
120
121 template <data_type_t dst_type, dim_t blk_size>
compute_fwd_bias_nCdhwXc(const exec_ctx_t & ctx,typename prec_traits<dst_type>::type * dst,const float * conv_output) const122 void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx,
123 typename prec_traits<dst_type>::type *dst,
124 const float *conv_output) const {
125 using dst_data_t = typename prec_traits<dst_type>::type;
126 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
127 const memory_desc_wrapper dst_d(pd()->dst_md());
128 const memory_desc_wrapper bias_d(pd()->weights_md(1));
129
130 const auto MB = pd()->MB();
131 const auto OC = pd()->OC();
132 const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
133 const auto stride_mb = dst_d.blocking_desc().strides[0];
134
135 parallel_nd(MB, utils::div_up(OC, blk_size), SP,
136 [&](dim_t mb, dim_t oc_blk, dim_t sp) {
137 const dim_t oc = oc_blk * blk_size;
138 const dim_t off = mb * stride_mb + oc * SP + sp * blk_size;
139 const dim_t blk = nstl::min(blk_size, OC - oc);
140
141 PRAGMA_OMP_SIMD()
142 for (dim_t i = 0; i < blk_size; ++i) {
143 float b = i < blk ? types::get_float_value(
144 bias_d.data_type(), bias, oc + i)
145 : 0;
146 float d = conv_output[off + i];
147 dst[off + i] = cpu::saturate_and_round<dst_data_t>(d + b);
148 }
149 });
150 }
151
152 template <data_type_t dst_type>
compute_fwd_bias(const exec_ctx_t & ctx,typename prec_traits<dst_type>::type * dst,const float * conv_output) const153 void ref_deconvolution_fwd_t::compute_fwd_bias(const exec_ctx_t &ctx,
154 typename prec_traits<dst_type>::type *dst,
155 const float *conv_output) const {
156 using namespace format_tag;
157 switch (pd()->dst_tag_) {
158 case ncdhw:
159 case nchw:
160 case ncw:
161 compute_fwd_bias_ncdhw<dst_type>(ctx, dst, conv_output);
162 break;
163 case ndhwc:
164 case nhwc:
165 case nwc:
166 compute_fwd_bias_ndhwc<dst_type>(ctx, dst, conv_output);
167 break;
168 case nCdhw8c:
169 case nChw8c:
170 case nCw8c:
171 assert(dst_type != data_type::bf16);
172 compute_fwd_bias_nCdhwXc<dst_type, 8>(ctx, dst, conv_output);
173 break;
174 case nCdhw16c:
175 case nChw16c:
176 case nCw16c:
177 compute_fwd_bias_nCdhwXc<dst_type, 16>(ctx, dst, conv_output);
178 break;
179 default:
180 compute_fwd_bias_common<dst_type>(ctx, dst, conv_output);
181 break;
182 }
183 }
184
template <data_type_t dst_type>
status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx,
        const float *conv_output,
        typename prec_traits<dst_type>::type *original_dst) const {
    // Applies the attribute pipeline (output scales -> post-ops -> dst
    // zero-point, in that order) to the f32 convolution result and writes
    // the saturated/converted values into the user's dst buffer.
    // original_dst is the pre-convolution dst snapshot used by a sum post-op.
    using dst_data_t = typename prec_traits<dst_type>::type;
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
    const bool is_dst_zp_common
            = pd()->attr()->zero_points_.common(DNNL_ARG_DST);

    const memory_desc_wrapper dst_d(pd()->dst_md());

    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC();
    // Iterate over the padded channel count so padded tail elements are
    // written (zero-filled) as well.
    const auto OCP = dst_d.padded_dims()[1];
    const auto ndims = pd()->desc()->src_desc.ndims;

    const auto maybe_oscale = [=](float &d, dim_t oc) {
        // scale_idx_mult = 1 for per_oc scales and 0, otherwise
        const int scale_idx_mult
                = pd()->attr()->output_scales_.mask_ == (1 << 1);
        const float *scales = pd()->attr()->output_scales_.scales_;
        d *= scales[oc * scale_idx_mult];
    };

    // Adds the (common or per-channel) destination zero-point.
    const auto maybe_dst_zero_point = [=](float &result, dim_t oc) {
        if (is_dst_zp_common)
            result += dst_zero_point[0];
        else
            result += dst_zero_point[oc];
    };

    parallel_nd(MB, OCP, OD, OH, OW,
            [&](dim_t mb, int ocp, dim_t od, dim_t oh, dim_t ow) {
                auto dst_off = get_data_off(dst_d, ndims, mb, ocp, od, oh, ow);
                float tmp_result = 0;

                if (ocp < OC) {
                    // Logical (dense, layout-independent) offset; this is
                    // what the post-op machinery expects in l_offset.
                    dim_t dst_l_off = (mb * OC + ocp) * OD * OH * OW
                            + od * OH * OW + oh * OW + ow;
                    tmp_result = conv_output[dst_off];
                    maybe_oscale(tmp_result, ocp);

                    ref_post_ops_t::args_t args;
                    // Sum post-op consumes the saved original dst value.
                    if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1)
                        args.dst_val = (float)original_dst[dst_off];
                    args.ctx = &ctx;
                    args.l_offset = dst_l_off;
                    args.dst_md = pd()->dst_md();
                    ref_post_ops->execute(tmp_result, args);
                    maybe_dst_zero_point(tmp_result, ocp);
                }
                // Padded channels (ocp >= OC) are written as zero.
                dst[dst_off] = cpu::saturate_and_round<dst_data_t>(tmp_result);
            });

    return status_t::dnnl_success;
}
245
get_weights_off(const memory_desc_wrapper & wei_d,bool with_groups,int ndims,dim_t g,dim_t oc,dim_t ic,dim_t kd,dim_t kh,dim_t kw)246 dim_t get_weights_off(const memory_desc_wrapper &wei_d, bool with_groups,
247 int ndims, dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
248 switch (ndims) {
249 case 5:
250 return with_groups ? wei_d.off(g, oc, ic, kd, kh, kw)
251 : wei_d.off(oc, ic, kd, kh, kw);
252 case 4:
253 return with_groups ? wei_d.off(g, oc, ic, kh, kw)
254 : wei_d.off(oc, ic, kh, kw);
255 case 3:
256 return with_groups ? wei_d.off(g, oc, ic, kw)
257 : wei_d.off(oc, ic, kw);
258 default: assert(!"unsupported ndims"); return dim_t(0);
259 }
260
261 return 0;
262 };
263
264 template <data_type_t wei_type>
compute_src_zp_compensation(const exec_ctx_t & ctx,const int32_t * src_zero_point,const bool is_src_zp_common,typename prec_traits<wei_type>::type * wei,const cpu_deconvolution_fwd_pd_t * pd)265 static void compute_src_zp_compensation(const exec_ctx_t &ctx,
266 const int32_t *src_zero_point, const bool is_src_zp_common,
267 typename prec_traits<wei_type>::type *wei,
268 const cpu_deconvolution_fwd_pd_t *pd) {
269 using namespace memory_tracking::names;
270
271 const auto scratchpad = ctx.get_scratchpad_grantor();
272 int32_t *zp_compensation = scratchpad.get<int32_t>(key_deconv_zp);
273 const auto G = pd->G();
274 const auto KH = pd->KH();
275 const auto KW = pd->KW();
276 const auto KD = pd->KD();
277 const auto OC = pd->OC() / G;
278 const auto IC = pd->IC() / G;
279 const memory_desc_wrapper wei_d(pd->weights_md());
280 const bool with_groups = pd->with_groups();
281 const auto ndims = wei_d.ndims() - (with_groups ? 1 : 0);
282 const auto get_wei_off
283 = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
284 return get_weights_off(
285 wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
286 };
287
288 parallel_nd(G, OC, [&](const dim_t g, const dim_t oc) {
289 const auto out_offset = g * OC + oc;
290 int32_t acc = 0;
291
292 for_(dim_t kd = 0; kd < KD; ++kd)
293 for_(dim_t kh = 0; kh < KH; ++kh)
294 for (dim_t kw = 0; kw < KW; ++kw) {
295 for (dim_t ic = 0; ic < IC; ++ic) {
296 const auto weights_offset = get_wei_off(g, oc, ic, kd, kh, kw);
297 const int32_t wei32 = static_cast<int32_t>(wei[weights_offset]);
298
299 if (is_src_zp_common)
300 acc += wei32;
301 else
302 acc += wei32 * src_zero_point[g * IC + ic];
303 }
304 }
305
306 zp_compensation[out_offset] = acc * src_zero_point[0];
307 });
308 }
309
template <data_type_t wei_type>
static std::function<int32_t(
        const dim_t, const dim_t, const dim_t, const dim_t, const dim_t)>
prepare_zp_pad_comp_ker(const dim_t ndims, const int32_t *src_zero_point,
        const bool is_src_zp_common, typename prec_traits<wei_type>::type *wei,
        const cpu_deconvolution_fwd_pd_t *deconv_pd) {

    // Builds a closure that, for an output point (g, oc, od, oh, ow), sums
    // wei * src_zero_point over the kernel taps whose corresponding input
    // coordinate falls outside the valid input range (i.e. into padding).
    // Those taps read zeros instead of zero-point values, so their
    // contribution must be added back after the full-kernel compensation.
    const auto KH = deconv_pd->KH();
    const auto KW = deconv_pd->KW();
    const auto KD = deconv_pd->KD();
    const auto KSD = deconv_pd->KSD();
    const auto KSH = deconv_pd->KSH();
    const auto KSW = deconv_pd->KSW();
    // Dilations are stored as the "extra gap"; +1 turns them into tap steps.
    const auto KDD = deconv_pd->KDD() + 1;
    const auto KDH = deconv_pd->KDH() + 1;
    const auto KDW = deconv_pd->KDW() + 1;
    const auto IC = deconv_pd->IC() / deconv_pd->G();
    const auto IH = deconv_pd->IH();
    const auto IW = deconv_pd->IW();
    const auto ID = deconv_pd->ID();
    const auto pad_front = deconv_pd->padFront();
    const auto pad_top = deconv_pd->padT();
    const auto pad_left = deconv_pd->padL();
    const bool with_groups = deconv_pd->with_groups();
    const memory_desc_wrapper wei_d(deconv_pd->weights_md());
    const auto get_wei_off
            = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
                  return get_weights_off(
                          wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
              };

    // Captured by value; wei and src_zero_point must outlive the closure.
    return [=](const dim_t g, const dim_t oc, const dim_t od, const dim_t oh,
                   const dim_t ow) {
        int32_t zp_pad_compensation = 0;

        for (dim_t kd = 0; kd < KD; ++kd) {
            // Candidate input coordinate of this tap (deconvolution view).
            // It is valid only if non-negative, aligned to the stride, and
            // inside the input extent; otherwise the tap hit padding.
            const dim_t id = od - kd * KDD + pad_front;
            const bool should_apply_pad_comp_d
                    = id < 0 || id % KSD != 0 || (id / KSD) >= ID;

            for (dim_t kh = 0; kh < KH; ++kh) {
                const dim_t ih = oh - kh * KDH + pad_top;
                const bool should_apply_pad_comp_h
                        = ih < 0 || ih % KSH != 0 || (ih / KSH) >= IH;

                for (dim_t kw = 0; kw < KW; ++kw) {
                    const dim_t iw = ow - kw * KDW + pad_left;
                    const bool should_apply_pad_comp_w
                            = iw < 0 || iw % KSW != 0 || (iw / KSW) >= IW;

                    if (should_apply_pad_comp_d || should_apply_pad_comp_h
                            || should_apply_pad_comp_w) {

                        for (dim_t ic = 0; ic < IC; ic++) {
                            const auto wei_off
                                    = get_wei_off(g, oc, ic, kd, kh, kw);
                            const int32_t wei32
                                    = static_cast<int32_t>(wei[wei_off]);

                            if (is_src_zp_common)
                                zp_pad_compensation += wei32;
                            else
                                zp_pad_compensation
                                        += wei32 * src_zero_point[g * IC + ic];
                        }
                    }
                }
            }
        }

        // Common zero-point factor was hoisted out of the inner loop.
        if (is_src_zp_common && zp_pad_compensation)
            zp_pad_compensation *= src_zero_point[0];

        return zp_pad_compensation;
    };
}
386
template <data_type_t wei_type>
static status_t apply_src_zero_point(const exec_ctx_t &ctx,
        const cpu_deconvolution_fwd_pd_t *deconv_pd, float *conv_output) {
    // In-place correction of the f32 convolution output for src zero-points:
    //   result -= full-kernel compensation (computed into the scratchpad)
    //   result += compensation for taps that landed in the padded area
    //             (those taps read zeros, not zero-point values).
    using wei_data_t = typename prec_traits<wei_type>::type;
    using namespace memory_tracking::names;
    using namespace data_type;

    // required by DEFINE_ZERO_POINTS_BUFFER macro
    const auto pd = [&]() { return deconv_pd; };
    // NOTE(review): weights are an input to this computation; CTX_OUT_MEM
    // with a non-const pointer looks odd -- presumably kept non-const to
    // match the helper signatures above. Confirm against the macro docs.
    const auto wei = CTX_OUT_MEM(wei_data_t *, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    const bool is_src_zp_common
            = deconv_pd->attr()->zero_points_.common(DNNL_ARG_SRC);

    const auto scratchpad = ctx.get_scratchpad_grantor();
    const int32_t *const zp_src_compensation
            = scratchpad.get<int32_t>(key_deconv_zp);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const auto ndims = dst_d.ndims();

    const auto G = pd()->G();
    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC() / G;

    // Fill the per-channel compensation table, then build the per-point
    // padding-compensation kernel.
    compute_src_zp_compensation<wei_type>(
            ctx, src_zero_point, is_src_zp_common, wei, deconv_pd);
    const auto zp_pad_comp_ker = prepare_zp_pad_comp_ker<wei_type>(
            ndims, src_zero_point, is_src_zp_common, wei, deconv_pd);

    parallel_nd(MB, G, OC, OD, OH, OW,
            [&](const dim_t mb, const dim_t g, const dim_t oc, const dim_t od,
                    const dim_t oh, const dim_t ow) {
                const auto oc_off = g * OC + oc;
                const auto dst_off
                        = get_data_off(dst_d, ndims, mb, oc_off, od, oh, ow);
                // Accumulate in int32 since the compensation terms are int32.
                int32_t conv_result
                        = conv_output[dst_off] - zp_src_compensation[oc_off];

                if (const auto zp_pad_compensation
                        = zp_pad_comp_ker(g, oc, od, oh, ow)) {
                    conv_result += zp_pad_compensation;
                }

                conv_output[dst_off] = static_cast<float>(conv_result);
            });

    return status::success;
}
438
execute(const exec_ctx_t & ctx) const439 status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const {
440 using namespace memory_tracking::names;
441 const auto scratchpad = ctx.get_scratchpad_grantor();
442 const bool ref_bias = pd()->with_bias() && !pd()->conv_supports_bias_;
443 const bool non_default_attr = !pd()->attr()->has_default_values();
444
445 const auto &args = ctx.args();
446 exec_args_t conv_args;
447 conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
448 conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
449 if (pd()->with_bias() && pd()->conv_supports_bias_)
450 conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);
451
452 // Create intermediate memory for f32 output if needed.
453 auto dst = args.at(DNNL_ARG_DST);
454 memory_t tmp_memory(dst.mem->engine(), pd()->conv_pd_->diff_src_md(),
455 scratchpad.get_memory_storage(key_deconv_bias));
456 memory_arg_t tmp_conv_output = {&tmp_memory, false};
457
458 conv_args[DNNL_ARG_DIFF_SRC]
459 = ref_bias || non_default_attr ? tmp_conv_output : dst;
460
461 // When sum post-op happens, we need to copy original destination memory
462 // prior call to external convolution happens.
463 if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1) {
464 void *original_dst = scratchpad.get(key_deconv_sum);
465 const memory_desc_wrapper dst_d(pd()->dst_md());
466 void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
467 const auto dt_size = dst_d.data_type_size();
468
469 parallel(0, [&](const int ithr, const int nthr) {
470 dim_t start {0}, end {0};
471 balance211(dst_d.nelems(true), nthr, ithr, start, end);
472 auto o_dst_start = (char *)original_dst + start * dt_size;
473 auto dst_start = (char *)dst + start * dt_size;
474 const auto size = (end - start) * dt_size;
475
476 std::memcpy(o_dst_start, dst_start, size);
477 });
478 }
479
480 exec_ctx_t conv_ctx(ctx, std::move(conv_args));
481
482 nested_scratchpad_t ns(ctx, key_nested, conv_p_);
483 conv_ctx.set_scratchpad_grantor(ns.grantor());
484 auto status = conv_p_->execute(conv_ctx);
485 if (status != status::success) return status;
486
487 conv_args[DNNL_ARG_DIFF_SRC]
488 = ref_bias || non_default_attr ? tmp_conv_output : dst;
489
490 using namespace data_type;
491
492 if (!pd()->attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
493 float *conv_output = scratchpad.get<float>(key_deconv_bias);
494 const auto wei_dt = pd()->weights_md()->data_type;
495 switch (wei_dt) {
496 case s8: apply_src_zero_point<s8>(ctx, pd(), conv_output); break;
497 case u8: apply_src_zero_point<u8>(ctx, pd(), conv_output); break;
498 default: assert(!"unsupported data type");
499 }
500 }
501
502 auto dst_type = pd()->dst_md()->data_type;
503
504 if (ref_bias) {
505 float *conv_output = scratchpad.get<float>(key_deconv_bias);
506 if (non_default_attr) {
507 // Overwrite conv_output since further attr computations still need
508 // f32 output.
509 compute_fwd_bias<f32>(ctx, conv_output, conv_output);
510 } else {
511
512 #define CASE(DT) \
513 case (DT): { \
514 using dst_data_t = typename prec_traits<DT>::type; \
515 dst_data_t *dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); \
516 compute_fwd_bias<DT>(ctx, dst, conv_output); \
517 } break
518
519 switch (dst_type) {
520 CASE(f32);
521 CASE(bf16);
522 CASE(s32);
523 CASE(s8);
524 CASE(u8);
525 default: assert(!"unsupported data type");
526 }
527 #undef CASE
528 }
529 }
530
531 if (non_default_attr) {
532 float *conv_output = scratchpad.get<float>(key_deconv_bias);
533
534 #define CASE(DT) \
535 case (DT): { \
536 using dst_data_t = typename prec_traits<DT>::type; \
537 dst_data_t *original_dst = scratchpad.get<dst_data_t>(key_deconv_sum); \
538 compute_ref_attrs<DT>(ctx, conv_output, original_dst); \
539 } break
540
541 switch (dst_type) {
542 CASE(f32);
543 CASE(bf16);
544 CASE(s32);
545 CASE(s8);
546 CASE(u8);
547 default: assert(!"unsupported data type");
548 }
549 #undef CASE
550 }
551
552 return status::success;
553 }
554
execute(const exec_ctx_t & ctx) const555 status_t ref_deconvolution_bwd_data_t::execute(const exec_ctx_t &ctx) const {
556 using namespace memory_tracking::names;
557 const auto &args = ctx.args();
558 exec_args_t conv_args;
559 conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
560 conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
561 conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
562 exec_ctx_t conv_ctx(ctx, std::move(conv_args));
563
564 nested_scratchpad_t ns(ctx, key_nested, conv_p_);
565 conv_ctx.set_scratchpad_grantor(ns.grantor());
566 conv_p_->execute(conv_ctx);
567 return status::success;
568 }
569
compute_bwd_bias(float * diff_bias,const float * diff_dst) const570 void ref_deconvolution_bwd_weights_t::compute_bwd_bias(
571 float *diff_bias, const float *diff_dst) const {
572 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
573
574 const auto G = pd()->G();
575 const auto MB = pd()->MB();
576 const auto OH = pd()->OH();
577 const auto OW = pd()->OW();
578 const auto OC = pd()->OC() / G;
579 const auto OD = pd()->OD();
580 const auto ndims = pd()->desc()->src_desc.ndims;
581
582 parallel_nd(G, OC, [&](dim_t g, dim_t oc) {
583 float db = 0;
584 for_(dim_t mb = 0; mb < MB; ++mb)
585 for_(dim_t od = 0; od < OD; ++od)
586 for_(dim_t oh = 0; oh < OH; ++oh)
587 for (dim_t ow = 0; ow < OW; ++ow) {
588 const auto d_dst_off = get_data_off(
589 diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
590 db += diff_dst[d_dst_off];
591 }
592 diff_bias[g * OC + oc] = db;
593 });
594 }
595
596 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bwd_bias_ncdhw(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const597 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw(
598 typename prec_traits<dbia_type>::type *diff_bias,
599 const typename prec_traits<ddst_type>::type *diff_dst) const {
600 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
601
602 const auto OC = pd()->OC();
603 const auto MB = pd()->MB();
604 const auto SP = pd()->OH() * pd()->OW() * pd()->OD();
605
606 parallel_nd(OC, [&](dim_t oc) {
607 float db = 0;
608 for (dim_t mb = 0; mb < MB; ++mb) {
609 PRAGMA_OMP_SIMD(reduction(+ : db))
610 for (dim_t sp = 0; sp < SP; ++sp) {
611 auto offset = (size_t)(mb * OC + oc) * SP + sp;
612 db += diff_dst[offset];
613 }
614 }
615 diff_bias[oc] = db;
616 });
617 }
618
619 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bwd_bias_ndhwc(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const620 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ndhwc(
621 typename prec_traits<dbia_type>::type *diff_bias,
622 const typename prec_traits<ddst_type>::type *diff_dst) const {
623 const auto MB = pd()->MB();
624 const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
625 const auto OC = pd()->OC();
626
627 parallel_nd(OC, [&](dim_t oc) {
628 float db = 0;
629 for (dim_t mb = 0; mb < MB; ++mb) {
630 PRAGMA_OMP_SIMD(reduction(+ : db))
631 for (dim_t sp = 0; sp < SP; ++sp) {
632 const dim_t offset = (mb * SP + sp) * OC + oc;
633 db += diff_dst[offset];
634 }
635 }
636 diff_bias[oc] = static_cast<typename prec_traits<dbia_type>::type>(db);
637 });
638 }
639
640 template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
compute_bwd_bias_nCdhwXc(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const641 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc(
642 typename prec_traits<dbia_type>::type *diff_bias,
643 const typename prec_traits<ddst_type>::type *diff_dst) const {
644 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
645
646 const auto OC = pd()->OC();
647 const auto MB = pd()->MB();
648 const auto SP = pd()->OH() * pd()->OW() * pd()->OD();
649
650 const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0];
651
652 parallel_nd(utils::div_up(OC, blksize), [&](dim_t ocb) {
653 float db[blksize] = {0};
654
655 for (dim_t mb = 0; mb < MB; ++mb) {
656 for (dim_t sp = 0; sp < SP; ++sp) {
657 auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
658
659 PRAGMA_OMP_SIMD()
660 for (dim_t i = 0; i < blksize; ++i)
661 db[i] += diff_dst[offset + i];
662 }
663 }
664
665 const dim_t blk = nstl::min(blksize, OC - ocb * blksize);
666
667 PRAGMA_OMP_SIMD()
668 for (dim_t i = 0; i < blk; ++i)
669 diff_bias[ocb * blksize + i] = db[i];
670 });
671 }
672
673 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bias(const exec_ctx_t & ctx) const674 void ref_deconvolution_bwd_weights_t::compute_bias(
675 const exec_ctx_t &ctx) const {
676 using dbia_data_t = typename prec_traits<dbia_type>::type;
677 using ddst_data_t = typename prec_traits<ddst_type>::type;
678
679 auto diff_bias = CTX_OUT_MEM(dbia_data_t *, DNNL_ARG_DIFF_BIAS);
680 auto diff_dst = CTX_IN_MEM(const ddst_data_t *, DNNL_ARG_DIFF_DST);
681
682 using namespace format_tag;
683 switch (pd()->dst_tag_) {
684 case ncdhw:
685 case nchw:
686 case ncw:
687 compute_bwd_bias_ncdhw<dbia_type, ddst_type>(diff_bias, diff_dst);
688 break;
689 case ndhwc:
690 case nhwc:
691 case nwc:
692 compute_bwd_bias_ndhwc<dbia_type, ddst_type>(diff_bias, diff_dst);
693 break;
694 case nCdhw8c:
695 case nChw8c:
696 case nCw8c:
697 assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
698 compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 8>(
699 diff_bias, diff_dst);
700 break;
701 case nCdhw16c:
702 case nChw16c:
703 case nCw16c:
704 compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 16>(
705 diff_bias, diff_dst);
706 break;
707 default:
708 assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
709 compute_bwd_bias((float *)diff_bias, (const float *)diff_dst);
710 break;
711 }
712 }
713
execute(const exec_ctx_t & ctx) const714 status_t ref_deconvolution_bwd_weights_t::execute(const exec_ctx_t &ctx) const {
715 using namespace memory_tracking::names;
716 const auto &args = ctx.args();
717 exec_args_t conv_args;
718 conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
719 conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
720 conv_args[DNNL_ARG_DIFF_WEIGHTS] = args.at(DNNL_ARG_DIFF_WEIGHTS);
721 exec_ctx_t conv_ctx(ctx, std::move(conv_args));
722
723 nested_scratchpad_t ns(ctx, key_nested, conv_p_);
724 conv_ctx.set_scratchpad_grantor(ns.grantor());
725 status_t status = conv_p_->execute(conv_ctx);
726 if (status != status::success) return status;
727
728 if (pd()->with_bias()) {
729 using namespace data_type;
730
731 auto dbia_type = pd()->diff_weights_md(1)->data_type;
732 auto ddst_type = pd()->diff_dst_md()->data_type;
733 if (utils::everyone_is(f32, dbia_type, ddst_type))
734 compute_bias<f32, f32>(ctx);
735 else if (utils::everyone_is(bf16, dbia_type, ddst_type))
736 compute_bias<bf16, bf16>(ctx);
737 else if (dbia_type == f32 && ddst_type == bf16)
738 compute_bias<f32, bf16>(ctx);
739 else {
740 assert(!"unsupported data type");
741 return status::runtime_error;
742 }
743 }
744 return status::success;
745 }
746
using namespace data_type;

// Explicit instantiations for the (diff_bias, diff_dst) data-type pairs
// selected by ref_deconvolution_bwd_weights_t::execute() above.
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, f32>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, bf16>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<bf16, bf16>(
        const exec_ctx_t &ctx) const;
755 } // namespace cpu
756 } // namespace impl
757 } // namespace dnnl
758
759 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
760