1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <functional>
18 #include "common/c_types_map.hpp"
19 #include "common/dnnl_thread.hpp"
20 #include "common/dnnl_traits.hpp"
21 #include "common/math_utils.hpp"
22 #include "common/type_helpers.hpp"
23
24 #include "cpu/cpu_primitive.hpp"
25 #include "cpu/ref_io_helper.hpp"
26
27 #include "cpu/ref_convolution_utils.hpp"
28 #include "cpu/ref_deconvolution.hpp"
29
30 namespace dnnl {
31 namespace impl {
32 namespace cpu {
33
compute_fwd_bias_common(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const34 void ref_deconvolution_fwd_t::compute_fwd_bias_common(const exec_ctx_t &ctx,
35 void *dst, const float *conv_output, bool non_default_attr) const {
36 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
37 const memory_desc_wrapper dst_d(pd()->dst_md());
38 const memory_desc_wrapper bias_d(pd()->weights_md(1));
39
40 const auto G = pd()->G();
41 const auto MB = pd()->MB();
42 const auto OH = pd()->OH();
43 const auto OW = pd()->OW();
44 const auto OD = pd()->OD();
45 const auto OC = pd()->OC() / G;
46 const auto ndims = pd()->desc()->src_desc.ndims;
47
48 parallel_nd(MB, G, OC, OD, OH, OW,
49 [&](dim_t mb, dim_t g, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
50 const dim_t c = g * OC + oc;
51 const dim_t off = ref_conv_utils::get_data_off(
52 dst_d, ndims, mb, c, od, oh, ow);
53 float b = io::load_float_value(bias_d.data_type(), bias, c);
54 float d = conv_output[off];
55 // Use f32 if attributes happen after bias to get precise answer
56 auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
57 io::store_float_value(dt, d + b, dst, off);
58 });
59 }
60
compute_fwd_bias_ncdhw(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const61 void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const exec_ctx_t &ctx,
62 void *dst, const float *conv_output, bool non_default_attr) const {
63 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
64 const memory_desc_wrapper dst_d(pd()->dst_md());
65 const memory_desc_wrapper bias_d(pd()->weights_md(1));
66
67 const auto MB = pd()->MB();
68 const auto OC = pd()->OC();
69 const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
70
71 parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
72 const dim_t off = (mb * OC + oc) * SP;
73 float b = io::load_float_value(bias_d.data_type(), bias, oc);
74 PRAGMA_OMP_SIMD()
75 for (dim_t sp = 0; sp < SP; ++sp) {
76 float d = conv_output[off + sp];
77 // Use f32 if attributes happen after bias to get precise answer.
78 auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
79 io::store_float_value(dt, d + b, dst, off + sp);
80 }
81 });
82 }
83
compute_fwd_bias_ndhwc(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const84 void ref_deconvolution_fwd_t::compute_fwd_bias_ndhwc(const exec_ctx_t &ctx,
85 void *dst, const float *conv_output, bool non_default_attr) const {
86 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
87 const memory_desc_wrapper dst_d(pd()->dst_md());
88 const memory_desc_wrapper bias_d(pd()->weights_md(1));
89
90 const auto MB = pd()->MB();
91 const auto OC = pd()->OC();
92 const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
93
94 parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
95 const dim_t off = (mb * SP + sp) * OC;
96 PRAGMA_OMP_SIMD()
97 for (dim_t oc = 0; oc < OC; ++oc) {
98 float b = io::load_float_value(bias_d.data_type(), bias, oc);
99 float d = conv_output[off + oc];
100 // Use f32 if attributes happen after bias to get precise answer.
101 auto dt = non_default_attr ? data_type::f32 : dst_d.data_type();
102 io::store_float_value(dt, d + b, dst, off + oc);
103 }
104 });
105 }
106
107 template <dim_t blk_size>
compute_fwd_bias_nCdhwXc(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const108 void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx,
109 void *dst, const float *conv_output, bool non_default_attr) const {
110 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
111 const memory_desc_wrapper dst_d(pd()->dst_md());
112 const memory_desc_wrapper bias_d(pd()->weights_md(1));
113
114 const auto MB = pd()->MB();
115 const auto OC = pd()->OC();
116 const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
117 const auto stride_mb = dst_d.blocking_desc().strides[0];
118
119 parallel_nd(MB, utils::div_up(OC, blk_size), SP,
120 [&](dim_t mb, dim_t oc_blk, dim_t sp) {
121 const dim_t oc = oc_blk * blk_size;
122 const dim_t off = mb * stride_mb + oc * SP + sp * blk_size;
123 const dim_t blk = nstl::min(blk_size, OC - oc);
124
125 PRAGMA_OMP_SIMD()
126 for (dim_t i = 0; i < blk_size; ++i) {
127 float b = i < blk ? io::load_float_value(
128 bias_d.data_type(), bias, oc + i)
129 : 0;
130 float d = conv_output[off + i];
131 // Use f32 if attributes happen after bias to get precise
132 // answer.
133 auto dt = non_default_attr ? data_type::f32
134 : dst_d.data_type();
135 io::store_float_value(dt, d + b, dst, off + i);
136 }
137 });
138 }
139
compute_fwd_bias(const exec_ctx_t & ctx,void * dst,const float * conv_output,bool non_default_attr) const140 void ref_deconvolution_fwd_t::compute_fwd_bias(const exec_ctx_t &ctx, void *dst,
141 const float *conv_output, bool non_default_attr) const {
142 using namespace format_tag;
143 switch (pd()->dst_tag_) {
144 case ncdhw:
145 case nchw:
146 case ncw:
147 compute_fwd_bias_ncdhw(ctx, dst, conv_output, non_default_attr);
148 break;
149 case ndhwc:
150 case nhwc:
151 case nwc:
152 compute_fwd_bias_ndhwc(ctx, dst, conv_output, non_default_attr);
153 break;
154 case nCdhw8c:
155 case nChw8c:
156 case nCw8c:
157 compute_fwd_bias_nCdhwXc<8>(
158 ctx, dst, conv_output, non_default_attr);
159 break;
160 case nCdhw16c:
161 case nChw16c:
162 case nCw16c:
163 compute_fwd_bias_nCdhwXc<16>(
164 ctx, dst, conv_output, non_default_attr);
165 break;
166 default:
167 compute_fwd_bias_common(ctx, dst, conv_output, non_default_attr);
168 break;
169 }
170 }
171
compute_ref_attrs(const exec_ctx_t & ctx,const float * conv_output,void * original_dst) const172 status_t ref_deconvolution_fwd_t::compute_ref_attrs(const exec_ctx_t &ctx,
173 const float *conv_output, void *original_dst) const {
174 auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
175 DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
176 const bool is_dst_zp_common
177 = pd()->attr()->zero_points_.common(DNNL_ARG_DST);
178
179 const memory_desc_wrapper dst_d(pd()->dst_md());
180
181 const auto MB = pd()->MB();
182 const auto OH = pd()->OH();
183 const auto OW = pd()->OW();
184 const auto OD = pd()->OD();
185 const auto OC = pd()->OC();
186 const auto OCP = dst_d.padded_dims()[1];
187 const auto ndims = pd()->desc()->src_desc.ndims;
188
189 const auto maybe_oscale = [=](float &d, dim_t oc) {
190 // scale_idx_mult = 1 for per_oc scales and 0, otherwise
191 const int scale_idx_mult
192 = pd()->attr()->output_scales_.mask_ == (1 << 1);
193 const float *scales = pd()->attr()->output_scales_.scales_;
194 d *= scales[oc * scale_idx_mult];
195 };
196
197 const auto maybe_dst_zero_point = [=](float &result, dim_t oc) {
198 if (is_dst_zp_common)
199 result += dst_zero_point[0];
200 else
201 result += dst_zero_point[oc];
202 };
203
204 parallel_nd(MB, OCP, OD, OH, OW,
205 [&](dim_t mb, int ocp, dim_t od, dim_t oh, dim_t ow) {
206 auto dst_off = ref_conv_utils::get_data_off(
207 dst_d, ndims, mb, ocp, od, oh, ow);
208 float tmp_result = 0;
209
210 if (ocp < OC) {
211 dim_t dst_l_off = (mb * OC + ocp) * OD * OH * OW
212 + od * OH * OW + oh * OW + ow;
213 tmp_result = conv_output[dst_off];
214 maybe_oscale(tmp_result, ocp);
215
216 ref_post_ops_t::args_t args;
217 if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1)
218 args.dst_val = io::load_float_value(
219 dst_d.data_type(), original_dst, dst_off);
220 args.ctx = &ctx;
221 args.l_offset = dst_l_off;
222 args.dst_md = pd()->dst_md();
223 ref_post_ops->execute(tmp_result, args);
224 maybe_dst_zero_point(tmp_result, ocp);
225 }
226 io::store_float_value(
227 dst_d.data_type(), tmp_result, dst, dst_off);
228 });
229
230 return status_t::dnnl_success;
231 }
232
get_weights_off(const memory_desc_wrapper & wei_d,bool with_groups,int ndims,dim_t g,dim_t oc,dim_t ic,dim_t kd,dim_t kh,dim_t kw)233 dim_t get_weights_off(const memory_desc_wrapper &wei_d, bool with_groups,
234 int ndims, dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
235 switch (ndims) {
236 case 5:
237 return with_groups ? wei_d.off(g, oc, ic, kd, kh, kw)
238 : wei_d.off(oc, ic, kd, kh, kw);
239 case 4:
240 return with_groups ? wei_d.off(g, oc, ic, kh, kw)
241 : wei_d.off(oc, ic, kh, kw);
242 case 3:
243 return with_groups ? wei_d.off(g, oc, ic, kw)
244 : wei_d.off(oc, ic, kw);
245 default: assert(!"unsupported ndims"); return dim_t(0);
246 }
247
248 return 0;
249 };
250
251 template <data_type_t wei_type>
compute_src_zp_compensation(const exec_ctx_t & ctx,const int32_t * src_zero_point,const bool is_src_zp_common,typename prec_traits<wei_type>::type * wei,const cpu_deconvolution_fwd_pd_t * pd)252 static void compute_src_zp_compensation(const exec_ctx_t &ctx,
253 const int32_t *src_zero_point, const bool is_src_zp_common,
254 typename prec_traits<wei_type>::type *wei,
255 const cpu_deconvolution_fwd_pd_t *pd) {
256 using namespace memory_tracking::names;
257
258 const auto scratchpad = ctx.get_scratchpad_grantor();
259 int32_t *zp_compensation = scratchpad.get<int32_t>(key_deconv_zp);
260 const auto G = pd->G();
261 const auto KH = pd->KH();
262 const auto KW = pd->KW();
263 const auto KD = pd->KD();
264 const auto OC = pd->OC() / G;
265 const auto IC = pd->IC() / G;
266 const memory_desc_wrapper wei_d(pd->weights_md());
267 const bool with_groups = pd->with_groups();
268 const auto ndims = wei_d.ndims() - (with_groups ? 1 : 0);
269 const auto get_wei_off
270 = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
271 return get_weights_off(
272 wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
273 };
274
275 parallel_nd(G, OC, [&](const dim_t g, const dim_t oc) {
276 const auto out_offset = g * OC + oc;
277 int32_t acc = 0;
278
279 for_(dim_t kd = 0; kd < KD; ++kd)
280 for_(dim_t kh = 0; kh < KH; ++kh)
281 for (dim_t kw = 0; kw < KW; ++kw) {
282 for (dim_t ic = 0; ic < IC; ++ic) {
283 const auto weights_offset = get_wei_off(g, oc, ic, kd, kh, kw);
284 const int32_t wei32 = static_cast<int32_t>(wei[weights_offset]);
285
286 if (is_src_zp_common)
287 acc += wei32;
288 else
289 acc += wei32 * src_zero_point[g * IC + ic];
290 }
291 }
292
293 zp_compensation[out_offset] = acc * src_zero_point[0];
294 });
295 }
296
297 template <data_type_t wei_type>
298 static std::function<int32_t(
299 const dim_t, const dim_t, const dim_t, const dim_t, const dim_t)>
prepare_zp_pad_comp_ker(const dim_t ndims,const int32_t * src_zero_point,const bool is_src_zp_common,typename prec_traits<wei_type>::type * wei,const cpu_deconvolution_fwd_pd_t * deconv_pd)300 prepare_zp_pad_comp_ker(const dim_t ndims, const int32_t *src_zero_point,
301 const bool is_src_zp_common, typename prec_traits<wei_type>::type *wei,
302 const cpu_deconvolution_fwd_pd_t *deconv_pd) {
303
304 const auto KH = deconv_pd->KH();
305 const auto KW = deconv_pd->KW();
306 const auto KD = deconv_pd->KD();
307 const auto KSD = deconv_pd->KSD();
308 const auto KSH = deconv_pd->KSH();
309 const auto KSW = deconv_pd->KSW();
310 const auto KDD = deconv_pd->KDD() + 1;
311 const auto KDH = deconv_pd->KDH() + 1;
312 const auto KDW = deconv_pd->KDW() + 1;
313 const auto IC = deconv_pd->IC() / deconv_pd->G();
314 const auto IH = deconv_pd->IH();
315 const auto IW = deconv_pd->IW();
316 const auto ID = deconv_pd->ID();
317 const auto pad_front = deconv_pd->padFront();
318 const auto pad_top = deconv_pd->padT();
319 const auto pad_left = deconv_pd->padL();
320 const bool with_groups = deconv_pd->with_groups();
321 const memory_desc_wrapper wei_d(deconv_pd->weights_md());
322 const auto get_wei_off
323 = [=](dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, dim_t kw) {
324 return get_weights_off(
325 wei_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
326 };
327
328 return [=](const dim_t g, const dim_t oc, const dim_t od, const dim_t oh,
329 const dim_t ow) {
330 int32_t zp_pad_compensation = 0;
331
332 for (dim_t kd = 0; kd < KD; ++kd) {
333 const dim_t id = od - kd * KDD + pad_front;
334 const bool should_apply_pad_comp_d
335 = id < 0 || id % KSD != 0 || (id / KSD) >= ID;
336
337 for (dim_t kh = 0; kh < KH; ++kh) {
338 const dim_t ih = oh - kh * KDH + pad_top;
339 const bool should_apply_pad_comp_h
340 = ih < 0 || ih % KSH != 0 || (ih / KSH) >= IH;
341
342 for (dim_t kw = 0; kw < KW; ++kw) {
343 const dim_t iw = ow - kw * KDW + pad_left;
344 const bool should_apply_pad_comp_w
345 = iw < 0 || iw % KSW != 0 || (iw / KSW) >= IW;
346
347 if (should_apply_pad_comp_d || should_apply_pad_comp_h
348 || should_apply_pad_comp_w) {
349
350 for (dim_t ic = 0; ic < IC; ic++) {
351 const auto wei_off
352 = get_wei_off(g, oc, ic, kd, kh, kw);
353 const int32_t wei32
354 = static_cast<int32_t>(wei[wei_off]);
355
356 if (is_src_zp_common)
357 zp_pad_compensation += wei32;
358 else
359 zp_pad_compensation
360 += wei32 * src_zero_point[g * IC + ic];
361 }
362 }
363 }
364 }
365 }
366
367 if (is_src_zp_common && zp_pad_compensation)
368 zp_pad_compensation *= src_zero_point[0];
369
370 return zp_pad_compensation;
371 };
372 }
373
374 template <data_type_t wei_type>
apply_src_zero_point(const exec_ctx_t & ctx,const cpu_deconvolution_fwd_pd_t * deconv_pd,float * conv_output)375 static status_t apply_src_zero_point(const exec_ctx_t &ctx,
376 const cpu_deconvolution_fwd_pd_t *deconv_pd, float *conv_output) {
377 using wei_data_t = typename prec_traits<wei_type>::type;
378 using namespace memory_tracking::names;
379 using namespace data_type;
380
381 // required by DEFINE_ZERO_POINTS_BUFFER macro
382 const auto pd = [&]() { return deconv_pd; };
383 const auto wei = CTX_OUT_MEM(wei_data_t *, DNNL_ARG_WEIGHTS);
384 DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
385 const bool is_src_zp_common
386 = deconv_pd->attr()->zero_points_.common(DNNL_ARG_SRC);
387
388 const auto scratchpad = ctx.get_scratchpad_grantor();
389 const int32_t *const zp_src_compensation
390 = scratchpad.get<int32_t>(key_deconv_zp);
391 const memory_desc_wrapper dst_d(pd()->dst_md());
392 const auto ndims = dst_d.ndims();
393
394 const auto G = pd()->G();
395 const auto MB = pd()->MB();
396 const auto OH = pd()->OH();
397 const auto OW = pd()->OW();
398 const auto OD = pd()->OD();
399 const auto OC = pd()->OC() / G;
400
401 compute_src_zp_compensation<wei_type>(
402 ctx, src_zero_point, is_src_zp_common, wei, deconv_pd);
403 const auto zp_pad_comp_ker = prepare_zp_pad_comp_ker<wei_type>(
404 ndims, src_zero_point, is_src_zp_common, wei, deconv_pd);
405
406 parallel_nd(MB, G, OC, OD, OH, OW,
407 [&](const dim_t mb, const dim_t g, const dim_t oc, const dim_t od,
408 const dim_t oh, const dim_t ow) {
409 const auto oc_off = g * OC + oc;
410 const auto dst_off = ref_conv_utils::get_data_off(
411 dst_d, ndims, mb, oc_off, od, oh, ow);
412 int32_t conv_result
413 = conv_output[dst_off] - zp_src_compensation[oc_off];
414
415 if (const auto zp_pad_compensation
416 = zp_pad_comp_ker(g, oc, od, oh, ow)) {
417 conv_result += zp_pad_compensation;
418 }
419
420 conv_output[dst_off] = static_cast<float>(conv_result);
421 });
422
423 return status::success;
424 }
425
execute(const exec_ctx_t & ctx) const426 status_t ref_deconvolution_fwd_t::execute(const exec_ctx_t &ctx) const {
427 using namespace memory_tracking::names;
428 const auto scratchpad = ctx.get_scratchpad_grantor();
429 const bool ref_bias = pd()->with_bias() && !pd()->conv_supports_bias_;
430 const bool non_default_attr = !pd()->attr()->has_default_values();
431
432 const auto &args = ctx.args();
433 exec_args_t conv_args;
434 conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
435 conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
436 if (pd()->with_bias() && pd()->conv_supports_bias_)
437 conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);
438
439 // Create intermediate memory for f32 output if needed.
440 auto dst = args.at(DNNL_ARG_DST);
441 memory_t tmp_memory(dst.mem->engine(), pd()->conv_pd_->diff_src_md(),
442 scratchpad.get_memory_storage(key_deconv_bias));
443 memory_arg_t tmp_conv_output = {&tmp_memory, false};
444
445 conv_args[DNNL_ARG_DIFF_SRC]
446 = ref_bias || non_default_attr ? tmp_conv_output : dst;
447
448 // When sum post-op happens, we need to copy original destination memory
449 // prior call to external convolution happens.
450 if (pd()->attr()->post_ops_.find(primitive_kind::sum) != -1) {
451 void *original_dst = scratchpad.get(key_deconv_sum);
452 const memory_desc_wrapper dst_d(pd()->dst_md());
453 void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
454 const auto dt_size = dst_d.data_type_size();
455
456 parallel(0, [&](const int ithr, const int nthr) {
457 dim_t start {0}, end {0};
458 balance211(dst_d.nelems(true), nthr, ithr, start, end);
459 auto o_dst_start = (char *)original_dst + start * dt_size;
460 auto dst_start = (char *)dst + start * dt_size;
461 const auto size = (end - start) * dt_size;
462
463 std::memcpy(o_dst_start, dst_start, size);
464 });
465 }
466
467 exec_ctx_t conv_ctx(ctx, std::move(conv_args));
468
469 nested_scratchpad_t ns(ctx, key_nested, conv_p_);
470 conv_ctx.set_scratchpad_grantor(ns.grantor());
471 auto status = conv_p_->execute(conv_ctx);
472 if (status != status::success) return status;
473
474 using namespace data_type;
475
476 if (!pd()->attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) {
477 float *conv_output = scratchpad.get<float>(key_deconv_bias);
478 const auto wei_dt = pd()->weights_md()->data_type;
479 switch (wei_dt) {
480 case s8: apply_src_zero_point<s8>(ctx, pd(), conv_output); break;
481 case u8: apply_src_zero_point<u8>(ctx, pd(), conv_output); break;
482 default: assert(!"unsupported data type");
483 }
484 }
485
486 float *conv_output = scratchpad.get<float>(key_deconv_bias);
487
488 if (ref_bias) {
489 void *dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
490 void *tmp_output = non_default_attr ? conv_output : dst;
491 compute_fwd_bias(ctx, tmp_output, conv_output, non_default_attr);
492 }
493
494 if (non_default_attr) {
495 void *original_dst = scratchpad.get<void>(key_deconv_sum);
496 compute_ref_attrs(ctx, conv_output, original_dst);
497 }
498
499 return status::success;
500 }
501
execute(const exec_ctx_t & ctx) const502 status_t ref_deconvolution_bwd_data_t::execute(const exec_ctx_t &ctx) const {
503 using namespace memory_tracking::names;
504 const auto &args = ctx.args();
505 exec_args_t conv_args;
506 conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
507 conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
508 conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
509 exec_ctx_t conv_ctx(ctx, std::move(conv_args));
510
511 nested_scratchpad_t ns(ctx, key_nested, conv_p_);
512 conv_ctx.set_scratchpad_grantor(ns.grantor());
513 conv_p_->execute(conv_ctx);
514 return status::success;
515 }
516
compute_bwd_bias(float * diff_bias,const float * diff_dst) const517 void ref_deconvolution_bwd_weights_t::compute_bwd_bias(
518 float *diff_bias, const float *diff_dst) const {
519 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
520
521 const auto G = pd()->G();
522 const auto MB = pd()->MB();
523 const auto OH = pd()->OH();
524 const auto OW = pd()->OW();
525 const auto OC = pd()->OC() / G;
526 const auto OD = pd()->OD();
527 const auto ndims = pd()->desc()->src_desc.ndims;
528
529 parallel_nd(G, OC, [&](dim_t g, dim_t oc) {
530 float db = 0;
531 for_(dim_t mb = 0; mb < MB; ++mb)
532 for_(dim_t od = 0; od < OD; ++od)
533 for_(dim_t oh = 0; oh < OH; ++oh)
534 for (dim_t ow = 0; ow < OW; ++ow) {
535 const auto d_dst_off = ref_conv_utils::get_data_off(
536 diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
537 db += diff_dst[d_dst_off];
538 }
539 diff_bias[g * OC + oc] = db;
540 });
541 }
542
543 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bwd_bias_ncdhw(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const544 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw(
545 typename prec_traits<dbia_type>::type *diff_bias,
546 const typename prec_traits<ddst_type>::type *diff_dst) const {
547 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
548
549 const auto OC = pd()->OC();
550 const auto MB = pd()->MB();
551 const auto SP = pd()->OH() * pd()->OW() * pd()->OD();
552
553 parallel_nd(OC, [&](dim_t oc) {
554 float db = 0;
555 for (dim_t mb = 0; mb < MB; ++mb) {
556 PRAGMA_OMP_SIMD(reduction(+ : db))
557 for (dim_t sp = 0; sp < SP; ++sp) {
558 auto offset = (size_t)(mb * OC + oc) * SP + sp;
559 db += diff_dst[offset];
560 }
561 }
562 diff_bias[oc] = db;
563 });
564 }
565
566 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bwd_bias_ndhwc(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const567 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ndhwc(
568 typename prec_traits<dbia_type>::type *diff_bias,
569 const typename prec_traits<ddst_type>::type *diff_dst) const {
570 const auto MB = pd()->MB();
571 const auto SP = pd()->OW() * pd()->OH() * pd()->OD();
572 const auto OC = pd()->OC();
573
574 parallel_nd(OC, [&](dim_t oc) {
575 float db = 0;
576 for (dim_t mb = 0; mb < MB; ++mb) {
577 PRAGMA_OMP_SIMD(reduction(+ : db))
578 for (dim_t sp = 0; sp < SP; ++sp) {
579 const dim_t offset = (mb * SP + sp) * OC + oc;
580 db += diff_dst[offset];
581 }
582 }
583 diff_bias[oc] = static_cast<typename prec_traits<dbia_type>::type>(db);
584 });
585 }
586
587 template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize>
compute_bwd_bias_nCdhwXc(typename prec_traits<dbia_type>::type * diff_bias,const typename prec_traits<ddst_type>::type * diff_dst) const588 void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc(
589 typename prec_traits<dbia_type>::type *diff_bias,
590 const typename prec_traits<ddst_type>::type *diff_dst) const {
591 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
592
593 const auto OC = pd()->OC();
594 const auto MB = pd()->MB();
595 const auto SP = pd()->OH() * pd()->OW() * pd()->OD();
596
597 const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0];
598
599 parallel_nd(utils::div_up(OC, blksize), [&](dim_t ocb) {
600 float db[blksize] = {0};
601
602 for (dim_t mb = 0; mb < MB; ++mb) {
603 for (dim_t sp = 0; sp < SP; ++sp) {
604 auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
605
606 PRAGMA_OMP_SIMD()
607 for (dim_t i = 0; i < blksize; ++i)
608 db[i] += diff_dst[offset + i];
609 }
610 }
611
612 const dim_t blk = nstl::min(blksize, OC - ocb * blksize);
613
614 PRAGMA_OMP_SIMD()
615 for (dim_t i = 0; i < blk; ++i)
616 diff_bias[ocb * blksize + i] = db[i];
617 });
618 }
619
620 template <data_type_t dbia_type, data_type_t ddst_type>
compute_bias(const exec_ctx_t & ctx) const621 void ref_deconvolution_bwd_weights_t::compute_bias(
622 const exec_ctx_t &ctx) const {
623 using dbia_data_t = typename prec_traits<dbia_type>::type;
624 using ddst_data_t = typename prec_traits<ddst_type>::type;
625
626 auto diff_bias = CTX_OUT_MEM(dbia_data_t *, DNNL_ARG_DIFF_BIAS);
627 auto diff_dst = CTX_IN_MEM(const ddst_data_t *, DNNL_ARG_DIFF_DST);
628
629 using namespace format_tag;
630 switch (pd()->dst_tag_) {
631 case ncdhw:
632 case nchw:
633 case ncw:
634 compute_bwd_bias_ncdhw<dbia_type, ddst_type>(diff_bias, diff_dst);
635 break;
636 case ndhwc:
637 case nhwc:
638 case nwc:
639 compute_bwd_bias_ndhwc<dbia_type, ddst_type>(diff_bias, diff_dst);
640 break;
641 case nCdhw8c:
642 case nChw8c:
643 case nCw8c:
644 assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
645 compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 8>(
646 diff_bias, diff_dst);
647 break;
648 case nCdhw16c:
649 case nChw16c:
650 case nCw16c:
651 compute_bwd_bias_nCdhwXc<dbia_type, ddst_type, 16>(
652 diff_bias, diff_dst);
653 break;
654 default:
655 assert(!utils::one_of(data_type::bf16, dbia_type, ddst_type));
656 compute_bwd_bias((float *)diff_bias, (const float *)diff_dst);
657 break;
658 }
659 }
660
execute(const exec_ctx_t & ctx) const661 status_t ref_deconvolution_bwd_weights_t::execute(const exec_ctx_t &ctx) const {
662 using namespace memory_tracking::names;
663 const auto &args = ctx.args();
664 exec_args_t conv_args;
665 conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
666 conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
667 conv_args[DNNL_ARG_DIFF_WEIGHTS] = args.at(DNNL_ARG_DIFF_WEIGHTS);
668 exec_ctx_t conv_ctx(ctx, std::move(conv_args));
669
670 nested_scratchpad_t ns(ctx, key_nested, conv_p_);
671 conv_ctx.set_scratchpad_grantor(ns.grantor());
672 status_t status = conv_p_->execute(conv_ctx);
673 if (status != status::success) return status;
674
675 if (pd()->with_bias()) {
676 using namespace data_type;
677
678 auto dbia_type = pd()->diff_weights_md(1)->data_type;
679 auto ddst_type = pd()->diff_dst_md()->data_type;
680 if (utils::everyone_is(f32, dbia_type, ddst_type))
681 compute_bias<f32, f32>(ctx);
682 else if (utils::everyone_is(bf16, dbia_type, ddst_type))
683 compute_bias<bf16, bf16>(ctx);
684 else if (dbia_type == f32 && ddst_type == bf16)
685 compute_bias<f32, bf16>(ctx);
686 else {
687 assert(!"unsupported data type");
688 return status::runtime_error;
689 }
690 }
691 return status::success;
692 }
693
694 using namespace data_type;
695
696 template void ref_deconvolution_bwd_weights_t::compute_bias<f32, f32>(
697 const exec_ctx_t &ctx) const;
698 template void ref_deconvolution_bwd_weights_t::compute_bias<f32, bf16>(
699 const exec_ctx_t &ctx) const;
700 template void ref_deconvolution_bwd_weights_t::compute_bias<bf16, bf16>(
701 const exec_ctx_t &ctx) const;
702 } // namespace cpu
703 } // namespace impl
704 } // namespace dnnl
705
706 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
707