/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include "oneapi/dnnl/dnnl_types.h"

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/x64/gemm_bf16_convolution.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::cpu::x64::bf16_support;

// The two stand-alone functions below are moved out of execute_backward_data
// and execute_backward_weights to avoid the gcc 6.x and 7.x warning
// "declared with greater visibility than the type of its field",
// which fires when one lambda function is declared within another one.
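// A minimal, hypothetical sketch of the pattern that used to trigger the
// warning (illustration only, not code from this file):
//
//   parallel(nthr, [&](const int ithr, const int nthr) {
//       auto inner = [&](size_t off) { /* ... */ }; // nested lambda: warning
//       inner(0);
//   });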
void store_bfloat16_in_parallel(bfloat16_t *output_data, const float *acc_data,
        size_t parallel_work, size_t parallel_work_size, bool do_in_parallel) {
    parallel(do_in_parallel ? 0 : 1, [&](const int ithr, const int nthr) {
        size_t start = 0, end = 0;
        balance211(parallel_work, nthr, ithr, start, end);
        if (start < end)
            cvt_float_to_bfloat16(&output_data[start * parallel_work_size],
                    &acc_data[start * parallel_work_size],
                    (end - start) * parallel_work_size);
    });
}

void cvt_acc_to_dst(const conv_gemm_conf_t &jcp, size_t g_start, size_t g_end,
        const float *acc_base, bfloat16_t *diff_weights) {
    const size_t parallel_work_size = jcp.ic * jcp.ks;
    parallel(jcp.nthr == 1 ? 0 : 1, [&](const int ithr, const int nthr) {
        size_t w_start = 0, w_end = 0;
        balance211(parallel_work_size, nthr, ithr, w_start, w_end);
        for_(auto w = w_start; w < w_end; ++w)
        for (auto g = g_start; g < g_end; ++g) {
            const float *__restrict acc_ptr
                    = acc_base + (w * jcp.ngroups + g) * jcp.oc;
            bfloat16_t *__restrict dw_ptr
                    = diff_weights + (w * jcp.ngroups + g) * jcp.oc;
            cvt_float_to_bfloat16(dw_ptr, acc_ptr, jcp.oc);
        }
    });
}

template <data_type_t dst_data_type>
gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)
    : jcp_(pd->jcp_)
    , do_sum_(dst_data_type != data_type::f32 && jcp_.with_sum)
    , max_data_reg_idx_(31)
    , max_unroll_(12)
    , compute_reg_step_(1)
    , data_reg_base_idx_(0) {
    using namespace types;
    using namespace Xbyak;

    if (!mayiuse(avx512_core))
        // bf16 is not supported
        return;

    const auto &post_ops = jcp_.post_ops;
    if (jcp_.with_eltwise || jcp_.with_binary) {
#define PARAM_OFF(field) offsetof(ker_args, field)
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        static constexpr size_t helper_vmm_idx = 31;
        static constexpr size_t tail_size = 1;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                helper_vmm_idx, reserved_eltwise_gpr, r14, preserve_gpr,
                preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
                memory_desc_wrapper(pd->dst_md()), tail_size, kreg_rem_mask,
                use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t binary_static_params {
                this->reg_param, rhs_arg_static_params};
        static constexpr bool save_state = true;
        const eltwise_injector::static_params_t eltwise_static_params {
                save_state, reserved_eltwise_gpr, reserved_eltwise_maskr};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, post_ops, binary_static_params, eltwise_static_params);
#undef PARAM_OFF
    }

    if (do_sum_) {
        compute_reg_step_ = 2;
        vreg_sum_scale = Zmm(data_reg_base_idx_++);
    }

    if (jcp_.with_bias) vreg_bias = Zmm(data_reg_base_idx_++);

    vlen_ = cpu_isa_traits<avx512_core>::vlen / sizeof(float);

    isa_ = mayiuse(avx512_core_bf16) ? avx512_core_bf16
                                     : bf16_emulation_t::get_isa();

    if (isa_ != avx512_core_bf16) {
        max_data_reg_idx_ = 26;
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_6);
    }

    max_unroll_
            = (max_data_reg_idx_ - data_reg_base_idx_ + 1) / compute_reg_step_;
}

template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::apply_postops(
        const bool apply_mask, const int vmm_idx) {
#define PARAM_OFF(x) offsetof(ker_args, x)
    if (jcp_.with_eltwise || jcp_.with_binary) {
        static constexpr int offset = 0;
        if (jcp_.with_binary) {
            const auto oc_off_oprnd = this->r12;
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
            rhs_arg_params.vmm_idx_to_oc_elem_off_addr.emplace(
                    vmm_idx, ptr[reg_param + PARAM_OFF(g_oc_offset)]);
            rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(vmm_idx, offset);
            rhs_arg_params.vmm_idx_to_oc_off_oprnd.emplace(
                    vmm_idx, oc_off_oprnd);
            if (apply_mask) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);

            const injector_utils::register_preserve_guard_t register_guard(
                    this, {oc_off_oprnd});
            mov(oc_off_oprnd,
                    ptr[rsp + reg_binary_post_op_acc_off
                            + register_guard.stack_space_occupied()]);

            postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
        } else
            postops_injector_->compute_vector(vmm_idx);
    }
#undef PARAM_OFF
}

template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {
    using namespace Xbyak;
    using namespace utils;

    preamble();

#ifdef _WIN32
    mov(reg_param, rcx);
#endif

#define PARAM_OFF(x) offsetof(ker_args, x)
    mov(reg_dst_base, ptr[reg_param + PARAM_OFF(dst)]);
    mov(reg_acc_base, ptr[reg_param + PARAM_OFF(acc)]);
    if (jcp_.with_bias) mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
    mov(reg_dst_str, ptr[reg_param + PARAM_OFF(dst_stride_in_bytes)]);
    mov(reg_acc_str, ptr[reg_param + PARAM_OFF(acc_stride_in_bytes)]);
    mov(reg_len, ptr[reg_param + PARAM_OFF(spatial_length)]);
    mov(reg_oc_iter, ptr[reg_param + PARAM_OFF(oc_work)]);

    if (jcp_.with_binary) {
        // zero initialize binary post_ops offset accumulator (store on stack)
        const auto binary_post_op_acc_off_reg = reg_tmp;
        xor_(binary_post_op_acc_off_reg, binary_post_op_acc_off_reg);
        push(binary_post_op_acc_off_reg);
    }

    if (do_sum_)
        vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
#undef PARAM_OFF

    // Load accumulated value, apply sum (if any), bias (if any)
    // and relu (if any); then convert to destination type and store
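    // Scalar sketch of what one compute() call does per element (pseudocode
    // for illustration only; names below are not real identifiers):
    //   float v = acc[i];
    //   if (with_bias) v += bias;
    //   if (do_sum) v += sum_scale * prev_dst[i];
    //   v = post_ops(v);
    //   dst[i] = convert_to_dst_type(v);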
    auto compute = [&](size_t offset, int idx, bool apply_mask) {
        auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
        auto vreg_dst_ = vreg_dst(idx);

        if (dst_data_type == data_type::bf16 && isa_ != avx512_core_bf16)
            bf16_emu_->init_vcvtneps2bf16();

        if (apply_mask) vreg_dst_ = vreg_dst_ | kreg_rem_mask;
        vmovups(vreg_dst_, acc_addr);

        if (jcp_.with_bias) vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias);

        auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
        if (do_sum_) {
            auto vreg_prev_dst_ = vreg_prev_dst(idx);
            if (dst_data_type == data_type::f32) {
                if (apply_mask) vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask;

                vmovups(vreg_prev_dst_, dst_addr);
            } else if (dst_data_type == data_type::bf16) {
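                // bf16 holds the upper 16 bits of an f32 value, so
                // zero-extending each 16-bit element to 32 bits and shifting
                // left by 16 reconstructs the previous dst values as f32.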
                auto vreg_prev_dst_ymm_ = vreg_prev_dst_ymm(idx);
                if (apply_mask)
                    vreg_prev_dst_ymm_ = vreg_prev_dst_ymm_ | kreg_rem_mask;

                vmovdqu16(vreg_prev_dst_ymm_, dst_addr);
                vpmovzxwd(vreg_prev_dst(idx), vreg_prev_dst_ymm_);
                vpslld(vreg_prev_dst(idx), vreg_prev_dst(idx), 0x10);
            } else
                assert(!"unsupported data type");

            vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
        }

        apply_postops(apply_mask, vreg_dst_idx(idx));

        if (dst_data_type == data_type::bf16) {
            // TODO: implement store by zmm registers for bf16
            auto vreg_dst_ymm_ = vreg_dst_ymm(idx);
            if (isa_ == avx512_core_bf16)
                vcvtneps2bf16(vreg_dst_ymm_, vreg_dst(idx));
            else
                bf16_emu_->vcvtneps2bf16(vreg_dst_ymm_, vreg_dst(idx));

            if (apply_mask) vreg_dst_ymm_ = vreg_dst_ymm_ | kreg_rem_mask;

            vmovdqu16(dst_addr, vreg_dst_ymm_);
        } else if (dst_data_type == data_type::f32)
            vmovups(dst_addr, vreg_dst_);
        else
            assert(!"unimplemented");
    };

    // Advance all pointers by an immediate
    auto advance_ptrs_imm = [&](size_t offset) {
        add(reg_dst, offset * sizeof(dst_data_t));
        add(reg_acc, offset * sizeof(acc_data_t));
    };

    Xbyak::Label oc_loop, oc_loop_end;

    cmp(reg_oc_iter, 0);
    jle(oc_loop_end, T_NEAR);

    L(oc_loop);

    mov(reg_len_iter, reg_len);
    mov(reg_dst, reg_dst_base);
    mov(reg_acc, reg_acc_base);

    if (jcp_.with_bias) vbroadcastss(vreg_bias, ptr[reg_bias]);

    constexpr int n_unroll = default_unroll_2_pow_; // unroll by powers of 2
            // from 2^n to 2^0
    assert((1 << n_unroll) <= max_unroll_);

    Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
    for (int i = n_unroll; i >= 0; i--) {
        const int unroll = 1 << i; // 4, 2, 1
        L(l_simd_loop[i + 1]);
        {
            const int loop_len = unroll * vlen_;
            cmp(reg_len_iter, loop_len);
            jl(l_simd_loop[i], T_NEAR);
            for (int j = 0; j < unroll; j++)
                compute(j * vlen_, j, false);

            advance_ptrs_imm(loop_len);
            sub(reg_len_iter, loop_len);
            jmp(l_simd_loop[i + 1], T_NEAR);
        }
    }
    L(l_simd_loop[0]);

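    // Build a k-mask covering the remaining tail elements:
    // mask = (1 << reg_len_iter) - 1; e.g. a tail of 5 elements yields
    // 0b11111, selecting the first 5 lanes of the zmm register.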
    mov(reg_tmp, reg_len_iter); // reg_tmp is rcx, and we need cl for the shift
    mov(reg_rem_mask, 1);
    shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen_ == 16
    sub(reg_rem_mask, 1);
    jz(l_simd_notail, T_NEAR);
    kmovq(kreg_rem_mask, reg_rem_mask);
    compute(0, 0, true);

    L(l_simd_notail);

    add(reg_dst_base, reg_dst_str);
    add(reg_acc_base, reg_acc_str);
    if (jcp_.with_bias) add(reg_bias, sizeof(acc_data_t));
    if (jcp_.with_binary)
        inc(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));

    dec(reg_oc_iter);
    jnz(oc_loop, T_NEAR); // oc_loop end

    L(oc_loop_end);

    if (jcp_.with_binary) add(rsp, stack_space_needed);

    postamble();

    if (jcp_.with_eltwise) postops_injector_->prepare_table();
}

// operator () specialized for nspc format
template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
        dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias,
        float sum_scale, size_t oc_work,
        const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
        const size_t g_oc_offset) {

    ker_args args;
    args.acc = acc;
    args.dst = dst;
    args.bias = bias;
    args.sum_scale = sum_scale;
    args.dst_stride_in_bytes = sizeof(dst_data_t);
    args.acc_stride_in_bytes = sizeof(acc_data_t);
    args.spatial_length = 1;
    args.oc_work = oc_work;

    args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
    args.dst_orig = dst_orig;
    args.g_oc_offset = g_oc_offset;
    jit_generator::operator()(&args);
}

// operator () specialized for ncsp format
template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
        dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias,
        float sum_scale, size_t dst_stride_in_elements,
        size_t acc_stride_in_elements, size_t sp_len, size_t oc_len,
        const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
        const size_t g_oc_offset) {
    if (sp_len == 0) return;

    ker_args args;
    args.acc = acc;
    args.dst = dst;
    args.bias = bias;
    args.sum_scale = sum_scale;
    args.dst_stride_in_bytes = dst_stride_in_elements * sizeof(dst_data_t);
    args.acc_stride_in_bytes = acc_stride_in_elements * sizeof(acc_data_t);
    args.spatial_length = sp_len;
    args.oc_work = oc_len;

    args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
    args.dst_orig = dst_orig;
    args.g_oc_offset = g_oc_offset;
    jit_generator::operator()(&args);
}

template <data_type_t dst_data_type>
status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_nspc(
        const exec_ctx_t &ctx) const {
    auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto dst_base = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    this->pd()->attr()->post_ops_, ctx);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    float *bia_base = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->bias_desc.data_type == data_type::bf16) {
            auto bias_in = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_BIAS);
            bia_base = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
            cvt_bfloat16_to_float(bia_base, bias_in, jcp.ngroups * jcp.oc);
        } else {
            auto bias_in = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
            bia_base = const_cast<float *>(bias_in);
        }
    }
    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_forward_thr_nspc(ithr, nthr, src_base,
                wei_base, bia_base, dst_base, scratchpad,
                post_ops_binary_rhs_arg_vec.data());
        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

template <data_type_t dst_data_type>
status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_thr_nspc(
        const int ithr, const int nthr, const src_data_t *src_base,
        const wei_data_t *wei_base, const float *bia_base, dst_data_t *dst_base,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec) const {
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Src Format: mb-spatial-groups-input_channels
    const size_t src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw
            * jcp.ngroups * jcp.ic;
    const size_t src_g_stride = jcp.ic;
    // Wei Format: spatial-input_channels-groups-output_channels
    const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Dst Format: mb-spatial-groups-output_channels
    const size_t dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh * jcp.ow
            * jcp.ngroups * jcp.oc;
    const size_t dst_g_stride = jcp.oc;
    const size_t dst_os_stride = jcp.ngroups * jcp.oc;

    src_data_t *__restrict col = scratchpad.get<src_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;
    acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
            + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;

    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    int g {0}, n {0}, ohb {0}, owb {0};
    size_t start = 0, end = 0;

    const bool is_problem_3d = pd()->ndims() == 5;
    assert(IMPLICATION(is_problem_3d,
            jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
                    && jcp.ic_block == jcp.ic));

    const int nb_oh = div_up(jcp.oh, jcp.oh_block);
    const int nb_ow = div_up(jcp.ow, jcp.ow_block);
    const size_t work_amount = (size_t)jcp.ngroups * jcp.mb * nb_oh * nb_ow;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);

    if (jcp.im2col_sz && is_problem_3d) {
        // jit_gemm_convolution_utils::im2col_dt_3d() requires external
        // data initialization by zeroes
        // For performance reasons use uint16_t as a proxy for bfloat16_t
        uint16_t *__restrict col_r
                = reinterpret_cast<uint16_t *__restrict>(col);
        constexpr uint16_t zero_val = 0;

        PRAGMA_OMP_SIMD()
        for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
            col_r[i] = zero_val;
    }
    for (size_t iwork = start; iwork < end; ++iwork) {
        int oh = ohb * jcp.oh_block;
        int ow = owb * jcp.ow_block;
        const src_data_t *__restrict src
                = src_base + n * src_mb_stride + g * src_g_stride;
        const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;

        const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
        const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
        if (jcp.im2col_sz && is_problem_3d)
            jit_gemm_convolution_utils::transpose_dt(jcp, src, imtr);

        for (int od = 0; od < jcp.od; od++) {
            dst_data_t *__restrict dst = dst_base + n * dst_mb_stride
                    + g * dst_g_stride
                    + ((od * jcp.oh + oh) * jcp.ow + ow) * dst_os_stride;
            if (jcp.im2col_sz) {
                if (is_problem_3d)
                    jit_gemm_convolution_utils::im2col_dt_3d<src_data_t,
                            src_data_t>(jcp, imtr, col, od);
                else
                    jit_gemm_convolution_utils::im2col_dt<src_data_t,
                            src_data_t>(
                            jcp, src, imtr, col, oh, h_step, ow, w_step);
            }

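            // GEMM shapes (column-major): dst[M x N] = wei[M x K] x col[K x N]
            // with M = oc, N = h_step * w_step, and K = ks * ic; the im2col
            // buffer is laid out [N x K], hence "T" for B when it is used.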
            const dim_t M = jcp.oc;
            const dim_t K = jcp.ks * jcp.ic;
            const dim_t N = h_step * w_step;
            const dim_t LDA = M * jcp.ngroups;
            const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
            const char *BT = jcp.im2col_sz ? "T" : "N";
            const float onef = 1.f;
            const float beta = this->beta_;
            const src_data_t *__restrict src_od
                    = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
            const bool acc_needed = dst_data_type == data_type::bf16;
            status_t st = gemm_bf16bf16f32("N", BT, &M, &N, &K, &onef, wei,
                    &LDA, jcp.im2col_sz ? col : (src_data_t *)src_od, &LDB,
                    &beta, acc_needed ? acc : (float *)dst,
                    acc_needed ? &M : &LDA);
            if (st != status::success) return st;

            const bool do_postprocess = pd()->is_postprocess_required();
            if (do_postprocess) {
                parallel_nd_ext(jcp.nthr == 1 ? 0 : 1, N,
                        [&](size_t ithr, size_t nthr, size_t os) {
                            const float *__restrict acc_arr = acc + os * jcp.oc;
                            const float *__restrict bia_arr
                                    = (bia_base == nullptr)
                                    ? nullptr
                                    : bia_base + g * jcp.oc;
                            dst_data_t *__restrict dst_arr
                                    = dst + os * dst_os_stride;

                            (*pp_ker_)(dst_arr,
                                    acc_needed ? acc_arr : (float *)dst_arr,
                                    bia_arr, sum_scale, jcp.oc,
                                    post_ops_binary_rhs_arg_vec, dst,
                                    g * jcp.oc);
                        });
            }
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
    }
    return status::success;
}

template <data_type_t dst_data_type>
status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_ncsp(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    this->pd()->attr()->post_ops_, ctx);

    bool is_bf16_dst = dst_data_type == data_type::bf16;

    auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
            key_conv_gemm_col);
    acc_data_t *acc_base = is_bf16_dst
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : nullptr;

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    float *bias = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->bias_desc.data_type == data_type::bf16) {
            auto bias_in = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_BIAS);
            bias = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
            cvt_bfloat16_to_float(bias, bias_in, jcp.ngroups * jcp.oc);
        } else {
            auto bias_in = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
            bias = const_cast<float *>(bias_in);
        }
    }

    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    const dim_t M = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * M;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
    const size_t weights_oc_size = jcp.ic * jcp.ks;

    const dim_t LDB = weights_oc_size;
    const size_t work_amount
            = (size_t)jcp.ngroups * jcp.mb * jcp.od * jcp.os_nb_block;
    const bool is_problem_3d = pd()->ndims() == 5;
    std::atomic<status_t> st(status::success);

    auto inner_ker = [&](const int ic, const int oc, const int groups,
                             const int od, const int spatial,
                             const src_data_t *src, const wei_data_t *weights,
                             src_data_t *col, dst_data_t *dst, acc_data_t *acc,
                             int ic_block, int oc_block) {
        const dim_t os_block = nstl::min(
                (dim_t)jcp.os_block, (dim_t)jcp.os - spatial * jcp.os_block);

        if (jcp.im2col_sz) {
            if (!is_problem_3d) {
                jit_gemm_convolution_utils::im2col<src_data_t>(jcp, src, col,
                        spatial * jcp.os_block, os_block, ic, ic_block);
            } else {
                assert(jcp.ic_block == jcp.ic);
                jit_gemm_convolution_utils::im2col_3d<src_data_t>(
                        jcp, src, col, od, spatial * jcp.os_block, os_block);
            }
        }

        const acc_data_t one = 1.0;
        const dim_t N = oc_block;
        const dim_t K = ic_block * jcp.ks;
        const dim_t m = os_block;
        const dim_t LDA = jcp.im2col_sz ? m : M;
        const dim_t LDC = is_bf16_dst ? m : M;
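        // The first ic block initializes the accumulator with this->beta_;
        // subsequent ic blocks use beta = 1 so that partial products over
        // input-channel blocks accumulate.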
        const float beta = (ic == 0) ? this->beta_ : one;
        auto out_off = spatial * jcp.os_block + od * jcp.os;
        dst_data_t *dst_local = dst + out_off;

        status_t st_thr = gemm_bf16bf16f32("N", "N", &m, &N, &K, &one,
                jcp.im2col_sz ? col : src + ic * M + out_off, &LDA, weights,
                &LDB, &beta, acc, &LDC);

        if (st_thr != status::success) {
            st = st_thr;
            return;
        }

        if (this->pd()->is_postprocess_required() && ic + ic_block >= jcp.ic) {
            size_t acc_str = LDC;
            size_t dst_str = M;
            float *bias_ptr = bias ? bias + groups * jcp.oc + oc : nullptr;
            (*pp_ker_)(dst_local, acc, bias_ptr, sum_scale, dst_str, acc_str, m,
                    oc_block, post_ops_binary_rhs_arg_vec.data(), dst,
                    groups * jcp.oc + oc);
        }
    };

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
        if (is_problem_3d) {
            // jit_gemm_convolution_utils::im2col_3d() requires external
            // data initialization by zeroes
            for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                _col[i] = (src_data_t)0;
        }
        int g {0}, n {0}, od {0}, nb_os {0};
        size_t start = 0, end = 0;
        size_t oc_start = 0, oc_end = 0;

        assert(jcp.loop_order == gemm_loop_lbr);
        balance2D(nthr, ithr, work_amount, start, end, (size_t)jcp.oc, oc_start,
                oc_end, (size_t)jcp.nthr_oc);

        nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os,
                jcp.os_nb_block);
        for (size_t iwork = start; iwork < end; ++iwork) {
            for_(int oc = (int)oc_start; oc < (int)oc_end; oc += jcp.oc_block)
            for (int ic = 0; ic < jcp.ic; ic += jcp.ic_block) {
                const src_data_t *_src = src + (n * jcp.ngroups + g) * src_step;
                const wei_data_t *_weights = weights + g * weights_g_size
                        + oc * weights_oc_size + ic * jcp.ks;
                dst_data_t *_dst_im
                        = dst + (n * jcp.ngroups + g) * dst_step + oc * M;
                auto out_off = nb_os * jcp.os_block + od * jcp.os;
                dst_data_t *dst_local = _dst_im + out_off;
                const int sizeof_cacheline_float = 16;
                acc_data_t *_acc = is_bf16_dst ? acc_base
                                + ithr
                                        * rnd_up(jcp.oc_block * jcp.os_block,
                                                sizeof_cacheline_float)
                                               : (acc_data_t *)dst_local;

                const int ic_block = nstl::min(jcp.ic - ic, jcp.ic_block);
                const int oc_block = nstl::min(int(oc_end) - oc, jcp.oc_block);

                inner_ker(ic, oc, g, od, nb_os, _src, _weights, _col, _dst_im,
                        _acc, ic_block, oc_block);
            }
            nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os,
                    jcp.os_nb_block);
        }
    });

    return st;
}

template <data_type_t diff_src_data_type>
status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
        execute_backward_data_nspc(const exec_ctx_t &ctx) const {

    auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_backward_data_thr_nspc(
                ithr, nthr, diff_src_base, wei_base, diff_dst_base, scratchpad);
        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

template <data_type_t diff_src_data_type>
status_t gemm_bf16_convolution_bwd_data_t<
        diff_src_data_type>::execute_backward_data_thr_nspc(const int ithr,
        const int nthr, diff_src_data_t *diff_src_base,
        const wei_data_t *wei_base, const diff_dst_data_t *diff_dst_base,
        const memory_tracking::grantor_t &scratchpad) const {

    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Diff_dst Format: mb-spatial-groups-output_channels
    const size_t diff_dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh
            * jcp.ow * jcp.ngroups * jcp.oc;
    const size_t diff_dst_g_stride = jcp.oc;

    // Wei Format: spatial-input_channels-groups-output_channels
    const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Diff_src Format: mb-spatial-groups-input_channels
    const size_t diff_src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih
            * jcp.iw * jcp.ngroups * jcp.ic;
    const size_t diff_src_g_stride = jcp.ic;
    const size_t diff_src_os_stride = jcp.ngroups * jcp.ic;

    // threads share work across mini-batch and groups
    const size_t work_amount = jcp.ngroups * jcp.mb;

    acc_data_t *__restrict col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
            + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic;

    int n {0}, g {0};
    size_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (size_t iwork = start; iwork < end; ++iwork) {
        const diff_dst_data_t *__restrict diff_dst = diff_dst_base
                + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
        diff_src_data_t *__restrict diff_src = diff_src_base
                + n * diff_src_mb_stride + g * diff_src_g_stride;

        const dim_t M = jcp.ks * jcp.ic;
        const dim_t N = jcp.os * jcp.od;
        const dim_t K = jcp.oc;

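        // Backward-by-data GEMM (column-major):
        // col[M x N] = wei^T[M x K] x diff_dst[K x N], with M = ks * ic,
        // N = od * os, and K = oc; when an im2col buffer is used, col2im_dt
        // below scatters the result into the diff_src layout.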
        const float onef = 1.0f, zerof = 0.0f;
        const dim_t LD = K * jcp.ngroups;

        status_t st = gemm_bf16bf16f32("T", "N", &M, &N, &K, &onef, wei, &LD,
                diff_dst, &LD, &zerof, jcp.im2col_sz ? col : acc, &M);
        if (st != status::success) return st;

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_dt<acc_data_t>(jcp, col, acc);

        const bool is_diff_src_bf16 = diff_src_data_type == data_type::bf16;

        if (is_diff_src_bf16 && jcp.ngroups == 1 && jcp.nthr != 1) {
            cvt_float_to_bfloat16((bfloat16_t *)diff_src, (const float *)acc,
                    static_cast<size_t>(jcp.is) * jcp.id * jcp.ic);
        } else if (is_diff_src_bf16) {
            parallel_nd_ext(jcp.nthr == 1 ? 0 : 1,
                    static_cast<size_t>(jcp.is) * jcp.id,
                    [&](size_t ithr, size_t nthr, size_t is) {
                        diff_src_data_t *__restrict diff_src_loc
                                = diff_src + is * diff_src_os_stride;
                        const acc_data_t *__restrict acc_loc
                                = acc + is * jcp.ic;
                        cvt_float_to_bfloat16((bfloat16_t *)diff_src_loc,
                                (const float *)acc_loc, jcp.ic);
                    });
        } else {
            assert(diff_src_data_type == data_type::f32);
            parallel_nd_ext(jcp.nthr == 1 ? 0 : 1,
                    static_cast<size_t>(jcp.is) * jcp.id,
                    [&](size_t ithr, size_t nthr, size_t is) {
                        diff_src_data_t *__restrict diff_src_loc
                                = diff_src + is * diff_src_os_stride;
                        const acc_data_t *__restrict acc_loc
                                = acc + is * jcp.ic;
                        PRAGMA_OMP_SIMD()
                        for (int ic = 0; ic < jcp.ic; ++ic)
                            diff_src_loc[ic] = acc_loc[ic];
                    });
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
    return status::success;
}

template <data_type_t diff_src_data_type>
status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
        execute_backward_data_ncsp(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    auto col = ctx.get_scratchpad_grantor().template get<acc_data_t>(
            key_conv_gemm_col);
    acc_data_t *acc_base = diff_src_data_type == data_type::bf16
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : nullptr;

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const dim_t M = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * M;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;

    const dim_t m = jcp.os_block;
    const dim_t K = jcp.oc;
    const dim_t N = jcp.ic * jcp.ks;

    const size_t work_amount = (size_t)jcp.ngroups * jcp.mb;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        acc_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

        int g {0}, n {0};
        size_t start = 0, end = 0;
        balance211(work_amount, nthr, ithr, start, end);
        nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb);
        for (size_t iwork = start; iwork < end; ++iwork) {

            diff_src_data_t *diff_src_local
                    = diff_src + (n * jcp.ngroups + g) * src_step;
            acc_data_t *acc = diff_src_data_type == data_type::bf16
                    ? acc_base + ithr * rnd_up(src_step, 16)
                    : (acc_data_t *)diff_src_local;

            if (is_problem_3d && jcp.im2col_sz > 0) {
                // jit_gemm_convolution_utils::col2im_3d() assumes that the
                // accumulator is initialized by zeroes
                for (size_t i = 0; i < src_step; i++)
                    acc[i] = (acc_data_t)0;
            }

            const wei_data_t *_weights = weights + g * weights_g_size;
            for_(int od = 0; od < jcp.od; ++od)
            for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                auto out_off = os_nb * m + od * jcp.os;
                const diff_dst_data_t *_diff_dst
                        = diff_dst + (n * jcp.ngroups + g) * dst_step + out_off;
                const dim_t os_block
                        = nstl::min((dim_t)jcp.os_block, jcp.os - os_nb * m);
                const dim_t LDC = jcp.im2col_sz ? os_block : M;

                const acc_data_t zero = 0.0, one = 1.0;
                status_t st_thr = gemm_bf16bf16f32("N", "T", &os_block, &N, &K,
                        &one, _diff_dst, &M, _weights, &N, &zero,
                        jcp.im2col_sz ? _col : acc + out_off, &LDC);

                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (jcp.im2col_sz) {
                    if (!is_problem_3d)
                        jit_gemm_convolution_utils::col2im(
                                jcp, _col, acc, os_nb * jcp.os_block, os_block);
                    else
                        jit_gemm_convolution_utils::col2im_3d(jcp, _col, acc,
                                od, os_nb * jcp.os_block, os_block);
                }
            }
            if (diff_src_data_type == data_type::bf16) {
                size_t spatial_size = (size_t)jcp.ih * jcp.iw * jcp.id;
                store_bfloat16_in_parallel((bfloat16_t *)diff_src_local,
                        (const float *)acc, jcp.ic, spatial_size,
                        jcp.nthr == 1);
            }
            nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
        }
    });

    return st;
}

template <data_type_t diff_wei_data_type>
void gemm_bf16_convolution_bwd_weights_t<
        diff_wei_data_type>::bf16_bwd_weights_reduction_par_nspc(int ithr_mb,
        int nthr_mb, size_t g_start, size_t g_end, const conv_gemm_conf_t &jcp,
        const acc_data_t *weights_reduce_base,
        diff_wei_data_t *weights_base) const {
    assert(nthr_mb > 1); // no reduction for nthr_mb == 1

    const bool is_bf16_out = diff_wei_data_type == data_type::bf16;
    const size_t weights_g_size = jcp.oc;
    size_t weights_start {0}, weights_end {0};
    balance211(size_t(jcp.ks) * jcp.ic, nthr_mb, ithr_mb, weights_start,
            weights_end);

    for (auto tidx = 1; tidx < nthr_mb; ++tidx) {
        const acc_data_t *ws_base
                = weights_reduce_base + tidx * weights_g_size * jcp.ks * jcp.ic;
        for_(auto w = weights_start; w < weights_end; ++w)
        for (auto g = g_start; g < g_end; ++g) {
            const acc_data_t *ws_ptr = ws_base + w * jcp.oc;
            float *wei_reduced = is_bf16_out
                    ? (float *)weights_reduce_base + w * jcp.oc
                    : (float *)weights_base + (w * jcp.ngroups + g) * jcp.oc;
            if (is_bf16_out && tidx == nthr_mb - 1) {
                // the last iteration for bfloat16 requires conversion
                // and store to diff_weights array
                diff_wei_data_t *dwei_ptr
                        = weights_base + (w * jcp.ngroups + g) * jcp.oc;
                add_floats_and_cvt_to_bfloat16(
                        (bfloat16_t *)(dwei_ptr), wei_reduced, ws_ptr, jcp.oc);
            } else {
                acc_ker_->accumulate(wei_reduced, ws_ptr, jcp.oc);
            }
        }
    }
}

template <data_type_t diff_wei_data_type>
void gemm_bf16_convolution_bwd_weights_t<
        diff_wei_data_type>::bf16_bwd_weights_reduction_par_ncsp(int ithr_mb,
        int nthr_mb, const conv_gemm_conf_t &jcp,
        const acc_data_t *weights_reduce_base,
        diff_wei_data_t *weights_base) const {
    assert(nthr_mb > 1); // no reduction for nthr_mb == 1

    const bool is_bf16_out = diff_wei_data_type == data_type::bf16;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
    size_t weights_start {0}, weights_end {0};
    balance211(weights_g_size, nthr_mb, ithr_mb, weights_start, weights_end);

    if (weights_start >= weights_end) return; // nothing to do

    size_t acc_size = weights_end - weights_start;
    float *wei_reduced = is_bf16_out
            ? (float *)weights_reduce_base + weights_start
            : (float *)weights_base + weights_start;
    if (!is_bf16_out) {
        // f32 diff_weights require initialization by weights_reduce
        // for thr_mb = 0
        for (size_t i = 0; i < acc_size; i++)
            wei_reduced[i] = ((float *)weights_reduce_base + weights_start)[i];
    }

    for (int thr_mb = 1; thr_mb < nthr_mb; ++thr_mb) {
        float *wei_to_reduce = (float *)weights_reduce_base
                + thr_mb * weights_g_size + weights_start;

        if (is_bf16_out && thr_mb == nthr_mb - 1)
            // the last iteration for bfloat16 requires conversion
            // and store to diff_weights array
            add_floats_and_cvt_to_bfloat16(
                    (bfloat16_t *)(weights_base + weights_start), wei_reduced,
                    wei_to_reduce, acc_size);
        else
            acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size);
    }
}

template <data_type_t diff_wei_data_type>
status_t gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
        execute_backward_weights_nspc(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);

    auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
            key_conv_gemm_col);
    auto wei_reduction = ctx.get_scratchpad_grantor().template get<acc_data_t>(
            key_conv_wei_reduction);
    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    acc_data_t *acc_base = diff_wei_data_type == data_type::bf16
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : (acc_data_t *)diff_weights;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16)
            diff_bias = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
        else
            diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
    }

    const dim_t K = jcp.os * static_cast<size_t>(jcp.od);
    const size_t src_step
            = static_cast<size_t>(jcp.ic) * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = jcp.oc * K;
    const size_t weights_g_size = jcp.oc;

    const dim_t k = jcp.os;
    const dim_t M = jcp.oc;
    const dim_t N = static_cast<dim_t>(jcp.ic) * jcp.ks;
    const dim_t LDB = jcp.ngroups * jcp.oc;
    const dim_t LDA = jcp.im2col_sz ? jcp.oh * jcp.ow : jcp.ngroups * jcp.ic;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));

        const int need_reduction = nthr_mb != 1;
        src_data_t *__restrict imtr
                = ctx.get_scratchpad_grantor().template get<src_data_t>(
                          key_conv_gemm_imtr)
                + (ptrdiff_t)ithr * jcp.id * jcp.ic * jcp.is;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
            if (is_problem_3d) {
                // jit_gemm_convolution_utils::im2col_3d() requires external
                // data initialization by zeroes
                // For performance reasons use uint16_t as proxy for bfloat16_t
                uint16_t *__restrict _col_r
                        = reinterpret_cast<uint16_t *__restrict>(_col);
                constexpr uint16_t zero_val = 0;

                PRAGMA_OMP_SIMD()
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col_r[i] = zero_val;
            }

            acc_data_t *weights_reduce_base = wei_reduction
                    + ithr_g * nthr_mb * weights_g_size * jcp.ks * jcp.ic;
            acc_data_t *weights_reduce = weights_reduce_base
                    + ithr_mb * weights_g_size * jcp.ks * jcp.ic;

            const bool use_diff_wei
                    = ithr_mb == 0 && diff_wei_data_type == data_type::f32;
            for (size_t g = g_start; g < g_end; ++g) {
                acc_data_t *_diff_weights = use_diff_wei
                        ? (acc_data_t *)diff_weights + g * weights_g_size
                        : need_reduction ? weights_reduce
                                         : acc_base + g * weights_g_size;
                const dim_t LDC = use_diff_wei
                        ? jcp.ngroups * jcp.oc
                        : need_reduction ? jcp.oc : jcp.ngroups * jcp.oc;
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const src_data_t *_src
                            = src + mb * jcp.ngroups * src_step + g * jcp.ic;
                    if (jcp.im2col_sz && is_problem_3d)
                        jit_gemm_convolution_utils::transpose_dt(
                                jcp, _src, imtr);
                    for (int od = 0; od < jcp.od; ++od) {
                        const diff_dst_data_t *_diff_dst = diff_dst
                                + mb * jcp.ngroups * dst_step
                                + od * k * jcp.ngroups * jcp.oc + g * jcp.oc;

                        if (jcp.im2col_sz) {
                            if (is_problem_3d)
                                jit_gemm_convolution_utils::im2col_dt_3d<
                                        src_data_t, src_data_t>(
                                        jcp, imtr, _col, od);
                            else
                                jit_gemm_convolution_utils::im2col_dt<
                                        src_data_t, src_data_t>(jcp, _src, imtr,
                                        _col, 0, jcp.oh, 0, jcp.ow);
                        }
                        const float zero = 0.0f, one = 1.0f;
                        status_t st_thr = gemm_bf16bf16f32("N",
                                jcp.im2col_sz ? "N" : "T", &M, &N, &k, &one,
                                _diff_dst, &LDB,
                                jcp.im2col_sz
                                        ? _col
                                        : _src + od * k * jcp.ngroups * jcp.ic,
                                &LDA, mb == mb_start && od == 0 ? &zero : &one,
                                _diff_weights, &LDC);
                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if a failure occurred.
1091                             g = g_end;
1092                             mb = mb_end;
1093                             od = jcp.od;
1094                         }
1095                     }
1096                 }
1097             }
1098             if (need_reduction && dnnl_thr_syncable()) {
1099                 dnnl_thr_barrier();
1100                 if (st != status::success) return;
1101                 bf16_bwd_weights_reduction_par_nspc(ithr_mb, nthr_mb, g_start,
1102                         g_end, jcp, weights_reduce_base, diff_weights);
1103             } else if (diff_wei_data_type == data_type::bf16
1104                     && g_end > g_start) {
1105                 cvt_acc_to_dst(jcp, g_start, g_end, (const float *)acc_base,
1106                         (bfloat16_t *)diff_weights);
1107             }
1108         } else {
1109             if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
1110         }
1111     });
1112 
1113     if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
1114         parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1115             int ithr_g, nthr_g, ithr_mb, nthr_mb;
1116             size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};
1117 
1118             const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
1119             jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
1120                     jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
1121                     nthr_mb);
1122 
1123             assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
1124             const int need_reduction = nthr_mb != 1;
1125 
1126             if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
1127                 balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
1128                 balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);
1129 
1130                 assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
1131 
1132                 acc_data_t *weights_reduce_base = wei_reduction
1133                         + ithr_g * nthr_mb * weights_g_size * jcp.ic * jcp.ks;
1134 
1135                 bf16_bwd_weights_reduction_par_nspc(ithr_mb, nthr_mb, g_start,
1136                         g_end, jcp, weights_reduce_base, diff_weights);
1137             }
1138         });
1139     }
1140 
1141     if (jcp.with_bias) {
1142         parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) {
1143             acc_data_t db = 0;
1144             const size_t offset_base = g * jcp.oc + oc;
1145             for_(int mb = 0; mb < jcp.mb; ++mb)
1146             for_(int od = 0; od < jcp.od; ++od)
1147             for (int oh = 0; oh < jcp.oh; ++oh) {
1148                 const int width_stride = jcp.ngroups * jcp.oc;
1149                 const diff_dst_data_t *__restrict diff_dst_arr = diff_dst
1150                         + offset_base
1151                         + ((mb * jcp.od + od) * jcp.oh + oh) * jcp.ow
1152                                 * width_stride;
1153 
1154                 PRAGMA_OMP_SIMD(reduction(+ : db))
1155                 for (int ow = 0; ow < jcp.ow; ++ow) {
1156                     db += diff_dst_arr[ow * width_stride];
1157                 }
1158             }
1159             diff_bias[g * jcp.oc + oc] = db;
1160         });
1161 
1162         if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) {
1163             auto diff_bias_in = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS);
1164             cvt_float_to_bfloat16(
1165                     diff_bias_in, diff_bias, jcp.ngroups * jcp.oc);
1166         }
1167     }
1168     return st;
1169 }
1170 
1171 template <data_type_t diff_wei_data_type>
1172 status_t gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
execute_backward_weights_ncsp(const exec_ctx_t & ctx) const1173         execute_backward_weights_ncsp(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);

    auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
            key_conv_gemm_col);
    auto wei_reduction = ctx.get_scratchpad_grantor().template get<acc_data_t>(
            key_conv_wei_reduction);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    acc_data_t *acc_base = diff_wei_data_type == data_type::bf16
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : (acc_data_t *)diff_weights;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16)
            diff_bias = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
        else
            diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
    }

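    // Weights-update GEMM shapes: each spatial block contributes an
    // (os_block x ic*ks) im2col matrix and an (os_block x oc) diff_dst slice;
    // their product is accumulated into the (ic*ks x oc) weights accumulator.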
    const dim_t K = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * K;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;

    const dim_t k = jcp.os_block;
    const dim_t N = jcp.oc;
    const dim_t M = jcp.ic * jcp.ks;
    const bool is_problem_3d = pd()->ndims() == 5;

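    // GEMM failures are recorded atomically so the first failing thread makes
    // the remaining iterations bail out early; the status is checked once the
    // parallel region completes.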
    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
        const int need_reduction = nthr_mb != 1;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
            // The non-blocked jit_gemm_convolution_utils::im2col_3d() variant
            // requires its output buffer to be zero-initialized by the caller.
            const bool outer_padding = jcp.os_nb_block == 1;
            if (outer_padding && is_problem_3d) {
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col[i] = (src_data_t)0;
            }

            acc_data_t *weights_reduce_base
                    = wei_reduction + ithr_g * nthr_mb * weights_g_size;
            acc_data_t *weights_reduce
                    = weights_reduce_base + ithr_mb * weights_g_size;

            for (size_t g = g_start; g < g_end; ++g) {
                acc_data_t *acc = need_reduction
                        ? weights_reduce
                        : (acc_base + g * weights_g_size);
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const src_data_t *_src
                            = src + (mb * jcp.ngroups + g) * src_step;
                    for_(int od = 0; od < jcp.od; ++od)
                    for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                        auto out_off = os_nb * k + od * jcp.os;
                        const dim_t os_block = nstl::min(
                                (dim_t)jcp.os_block, jcp.os - os_nb * k);
                        const diff_dst_data_t *_diff_dst = diff_dst
                                + (mb * jcp.ngroups + g) * dst_step + out_off;

                        if (jcp.im2col_sz) {
                            if (!is_problem_3d)
                                jit_gemm_convolution_utils::im2col<src_data_t>(
                                        jcp, _src, _col, os_nb * jcp.os_block,
                                        os_block, 0, jcp.ic);
                            else
                                jit_gemm_convolution_utils::im2col_3d<
                                        src_data_t>(jcp, _src, _col, od,
                                        os_nb * jcp.os_block, os_block);
                        }

                        const dim_t LDA = jcp.im2col_sz ? os_block : K;
                        const acc_data_t zero = 0.0, one = 1.0;
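                        // acc += col^T * diff_dst; beta is zero for the first
                        // (mb, od, os_nb) block of a group so the accumulator
                        // is initialized, and one thereafter.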
                        status_t st_thr = gemm_bf16bf16f32("T", "N", &M, &N,
                                &os_block, &one,
                                jcp.im2col_sz ? _col : _src + out_off, &LDA,
                                _diff_dst, &K,
                                mb == mb_start && os_nb == 0 && od == 0 ? &zero
                                                                        : &one,
                                acc, &M);

                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if a failure occurred.
                            g = g_end;
                            mb = mb_end;
                            od = jcp.od;
                            os_nb = jcp.os_nb_block;
                        }
                    }
                }
            }
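            // With a syncable runtime the per-thread partial results are
            // reduced in place right after the barrier; otherwise a bf16
            // destination still needs the f32 accumulator converted to
            // bfloat16.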
            if (need_reduction && dnnl_thr_syncable()) {
                dnnl_thr_barrier();
                if (st != status::success) return;
                diff_wei_data_t *weights_base
                        = diff_weights + g_start * weights_g_size;
                bf16_bwd_weights_reduction_par_ncsp(ithr_mb, nthr_mb, jcp,
                        weights_reduce_base, weights_base);
            } else if (diff_wei_data_type == data_type::bf16
                    && g_end > g_start) {
                const size_t work_size = (g_end - g_start) * weights_g_size;
                bfloat16_t *diff_weights_local
                        = (bfloat16_t *)diff_weights + g_start * weights_g_size;
                const float *acc_local
                        = (const float *)acc_base + g_start * weights_g_size;
                store_bfloat16_in_parallel(diff_weights_local, acc_local,
                        work_size, 1, jcp.nthr == 1);
            }
        } else {
            if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
        }
    });

    if (st != status::success) return st;

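    // Deferred reduction path for non-syncable runtimes: rebuild the thread
    // decomposition used above and fold the per-thread f32 partial sums into
    // diff_weights.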
    if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            int ithr_g, nthr_g, ithr_mb, nthr_mb;
            size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

            const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
            jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
                    jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
                    nthr_mb);

            assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
            const int need_reduction = nthr_mb != 1;

            if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
                balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
                balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

                assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

                acc_data_t *weights_reduce_base
                        = wei_reduction + ithr_g * nthr_mb * weights_g_size;

                diff_wei_data_t *weights_base
                        = diff_weights + g_start * weights_g_size;
                bf16_bwd_weights_reduction_par_ncsp(ithr_mb, nthr_mb, jcp,
                        weights_reduce_base, weights_base);
            }
        });
    }

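    // diff_bias[g, oc] reduces diff_dst over minibatch and spatial dimensions;
    // in the ncsp layout each (g, oc) channel occupies K = od*oh*ow contiguous
    // elements, so the reduction walks a dense range per minibatch image.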
    if (jcp.with_bias) {
        parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) {
            acc_data_t db = 0;
            size_t offset_ = (size_t)g * dst_step + (size_t)oc * K;
            for (int mb = 0; mb < jcp.mb; ++mb) {
                size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step;
                for_(int od = 0; od < jcp.od; ++od)
                for (int oh = 0; oh < jcp.oh; ++oh) {
                    PRAGMA_OMP_SIMD(reduction(+ : db))
                    for (int ow = 0; ow < jcp.ow; ++ow) {
                        db += diff_dst[offset];
                        offset++;
                    }
                }
            }
            diff_bias[g * jcp.oc + oc] = db;
        });

        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) {
            auto diff_bias_in = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS);
            cvt_float_to_bfloat16(
                    diff_bias_in, diff_bias, jcp.ngroups * jcp.oc);
        }
    }

    return st;
}

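// Explicit instantiations for the f32 and bf16 destination / diff data types
// supported by these primitives.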
template struct gemm_bf16_convolution_fwd_t<data_type::f32>;
template struct gemm_bf16_convolution_fwd_t<data_type::bf16>;
template struct gemm_bf16_convolution_bwd_data_t<data_type::f32>;
template struct gemm_bf16_convolution_bwd_data_t<data_type::bf16>;
template struct gemm_bf16_convolution_bwd_weights_t<data_type::f32>;
template struct gemm_bf16_convolution_bwd_weights_t<data_type::bf16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl