/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include "oneapi/dnnl/dnnl_types.h"

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/x64/gemm_bf16_convolution.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::cpu::x64::bf16_support;

// The two stand-alone functions below are moved out of execute_backward_data
// and execute_backward_weights to avoid gcc 6.x and 7.x compiler warnings
// "declared with greater visibility than the type of its field"
// emitted when one lambda function is declared within another one
void store_bfloat16_in_parallel(bfloat16_t *output_data, const float *acc_data,
        size_t parallel_work, size_t parallel_work_size, bool do_in_parallel) {
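    // parallel(0, f) runs f across the default number of threads, while
    // parallel(1, f) degenerates to a single-threaded call; do_in_parallel
    // selects between the two modes.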
    parallel(do_in_parallel ? 0 : 1, [&](const int ithr, const int nthr) {
        size_t start = 0, end = 0;
        balance211(parallel_work, nthr, ithr, start, end);
        if (start < end)
            cvt_float_to_bfloat16(&output_data[start * parallel_work_size],
                    &acc_data[start * parallel_work_size],
                    (end - start) * parallel_work_size);
    });
}

void cvt_acc_to_dst(const conv_gemm_conf_t &jcp, size_t g_start, size_t g_end,
        const float *acc_base, bfloat16_t *diff_weights) {
    const size_t parallel_work_size = jcp.ic * jcp.ks;
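    // Both the accumulator and diff_weights are laid out as
    // (ic * ks, ngroups, oc) with oc innermost, so each (w, g) pair owns a
    // contiguous run of jcp.oc values that can be converted independently.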
    parallel(jcp.nthr == 1 ? 0 : 1, [&](const int ithr, const int nthr) {
        size_t w_start = 0, w_end = 0;
        balance211(parallel_work_size, nthr, ithr, w_start, w_end);
        for_(auto w = w_start; w < w_end; ++w)
        for (auto g = g_start; g < g_end; ++g) {
            const float *__restrict acc_ptr
                    = acc_base + (w * jcp.ngroups + g) * jcp.oc;
            bfloat16_t *__restrict dw_ptr
                    = diff_weights + (w * jcp.ngroups + g) * jcp.oc;
            cvt_float_to_bfloat16(dw_ptr, acc_ptr, jcp.oc);
        }
    });
}

template <data_type_t dst_data_type>
gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)
    : jcp_(pd->jcp_)
    , do_sum_(dst_data_type != data_type::f32 && jcp_.with_sum)
    , max_data_reg_idx_(31)
    , max_unroll_(12)
    , compute_reg_step_(1)
    , data_reg_base_idx_(0) {
    using namespace types;
    using namespace Xbyak;

    if (!mayiuse(avx512_core))
        // bf16 is not supported
        return;

    const auto &post_ops = jcp_.post_ops;
    if (jcp_.with_eltwise || jcp_.with_binary) {
#define PARAM_OFF(field) offsetof(ker_args, field)
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        static constexpr size_t helper_vmm_idx = 31;
        static constexpr size_t tail_size = 1;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                helper_vmm_idx, reserved_eltwise_gpr, r14, preserve_gpr,
                preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
                memory_desc_wrapper(pd->dst_md()), tail_size, kreg_rem_mask,
                use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t binary_static_params {
                this->reg_param, rhs_arg_static_params};
        static constexpr bool save_state = true;
        const eltwise_injector::static_params_t eltwise_static_params {
                save_state, reserved_eltwise_gpr, reserved_eltwise_maskr};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, post_ops, binary_static_params, eltwise_static_params);
#undef PARAM_OFF
    }

    if (do_sum_) {
        compute_reg_step_ = 2;
        vreg_sum_scale = Zmm(data_reg_base_idx_++);
    }

    if (jcp_.with_bias) vreg_bias = Zmm(data_reg_base_idx_++);

    vlen_ = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
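    // vlen_ = 64 bytes / sizeof(float) = 16 f32 lanes per zmm register.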

    isa_ = mayiuse(avx512_core_bf16) ? avx512_core_bf16
                                     : bf16_emulation_t::get_isa();

    if (isa_ != avx512_core_bf16) {
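        // Without native avx512_core_bf16 support the f32 -> bf16 convert is
        // emulated, and the emulation helper claims the upper zmm registers
        // as scratch, shrinking the register budget available for data.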
        max_data_reg_idx_ = 26;
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_6);
    }

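    // Each unrolled iteration consumes compute_reg_step_ registers (two when
    // a sum post-op also needs the previous dst value), so the unroll factor
    // is whatever fits into the remaining register budget.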
    max_unroll_
            = (max_data_reg_idx_ - data_reg_base_idx_ + 1) / compute_reg_step_;
}

template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::apply_postops(
        const bool apply_mask, const int vmm_idx) {
#define PARAM_OFF(x) offsetof(ker_args, x)
    if (jcp_.with_eltwise || jcp_.with_binary) {
        static constexpr int offset = 0;
        if (jcp_.with_binary) {
            const auto oc_off_oprnd = this->r12;
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
            rhs_arg_params.vmm_idx_to_oc_elem_off_addr.emplace(
                    vmm_idx, ptr[reg_param + PARAM_OFF(g_oc_offset)]);
            rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(vmm_idx, offset);
            rhs_arg_params.vmm_idx_to_oc_off_oprnd.emplace(
                    vmm_idx, oc_off_oprnd);
            if (apply_mask) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);

            const injector_utils::register_preserve_guard_t register_guard(
                    this, {oc_off_oprnd});
            mov(oc_off_oprnd,
                    ptr[rsp + reg_binary_post_op_acc_off
                            + register_guard.stack_space_occupied()]);

            postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
        } else
            postops_injector_->compute_vector(vmm_idx);
    }
#undef PARAM_OFF
}

template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {
    using namespace Xbyak;
    using namespace utils;

    preamble();

#ifdef _WIN32
    mov(reg_param, rcx);
#endif

#define PARAM_OFF(x) offsetof(ker_args, x)
    mov(reg_dst_base, ptr[reg_param + PARAM_OFF(dst)]);
    mov(reg_acc_base, ptr[reg_param + PARAM_OFF(acc)]);
    if (jcp_.with_bias) mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
    mov(reg_dst_str, ptr[reg_param + PARAM_OFF(dst_stride_in_bytes)]);
    mov(reg_acc_str, ptr[reg_param + PARAM_OFF(acc_stride_in_bytes)]);
    mov(reg_len, ptr[reg_param + PARAM_OFF(spatial_length)]);
    mov(reg_oc_iter, ptr[reg_param + PARAM_OFF(oc_work)]);

    if (jcp_.with_binary) {
        // zero initialize binary post_ops offset accumulator (store on stack)
        const auto binary_post_op_acc_off_reg = reg_tmp;
        xor_(binary_post_op_acc_off_reg, binary_post_op_acc_off_reg);
        push(binary_post_op_acc_off_reg);
    }
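    // The pushed slot is released explicitly: when with_binary is set,
    // generate() adjusts rsp back before postamble() (see the matching
    // `add(rsp, stack_space_needed)` near the end of this function).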

    if (do_sum_)
        vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
#undef PARAM_OFF

    // Load accumulated value, apply sum (if any), bias (if any)
    // and relu (if any); then convert to destination type and store
    auto compute = [&](size_t offset, int idx, bool apply_mask) {
        auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
        auto vreg_dst_ = vreg_dst(idx);

        if (dst_data_type == data_type::bf16 && isa_ != avx512_core_bf16)
            bf16_emu_->init_vcvtneps2bf16();

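        // Tail handling: merge-masked moves through kreg_rem_mask touch only
        // the remaining lanes, so the same body serves full and partial
        // vectors.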
        if (apply_mask) vreg_dst_ = vreg_dst_ | kreg_rem_mask;
        vmovups(vreg_dst_, acc_addr);

        if (jcp_.with_bias) vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias);

        auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
        if (do_sum_) {
            auto vreg_prev_dst_ = vreg_prev_dst(idx);
            if (dst_data_type == data_type::f32) {
                if (apply_mask) vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask;

                vmovups(vreg_prev_dst_, dst_addr);
            } else if (dst_data_type == data_type::bf16) {
                auto vreg_prev_dst_ymm_ = vreg_prev_dst_ymm(idx);
                if (apply_mask)
                    vreg_prev_dst_ymm_ = vreg_prev_dst_ymm_ | kreg_rem_mask;

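                // Upconvert bf16 -> f32 in-register: zero-extend each 16-bit
                // element to 32 bits, then shift left by 16 so the bf16 bits
                // land in the high half of the f32 representation.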
                vmovdqu16(vreg_prev_dst_ymm_, dst_addr);
                vpmovzxwd(vreg_prev_dst(idx), vreg_prev_dst_ymm_);
                vpslld(vreg_prev_dst(idx), vreg_prev_dst(idx), 0x10);
            } else
                assert(!"unsupported data type");

            vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
        }

        apply_postops(apply_mask, vreg_dst_idx(idx));

        if (dst_data_type == data_type::bf16) {
            // TODO: implement store by zmm registers for bf16
            auto vreg_dst_ymm_ = vreg_dst_ymm(idx);
            if (isa_ == avx512_core_bf16)
                vcvtneps2bf16(vreg_dst_ymm_, vreg_dst(idx));
            else
                bf16_emu_->vcvtneps2bf16(vreg_dst_ymm_, vreg_dst(idx));

            if (apply_mask) vreg_dst_ymm_ = vreg_dst_ymm_ | kreg_rem_mask;

            vmovdqu16(dst_addr, vreg_dst_ymm_);
        } else if (dst_data_type == data_type::f32)
            vmovups(dst_addr, vreg_dst_);
        else
            assert(!"unimplemented");
    };

    // Advance all pointers by an immediate
    auto advance_ptrs_imm = [&](size_t offset) {
        add(reg_dst, offset * sizeof(dst_data_t));
        add(reg_acc, offset * sizeof(acc_data_t));
    };

    Xbyak::Label oc_loop, oc_loop_end;

    cmp(reg_oc_iter, 0);
    jle(oc_loop_end, T_NEAR);

    L(oc_loop);

    mov(reg_len_iter, reg_len);
    mov(reg_dst, reg_dst_base);
    mov(reg_acc, reg_acc_base);

    if (jcp_.with_bias) vbroadcastss(vreg_bias, ptr[reg_bias]);

    constexpr int n_unroll = default_unroll_2_pow_; // unroll by powers of 2
                                                    // from 2^n to 2^0
    assert((1 << n_unroll) <= max_unroll_);

    Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
    for (int i = n_unroll; i >= 0; i--) {
        const int unroll = 1 << i; // 4, 2, 1
        L(l_simd_loop[i + 1]);
        {
            const int loop_len = unroll * vlen_;
            cmp(reg_len_iter, loop_len);
            jl(l_simd_loop[i], T_NEAR);
            for (int j = 0; j < unroll; j++)
                compute(j * vlen_, j, false);

            advance_ptrs_imm(loop_len);
            sub(reg_len_iter, loop_len);
            jmp(l_simd_loop[i + 1], T_NEAR);
        }
    }
    L(l_simd_loop[0]);

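    // Remainder: build a k-mask with the low (reg_len_iter % vlen_) bits set
    // via (1 << len) - 1; a zero mask means there is no tail left to do.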
    mov(reg_tmp, reg_len_iter); // reg_tmp is rcx, and we need cl for the shift
    mov(reg_rem_mask, 1);
    shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_len_iter < vlen_ == 16
    sub(reg_rem_mask, 1);
    jz(l_simd_notail, T_NEAR);
    kmovq(kreg_rem_mask, reg_rem_mask);
    compute(0, 0, true);

    L(l_simd_notail);

    add(reg_dst_base, reg_dst_str);
    add(reg_acc_base, reg_acc_str);
    if (jcp_.with_bias) add(reg_bias, sizeof(acc_data_t));
    if (jcp_.with_binary)
        inc(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));

    dec(reg_oc_iter);
    jnz(oc_loop, T_NEAR); // oc_loop end

    L(oc_loop_end);

    if (jcp_.with_binary) add(rsp, stack_space_needed);

    postamble();

    if (jcp_.with_eltwise) postops_injector_->prepare_table();
}

// operator () specialized for nspc format
template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
        dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias,
        float sum_scale, size_t oc_work,
        const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
        const size_t g_oc_offset) {

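    // In nspc layout the kernel is invoked per spatial point: both strides
    // equal one element, spatial_length is 1, and oc_work channels are
    // processed per call.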
    ker_args args;
    args.acc = acc;
    args.dst = dst;
    args.bias = bias;
    args.sum_scale = sum_scale;
    args.dst_stride_in_bytes = sizeof(dst_data_t);
    args.acc_stride_in_bytes = sizeof(acc_data_t);
    args.spatial_length = 1;
    args.oc_work = oc_work;

    args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
    args.dst_orig = dst_orig;
    args.g_oc_offset = g_oc_offset;
    jit_generator::operator()(&args);
}

// operator () specialized for ncsp format
template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
        dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias,
        float sum_scale, size_t dst_stride_in_elements,
        size_t acc_stride_in_elements, size_t sp_len, size_t oc_len,
        const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
        const size_t g_oc_offset) {
    if (sp_len == 0) return;

    ker_args args;
    args.acc = acc;
    args.dst = dst;
    args.bias = bias;
    args.sum_scale = sum_scale;
    args.dst_stride_in_bytes = dst_stride_in_elements * sizeof(dst_data_t);
    args.acc_stride_in_bytes = acc_stride_in_elements * sizeof(acc_data_t);
    args.spatial_length = sp_len;
    args.oc_work = oc_len;

    args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
    args.dst_orig = dst_orig;
    args.g_oc_offset = g_oc_offset;
    jit_generator::operator()(&args);
}

template <data_type_t dst_data_type>
status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_nspc(
        const exec_ctx_t &ctx) const {
    auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto dst_base = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    this->pd()->attr()->post_ops_, ctx);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    float *bia_base = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->bias_desc.data_type == data_type::bf16) {
            auto bias_in = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_BIAS);
            bia_base = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
            cvt_bfloat16_to_float(bia_base, bias_in, jcp.ngroups * jcp.oc);
        } else {
            auto bias_in = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
            bia_base = const_cast<float *>(bias_in);
        }
    }
    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_forward_thr_nspc(ithr, nthr, src_base,
                wei_base, bia_base, dst_base, scratchpad,
                post_ops_binary_rhs_arg_vec.data());
        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

template <data_type_t dst_data_type>
status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_thr_nspc(
        const int ithr, const int nthr, const src_data_t *src_base,
        const wei_data_t *wei_base, const float *bia_base, dst_data_t *dst_base,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec) const {
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Src Format: mb-spatial-groups-input_channels
    const size_t src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw
            * jcp.ngroups * jcp.ic;
    const size_t src_g_stride = jcp.ic;
    // Wei Format: spatial-input_channels-groups-output_channels
    const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Dst Format: mb-spatial-groups-output_channels
    const size_t dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh * jcp.ow
            * jcp.ngroups * jcp.oc;
    const size_t dst_g_stride = jcp.oc;
    const size_t dst_os_stride = jcp.ngroups * jcp.oc;

    src_data_t *__restrict col = scratchpad.get<src_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;
    acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
            + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;

    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    int g {0}, n {0}, ohb {0}, owb {0};
    size_t start = 0, end = 0;

    const bool is_problem_3d = pd()->ndims() == 5;
    assert(IMPLICATION(is_problem_3d,
            jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
                    && jcp.ic_block == jcp.ic));

    const int nb_oh = div_up(jcp.oh, jcp.oh_block);
    const int nb_ow = div_up(jcp.ow, jcp.ow_block);
    const size_t work_amount = (size_t)jcp.ngroups * jcp.mb * nb_oh * nb_ow;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);

    if (jcp.im2col_sz && is_problem_3d) {
        // jit_gemm_convolution_utils::im2col_dt_3d() requires external
        // data initialization by zeroes
        // For performance reasons use uint16_t as a proxy for bfloat16_t
        uint16_t *__restrict col_r
                = reinterpret_cast<uint16_t *__restrict>(col);
        constexpr uint16_t zero_val = 0;

        PRAGMA_OMP_SIMD()
        for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
            col_r[i] = zero_val;
    }
    for (size_t iwork = start; iwork < end; ++iwork) {
        int oh = ohb * jcp.oh_block;
        int ow = owb * jcp.ow_block;
        const src_data_t *__restrict src
                = src_base + n * src_mb_stride + g * src_g_stride;
        const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;

        const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
        const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
        if (jcp.im2col_sz && is_problem_3d)
            jit_gemm_convolution_utils::transpose_dt(jcp, src, imtr);

        for (int od = 0; od < jcp.od; od++) {
            dst_data_t *__restrict dst = dst_base + n * dst_mb_stride
                    + g * dst_g_stride
                    + ((od * jcp.oh + oh) * jcp.ow + ow) * dst_os_stride;
            if (jcp.im2col_sz) {
                if (is_problem_3d)
                    jit_gemm_convolution_utils::im2col_dt_3d<src_data_t,
                            src_data_t>(jcp, imtr, col, od);
                else
                    jit_gemm_convolution_utils::im2col_dt<src_data_t,
                            src_data_t>(
                            jcp, src, imtr, col, oh, h_step, ow, w_step);
            }

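            // GEMM view: dst(oc x pixels) = wei(oc x ic*ks) * B(ic*ks x
            // pixels), where B is the transposed im2col buffer or, when no
            // im2col is needed, the source itself. LDA = oc * ngroups strides
            // over the other groups' weights; the output leading dimension
            // depends on whether the result goes to the f32 accumulator or
            // directly to an f32 dst.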
            const dim_t M = jcp.oc;
            const dim_t K = jcp.ks * jcp.ic;
            const dim_t N = h_step * w_step;
            const dim_t LDA = M * jcp.ngroups;
            const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
            const char *BT = jcp.im2col_sz ? "T" : "N";
            const float onef = 1.f;
            const float beta = this->beta_;
            const src_data_t *__restrict src_od
                    = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
            const bool acc_needed = dst_data_type == data_type::bf16;
            status_t st = gemm_bf16bf16f32("N", BT, &M, &N, &K, &onef, wei,
                    &LDA, jcp.im2col_sz ? col : (src_data_t *)src_od, &LDB,
                    &beta, acc_needed ? acc : (float *)dst,
                    acc_needed ? &M : &LDA);
            if (st != status::success) return st;

            const bool do_postprocess = pd()->is_postprocess_required();
            if (do_postprocess) {
                parallel_nd_ext(jcp.nthr == 1 ? 0 : 1, N,
                        [&](size_t ithr, size_t nthr, size_t os) {
                            const float *__restrict acc_arr = acc + os * jcp.oc;
                            const float *__restrict bia_arr
                                    = (bia_base == nullptr)
                                    ? nullptr
                                    : bia_base + g * jcp.oc;
                            dst_data_t *__restrict dst_arr
                                    = dst + os * dst_os_stride;

                            (*pp_ker_)(dst_arr,
                                    acc_needed ? acc_arr : (float *)dst_arr,
                                    bia_arr, sum_scale, jcp.oc,
                                    post_ops_binary_rhs_arg_vec, dst,
                                    g * jcp.oc);
                        });
            }
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
    }
    return status::success;
}

template <data_type_t dst_data_type>
status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_ncsp(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    this->pd()->attr()->post_ops_, ctx);

    bool is_bf16_dst = dst_data_type == data_type::bf16;

    auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
            key_conv_gemm_col);
    acc_data_t *acc_base = is_bf16_dst
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : nullptr;

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    float *bias = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->bias_desc.data_type == data_type::bf16) {
            auto bias_in = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_BIAS);
            bias = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
            cvt_bfloat16_to_float(bias, bias_in, jcp.ngroups * jcp.oc);
        } else {
            auto bias_in = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
            bias = const_cast<float *>(bias_in);
        }
    }

    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    const dim_t M = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * M;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
    const size_t weights_oc_size = jcp.ic * jcp.ks;

    const dim_t LDB = weights_oc_size;
    const size_t work_amount
            = (size_t)jcp.ngroups * jcp.mb * jcp.od * jcp.os_nb_block;
    const bool is_problem_3d = pd()->ndims() == 5;
    std::atomic<status_t> st(status::success);

    auto inner_ker = [&](const int ic, const int oc, const int groups,
                             const int od, const int spatial,
                             const src_data_t *src, const wei_data_t *weights,
                             src_data_t *col, dst_data_t *dst, acc_data_t *acc,
                             int ic_block, int oc_block) {
        const dim_t os_block = nstl::min(
                (dim_t)jcp.os_block, (dim_t)jcp.os - spatial * jcp.os_block);

        if (jcp.im2col_sz) {
            if (!is_problem_3d) {
                jit_gemm_convolution_utils::im2col<src_data_t>(jcp, src, col,
                        spatial * jcp.os_block, os_block, ic, ic_block);
            } else {
                assert(jcp.ic_block == jcp.ic);
                jit_gemm_convolution_utils::im2col_3d<src_data_t>(
                        jcp, src, col, od, spatial * jcp.os_block, os_block);
            }
        }

        const acc_data_t one = 1.0;
        const dim_t N = oc_block;
        const dim_t K = ic_block * jcp.ks;
        const dim_t m = os_block;
        const dim_t LDA = jcp.im2col_sz ? m : M;
        const dim_t LDC = is_bf16_dst ? m : M;
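        // The first ic chunk initializes the output; since the JIT pp kernel
        // only applies a sum post-op for non-f32 dst (see do_sum_ in
        // pp_ker_t), an f32 sum is presumably folded into beta_ here.
        // Subsequent ic chunks accumulate with beta = 1.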
        const float beta = (ic == 0) ? this->beta_ : one;
        auto out_off = spatial * jcp.os_block + od * jcp.os;
        dst_data_t *dst_local = dst + out_off;

        status_t st_thr = gemm_bf16bf16f32("N", "N", &m, &N, &K, &one,
                jcp.im2col_sz ? col : src + ic * M + out_off, &LDA, weights,
                &LDB, &beta, acc, &LDC);

        if (st_thr != status::success) {
            st = st_thr;
            return;
        }

        if (this->pd()->is_postprocess_required() && ic + ic_block >= jcp.ic) {
            size_t acc_str = LDC;
            size_t dst_str = M;
            float *bias_ptr = bias ? bias + groups * jcp.oc + oc : nullptr;
            (*pp_ker_)(dst_local, acc, bias_ptr, sum_scale, dst_str, acc_str, m,
                    oc_block, post_ops_binary_rhs_arg_vec.data(), dst,
                    groups * jcp.oc + oc);
        }
    };

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
        if (is_problem_3d) {
            // jit_gemm_convolution_utils::im2col_3d() requires external
            // data initialization by zeroes
            for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                _col[i] = (src_data_t)0;
        }
        int g {0}, n {0}, od {0}, nb_os {0};
        size_t start = 0, end = 0;
        size_t oc_start = 0, oc_end = 0;

        assert(jcp.loop_order == gemm_loop_lbr);
        balance2D(nthr, ithr, work_amount, start, end, (size_t)jcp.oc, oc_start,
                oc_end, (size_t)jcp.nthr_oc);

        nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os,
                jcp.os_nb_block);
        for (size_t iwork = start; iwork < end; ++iwork) {
            for_(int oc = (int)oc_start; oc < (int)oc_end; oc += jcp.oc_block)
            for (int ic = 0; ic < jcp.ic; ic += jcp.ic_block) {
                const src_data_t *_src = src + (n * jcp.ngroups + g) * src_step;
                const wei_data_t *_weights = weights + g * weights_g_size
                        + oc * weights_oc_size + ic * jcp.ks;
                dst_data_t *_dst_im
                        = dst + (n * jcp.ngroups + g) * dst_step + oc * M;
                auto out_off = nb_os * jcp.os_block + od * jcp.os;
                dst_data_t *dst_local = _dst_im + out_off;
                const int sizeof_cacheline_float = 16;
                acc_data_t *_acc = is_bf16_dst ? acc_base
                                + ithr
                                        * rnd_up(jcp.oc_block * jcp.os_block,
                                                sizeof_cacheline_float)
                                               : (acc_data_t *)dst_local;

                const int ic_block = nstl::min(jcp.ic - ic, jcp.ic_block);
                const int oc_block = nstl::min(int(oc_end) - oc, jcp.oc_block);

                inner_ker(ic, oc, g, od, nb_os, _src, _weights, _col, _dst_im,
                        _acc, ic_block, oc_block);
            }
            nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os,
                    jcp.os_nb_block);
        }
    });

    return st;
}

template <data_type_t diff_src_data_type>
status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
        execute_backward_data_nspc(const exec_ctx_t &ctx) const {

    auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_backward_data_thr_nspc(
                ithr, nthr, diff_src_base, wei_base, diff_dst_base, scratchpad);
        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

template <data_type_t diff_src_data_type>
status_t gemm_bf16_convolution_bwd_data_t<
        diff_src_data_type>::execute_backward_data_thr_nspc(const int ithr,
        const int nthr, diff_src_data_t *diff_src_base,
        const wei_data_t *wei_base, const diff_dst_data_t *diff_dst_base,
        const memory_tracking::grantor_t &scratchpad) const {

    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Diff_dst Format: mb-spatial-groups-output_channels
    const size_t diff_dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh
            * jcp.ow * jcp.ngroups * jcp.oc;
    const size_t diff_dst_g_stride = jcp.oc;

    // Wei Format: spatial-input_channels-groups-output_channels
    const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Diff_src Format: mb-spatial-groups-input_channels
    const size_t diff_src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih
            * jcp.iw * jcp.ngroups * jcp.ic;
    const size_t diff_src_g_stride = jcp.ic;
    const size_t diff_src_os_stride = jcp.ngroups * jcp.ic;

    // threads share work across mini-batch and groups
    const size_t work_amount = jcp.ngroups * jcp.mb;

    acc_data_t *__restrict col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
            + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic;

    int n {0}, g {0};
    size_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (size_t iwork = start; iwork < end; ++iwork) {
        const diff_dst_data_t *__restrict diff_dst = diff_dst_base
                + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
        diff_src_data_t *__restrict diff_src = diff_src_base
                + n * diff_src_mb_stride + g * diff_src_g_stride;

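        // GEMM view: col(ic*ks x os*od) = wei^T(ic*ks x oc) * diff_dst(oc x
        // os*od). The "T" undoes the oc-innermost weights storage, and
        // LD = oc * ngroups strides over the interleaved groups on both
        // inputs; without im2col the result goes straight to the accumulator.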
        const dim_t M = jcp.ks * jcp.ic;
        const dim_t N = jcp.os * jcp.od;
        const dim_t K = jcp.oc;

        const float onef = 1.0f, zerof = 0.0f;
        const dim_t LD = K * jcp.ngroups;

        status_t st = gemm_bf16bf16f32("T", "N", &M, &N, &K, &onef, wei, &LD,
                diff_dst, &LD, &zerof, jcp.im2col_sz ? col : acc, &M);
        if (st != status::success) return st;

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_dt<acc_data_t>(jcp, col, acc);

        const bool is_diff_src_bf16 = diff_src_data_type == data_type::bf16;

        if (is_diff_src_bf16 && jcp.ngroups == 1 && jcp.nthr != 1) {
            cvt_float_to_bfloat16((bfloat16_t *)diff_src, (const float *)acc,
                    static_cast<size_t>(jcp.is) * jcp.id * jcp.ic);
        } else if (is_diff_src_bf16) {
            parallel_nd_ext(jcp.nthr == 1 ? 0 : 1,
                    static_cast<size_t>(jcp.is) * jcp.id,
                    [&](size_t ithr, size_t nthr, size_t is) {
                        diff_src_data_t *__restrict diff_src_loc
                                = diff_src + is * diff_src_os_stride;
                        const acc_data_t *__restrict acc_loc
                                = acc + is * jcp.ic;
                        cvt_float_to_bfloat16((bfloat16_t *)diff_src_loc,
                                (const float *)acc_loc, jcp.ic);
                    });
        } else {
            assert(diff_src_data_type == data_type::f32);
            parallel_nd_ext(jcp.nthr == 1 ? 0 : 1,
                    static_cast<size_t>(jcp.is) * jcp.id,
                    [&](size_t ithr, size_t nthr, size_t is) {
                        diff_src_data_t *__restrict diff_src_loc
                                = diff_src + is * diff_src_os_stride;
                        const acc_data_t *__restrict acc_loc
                                = acc + is * jcp.ic;
                        PRAGMA_OMP_SIMD()
                        for (int ic = 0; ic < jcp.ic; ++ic)
                            diff_src_loc[ic] = acc_loc[ic];
                    });
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
    return status::success;
}

template <data_type_t diff_src_data_type>
status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
        execute_backward_data_ncsp(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    auto col = ctx.get_scratchpad_grantor().template get<acc_data_t>(
            key_conv_gemm_col);
    acc_data_t *acc_base = diff_src_data_type == data_type::bf16
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : nullptr;

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const dim_t M = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * M;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;

    const dim_t m = jcp.os_block;
    const dim_t K = jcp.oc;
    const dim_t N = jcp.ic * jcp.ks;

    const size_t work_amount = (size_t)jcp.ngroups * jcp.mb;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        acc_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

        int g {0}, n {0};
        size_t start = 0, end = 0;
        balance211(work_amount, nthr, ithr, start, end);
        nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb);
        for (size_t iwork = start; iwork < end; ++iwork) {

            diff_src_data_t *diff_src_local
                    = diff_src + (n * jcp.ngroups + g) * src_step;
            acc_data_t *acc = diff_src_data_type == data_type::bf16
                    ? acc_base + ithr * rnd_up(src_step, 16)
                    : (acc_data_t *)diff_src_local;

            if (is_problem_3d && jcp.im2col_sz > 0) {
                // jit_gemm_convolution_utils::col2im_3d() assumes that the
                // accumulator is initialized by zeroes
                for (size_t i = 0; i < src_step; i++)
                    acc[i] = (acc_data_t)0;
            }

            const wei_data_t *_weights = weights + g * weights_g_size;
            for_(int od = 0; od < jcp.od; ++od)
            for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                auto out_off = os_nb * m + od * jcp.os;
                const diff_dst_data_t *_diff_dst
                        = diff_dst + (n * jcp.ngroups + g) * dst_step + out_off;
                const dim_t os_block
                        = nstl::min((dim_t)jcp.os_block, jcp.os - os_nb * m);
                const dim_t LDC = jcp.im2col_sz ? os_block : M;

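                // GEMM view: block(os_block x ic*ks) = diff_dst(os_block x
                // oc) * wei^T(oc x ic*ks). With im2col the block lands in
                // _col for a later col2im pass, otherwise it is written
                // directly into the accumulator at this od/os offset.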
                const acc_data_t zero = 0.0, one = 1.0;
                status_t st_thr = gemm_bf16bf16f32("N", "T", &os_block, &N, &K,
                        &one, _diff_dst, &M, _weights, &N, &zero,
                        jcp.im2col_sz ? _col : acc + out_off, &LDC);

                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (jcp.im2col_sz) {
                    if (!is_problem_3d)
                        jit_gemm_convolution_utils::col2im(
                                jcp, _col, acc, os_nb * jcp.os_block, os_block);
                    else
                        jit_gemm_convolution_utils::col2im_3d(jcp, _col, acc,
                                od, os_nb * jcp.os_block, os_block);
                }
            }
            if (diff_src_data_type == data_type::bf16) {
                size_t spatial_size = (size_t)jcp.ih * jcp.iw * jcp.id;
                store_bfloat16_in_parallel((bfloat16_t *)diff_src_local,
                        (const float *)acc, jcp.ic, spatial_size,
                        jcp.nthr == 1);
            }
            nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
        }
    });

    return st;
}

template <data_type_t diff_wei_data_type>
void gemm_bf16_convolution_bwd_weights_t<
        diff_wei_data_type>::bf16_bwd_weights_reduction_par_nspc(int ithr_mb,
        int nthr_mb, size_t g_start, size_t g_end, const conv_gemm_conf_t &jcp,
        const acc_data_t *weights_reduce_base,
        diff_wei_data_t *weights_base) const {
    assert(nthr_mb > 1); // no reduction for nthr_mb == 1

    const bool is_bf16_out = diff_wei_data_type == data_type::bf16;
    const size_t weights_g_size = jcp.oc;
    size_t weights_start {0}, weights_end {0};
    balance211(size_t(jcp.ks) * jcp.ic, nthr_mb, ithr_mb, weights_start,
            weights_end);
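    // Each mb-thread produced a private partial-sum slab in the reduction
    // scratchpad. Over its own [weights_start, weights_end) range this
    // thread folds every other slab into slab 0 (bf16 output) or directly
    // into diff_weights (f32 output); the last slab of a bf16 reduction is
    // added and converted to bf16 in a single pass.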

    for (auto tidx = 1; tidx < nthr_mb; ++tidx) {
        const acc_data_t *ws_base
                = weights_reduce_base + tidx * weights_g_size * jcp.ks * jcp.ic;
        for_(auto w = weights_start; w < weights_end; ++w)
        for (auto g = g_start; g < g_end; ++g) {
            const acc_data_t *ws_ptr = ws_base + w * jcp.oc;
            float *wei_reduced = is_bf16_out
                    ? (float *)weights_reduce_base + w * jcp.oc
                    : (float *)weights_base + (w * jcp.ngroups + g) * jcp.oc;
            if (is_bf16_out && tidx == nthr_mb - 1) {
                // the last iteration for bfloat16 requires conversion
                // and store to diff_weights array
                diff_wei_data_t *dwei_ptr
                        = weights_base + (w * jcp.ngroups + g) * jcp.oc;
                add_floats_and_cvt_to_bfloat16(
                        (bfloat16_t *)(dwei_ptr), wei_reduced, ws_ptr, jcp.oc);
            } else {
                acc_ker_->accumulate(wei_reduced, ws_ptr, jcp.oc);
            }
        }
    }
}

template <data_type_t diff_wei_data_type>
void gemm_bf16_convolution_bwd_weights_t<
        diff_wei_data_type>::bf16_bwd_weights_reduction_par_ncsp(int ithr_mb,
        int nthr_mb, const conv_gemm_conf_t &jcp,
        const acc_data_t *weights_reduce_base,
        diff_wei_data_t *weights_base) const {
    assert(nthr_mb > 1); // no reduction for nthr_mb == 1

    const bool is_bf16_out = diff_wei_data_type == data_type::bf16;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
    size_t weights_start {0}, weights_end {0};
    balance211(weights_g_size, nthr_mb, ithr_mb, weights_start, weights_end);

    if (weights_start >= weights_end) return; // nothing to do

    size_t acc_size = weights_end - weights_start;
    float *wei_reduced = is_bf16_out
            ? (float *)weights_reduce_base + weights_start
            : (float *)weights_base + weights_start;
    if (!is_bf16_out) {
        // f32 diff_weights require initialization by weights_reduce
        // for thr_mb = 0
        for (size_t i = 0; i < acc_size; i++)
            wei_reduced[i] = ((float *)weights_reduce_base + weights_start)[i];
    }

    for (int thr_mb = 1; thr_mb < nthr_mb; ++thr_mb) {
        float *wei_to_reduce = (float *)weights_reduce_base
                + thr_mb * weights_g_size + weights_start;

        if (is_bf16_out && thr_mb == nthr_mb - 1)
            // the last iteration for bfloat16 requires conversion
            // and store to diff_weights array
            add_floats_and_cvt_to_bfloat16(
                    (bfloat16_t *)(weights_base + weights_start), wei_reduced,
                    wei_to_reduce, acc_size);
        else
            acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size);
    }
}

template <data_type_t diff_wei_data_type>
status_t gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
        execute_backward_weights_nspc(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);

    auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
            key_conv_gemm_col);
    auto wei_reduction = ctx.get_scratchpad_grantor().template get<acc_data_t>(
            key_conv_wei_reduction);
    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    acc_data_t *acc_base = diff_wei_data_type == data_type::bf16
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : (acc_data_t *)diff_weights;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16)
            diff_bias = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
        else
            diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
    }

    const dim_t K = jcp.os * static_cast<size_t>(jcp.od);
    const size_t src_step
            = static_cast<size_t>(jcp.ic) * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = jcp.oc * K;
    const size_t weights_g_size = jcp.oc;

    const dim_t k = jcp.os;
    const dim_t M = jcp.oc;
    const dim_t N = static_cast<dim_t>(jcp.ic) * jcp.ks;
    const dim_t LDB = jcp.ngroups * jcp.oc;
    const dim_t LDA = jcp.im2col_sz ? jcp.oh * jcp.ow : jcp.ngroups * jcp.ic;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));

        const int need_reduction = nthr_mb != 1;
        src_data_t *__restrict imtr
                = ctx.get_scratchpad_grantor().template get<src_data_t>(
                          key_conv_gemm_imtr)
                + (ptrdiff_t)ithr * jcp.id * jcp.ic * jcp.is;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
            if (is_problem_3d) {
                // jit_gemm_convolution_utils::im2col_3d() requires external
                // data initialization by zeroes
                // For performance reasons use uint16_t as a proxy for
                // bfloat16_t
                uint16_t *__restrict _col_r
                        = reinterpret_cast<uint16_t *__restrict>(_col);
                constexpr uint16_t zero_val = 0;

                PRAGMA_OMP_SIMD()
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col_r[i] = zero_val;
            }

            acc_data_t *weights_reduce_base = wei_reduction
                    + ithr_g * nthr_mb * weights_g_size * jcp.ks * jcp.ic;
            acc_data_t *weights_reduce = weights_reduce_base
                    + ithr_mb * weights_g_size * jcp.ks * jcp.ic;

            const bool use_diff_wei
                    = ithr_mb == 0 && diff_wei_data_type == data_type::f32;
            for (size_t g = g_start; g < g_end; ++g) {
                acc_data_t *_diff_weights = use_diff_wei
                        ? (acc_data_t *)diff_weights + g * weights_g_size
                        : need_reduction ? weights_reduce
                                         : acc_base + g * weights_g_size;
                const dim_t LDC = use_diff_wei
                        ? jcp.ngroups * jcp.oc
                        : need_reduction ? jcp.oc : jcp.ngroups * jcp.oc;
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const src_data_t *_src
                            = src + mb * jcp.ngroups * src_step + g * jcp.ic;
                    if (jcp.im2col_sz && is_problem_3d)
                        jit_gemm_convolution_utils::transpose_dt(
                                jcp, _src, imtr);
                    for (int od = 0; od < jcp.od; ++od) {
                        const diff_dst_data_t *_diff_dst = diff_dst
                                + mb * jcp.ngroups * dst_step
                                + od * k * jcp.ngroups * jcp.oc + g * jcp.oc;

                        if (jcp.im2col_sz) {
                            if (is_problem_3d)
                                jit_gemm_convolution_utils::im2col_dt_3d<
                                        src_data_t, src_data_t>(
                                        jcp, imtr, _col, od);
                            else
                                jit_gemm_convolution_utils::im2col_dt<
                                        src_data_t, src_data_t>(jcp, _src, imtr,
                                        _col, 0, jcp.oh, 0, jcp.ow);
                        }
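                        // GEMM view: diff_wei(oc x ic*ks) += diff_dst(oc x
                        // os) * B(os x ic*ks), where B is the im2col buffer
                        // ("N") or the transposed source ("T"). beta is zero
                        // only for the first (mb, od) tile, so every later
                        // tile accumulates into the same output.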
                        const float zero = 0.0f, one = 1.0f;
                        status_t st_thr = gemm_bf16bf16f32("N",
                                jcp.im2col_sz ? "N" : "T", &M, &N, &k, &one,
                                _diff_dst, &LDB,
                                jcp.im2col_sz
                                        ? _col
                                        : _src + od * k * jcp.ngroups * jcp.ic,
                                &LDA, mb == mb_start && od == 0 ? &zero : &one,
                                _diff_weights, &LDC);
                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if a failure occurred.
                            g = g_end;
                            mb = mb_end;
                            od = jcp.od;
                        }
                    }
                }
            }
            if (need_reduction && dnnl_thr_syncable()) {
                dnnl_thr_barrier();
                if (st != status::success) return;
                bf16_bwd_weights_reduction_par_nspc(ithr_mb, nthr_mb, g_start,
                        g_end, jcp, weights_reduce_base, diff_weights);
            } else if (diff_wei_data_type == data_type::bf16
                    && g_end > g_start) {
                cvt_acc_to_dst(jcp, g_start, g_end, (const float *)acc_base,
                        (bfloat16_t *)diff_weights);
            }
        } else {
            if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
        }
    });

    if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            int ithr_g, nthr_g, ithr_mb, nthr_mb;
            size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

            const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
            jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
                    jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
                    nthr_mb);

            assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
            const int need_reduction = nthr_mb != 1;

            if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
                balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
                balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

                assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

                acc_data_t *weights_reduce_base = wei_reduction
                        + ithr_g * nthr_mb * weights_g_size * jcp.ic * jcp.ks;

                bf16_bwd_weights_reduction_par_nspc(ithr_mb, nthr_mb, g_start,
                        g_end, jcp, weights_reduce_base, diff_weights);
            }
        });
    }

    if (jcp.with_bias) {
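        // diff_bias[g, oc] = sum of diff_dst over the mini-batch and all
        // spatial positions; the innermost ow loop strides by ngroups * oc
        // and vectorizes with a SIMD reduction.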
        parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) {
            acc_data_t db = 0;
            const size_t offset_base = g * jcp.oc + oc;
            for_(int mb = 0; mb < jcp.mb; ++mb)
            for_(int od = 0; od < jcp.od; ++od)
            for (int oh = 0; oh < jcp.oh; ++oh) {
                const int width_stride = jcp.ngroups * jcp.oc;
                const diff_dst_data_t *__restrict diff_dst_arr = diff_dst
                        + offset_base
                        + ((mb * jcp.od + od) * jcp.oh + oh) * jcp.ow
                                * width_stride;

                PRAGMA_OMP_SIMD(reduction(+ : db))
                for (int ow = 0; ow < jcp.ow; ++ow) {
                    db += diff_dst_arr[ow * width_stride];
                }
            }
            diff_bias[g * jcp.oc + oc] = db;
        });

        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) {
            auto diff_bias_in = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS);
            cvt_float_to_bfloat16(
                    diff_bias_in, diff_bias, jcp.ngroups * jcp.oc);
        }
    }
    return st;
}

template <data_type_t diff_wei_data_type>
status_t gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
        execute_backward_weights_ncsp(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);

    auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
            key_conv_gemm_col);
    auto wei_reduction = ctx.get_scratchpad_grantor().template get<acc_data_t>(
            key_conv_wei_reduction);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    acc_data_t *acc_base = diff_wei_data_type == data_type::bf16
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : (acc_data_t *)diff_weights;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16)
            diff_bias = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
        else
            diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
    }

    const dim_t K = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * K;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;

    const dim_t k = jcp.os_block;
    const dim_t N = jcp.oc;
    const dim_t M = jcp.ic * jcp.ks;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
        const int need_reduction = nthr_mb != 1;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
            // non-blocked jit_gemm_convolution_utils::im2col_3d() requires
            // external data initialization by zeroes
            const bool outer_padding = jcp.os_nb_block == 1;
            if (outer_padding && is_problem_3d) {
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col[i] = (src_data_t)0;
            }

            acc_data_t *weights_reduce_base
                    = wei_reduction + ithr_g * nthr_mb * weights_g_size;
            acc_data_t *weights_reduce
                    = weights_reduce_base + ithr_mb * weights_g_size;

            for (size_t g = g_start; g < g_end; ++g) {
                acc_data_t *acc = need_reduction
                        ? weights_reduce
                        : (acc_base + g * weights_g_size);
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const src_data_t *_src
                            = src + (mb * jcp.ngroups + g) * src_step;
                    for_(int od = 0; od < jcp.od; ++od)
                    for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                        auto out_off = os_nb * k + od * jcp.os;
                        const dim_t os_block = nstl::min(
                                (dim_t)jcp.os_block, jcp.os - os_nb * k);
                        const diff_dst_data_t *_diff_dst = diff_dst
                                + (mb * jcp.ngroups + g) * dst_step + out_off;

                        if (jcp.im2col_sz) {
                            if (!is_problem_3d)
                                jit_gemm_convolution_utils::im2col<src_data_t>(
                                        jcp, _src, _col, os_nb * jcp.os_block,
                                        os_block, 0, jcp.ic);
                            else
                                jit_gemm_convolution_utils::im2col_3d<
                                        src_data_t>(jcp, _src, _col, od,
                                        os_nb * jcp.os_block, os_block);
                        }

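                        // GEMM view: diff_wei(ic*ks x oc) += B^T(ic*ks x
                        // os_block) * diff_dst(os_block x oc), where B is the
                        // im2col buffer or the source block. beta selects
                        // initialize-vs-accumulate across the (mb, od, os_nb)
                        // tiles.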
                        const dim_t LDA = jcp.im2col_sz ? os_block : K;
                        const acc_data_t zero = 0.0, one = 1.0;
                        status_t st_thr = gemm_bf16bf16f32("T", "N", &M, &N,
                                &os_block, &one,
                                jcp.im2col_sz ? _col : _src + out_off, &LDA,
                                _diff_dst, &K,
                                mb == mb_start && os_nb == 0 && od == 0 ? &zero
                                                                        : &one,
                                acc, &M);

                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if a failure occurred.
                            g = g_end;
                            mb = mb_end;
                            od = jcp.od;
                            os_nb = jcp.os_nb_block;
                        }
                    }
                }
            }
            if (need_reduction && dnnl_thr_syncable()) {
                dnnl_thr_barrier();
                if (st != status::success) return;
                diff_wei_data_t *weights_base
                        = diff_weights + g_start * weights_g_size;
                bf16_bwd_weights_reduction_par_ncsp(ithr_mb, nthr_mb, jcp,
                        weights_reduce_base, weights_base);
            } else if (diff_wei_data_type == data_type::bf16
                    && g_end > g_start) {
                const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
                const size_t work_size = (g_end - g_start) * weights_g_size;
                bfloat16_t *diff_weights_local
                        = (bfloat16_t *)diff_weights + g_start * weights_g_size;
                const float *acc_local
                        = (const float *)acc_base + g_start * weights_g_size;
                store_bfloat16_in_parallel(diff_weights_local, acc_local,
                        work_size, 1, jcp.nthr == 1);
            }
        } else {
            if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
        }
    });

    if (st != status::success) return st;

    if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            int ithr_g, nthr_g, ithr_mb, nthr_mb;
            size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

            const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
            jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
                    jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
                    nthr_mb);

            assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
            const int need_reduction = nthr_mb != 1;

            if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
                balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
                balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

                assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

                acc_data_t *weights_reduce_base
                        = wei_reduction + ithr_g * nthr_mb * weights_g_size;

                diff_wei_data_t *weights_base
                        = diff_weights + g_start * weights_g_size;
                bf16_bwd_weights_reduction_par_ncsp(ithr_mb, nthr_mb, jcp,
                        weights_reduce_base, weights_base);
            }
        });
    }

    if (jcp.with_bias) {
        parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) {
            acc_data_t db = 0;
            size_t offset_ = (size_t)g * dst_step + (size_t)oc * K;
            for (int mb = 0; mb < jcp.mb; ++mb) {
                size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step;
                for_(int od = 0; od < jcp.od; ++od)
                for (int oh = 0; oh < jcp.oh; ++oh)
                    PRAGMA_OMP_SIMD(reduction(+ : db))
                    for (int ow = 0; ow < jcp.ow; ++ow) {
                        db += diff_dst[offset];
                        offset++;
                    }
            }
            diff_bias[g * jcp.oc + oc] = db;
        });

        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) {
            auto diff_bias_in = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS);
            cvt_float_to_bfloat16(
                    diff_bias_in, diff_bias, jcp.ngroups * jcp.oc);
        }
    }

    return st;
}

template struct gemm_bf16_convolution_fwd_t<data_type::f32>;
template struct gemm_bf16_convolution_fwd_t<data_type::bf16>;
template struct gemm_bf16_convolution_bwd_data_t<data_type::f32>;
template struct gemm_bf16_convolution_bwd_data_t<data_type::bf16>;
template struct gemm_bf16_convolution_bwd_weights_t<data_type::f32>;
template struct gemm_bf16_convolution_bwd_weights_t<data_type::bf16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl