/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstring>
#include <functional>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/dnnl_traits.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/simple_q10n.hpp"

#include "cpu/ref_deconvolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

namespace {
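// Returns the physical offset of element (mb, c, id, ih, iw) in a data
// tensor, dispatching on the number of dimensions (3D, 4D or 5D tensors,
// i.e. 1D, 2D or 3D spatial).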
dim_t get_data_off(const memory_desc_wrapper &mdw, int ndims, dim_t mb, dim_t c,
        dim_t id, dim_t ih, dim_t iw) {
    switch (ndims) {
        case 5: return mdw.off(mb, c, id, ih, iw);
        case 4: return mdw.off(mb, c, ih, iw);
        case 3: return mdw.off(mb, c, iw);
        default: assert(!"unsupported ndims"); return dim_t(0);
    }
}

} // namespace

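// Adds bias to the f32 convolution output for an arbitrary destination
// format; every element offset goes through the memory descriptor.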
template <data_type_t dst_type>
void ref_deconvolution_fwd_t::compute_fwd_bias_common(const exec_ctx_t &ctx,
        typename prec_traits<dst_type>::type *dst,
        const float *conv_output) const {
    using dst_data_t = typename prec_traits<dst_type>::type;
    const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto G = pd()->G();
    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC() / G;
    const auto ndims = pd()->desc()->src_desc.ndims;

    parallel_nd(MB, G, OC, OD, OH, OW,
            [&](dim_t mb, dim_t g, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
                const dim_t c = g * OC + oc;
                const dim_t off = get_data_off(dst_d, ndims, mb, c, od, oh, ow);
                float b = types::get_float_value(bias_d.data_type(), bias, c);
                float d = conv_output[off];
                dst[off] = cpu::saturate_and_round<dst_data_t>(d + b);
            });
}

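// Bias addition specialized for plain channel-first formats (ncw, nchw,
// ncdhw): the spatial block of a given (mb, oc) pair is contiguous, so
// offsets are computed directly.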
template <data_type_t dst_type>
void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const exec_ctx_t &ctx,
        typename prec_traits<dst_type>::type *dst,
        const float *conv_output) const {
    using dst_data_t = typename prec_traits<dst_type>::type;
    const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto SP = pd()->OW() * pd()->OH() * pd()->OD();

    parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
        const dim_t off = (mb * OC + oc) * SP;
        float b = types::get_float_value(bias_d.data_type(), bias, oc);
        PRAGMA_OMP_SIMD()
        for (dim_t sp = 0; sp < SP; ++sp) {
            float d = conv_output[off + sp];
            dst[off + sp] = cpu::saturate_and_round<dst_data_t>(d + b);
        }
    });
}

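// Bias addition specialized for plain channel-last formats (nwc, nhwc,
// ndhwc): channels are contiguous for each spatial point, so the bias lookup
// moves into the innermost loop.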
template <data_type_t dst_type>
void ref_deconvolution_fwd_t::compute_fwd_bias_ndhwc(const exec_ctx_t &ctx,
        typename prec_traits<dst_type>::type *dst,
        const float *conv_output) const {
    using dst_data_t = typename prec_traits<dst_type>::type;
    const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto SP = pd()->OW() * pd()->OH() * pd()->OD();

    parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
        const dim_t off = (mb * SP + sp) * OC;
        PRAGMA_OMP_SIMD()
        for (dim_t oc = 0; oc < OC; ++oc) {
            float b = types::get_float_value(bias_d.data_type(), bias, oc);
            float d = conv_output[off + oc];
            dst[off + oc] = cpu::saturate_and_round<dst_data_t>(d + b);
        }
    });
}

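// Bias addition specialized for blocked formats (nCw8c, nChw16c, ...):
// channels are processed in blocks of blk_size, with the tail block masked so
// only valid channels read the bias.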
template <data_type_t dst_type, dim_t blk_size>
void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx,
        typename prec_traits<dst_type>::type *dst,
        const float *conv_output) const {
    using dst_data_t = typename prec_traits<dst_type>::type;
    const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const auto MB = pd()->MB();
    const auto OC = pd()->OC();
    const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
    const auto stride_mb = dst_d.blocking_desc().strides[0];

    parallel_nd(MB, utils::div_up(OC, blk_size), SP,
            [&](dim_t mb, dim_t oc_blk, dim_t sp) {
                const dim_t oc = oc_blk * blk_size;
                const dim_t off = mb * stride_mb + oc * SP + sp * blk_size;
                const dim_t blk = nstl::min(blk_size, OC - oc);

                PRAGMA_OMP_SIMD()
                for (dim_t i = 0; i < blk_size; ++i) {
                    float b = i < blk ? types::get_float_value(
                                      bias_d.data_type(), bias, oc + i)
                                      : 0;
                    float d = conv_output[off + i];
                    dst[off + i] = cpu::saturate_and_round<dst_data_t>(d + b);
                }
            });
}

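// Picks the bias kernel matching the destination format tag; unrecognized
// formats fall back to the generic implementation.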
template <data_type_t dst_type>
void ref_deconvolution_fwd_t::compute_fwd_bias(const exec_ctx_t &ctx,
        typename prec_traits<dst_type>::type *dst,
        const float *conv_output) const {
    using namespace format_tag;
    switch (pd()->dst_tag_) {
        case ncdhw:
        case nchw:
        case ncw:
            compute_fwd_bias_ncdhw<dst_type>(ctx, dst, conv_output);
            break;
        case ndhwc:
        case nhwc:
        case nwc:
            compute_fwd_bias_ndhwc<dst_type>(ctx, dst, conv_output);
            break;
        case nCdhw8c:
        case nChw8c:
        case nCw8c:
            assert(dst_type != data_type::bf16);
            compute_fwd_bias_nCdhwXc<dst_type, 8>(ctx, dst, conv_output);
            break;
        case nCdhw16c:
        case nChw16c:
        case nCw16c:
            compute_fwd_bias_nCdhwXc<dst_type, 16>(ctx, dst, conv_output);
            break;
        default:
            compute_fwd_bias_common<dst_type>(ctx, dst, conv_output);
            break;
    }
}

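// Applies the remaining attributes (output scales, post-ops, destination zero
// points) to the f32 convolution output and stores the result, saturated to
// the destination data type, into the user-provided dst memory.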
template <data_type_t dst_type>
status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx,
        const float *conv_output,
        typename prec_traits<dst_type>::type *original_dst) const {
    using dst_data_t = typename prec_traits<dst_type>::type;
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
    const bool is_dst_zp_common
            = pd()->attr()->zero_points_.common(DNNL_ARG_DST);

    const memory_desc_wrapper dst_d(pd()->dst_md());

    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC();
    const auto OCP = dst_d.padded_dims()[1];
    const auto ndims = pd()->desc()->src_desc.ndims;

    const auto maybe_oscale = [=](float &d, dim_t oc) {
        // scale_idx_mult = 1 for per_oc scales and 0 otherwise
        const int scale_idx_mult
                = pd()->attr()->output_scales_.mask_ == (1 << 1);
        const float *scales = pd()->attr()->output_scales_.scales_;
        d *= scales[oc * scale_idx_mult];
    };

    const auto maybe_dst_zero_point = [=](float &result, dim_t oc) {
        if (is_dst_zp_common)
            result += dst_zero_point[0];
        else
            result += dst_zero_point[oc];
    };

    parallel_nd(MB, OCP, OD, OH, OW,
            [&](dim_t mb, dim_t ocp, dim_t od, dim_t oh, dim_t ow) {
                auto dst_off = get_data_off(dst_d, ndims, mb, ocp, od, oh, ow);
                float tmp_result = 0;

                if (ocp < OC) {
                    dim_t dst_l_off = (mb * OC + ocp) * OD * OH * OW
                            + od * OH * OW + oh * OW + ow;
                    tmp_result = conv_output[dst_off];
                    maybe_oscale(tmp_result, ocp);

                    ref_post_ops_t::args_t args;
                    if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1)
                        args.dst_val = (float)original_dst[dst_off];
                    args.ctx = &ctx;
                    args.l_offset = dst_l_off;
                    args.dst_md = pd()->dst_md();
                    ref_post_ops->execute(tmp_result, args);
                    maybe_dst_zero_point(tmp_result, ocp);
                }
                dst[dst_off] = cpu::saturate_and_round<dst_data_t>(tmp_result);
            });

    return status::success;
}

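// Returns the physical offset of a weights element, handling grouped and
// non-grouped weights for 1D, 2D and 3D deconvolution.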
dim_t get_weights_off(const memory_desc_wrapper &wei_d, bool with_groups,
        int ndims, dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
    switch (ndims) {
        case 5:
            return with_groups ? wei_d.off(g, oc, ic, kd, kh, kw)
                               : wei_d.off(oc, ic, kd, kh, kw);
        case 4:
            return with_groups ? wei_d.off(g, oc, ic, kh, kw)
                               : wei_d.off(oc, ic, kh, kw);
        case 3:
            return with_groups ? wei_d.off(g, oc, ic, kw)
                               : wei_d.off(oc, ic, kw);
        default: assert(!"unsupported ndims"); return dim_t(0);
    }
}

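// Precomputes, per output channel, the sum over (ic, kd, kh, kw) of the
// weights scaled by the source zero point(s); the result goes into the
// key_deconv_zp scratchpad and is subtracted from the convolution output.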
template <data_type_t wei_type>
static void compute_src_zp_compensation(const exec_ctx_t &ctx,
        const int32_t *src_zero_point, const bool is_src_zp_common,
        typename prec_traits<wei_type>::type *wei,
        const cpu_deconvolution_fwd_pd_t *pd) {
    using namespace memory_tracking::names;

    const auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *zp_compensation = scratchpad.get<int32_t>(key_deconv_zp);
    const auto G = pd->G();
    const auto KH = pd->KH();
    const auto KW = pd->KW();
    const auto KD = pd->KD();
    const auto OC = pd->OC() / G;
    const auto IC = pd->IC() / G;
    const memory_desc_wrapper wei_d(pd->weights_md());
    const bool with_groups = pd->with_groups();
    const auto ndims = wei_d.ndims() - (with_groups ? 1 : 0);
    const auto get_wei_off
            = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
                  return get_weights_off(
                          wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
              };

    parallel_nd(G, OC, [&](const dim_t g, const dim_t oc) {
        const auto out_offset = g * OC + oc;
        int32_t acc = 0;

        for_(dim_t kd = 0; kd < KD; ++kd)
        for_(dim_t kh = 0; kh < KH; ++kh)
        for (dim_t kw = 0; kw < KW; ++kw) {
            for (dim_t ic = 0; ic < IC; ++ic) {
                const auto weights_offset = get_wei_off(g, oc, ic, kd, kh, kw);
                const int32_t wei32 = static_cast<int32_t>(wei[weights_offset]);

                if (is_src_zp_common)
                    acc += wei32;
                else
                    acc += wei32 * src_zero_point[g * IC + ic];
            }
        }

        zp_compensation[out_offset] = acc * src_zero_point[0];
    });
}

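// Builds a kernel that, for a given output point, accumulates the weight
// values (scaled by the source zero point) over kernel taps that do not map
// to a valid input position; this restores the zero-point contribution the
// convolution skipped for those taps.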
template <data_type_t wei_type>
static std::function<int32_t(
        const dim_t, const dim_t, const dim_t, const dim_t, const dim_t)>
prepare_zp_pad_comp_ker(const dim_t ndims, const int32_t *src_zero_point,
        const bool is_src_zp_common, typename prec_traits<wei_type>::type *wei,
        const cpu_deconvolution_fwd_pd_t *deconv_pd) {

    const auto KH = deconv_pd->KH();
    const auto KW = deconv_pd->KW();
    const auto KD = deconv_pd->KD();
    const auto KSD = deconv_pd->KSD();
    const auto KSH = deconv_pd->KSH();
    const auto KSW = deconv_pd->KSW();
    const auto KDD = deconv_pd->KDD() + 1;
    const auto KDH = deconv_pd->KDH() + 1;
    const auto KDW = deconv_pd->KDW() + 1;
    const auto IC = deconv_pd->IC() / deconv_pd->G();
    const auto IH = deconv_pd->IH();
    const auto IW = deconv_pd->IW();
    const auto ID = deconv_pd->ID();
    const auto pad_front = deconv_pd->padFront();
    const auto pad_top = deconv_pd->padT();
    const auto pad_left = deconv_pd->padL();
    const bool with_groups = deconv_pd->with_groups();
    const memory_desc_wrapper wei_d(deconv_pd->weights_md());
    const auto get_wei_off
            = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
                  return get_weights_off(
                          wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
              };

    return [=](const dim_t g, const dim_t oc, const dim_t od, const dim_t oh,
                   const dim_t ow) {
        int32_t zp_pad_compensation = 0;

        for (dim_t kd = 0; kd < KD; ++kd) {
            const dim_t id = od - kd * KDD + pad_front;
            const bool should_apply_pad_comp_d
                    = id < 0 || id % KSD != 0 || (id / KSD) >= ID;

            for (dim_t kh = 0; kh < KH; ++kh) {
                const dim_t ih = oh - kh * KDH + pad_top;
                const bool should_apply_pad_comp_h
                        = ih < 0 || ih % KSH != 0 || (ih / KSH) >= IH;

                for (dim_t kw = 0; kw < KW; ++kw) {
                    const dim_t iw = ow - kw * KDW + pad_left;
                    const bool should_apply_pad_comp_w
                            = iw < 0 || iw % KSW != 0 || (iw / KSW) >= IW;

                    if (should_apply_pad_comp_d || should_apply_pad_comp_h
                            || should_apply_pad_comp_w) {

                        for (dim_t ic = 0; ic < IC; ic++) {
                            const auto wei_off
                                    = get_wei_off(g, oc, ic, kd, kh, kw);
                            const int32_t wei32
                                    = static_cast<int32_t>(wei[wei_off]);

                            if (is_src_zp_common)
                                zp_pad_compensation += wei32;
                            else
                                zp_pad_compensation
                                        += wei32 * src_zero_point[g * IC + ic];
                        }
                    }
                }
            }
        }

        if (is_src_zp_common && zp_pad_compensation)
            zp_pad_compensation *= src_zero_point[0];

        return zp_pad_compensation;
    };
}

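// Adjusts the f32 convolution output for source zero points: subtracts the
// precomputed per-channel compensation and adds back the padding
// compensation where needed.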
template <data_type_t wei_type>
static status_t apply_src_zero_point(const exec_ctx_t &ctx,
        const cpu_deconvolution_fwd_pd_t *deconv_pd, float *conv_output) {
    using wei_data_t = typename prec_traits<wei_type>::type;
    using namespace memory_tracking::names;
    using namespace data_type;

    // required by DEFINE_ZERO_POINTS_BUFFER macro
    const auto pd = [&]() { return deconv_pd; };
    const auto wei = CTX_OUT_MEM(wei_data_t *, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    const bool is_src_zp_common
            = deconv_pd->attr()->zero_points_.common(DNNL_ARG_SRC);

    const auto scratchpad = ctx.get_scratchpad_grantor();
    const int32_t *const zp_src_compensation
            = scratchpad.get<int32_t>(key_deconv_zp);
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const auto ndims = dst_d.ndims();

    const auto G = pd()->G();
    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OD = pd()->OD();
    const auto OC = pd()->OC() / G;

    compute_src_zp_compensation<wei_type>(
            ctx, src_zero_point, is_src_zp_common, wei, deconv_pd);
    const auto zp_pad_comp_ker = prepare_zp_pad_comp_ker<wei_type>(
            ndims, src_zero_point, is_src_zp_common, wei, deconv_pd);

    parallel_nd(MB, G, OC, OD, OH, OW,
            [&](const dim_t mb, const dim_t g, const dim_t oc, const dim_t od,
                    const dim_t oh, const dim_t ow) {
                const auto oc_off = g * OC + oc;
                const auto dst_off
                        = get_data_off(dst_d, ndims, mb, oc_off, od, oh, ow);
                int32_t conv_result
                        = conv_output[dst_off] - zp_src_compensation[oc_off];

                if (const auto zp_pad_compensation
                        = zp_pad_comp_ker(g, oc, od, oh, ow)) {
                    conv_result += zp_pad_compensation;
                }

                conv_output[dst_off] = static_cast<float>(conv_result);
            });

    return status::success;
}

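// Forward deconvolution runs a backward-data convolution underneath. When
// bias or attributes still need to be applied, the convolution writes into an
// f32 scratchpad buffer which is post-processed below.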
status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;
    const auto scratchpad = ctx.get_scratchpad_grantor();
    const bool ref_bias = pd()->with_bias() && !pd()->conv_supports_bias_;
    const bool non_default_attr = !pd()->attr()->has_default_values();

    const auto &args = ctx.args();
    exec_args_t conv_args;
    conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
    conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
    if (pd()->with_bias() && pd()->conv_supports_bias_)
        conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

    // Create intermediate memory for f32 output if needed.
    auto dst = args.at(DNNL_ARG_DST);
    memory_t tmp_memory(dst.mem->engine(), pd()->conv_pd_->diff_src_md(),
            scratchpad.get_memory_storage(key_deconv_bias));
    memory_arg_t tmp_conv_output = {&tmp_memory, false};

    conv_args[DNNL_ARG_DIFF_SRC]
            = ref_bias || non_default_attr ? tmp_conv_output : dst;

    // When a sum post-op is present, the original destination memory has to
    // be copied before the call to the external convolution overwrites it.
    if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1) {
        void *original_dst = scratchpad.get(key_deconv_sum);
        const memory_desc_wrapper dst_d(pd()->dst_md());
        void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
        const auto dt_size = dst_d.data_type_size();

        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start {0}, end {0};
            balance211(dst_d.nelems(true), nthr, ithr, start, end);
            auto o_dst_start = (char *)original_dst + start * dt_size;
            auto dst_start = (char *)dst + start * dt_size;
            const auto size = (end - start) * dt_size;

            std::memcpy(o_dst_start, dst_start, size);
        });
    }

    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    auto status = conv_p_->execute(conv_ctx);
    if (status != status::success) return status;

    using namespace data_type;

    if (!pd()->attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
        float *conv_output = scratchpad.get<float>(key_deconv_bias);
        const auto wei_dt = pd()->weights_md()->data_type;
        switch (wei_dt) {
            case s8: apply_src_zero_point<s8>(ctx, pd(), conv_output); break;
            case u8: apply_src_zero_point<u8>(ctx, pd(), conv_output); break;
            default: assert(!"unsupported data type");
        }
    }

    auto dst_type = pd()->dst_md()->data_type;

    if (ref_bias) {
        float *conv_output = scratchpad.get<float>(key_deconv_bias);
        if (non_default_attr) {
            // Overwrite conv_output since further attr computations still need
            // f32 output.
            compute_fwd_bias<f32>(ctx, conv_output, conv_output);
        } else {

#define CASE(DT) \
    case (DT): { \
        using dst_data_t = typename prec_traits<DT>::type; \
        dst_data_t *dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); \
        compute_fwd_bias<DT>(ctx, dst, conv_output); \
    } break

            switch (dst_type) {
                CASE(f32);
                CASE(bf16);
                CASE(s32);
                CASE(s8);
                CASE(u8);
                default: assert(!"unsupported data type");
            }
#undef CASE
        }
    }

    if (non_default_attr) {
        float *conv_output = scratchpad.get<float>(key_deconv_bias);

#define CASE(DT) \
    case (DT): { \
        using dst_data_t = typename prec_traits<DT>::type; \
        dst_data_t *original_dst = scratchpad.get<dst_data_t>(key_deconv_sum); \
        compute_ref_attrs<DT>(ctx, conv_output, original_dst); \
    } break

        switch (dst_type) {
            CASE(f32);
            CASE(bf16);
            CASE(s32);
            CASE(s8);
            CASE(u8);
            default: assert(!"unsupported data type");
        }
#undef CASE
    }

    return status::success;
}

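// Backward data maps onto a forward convolution with the deconvolution's
// diff_dst as the convolution src and diff_src as the convolution dst.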
status_t ref_deconvolution_bwd_data_t::execute(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;
    const auto &args = ctx.args();
    exec_args_t conv_args;
    conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
    conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
    conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    return conv_p_->execute(conv_ctx);
}

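// Reduces diff_dst over minibatch and spatial dimensions into diff_bias;
// generic f32 version that walks offsets through the memory descriptor.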
void ref_deconvolution_bwd_weights_t::compute_bwd_bias(
        float *diff_bias, const float *diff_dst) const {
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const auto G = pd()->G();
    const auto MB = pd()->MB();
    const auto OH = pd()->OH();
    const auto OW = pd()->OW();
    const auto OC = pd()->OC() / G;
    const auto OD = pd()->OD();
    const auto ndims = pd()->desc()->src_desc.ndims;

    parallel_nd(G, OC, [&](dim_t g, dim_t oc) {
        float db = 0;
        for_(dim_t mb = 0; mb < MB; ++mb)
        for_(dim_t od = 0; od < OD; ++od)
        for_(dim_t oh = 0; oh < OH; ++oh)
        for (dim_t ow = 0; ow < OW; ++ow) {
            const auto d_dst_off = get_data_off(
                    diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
            db += diff_dst[d_dst_off];
        }
        diff_bias[g * OC + oc] = db;
    });
}

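// diff_bias reduction specialized for plain channel-first formats.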
template <data_type_t dbia_type, data_type_t ddst_type>
void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw(
        typename prec_traits<dbia_type>::type *diff_bias,
        const typename prec_traits<ddst_type>::type *diff_dst) const {
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const auto OC = pd()->OC();
    const auto MB = pd()->MB();
    const auto SP = pd()->OH() * pd()->OW() * pd()->OD();

    parallel_nd(OC, [&](dim_t oc) {
        float db = 0;
        for (dim_t mb = 0; mb < MB; ++mb) {
            PRAGMA_OMP_SIMD(reduction(+ : db))
            for (dim_t sp = 0; sp < SP; ++sp) {
                auto offset = (size_t)(mb * OC + oc) * SP + sp;
                db += diff_dst[offset];
            }
        }
        diff_bias[oc] = db;
    });
}

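// diff_bias reduction specialized for plain channel-last formats.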
template <data_type_t dbia_type, data_type_t ddst_type>
void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ndhwc(
        typename prec_traits<dbia_type>::type *diff_bias,
        const typename prec_traits<ddst_type>::type *diff_dst) const {
    const auto MB = pd()->MB();
    const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
    const auto OC = pd()->OC();

    parallel_nd(OC, [&](dim_t oc) {
        float db = 0;
        for (dim_t mb = 0; mb < MB; ++mb) {
            PRAGMA_OMP_SIMD(reduction(+ : db))
            for (dim_t sp = 0; sp < SP; ++sp) {
                const dim_t offset = (mb * SP + sp) * OC + oc;
                db += diff_dst[offset];
            }
        }
        diff_bias[oc] = static_cast<typename prec_traits<dbia_type>::type>(db);
    });
}

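// diff_bias reduction specialized for blocked formats: accumulate a whole
// channel block at once, then store only the valid tail channels.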
template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc(
        typename prec_traits<dbia_type>::type *diff_bias,
        const typename prec_traits<ddst_type>::type *diff_dst) const {
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const auto OC = pd()->OC();
    const auto MB = pd()->MB();
    const auto SP = pd()->OH() * pd()->OW() * pd()->OD();

    const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0];

    parallel_nd(utils::div_up(OC, blksize), [&](dim_t ocb) {
        float db[blksize] = {0};

        for (dim_t mb = 0; mb < MB; ++mb) {
            for (dim_t sp = 0; sp < SP; ++sp) {
                auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;

                PRAGMA_OMP_SIMD()
                for (dim_t i = 0; i < blksize; ++i)
                    db[i] += diff_dst[offset + i];
            }
        }

        const dim_t blk = nstl::min(blksize, OC - ocb * blksize);

        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < blk; ++i)
            diff_bias[ocb * blksize + i] = db[i];
    });
}

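// Picks the diff_bias kernel matching the diff_dst format tag; unrecognized
// formats fall back to the generic f32 implementation.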
template <data_type_t dbia_type, data_type_t ddst_type>
void ref_deconvolution_bwd_weights_t::compute_bias(
        const exec_ctx_t &ctx) const {
    using dbia_data_t = typename prec_traits<dbia_type>::type;
    using ddst_data_t = typename prec_traits<ddst_type>::type;

    auto diff_bias = CTX_OUT_MEM(dbia_data_t *, DNNL_ARG_DIFF_BIAS);
    auto diff_dst = CTX_IN_MEM(const ddst_data_t *, DNNL_ARG_DIFF_DST);

    using namespace format_tag;
    switch (pd()->dst_tag_) {
        case ncdhw:
        case nchw:
        case ncw:
            compute_bwd_bias_ncdhw<dbia_type, ddst_type>(diff_bias, diff_dst);
            break;
        case ndhwc:
        case nhwc:
        case nwc:
            compute_bwd_bias_ndhwc<dbia_type, ddst_type>(diff_bias, diff_dst);
            break;
        case nCdhw8c:
        case nChw8c:
        case nCw8c:
            assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
            compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 8>(
                    diff_bias, diff_dst);
            break;
        case nCdhw16c:
        case nChw16c:
        case nCw16c:
            compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 16>(
                    diff_bias, diff_dst);
            break;
        default:
            assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
            compute_bwd_bias((float *)diff_bias, (const float *)diff_dst);
            break;
    }
}

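// Backward weights runs a backward-weights convolution with src and diff_dst
// swapped; the bias gradient is not produced by the nested convolution, so it
// is reduced here directly from diff_dst.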
status_t ref_deconvolution_bwd_weights_t::execute(const exec_ctx_t &ctx) const {
    using namespace memory_tracking::names;
    const auto &args = ctx.args();
    exec_args_t conv_args;
    conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
    conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
    conv_args[DNNL_ARG_DIFF_WEIGHTS] = args.at(DNNL_ARG_DIFF_WEIGHTS);
    exec_ctx_t conv_ctx(ctx, std::move(conv_args));

    nested_scratchpad_t ns(ctx, key_nested, conv_p_);
    conv_ctx.set_scratchpad_grantor(ns.grantor());
    status_t status = conv_p_->execute(conv_ctx);
    if (status != status::success) return status;

    if (pd()->with_bias()) {
        using namespace data_type;

        auto dbia_type = pd()->diff_weights_md(1)->data_type;
        auto ddst_type = pd()->diff_dst_md()->data_type;
        if (utils::everyone_is(f32, dbia_type, ddst_type))
            compute_bias<f32, f32>(ctx);
        else if (utils::everyone_is(bf16, dbia_type, ddst_type))
            compute_bias<bf16, bf16>(ctx);
        else if (dbia_type == f32 && ddst_type == bf16)
            compute_bias<f32, bf16>(ctx);
        else {
            assert(!"unsupported data type");
            return status::runtime_error;
        }
    }
    return status::success;
}

using namespace data_type;

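// Explicit instantiations for the supported diff_bias/diff_dst data type
// combinations.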
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, f32>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<f32, bf16>(
        const exec_ctx_t &ctx) const;
template void ref_deconvolution_bwd_weights_t::compute_bias<bf16, bf16>(
        const exec_ctx_t &ctx) const;

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s