/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 
#include <cstring>
#include <functional>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/dnnl_traits.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/ref_io_helper.hpp"

#include "cpu/ref_convolution_utils.hpp"
#include "cpu/ref_deconvolution.hpp"
29 
30 namespace dnnl {
31 namespace impl {
32 namespace cpu {
33 
compute_fwd_bias_common(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const34 void ref_deconvolution_fwd_t::compute_fwd_bias_common(const exec_ctx_t &ctx,
35         void *dst, const float *conv_output, bool non_default_attr) const {
36     const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
37     const memory_desc_wrapper dst_d(pd()->dst_md());
38     const memory_desc_wrapper bias_d(pd()->weights_md(1));
39 
40     const auto G = pd()->G();
41     const auto MB = pd()->MB();
42     const auto OH = pd()->OH();
43     const auto OW = pd()->OW();
44     const auto OD = pd()->OD();
45     const auto OC = pd()->OC() / G;
46     const auto ndims = pd()->desc()->src_desc.ndims;
47 
48     parallel_nd(MB, G, OC, OD, OH, OW,
49             [&](dim_t mb, dim_t g, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
50                 const dim_t c = g * OC + oc;
51                 const dim_t off = ref_conv_utils::get_data_off(
52                         dst_d, ndims, mb, c, od, oh, ow);
53                 float b = io::load_float_value(bias_d.data_type(), bias, c);
54                 float d = conv_output[off];
55                 // Use f32 if attributes happen after bias to get precise answer
56                 auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
57                 io::store_float_value(dt, d + b, dst, off);
58             });
59 }
60 
compute_fwd_bias_ncdhw(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const61 void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const exec_ctx_t &ctx,
62         void *dst, const float *conv_output, bool non_default_attr) const {
63     const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
64     const memory_desc_wrapper dst_d(pd()->dst_md());
65     const memory_desc_wrapper bias_d(pd()->weights_md(1));
66 
67     const auto MB = pd()->MB();
68     const auto OC = pd()->OC();
69     const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
70 
71     parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
72         const dim_t off = (mb * OC + oc) * SP;
73         float b = io::load_float_value(bias_d.data_type(), bias, oc);
74         PRAGMA_OMP_SIMD()
75         for (dim_t sp = 0; sp < SP; ++sp) {
76             float d = conv_output[off + sp];
77             // Use f32 if attributes happen after bias to get precise answer.
78             auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
79             io::store_float_value(dt, d + b, dst, off + sp);
80         }
81     });
82 }
83 
compute_fwd_bias_ndhwc(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const84 void ref_deconvolution_fwd_t::compute_fwd_bias_ndhwc(const exec_ctx_t &ctx,
85         void *dst, const float *conv_output, bool non_default_attr) const {
86     const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
87     const memory_desc_wrapper dst_d(pd()->dst_md());
88     const memory_desc_wrapper bias_d(pd()->weights_md(1));
89 
90     const auto MB = pd()->MB();
91     const auto OC = pd()->OC();
92     const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
93 
94     parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
95         const dim_t off = (mb * SP + sp) * OC;
96         PRAGMA_OMP_SIMD()
97         for (dim_t oc = 0; oc < OC; ++oc) {
98             float b = io::load_float_value(bias_d.data_type(), bias, oc);
99             float d = conv_output[off + oc];
100             // Use f32 if attributes happen after bias to get precise answer.
101             auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
102             io::store_float_value(dt, d + b, dst, off + oc);
103         }
104     });
105 }
106 
107 template <dim_t blk_size>
compute_fwd_bias_nCdhwXc(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const108 void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx,
109         void *dst, const float *conv_output, bool non_default_attr) const {
110     const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
111     const memory_desc_wrapper dst_d(pd()->dst_md());
112     const memory_desc_wrapper bias_d(pd()->weights_md(1));
113 
114     const auto MB = pd()->MB();
115     const auto OC = pd()->OC();
116     const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
117     const auto stride_mb = dst_d.blocking_desc().strides[0];
118 
119     parallel_nd(MB, utils::div_up(OC, blk_size), SP,
120             [&](dim_t mb, dim_t oc_blk, dim_t sp) {
121                 const dim_t oc = oc_blk * blk_size;
122                 const dim_t off = mb * stride_mb + oc * SP + sp * blk_size;
123                 const dim_t blk = nstl::min(blk_size, OC - oc);
124 
125                 PRAGMA_OMP_SIMD()
126                 for (dim_t i = 0; i < blk_size; ++i) {
127                     float b = i < blk ? io::load_float_value(
128                                       bias_d.data_type(), bias, oc + i)
129                                       : 0;
130                     float d = conv_output[off + i];
131                     // Use f32 if attributes happen after bias to get precise
132                     // answer.
133                     auto dt = non_default_attr ? data_type::f32
134                                                : dst_d.data_type();
135                     io::store_float_value(dt, d + b, dst, off + i);
136                 }
137             });
138 }
139 
compute_fwd_bias(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const140 void ref_deconvolution_fwd_t::compute_fwd_bias(const exec_ctx_t &ctx, void *dst,
141         const float *conv_output, bool non_default_attr) const {
142     using namespace format_tag;
143     switch (pd()->dst_tag_) {
144         case ncdhw:
145         case nchw:
146         case ncw:
147             compute_fwd_bias_ncdhw(ctx, dst, conv_output, non_default_attr);
148             break;
149         case ndhwc:
150         case nhwc:
151         case nwc:
152             compute_fwd_bias_ndhwc(ctx, dst, conv_output, non_default_attr);
153             break;
154         case nCdhw8c:
155         case nChw8c:
156         case nCw8c:
157             compute_fwd_bias_nCdhwXc<8>(
158                     ctx, dst, conv_output, non_default_attr);
159             break;
160         case nCdhw16c:
161         case nChw16c:
162         case nCw16c:
163             compute_fwd_bias_nCdhwXc<16>(
164                     ctx, dst, conv_output, non_default_attr);
165             break;
166         default:
167             compute_fwd_bias_common(ctx, dst, conv_output, non_default_attr);
168             break;
169     }
170 }
171 
compute_ref_attrs(const exec_ctx_t & ctx,const float * conv_output,void * original_dst) const172 status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx,
173         const float *conv_output, void *original_dst) const {
174     auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
175     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
176     const bool is_dst_zp_common
177             = pd()->attr()->zero_points_.common(DNNL_ARG_DST);
178 
179     const memory_desc_wrapper dst_d(pd()->dst_md());
180 
181     const auto MB = pd()->MB();
182     const auto OH = pd()->OH();
183     const auto OW = pd()->OW();
184     const auto OD = pd()->OD();
185     const auto OC = pd()->OC();
186     const auto OCP = dst_d.padded_dims()[1];
187     const auto ndims = pd()->desc()->src_desc.ndims;
188 
189     const auto maybe_oscale = [=](float &d, dim_t oc) {
190         // scale_idx_mult = 1 for per_oc scales and 0, otherwise
191         const int scale_idx_mult
192                 = pd()->attr()->output_scales_.mask_ == (1 << 1);
193         const float *scales = pd()->attr()->output_scales_.scales_;
194         d *= scales[oc * scale_idx_mult];
195     };
196 
197     const auto maybe_dst_zero_point = [=](float &result, dim_t oc) {
198         if (is_dst_zp_common)
199             result += dst_zero_point[0];
200         else
201             result += dst_zero_point[oc];
202     };
203 
204     parallel_nd(MB, OCP, OD, OH, OW,
205             [&](dim_t mb, int ocp, dim_t od, dim_t oh, dim_t ow) {
206                 auto dst_off = ref_conv_utils::get_data_off(
207                         dst_d, ndims, mb, ocp, od, oh, ow);
208                 float tmp_result = 0;
209 
210                 if (ocp < OC) {
211                     dim_t dst_l_off = (mb * OC + ocp) * OD * OH * OW
212                             + od * OH * OW + oh * OW + ow;
213                     tmp_result = conv_output[dst_off];
214                     maybe_oscale(tmp_result, ocp);
215 
216                     ref_post_ops_t::args_t args;
217                     if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1)
218                         args.dst_val = io::load_float_value(
219                                 dst_d.data_type(), original_dst, dst_off);
220                     args.ctx = &ctx;
221                     args.l_offset = dst_l_off;
222                     args.dst_md = pd()->dst_md();
223                     ref_post_ops->execute(tmp_result, args);
224                     maybe_dst_zero_point(tmp_result, ocp);
225                 }
226                 io::store_float_value(
227                         dst_d.data_type(), tmp_result, dst, dst_off);
228             });
229 
230     return status_t::dnnl_success;
231 }
232 
get_weights_off(const memory_desc_wrapper & wei_d,bool with_groups,int ndims,dim_t g,dim_t oc,dim_t ic,dim_t kd,dim_t kh,dim_t kw)233 dim_t get_weights_off(const memory_desc_wrapper &wei_d, bool with_groups,
234         int ndims, dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
235     switch (ndims) {
236         case 5:
237             return with_groups ? wei_d.off(g, oc, ic, kd, kh, kw)
238                                : wei_d.off(oc, ic, kd, kh, kw);
239         case 4:
240             return with_groups ? wei_d.off(g, oc, ic, kh, kw)
241                                : wei_d.off(oc, ic, kh, kw);
242         case 3:
243             return with_groups ? wei_d.off(g, oc, ic, kw)
244                                : wei_d.off(oc, ic, kw);
245         default: assert(!"unsupported ndims"); return dim_t(0);
246     }
247 
248     return 0;
249 };
250 
251 template <data_type_t wei_type>
compute_src_zp_compensation(const exec_ctx_t & ctx,const int32_t * src_zero_point,const bool is_src_zp_common,typename prec_traits<wei_type>::type * wei,const cpu_deconvolution_fwd_pd_t * pd)252 static void compute_src_zp_compensation(const exec_ctx_t &ctx,
253         const int32_t *src_zero_point, const bool is_src_zp_common,
254         typename prec_traits<wei_type>::type *wei,
255         const cpu_deconvolution_fwd_pd_t *pd) {
256     using namespace memory_tracking::names;
257 
258     const auto scratchpad = ctx.get_scratchpad_grantor();
259     int32_t *zp_compensation = scratchpad.get<int32_t>(key_deconv_zp);
260     const auto G = pd->G();
261     const auto KH = pd->KH();
262     const auto KW = pd->KW();
263     const auto KD = pd->KD();
264     const auto OC = pd->OC() / G;
265     const auto IC = pd->IC() / G;
266     const memory_desc_wrapper wei_d(pd->weights_md());
267     const bool with_groups = pd->with_groups();
268     const auto ndims = wei_d.ndims() - (with_groups ? 1 : 0);
269     const auto get_wei_off
270             = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
271                   return get_weights_off(
272                           wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
273               };
274 
275     parallel_nd(G, OC, [&](const dim_t g, const dim_t oc) {
276         const auto out_offset = g * OC + oc;
277         int32_t acc = 0;
278 
279         for_(dim_t kd = 0; kd < KD; ++kd)
280         for_(dim_t kh = 0; kh < KH; ++kh)
281         for (dim_t kw = 0; kw < KW; ++kw) {
282             for (dim_t ic = 0; ic < IC; ++ic) {
283                 const auto weights_offset = get_wei_off(g, oc, ic, kd, kh, kw);
284                 const int32_t wei32 = static_cast<int32_t>(wei[weights_offset]);
285 
286                 if (is_src_zp_common)
287                     acc += wei32;
288                 else
289                     acc += wei32 * src_zero_point[g * IC + ic];
290             }
291         }
292 
293         zp_compensation[out_offset] = acc * src_zero_point[0];
294     });
295 }
296 
297 template <data_type_t wei_type>
298 static std::function<int32_t(
299         const dim_t, const dim_t, const dim_t, const dim_t, const dim_t)>
prepare_zp_pad_comp_ker(const dim_t ndims,const int32_t * src_zero_point,const bool is_src_zp_common,typename prec_traits<wei_type>::type * wei,const cpu_deconvolution_fwd_pd_t * deconv_pd)300 prepare_zp_pad_comp_ker(const dim_t ndims, const int32_t *src_zero_point,
301         const bool is_src_zp_common, typename prec_traits<wei_type>::type *wei,
302         const cpu_deconvolution_fwd_pd_t *deconv_pd) {
303 
304     const auto KH = deconv_pd->KH();
305     const auto KW = deconv_pd->KW();
306     const auto KD = deconv_pd->KD();
307     const auto KSD = deconv_pd->KSD();
308     const auto KSH = deconv_pd->KSH();
309     const auto KSW = deconv_pd->KSW();
310     const auto KDD = deconv_pd->KDD() + 1;
311     const auto KDH = deconv_pd->KDH() + 1;
312     const auto KDW = deconv_pd->KDW() + 1;
313     const auto IC = deconv_pd->IC() / deconv_pd->G();
314     const auto IH = deconv_pd->IH();
315     const auto IW = deconv_pd->IW();
316     const auto ID = deconv_pd->ID();
317     const auto pad_front = deconv_pd->padFront();
318     const auto pad_top = deconv_pd->padT();
319     const auto pad_left = deconv_pd->padL();
320     const bool with_groups = deconv_pd->with_groups();
321     const memory_desc_wrapper wei_d(deconv_pd->weights_md());
322     const auto get_wei_off
323             = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
324                   return get_weights_off(
325                           wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
326               };
327 
328     return [=](const dim_t g, const dim_t oc, const dim_t od, const dim_t oh,
329                    const dim_t ow) {
330         int32_t zp_pad_compensation = 0;
331 
332         for (dim_t kd = 0; kd < KD; ++kd) {
333             const dim_t id = od - kd * KDD + pad_front;
334             const bool should_apply_pad_comp_d
335                     = id < 0 || id % KSD != 0 || (id / KSD) >= ID;
336 
337             for (dim_t kh = 0; kh < KH; ++kh) {
338                 const dim_t ih = oh - kh * KDH + pad_top;
339                 const bool should_apply_pad_comp_h
340                         = ih < 0 || ih % KSH != 0 || (ih / KSH) >= IH;
341 
342                 for (dim_t kw = 0; kw < KW; ++kw) {
343                     const dim_t iw = ow - kw * KDW + pad_left;
344                     const bool should_apply_pad_comp_w
345                             = iw < 0 || iw % KSW != 0 || (iw / KSW) >= IW;
346 
347                     if (should_apply_pad_comp_d || should_apply_pad_comp_h
348                             || should_apply_pad_comp_w) {
349 
350                         for (dim_t ic = 0; ic < IC; ic++) {
351                             const auto wei_off
352                                     = get_wei_off(g, oc, ic, kd, kh, kw);
353                             const int32_t wei32
354                                     = static_cast<int32_t>(wei[wei_off]);
355 
356                             if (is_src_zp_common)
357                                 zp_pad_compensation += wei32;
358                             else
359                                 zp_pad_compensation
360                                         += wei32 * src_zero_point[g * IC + ic];
361                         }
362                     }
363                 }
364             }
365         }
366 
367         if (is_src_zp_common && zp_pad_compensation)
368             zp_pad_compensation *= src_zero_point[0];
369 
370         return zp_pad_compensation;
371     };
372 }
373 
374 template <data_type_t wei_type>
apply_src_zero_point(const exec_ctx_t & ctx,const cpu_deconvolution_fwd_pd_t * deconv_pd,float * conv_output)375 static status_t apply_src_zero_point(const exec_ctx_t &ctx,
376         const cpu_deconvolution_fwd_pd_t *deconv_pd, float *conv_output) {
377     using wei_data_t = typename prec_traits<wei_type>::type;
378     using namespace memory_tracking::names;
379     using namespace data_type;
380 
381     // required by DEFINE_ZERO_POINTS_BUFFER macro
382     const auto pd = [&]() { return deconv_pd; };
383     const auto wei = CTX_OUT_MEM(wei_data_t *, DNNL_ARG_WEIGHTS);
384     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
385     const bool is_src_zp_common
386             = deconv_pd->attr()->zero_points_.common(DNNL_ARG_SRC);
387 
388     const auto scratchpad = ctx.get_scratchpad_grantor();
389     const int32_t *const zp_src_compensation
390             = scratchpad.get<int32_t>(key_deconv_zp);
391     const memory_desc_wrapper dst_d(pd()->dst_md());
392     const auto ndims = dst_d.ndims();
393 
394     const auto G = pd()->G();
395     const auto MB = pd()->MB();
396     const auto OH = pd()->OH();
397     const auto OW = pd()->OW();
398     const auto OD = pd()->OD();
399     const auto OC = pd()->OC() / G;
400 
401     compute_src_zp_compensation<wei_type>(
402             ctx, src_zero_point, is_src_zp_common, wei, deconv_pd);
403     const auto zp_pad_comp_ker = prepare_zp_pad_comp_ker<wei_type>(
404             ndims, src_zero_point, is_src_zp_common, wei, deconv_pd);
405 
406     parallel_nd(MB, G, OC, OD, OH, OW,
407             [&](const dim_t mb, const dim_t g, const dim_t oc, const dim_t od,
408                     const dim_t oh, const dim_t ow) {
409                 const auto oc_off = g * OC + oc;
410                 const auto dst_off = ref_conv_utils::get_data_off(
411                         dst_d, ndims, mb, oc_off, od, oh, ow);
412                 int32_t conv_result
413                         = conv_output[dst_off] - zp_src_compensation[oc_off];
414 
415                 if (const auto zp_pad_compensation
416                         = zp_pad_comp_ker(g, oc, od, oh, ow)) {
417                     conv_result += zp_pad_compensation;
418                 }
419 
420                 conv_output[dst_off] = static_cast<float>(conv_result);
421             });
422 
423     return status::success;
424 }
425 
execute(const exec_ctx_t & ctx) const426 status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const {
427     using namespace memory_tracking::names;
428     const auto scratchpad = ctx.get_scratchpad_grantor();
429     const bool ref_bias = pd()->with_bias() && !pd()->conv_supports_bias_;
430     const bool non_default_attr = !pd()->attr()->has_default_values();
431 
432     const auto &args = ctx.args();
433     exec_args_t conv_args;
434     conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
435     conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
436     if (pd()->with_bias() && pd()->conv_supports_bias_)
437         conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);
438 
439     // Create intermediate memory for f32 output if needed.
440     auto dst = args.at(DNNL_ARG_DST);
441     memory_t tmp_memory(dst.mem->engine(), pd()->conv_pd_->diff_src_md(),
442             scratchpad.get_memory_storage(key_deconv_bias));
443     memory_arg_t tmp_conv_output = {&tmp_memory, false};
444 
445     conv_args[DNNL_ARG_DIFF_SRC]
446             = ref_bias || non_default_attr ? tmp_conv_output : dst;
447 
448     // When sum post-op happens, we need to copy original destination memory
449     // prior call to external convolution happens.
450     if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1) {
451         void *original_dst = scratchpad.get(key_deconv_sum);
452         const memory_desc_wrapper dst_d(pd()->dst_md());
453         void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
454         const auto dt_size = dst_d.data_type_size();
455 
456         parallel(0, [&](const int ithr, const int nthr) {
457             dim_t start {0}, end {0};
458             balance211(dst_d.nelems(true), nthr, ithr, start, end);
459             auto o_dst_start = (char *)original_dst + start * dt_size;
460             auto dst_start = (char *)dst + start * dt_size;
461             const auto size = (end - start) * dt_size;
462 
463             std::memcpy(o_dst_start, dst_start, size);
464         });
465     }
466 
467     exec_ctx_t conv_ctx(ctx, std::move(conv_args));
468 
469     nested_scratchpad_t ns(ctx, key_nested, conv_p_);
470     conv_ctx.set_scratchpad_grantor(ns.grantor());
471     auto status = conv_p_->execute(conv_ctx);
472     if (status != status::success) return status;
473 
474     using namespace data_type;
475 
476     if (!pd()->attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
477         float *conv_output = scratchpad.get<float>(key_deconv_bias);
478         const auto wei_dt = pd()->weights_md()->data_type;
479         switch (wei_dt) {
480             case s8: apply_src_zero_point<s8>(ctx, pd(), conv_output); break;
481             case u8: apply_src_zero_point<u8>(ctx, pd(), conv_output); break;
482             default: assert(!"unsupported data type");
483         }
484     }
485 
486     float *conv_output = scratchpad.get<float>(key_deconv_bias);
487 
488     if (ref_bias) {
489         void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
490         void *tmp_output = non_default_attr ? conv_output : dst;
491         compute_fwd_bias(ctx, tmp_output, conv_output, non_default_attr);
492     }
493 
494     if (non_default_attr) {
495         void *original_dst = scratchpad.get<void>(key_deconv_sum);
496         compute_ref_attrs(ctx, conv_output, original_dst);
497     }
498 
499     return status::success;
500 }
501 
execute(const exec_ctx_t & ctx) const502 status_t ref_deconvolution_bwd_data_t::execute(const exec_ctx_t &ctx) const {
503     using namespace memory_tracking::names;
504     const auto &args = ctx.args();
505     exec_args_t conv_args;
506     conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
507     conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
508     conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
509     exec_ctx_t conv_ctx(ctx, std::move(conv_args));
510 
511     nested_scratchpad_t ns(ctx, key_nested, conv_p_);
512     conv_ctx.set_scratchpad_grantor(ns.grantor());
513     conv_p_->execute(conv_ctx);
514     return status::success;
515 }
516 
compute_bwd_bias(float * diff_bias,const float * diff_dst) const517 void ref_deconvolution_bwd_weights_t::compute_bwd_bias(
518         float *diff_bias, const float *diff_dst) const {
519     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
520 
521     const auto G = pd()->G();
522     const auto MB = pd()->MB();
523     const auto OH = pd()->OH();
524     const auto OW = pd()->OW();
525     const auto OC = pd()->OC() / G;
526     const auto OD = pd()->OD();
527     const auto ndims = pd()->desc()->src_desc.ndims;
528 
529     parallel_nd(G, OC, [&](dim_t g, dim_t oc) {
530         float db = 0;
531         for_(dim_t mb = 0; mb < MB; ++mb)
532         for_(dim_t od = 0; od < OD; ++od)
533         for_(dim_t oh = 0; oh < OH; ++oh)
534         for (dim_t ow = 0; ow < OW; ++ow) {
535             const auto d_dst_off = ref_conv_utils::get_data_off(
536                     diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
537             db += diff_dst[d_dst_off];
538         }
539         diff_bias[g * OC + oc] = db;
540     });
541 }
542 
543 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bwd_bias_ncdhw(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const544 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw(
545         typename prec_traits<dbia_type>::type *diff_bias,
546         const typename prec_traits<ddst_type>::type *diff_dst) const {
547     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
548 
549     const auto OC = pd()->OC();
550     const auto MB = pd()->MB();
551     const auto SP = pd()->OH() * pd()->OW() * pd()->OD();
552 
553     parallel_nd(OC, [&](dim_t oc) {
554         float db = 0;
555         for (dim_t mb = 0; mb < MB; ++mb) {
556             PRAGMA_OMP_SIMD(reduction(+ : db))
557             for (dim_t sp = 0; sp < SP; ++sp) {
558                 auto offset = (size_t)(mb * OC + oc) * SP + sp;
559                 db += diff_dst[offset];
560             }
561         }
562         diff_bias[oc] = db;
563     });
564 }
565 
566 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bwd_bias_ndhwc(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const567 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ndhwc(
568         typename prec_traits<dbia_type>::type *diff_bias,
569         const typename prec_traits<ddst_type>::type *diff_dst) const {
570     const auto MB = pd()->MB();
571     const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
572     const auto OC = pd()->OC();
573 
574     parallel_nd(OC, [&](dim_t oc) {
575         float db = 0;
576         for (dim_t mb = 0; mb < MB; ++mb) {
577             PRAGMA_OMP_SIMD(reduction(+ : db))
578             for (dim_t sp = 0; sp < SP; ++sp) {
579                 const dim_t offset = (mb * SP + sp) * OC + oc;
580                 db += diff_dst[offset];
581             }
582         }
583         diff_bias[oc] = static_cast<typename prec_traits<dbia_type>::type>(db);
584     });
585 }
586 
587 template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
compute_bwd_bias_nCdhwXc(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const588 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc(
589         typename prec_traits<dbia_type>::type *diff_bias,
590         const typename prec_traits<ddst_type>::type *diff_dst) const {
591     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
592 
593     const auto OC = pd()->OC();
594     const auto MB = pd()->MB();
595     const auto SP = pd()->OH() * pd()->OW() * pd()->OD();
596 
597     const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0];
598 
599     parallel_nd(utils::div_up(OC, blksize), [&](dim_t ocb) {
600         float db[blksize] = {0};
601 
602         for (dim_t mb = 0; mb < MB; ++mb) {
603             for (dim_t sp = 0; sp < SP; ++sp) {
604                 auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
605 
606                 PRAGMA_OMP_SIMD()
607                 for (dim_t i = 0; i < blksize; ++i)
608                     db[i] += diff_dst[offset + i];
609             }
610         }
611 
612         const dim_t blk = nstl::min(blksize, OC - ocb * blksize);
613 
614         PRAGMA_OMP_SIMD()
615         for (dim_t i = 0; i < blk; ++i)
616             diff_bias[ocb * blksize + i] = db[i];
617     });
618 }
619 
620 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bias(const exec_ctx_t & ctx) const621 void ref_deconvolution_bwd_weights_t::compute_bias(
622         const exec_ctx_t &ctx) const {
623     using dbia_data_t = typename prec_traits<dbia_type>::type;
624     using ddst_data_t = typename prec_traits<ddst_type>::type;
625 
626     auto diff_bias = CTX_OUT_MEM(dbia_data_t *, DNNL_ARG_DIFF_BIAS);
627     auto diff_dst = CTX_IN_MEM(const ddst_data_t *, DNNL_ARG_DIFF_DST);
628 
629     using namespace format_tag;
630     switch (pd()->dst_tag_) {
631         case ncdhw:
632         case nchw:
633         case ncw:
634             compute_bwd_bias_ncdhw<dbia_type, ddst_type>(diff_bias, diff_dst);
635             break;
636         case ndhwc:
637         case nhwc:
638         case nwc:
639             compute_bwd_bias_ndhwc<dbia_type, ddst_type>(diff_bias, diff_dst);
640             break;
641         case nCdhw8c:
642         case nChw8c:
643         case nCw8c:
644             assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
645             compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 8>(
646                     diff_bias, diff_dst);
647             break;
648         case nCdhw16c:
649         case nChw16c:
650         case nCw16c:
651             compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 16>(
652                     diff_bias, diff_dst);
653             break;
654         default:
655             assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
656             compute_bwd_bias((float *)diff_bias, (const float *)diff_dst);
657             break;
658     }
659 }
660 
execute(const exec_ctx_t & ctx) const661 status_t ref_deconvolution_bwd_weights_t::execute(const exec_ctx_t &ctx) const {
662     using namespace memory_tracking::names;
663     const auto &args = ctx.args();
664     exec_args_t conv_args;
665     conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
666     conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
667     conv_args[DNNL_ARG_DIFF_WEIGHTS] = args.at(DNNL_ARG_DIFF_WEIGHTS);
668     exec_ctx_t conv_ctx(ctx, std::move(conv_args));
669 
670     nested_scratchpad_t ns(ctx, key_nested, conv_p_);
671     conv_ctx.set_scratchpad_grantor(ns.grantor());
672     status_t status = conv_p_->execute(conv_ctx);
673     if (status != status::success) return status;
674 
675     if (pd()->with_bias()) {
676         using namespace data_type;
677 
678         auto dbia_type = pd()->diff_weights_md(1)->data_type;
679         auto ddst_type = pd()->diff_dst_md()->data_type;
680         if (utils::everyone_is(f32, dbia_type, ddst_type))
681             compute_bias<f32, f32>(ctx);
682         else if (utils::everyone_is(bf16, dbia_type, ddst_type))
683             compute_bias<bf16, bf16>(ctx);
684         else if (dbia_type == f32 && ddst_type == bf16)
685             compute_bias<f32, bf16>(ctx);
686         else {
687             assert(!"unsupported data type");
688             return status::runtime_error;
689         }
690     }
691     return status::success;
692 }
693 
694 using namespace data_type;
695 
696 template void ref_deconvolution_bwd_weights_t::compute_bias<f32, f32>(
697         const exec_ctx_t &ctx) const;
698 template void ref_deconvolution_bwd_weights_t::compute_bias<f32, bf16>(
699         const exec_ctx_t &ctx) const;
700 template void ref_deconvolution_bwd_weights_t::compute_bias<bf16, bf16>(
701         const exec_ctx_t &ctx) const;
702 } // namespace cpu
703 } // namespace impl
704 } // namespace dnnl
705 
// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
707