1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "common/bfloat16.hpp"
18 #include "common/c_types_map.hpp"
19 #include "common/math_utils.hpp"
20 #include "common/nstl.hpp"
21 #include "common/type_helpers.hpp"
22 #include "common/utils.hpp"
23 
24 #include "cpu/platform.hpp"
25 #include "cpu/x64/cpu_barrier.hpp"
26 
27 #include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
28 #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
29 #include "cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp"
30 
31 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
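// For example, GET_OFF(src) is offsetof(jit_conv_call_s, src); kernel call
// arguments are then read below as ptr[param1 + GET_OFF(src)],
// ptr[param1 + GET_OFF(dst)], and so on.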
32 
33 namespace dnnl {
34 namespace impl {
35 namespace cpu {
36 namespace x64 {
37 
38 using namespace format_tag;
39 using namespace dnnl::impl::memory_tracking::names;
40 using namespace dnnl::impl::utils;
41 using namespace Xbyak;
42 
43 namespace {
44 
45 constexpr auto small_spatial = 14;
46 
47 inline void pick_loop_order(jit_conv_conf_t &jcp) {
48     using namespace prop_kind;
49     assert(one_of(
50             jcp.prop_kind, forward_training, forward_inference, backward_data));
51     auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
52     auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
53 
54     if (utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
55                 format_tag::nwc)
56             && jcp.ngroups > 1 && jcp.oc < 16) {
57         jcp.loop_order = loop_nhwcg;
58     } else if (jcp.prop_kind == backward_data) {
59             // ow-threading is currently implemented for forward only
60             // TODO: use a single code path for fwd and bwd once ow-threading
61             // is implemented for bwd (the meaningless switch was removed)
62         if (jcp.ndims < 5)
63             jcp.loop_order = (w <= small_spatial && h <= small_spatial)
64                     ? loop_cwgn
65                     : loop_gncw;
66         else
67             jcp.loop_order = (w <= small_spatial && h <= small_spatial)
68                     ? loop_cgn
69                     : loop_gnc;
70     } else {
71         jcp.loop_order = (w <= small_spatial && h <= small_spatial) ? loop_cwgn
72                                                                     : loop_gncw;
73     }
74 }
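// For example, a forward 2D case with a blocked layout and ow == oh == 7
// (both <= small_spatial) picks loop_cwgn, while larger spatial sizes use
// loop_gncw; 3D backward-data cases use the loop_cgn / loop_gnc variants.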
75 inline bool is_ow_threading_available(const jit_conv_conf_t &jcp) {
76     /* is 1D conv */
77     return (jcp.id == 1 && jcp.ih == 1 && jcp.kd == 1 && jcp.kh == 1);
78 }
79 inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
80     return (jcp.nb_ow > 1);
81 }
82 inline bool is_iw_threading_available(const jit_conv_conf_t &jcp) {
83     return one_of(jcp.ndims, 3, 4);
84 }
85 inline bool is_iw_threading_on(const jit_conv_conf_t &jcp) {
86     return (jcp.nb_iw > 1);
87 }
88 inline bool is_1stconv(const jit_conv_conf_t &jcp) {
89     const bool no_big_offt = nstl::max<size_t>(jcp.ic, jcp.oc)
90                     * nstl::max(jcp.typesize_in, jcp.typesize_out) * jcp.id
91                     * jcp.ih * jcp.iw
92             < INT_MAX;
93     return jcp.ic < 16 && jcp.ngroups == 1 && no_big_offt;
94 }
95 } // namespace
96 
97 template <typename Vmm>
98 _jit_avx512_core_bf16_fwd_kernel<Vmm>::_jit_avx512_core_bf16_fwd_kernel(
99         const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
100         const memory_desc_t &dst_md)
101     : jit_generator(nullptr, ker_code_size, true, avx512_core_bf16)
102     , jcp(ajcp)
103     , attr_(attr) {
104     if (jcp.with_eltwise || jcp.with_binary) {
105         using namespace binary_injector;
106         static constexpr bool preserve_gpr = true;
107         static constexpr bool preserve_vmm = false;
108         static constexpr size_t helper_vmm_idx = 31;
109         const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
110         const size_t tail_size = oc_block_tail
111                 ? oc_block_tail
112                 : jcp.oc_without_padding % isa_simd_width_;
113         static constexpr bool use_exact_tail_scalar_bcast = true;
114 
115         const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
116                 r14, r15, preserve_gpr, preserve_vmm,
117                 GET_OFF(post_ops_binary_rhs_arg_vec),
118                 memory_desc_wrapper(dst_md), tail_size, postops_mask,
119                 use_exact_tail_scalar_bcast};
120         const static_params_t static_params {
121                 this->param1, rhs_arg_static_params};
122 
123         postops_injector_ = utils::make_unique<
124                 injector::jit_uni_postops_injector_t<avx512_core>>(
125                 this, jcp.post_ops, static_params);
126     }
127     if (!isa_has_bf16(jcp.isa))
128         bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
129                 bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
130                 bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_5);
131 }
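// When the ISA has no native bf16 instructions (isa_has_bf16(jcp.isa) is
// false), bf16_emu_ provides emulated sequences for vdpbf16ps and
// vcvtneps2bf16 in the kernel body, using the reserved bf16_emu_* registers
// passed above.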
132 
133 template <typename Vmm>
134 void _jit_avx512_core_bf16_fwd_kernel<Vmm>::prepare_dst(int ur_w) {
135     for (int k = 0; k < jcp.nb_oc_blocking; k++)
136         for (int j = 0; j < ur_w; j++) {
137             Vmm vmm = vmm_dst(j, k);
138             vpxord(vmm, vmm, vmm);
139         }
140 }
141 
142 template <typename Vmm>
143 int _jit_avx512_core_bf16_fwd_kernel<Vmm>::vmm_dst_idx(
144         int i_ur, int i_oc) const {
145     const int idx = i_ur * jcp.nb_oc_blocking + i_oc;
146     assert(idx < ker_reg_base_idx);
147     return idx;
148 }
149 
150 template <typename Vmm>
151 Vmm _jit_avx512_core_bf16_fwd_kernel<Vmm>::vmm_dst(int i_ur, int i_oc) const {
152     return Vmm(vmm_dst_idx(i_ur, i_oc));
153 }
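// Accumulators therefore occupy vmm indices [0, ur_w * nb_oc_blocking), laid
// out row-major over (i_ur, i_oc); e.g. with nb_oc_blocking == 4, vmm_dst(2, 1)
// is Vmm(9).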
154 
155 template <typename F>
156 static void iterate(const int nb_oc_block, const int ur_w, const bool mask_tail,
157         const bool force_masking, const F &f) {
158     for (int k = 0; k < nb_oc_block; k++) {
159         const bool mask_flag
160                 = force_masking || (mask_tail && k + 1 == nb_oc_block);
161         for (int j = 0; j < ur_w; j++)
162             f(mask_flag, k, j);
163     }
164 }
165 template <typename F>
166 static void iterate(const int nb_oc_block, const int ur_w, const F &f) {
167     iterate(nb_oc_block, ur_w, false, false, f);
168 }
169 
170 template <typename Vmm>
171 void _jit_avx512_core_bf16_fwd_kernel<Vmm>::apply_postops(int ur_w) {
172     if (jcp.with_eltwise || jcp.with_binary) {
173         injector_utils::vmm_index_set_t vmm_idxs;
174         if (jcp.with_binary) {
175             binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
176                     rhs_arg_params_tail;
177             const auto temp_offset_reg = this->r12;
178             const auto mask_tail = jcp.oc_without_padding % jcp.simd_w;
179             const bool oc_blk_is_smaller_than_vmm
180                     = jcp.oc_block < isa_simd_width_;
181             iterate(jcp.nb_oc_blocking, ur_w, mask_tail,
182                     oc_blk_is_smaller_than_vmm,
183                     [&](const bool mask_flag, const int k, const int j) {
184                         const int aux_output_l_off
185                                 = get_dst_offset(j, k) / jcp.typesize_out;
186                         const auto vmm_idx = vmm_dst_idx(j, k);
187                         vmm_idxs.emplace(vmm_idx);
188 
189                         rhs_arg_params_tail.vmm_idx_to_oc_elem_off_addr.emplace(
190                                 vmm_idx, ptr[param1 + GET_OFF(oc_l_off)]);
191                         rhs_arg_params_tail.vmm_idx_to_oc_elem_off_val.emplace(
192                                 vmm_idx, k * jcp.oc_block);
193                         rhs_arg_params_tail.vmm_idx_to_out_off_oprnd.emplace(
194                                 vmm_idx, temp_offset_reg);
195                         rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
196                                 vmm_idx, aux_output_l_off);
197                         if (mask_flag)
198                             rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
199                     });
200             rhs_arg_params = rhs_arg_params_tail;
201             rhs_arg_params.vmm_tail_idx_.clear();
202 
203             const injector_utils::register_preserve_guard_t register_guard(
204                     this, {temp_offset_reg});
205             mov(temp_offset_reg, reg_dst);
206             sub(temp_offset_reg, ptr[param1 + GET_OFF(dst_orig)]);
207             shr(temp_offset_reg, std::log2(sizeof(float)));
208 
209             Label postops_done;
210             if (mask_tail || oc_blk_is_smaller_than_vmm) {
211                 Label postops_no_tail;
212                 if (mask_tail) {
213                     test(byte[param1 + GET_OFF(load_work)], jcp.oc_block - 1);
214                     jz(postops_no_tail, T_NEAR);
215                 }
216                 postops_injector_->compute_vector_range(
217                         vmm_idxs, rhs_arg_params_tail);
218                 jmp(postops_done, T_NEAR);
219                 L(postops_no_tail);
220             }
221             postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
222             L(postops_done);
223 
224         } else {
225             iterate(jcp.nb_oc_blocking, ur_w,
226                     [&](const bool, const int k, const int j) {
227                         vmm_idxs.emplace(vmm_dst_idx(j, k));
228                     });
229             postops_injector_->compute_vector_range(vmm_idxs);
230         }
231     }
232 }
233 
234 template <typename Vmm>
235 void _jit_avx512_core_bf16_fwd_kernel<Vmm>::store_dst(int ur_w) {
236     Label store_label;
237     const int oc_tail = jcp.oc_tail;
238     if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
239 
240     if (jcp.with_sum) {
241         for (int k = 0; k < jcp.nb_oc_blocking; k++) {
242             for (int j = 0; j < ur_w; j++) {
243                 // mask only needed for last oc_block
244                 bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking;
245                 Vmm vmm = vmm_dst(j, k);
246                 size_t aux_dst_offset = get_dst_offset(j, k);
247                 if (jcp.dst_dt == data_type::bf16) {
248                     vpmovzxwd(may_be_mask_vmm(vmm_prev_dst, mask_flag, true),
249                             make_safe_addr(
250                                     reg_dst, aux_dst_offset, reg_long_offt));
251                     vpslld(vmm_prev_dst, vmm_prev_dst, 16);
252                     vaddps(vmm, vmm_prev_dst);
253                 } else {
254                     vaddps(may_be_mask_vmm(vmm, mask_flag, true),
255                             make_safe_addr(
256                                     reg_dst, aux_dst_offset, reg_long_offt));
257                 }
258             }
259         }
260     }
261 
262     if (jcp.with_bias) {
263         mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
264         for (int k = 0; k < jcp.nb_oc_blocking; k++) {
265             int bias_offset = jcp.typesize_bia * k * jcp.oc_block;
266             for (int j = 0; j < ur_w; j++) {
267                 // mask only needed for last oc_block
268                 bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking;
269                 Vmm vmm = vmm_dst(j, k);
270                 if (jcp.bia_dt == data_type::bf16) {
271                     vpmovzxwd(may_be_mask_vmm(vmm_bias, mask_flag, true),
272                             EVEX_compress_addr(reg_bias, bias_offset));
273                     vpslld(vmm_bias, vmm_bias, 16);
274                     vaddps(vmm, vmm_bias);
275                 } else
276                     vaddps(may_be_mask_vmm(vmm, mask_flag, true),
277                             EVEX_compress_addr(reg_bias, bias_offset));
278             }
279         }
280     }
281 
282     apply_postops(ur_w);
283 
284     L(store_label);
285     if (jcp.dst_dt == data_type::f32) {
286         for (int k = 0; k < jcp.nb_oc_blocking; k++)
287             for (int j = 0; j < ur_w; j++) {
288                 Vmm vmm = vmm_dst(j, k);
289                 size_t aux_dst_offset = get_dst_offset(j, k);
290                 auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
291                 // mask only needed for last oc_block
292                 bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking
293                         && is_dst_layout_nxc();
294                 vmovups(addr, may_be_mask_vmm(vmm, mask_flag, false));
295             }
296     } else if (jcp.dst_dt == data_type::bf16) {
297         if (isa_has_bf16(jcp.isa) && is_dst_layout_nxc()) {
298             // Optimization: use single store instruction for pair of the
299             // nearest vectors along OC dimension
300             for (int j = 0; j < ur_w; j++) {
301                 int k = 0;
302                 for (; k < rnd_dn(jcp.nb_oc_blocking, 2); k += 2) {
303                     Vmm vmm = vmm_dst(j, k);
304                     Vmm vmm_next = vmm_dst(j, k + 1);
305                     size_t aux_dst_offset = get_dst_offset(j, k);
306                     auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
307                     vcvtne2ps2bf16(vmm, vmm_next, vmm);
308                     // mask only needed for last oc_block
309                     bool mask_flag = oc_tail && k + 2 == jcp.nb_oc_blocking;
310                     vmovdqu16(
311                             addr, may_be_mask_vmm(vmm, mask_flag, false, true));
312                 }
313                 if (jcp.nb_oc_blocking % 2 != 0) {
314                     Vmm vmm = vmm_dst(j, k);
315                     auto vmm_down = Vmm_down_t(vmm.getIdx());
316                     size_t aux_dst_offset = get_dst_offset(j, k);
317                     auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
318                     vcvtneps2bf16(vmm_down, vmm);
319                     // for xmm, upper half is zero after conversion to
320                     // bf16, so mask always & mask for tails
321                     bool mask_flag = jcp.simd_w == 4 || oc_tail;
322                     vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
323                 }
324             }
325         } else if (isa_has_bf16(jcp.isa) /* !is_dst_layout_nxc() */) {
326             // Optimization: use single store instruction for pair of the
327             // nearest vectors along WIDTH dimension
328             for (int k = 0; k < jcp.nb_oc_blocking; k++) {
329                 int n_2bf2ps = (ur_w / 2) * 2, j = 0;
330                 for (j = 0; j < n_2bf2ps; j += 2) {
331                     size_t aux_dst_offset = get_dst_offset(j, k);
332                     auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
333 
334                     auto vmm_str = vmm_src(j, jcp.nb_oc_blocking);
335                     vcvtne2ps2bf16(vmm_str, vmm_dst(j + 1, k), vmm_dst(j, k));
336                     vmovups(addr, vmm_str);
337                 }
338                 if (j < ur_w) {
339                     size_t aux_dst_offset = get_dst_offset(j, k);
340 
341                     auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
342                     auto vmm_down_str = vmm_src_down(j, jcp.nb_oc_blocking);
343                     vcvtneps2bf16(vmm_down_str, vmm_dst(j, k));
344                     // for xmm, upper half is zero after conversion to
345                     // bf16, so mask always.
346                     const bool mask_flag = jcp.simd_w == 4;
347                     vmovdqu16(addr, may_be_mask_vmm(vmm_down_str, mask_flag));
348                 }
349             }
350         } else {
351             for (int k = 0; k < jcp.nb_oc_blocking; k++)
352                 for (int j = 0; j < ur_w; j++) {
353                     Vmm vmm = vmm_dst(j, k);
354                     size_t aux_dst_offset = get_dst_offset(j, k);
355                     auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
356                     auto vmm_down = vmm_src_down(0, jcp.nb_oc_blocking);
357                     bf16_emu_->vcvtneps2bf16(
358                             Ymm(vmm_down.getIdx()), Zmm(vmm.getIdx()));
359                     bool mask_flag = (oc_tail && k + 1 == jcp.nb_oc_blocking
360                                              && is_dst_layout_nxc())
361                             // for xmm, upper half is zero after conversion to
362                             // bf16, so mask always & mask for tails
363                             || jcp.simd_w == 4;
364                     vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
365                 }
366         }
367     } else
368         assert(!"unsupported destination type");
369 }
370 
371 template <typename Vmm>
372 void _jit_avx512_core_bf16_fwd_kernel<Vmm>::compute_loop(
373         int ur_w, int pad_l, int pad_r) {
374     Label kh_label, kd_label;
375     const int ic_tail = jcp.ic_tail;
376     const int ic_step = 2;
377 
378     /* max_src_offset is explicitly used in the 1st convolution.
379      * Set its value so that accessing the double-word memory
380      * referenced by ptr[src_base + offset] is safe whenever
381      *     0 <= offset < max_src_offset
382      *
383      * Note: the arguments pad_l and pad_r might not exactly match
384      * jcp.l_pad and jcp.r_pad respectively, so this value needs to be
385      * computed separately for each invocation of compute_loop.
386      */
387     dim_t max_src_offset = 0;
388     if (jcp.is_1stconv || ic_tail) {
389         for (int ki = 0; ki < jcp.kw; ki++) {
390             int ow_fst = get_ow_start(ki, pad_l);
391             int ow_last = get_ow_end(ur_w, ki, pad_r) - 1;
392             if (ow_fst > ow_last) continue;
393             int ic_last = rnd_up(nstl::min(jcp.ic_block,
394                                          nstl::max(jcp.ic, ic_tail)),
395                                   ic_step)
396                     - ic_step;
397 
398             dim_t src_offset = get_src_offset(
399                     ic_last, filter_w_to_src(ki, ow_last, pad_l));
400             if (src_offset > max_src_offset) max_src_offset = src_offset;
401         }
402     }
403 
404     prepare_dst(ur_w);
405 
406     Label skip_compute_loop;
407     if (jcp.ndims == 5) {
408         mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
409         if ((jcp.dilate_d >= jcp.id)
410                 || (jcp.kd - 1) * (jcp.dilate_d + 1)
411                         < nstl::max(jcp.f_pad, jcp.back_pad)) {
412             cmp(reg_kj, 0);
413             je(skip_compute_loop, T_NEAR);
414         }
415     }
416     mov(reg_kj, reg_kh);
417     if ((jcp.dilate_h >= jcp.ih)
418             || (jcp.kh - 1) * (jcp.dilate_h + 1)
419                     < nstl::max(jcp.t_pad, jcp.b_pad)) {
420         cmp(reg_kj, 0);
421         je(skip_compute_loop, T_NEAR);
422     }
423 
424     // IC loop
425     Label icb_label;
426     mov(reg_ic, jcp.ic);
427     L(icb_label);
428 
429     if (jcp.ndims == 5) {
430         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
431         mov(ptr[rsp + off_reg_ker_], reg_ker);
432         mov(ptr[rsp + off_reg_src_], reg_src);
433 
434         L(kd_label);
435     }
436 
437     mov(aux_reg_src, reg_src);
438     mov(aux_reg_ker, reg_ker);
439 
440     mov(reg_kj, reg_kh);
441 
442     std::vector<Label> ic_tail_jmp(jcp.kw);
443     L(kh_label);
444     {
445         for (int ki = 0; ki < jcp.kw; ki++) {
446             int ow_start = get_ow_start(ki, pad_l);
447             int ow_end = get_ow_end(ur_w, ki, pad_r);
448             for (int ic = 0;
449                     ic < rnd_up(nstl::min(jcp.ic_block, jcp.ic), ic_step);
450                     ic += ic_step) {
451                 if (ic_tail && ic == rnd_up(ic_tail, ic_step)) {
452                     // insert this check at most once per icb, no more.
453                     cmp(reg_ic, ic_tail);
454                     je(ic_tail_jmp[ki], T_NEAR);
455                 }
456                 for (int oi = ow_start; oi < ow_end; oi++) {
457                     dim_t src_offset = get_src_offset(
458                             ic, filter_w_to_src(ki, oi, pad_l));
459                     auto vmm_in = vmm_src(oi, jcp.nb_oc_blocking);
460                     const auto addr_base = EVEX_compress_addr_safe(
461                             aux_reg_src, src_offset, reg_long_offt);
462                     const bool tail_load
463                             = ic_tail && ic == rnd_dn(ic_tail, ic_step);
464                     if (jcp.is_1stconv || tail_load) {
465                         const bool need_single_load
466                                 = (ic + 1 == jcp.ic || ic + 1 == ic_tail);
467                         const bool safe_overstep = (src_offset < max_src_offset)
468                                 && !is_src_layout_nxc();
469 
470                         /* For the comment below, let us define three words
471                          * x_b = ptr[addr_base] and x_s = ptr[addr_strided]
472                          * x_g = ptr[addr_base + 2]
473                          *
474                          * For single load case:
475                          * Without overstep zmm_in register is loaded as
476                          *     [0, x_b, ..., 0, x_b, 0, x_b]
477                          * On the other hand, "with overstep" zmm_in register
478                          * is loaded as
479                          *     [x_g, x_b, ..., x_g, x_b, x_g, x_b]
480                          * where x_g is a garbage word.
481                          *
482                          * Note:
483                          * 1. In single load case with safe_overstep enabled,
484                          * it is implicitly assumed that the element in zmm_wei
485                          * register corresponding to the "garbage value x_g" in
486                          * zmm_in register is zero.
487                          * 2. A potential problem can arise when x_g is
488                          * either Inf or NaN, since it gets multiplied by
489                          * zero in the accumulation. But as x_g is a "valid
490                          * input" taken from a different offset, one may
491                          * assume that x_g is neither Inf nor NaN.
492                          *
493                          * For non single load case:
494                          * zmm_in register is loaded as
495                          *     [x_s, x_b, ...., x_s, x_b, x_s, x_b]
496                          */
497                         if (tail_load) {
498                             if (need_single_load) {
499                                 Label mask_load, load_done;
500                                 cmp(reg_ic, ic + ic_step);
501                                 jl(mask_load, T_NEAR);
502                                 vpbroadcastd(vmm_in, addr_base);
503                                 jmp(load_done, T_NEAR);
504                                 L(mask_load);
505                                 vpbroadcastw(vmm_in | odd_load_mask | T_z,
506                                         addr_base);
507                                 L(load_done);
508                             } else {
509                                 vpbroadcastd(vmm_in, addr_base);
510                             }
511                         } else if (need_single_load && !safe_overstep)
512                             vpbroadcastw(
513                                     vmm_in | odd_load_mask | T_z, addr_base);
514                         else if (IMPLICATION(!is_src_layout_nxc(),
515                                          need_single_load && safe_overstep))
516                             vpbroadcastd(vmm_in, addr_base);
517                         else {
518                             const auto addr_strided
519                                     = EVEX_compress_addr_safe(aux_reg_src,
520                                             src_offset + get_src_offset(1, 0),
521                                             reg_long_offt);
522                             vpbroadcastd(vmm_in, addr_base);
523                             vpbroadcastw(vmm_in | even_load_mask, addr_strided);
524                         }
525                     } else {
526                         vpbroadcastd(vmm_in, addr_base);
527                     }
528                 }
529                 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
530                     auto wei_off = get_kernel_offset(kk, ic, ki);
531                     vmovups(vmm_wei,
532                             EVEX_compress_addr_safe(
533                                     aux_reg_ker, wei_off, reg_long_offt));
534                     for (int oi = ow_start; oi < ow_end; oi++) {
535                         auto acc = vmm_dst(oi, kk);
536                         auto src = vmm_src(oi, jcp.nb_oc_blocking);
537                         if (isa_has_bf16(jcp.isa)) {
538                             vdpbf16ps(acc, vmm_wei, src);
539                         } else
540                             bf16_emu_->vdpbf16ps(Zmm(acc.getIdx()),
541                                     Zmm(vmm_wei.getIdx()), Zmm(src.getIdx()));
542                     }
543                 }
544             }
545             L(ic_tail_jmp[ki]);
546         }
547         safe_add(aux_reg_ker, get_kernel_offset(0, 0, 0, 1), reg_long_offt);
548         safe_add(aux_reg_src, get_src_offset(0, filter_h_to_src(1)),
549                 reg_long_offt);
550 
551         dec(reg_kj);
552         cmp(reg_kj, 0);
553         jg(kh_label, T_NEAR);
554     }
555 
556     if (jcp.ndims == 5) {
557         safe_add(reg_src, get_src_offset(0, filter_d_to_src(1)), reg_long_offt);
558         safe_add(reg_ker, get_kernel_offset(0, 0, 0, 0, 1), reg_long_offt);
559         dec(reg_ki);
560         cmp(reg_ki, 0);
561         jg(kd_label, T_NEAR);
562 
563         mov(reg_ker, ptr[rsp + off_reg_ker_]);
564         mov(reg_src, ptr[rsp + off_reg_src_]);
565     }
566 
567     // End of IC Loop
568     dim_t src_step = get_src_offset(jcp.ic_block, 0);
569     const size_t ker_step = get_kernel_offset(0, jcp.ic_block, 0);
570     safe_add(reg_src, src_step, reg_long_offt);
571     safe_add(reg_ker, ker_step, reg_long_offt);
572 
573     sub(reg_ic, jcp.ic_block);
574     cmp(reg_ic, 0);
575     jg(icb_label, T_NEAR);
576 
577     safe_sub(reg_src, src_step * jcp.nb_ic, reg_long_offt);
578     safe_sub(reg_ker, ker_step * jcp.nb_ic, reg_long_offt);
579 
580     L(skip_compute_loop);
581     store_dst(ur_w);
582 }
583 
584 template <typename Vmm>
585 void _jit_avx512_core_bf16_fwd_kernel<Vmm>::generate() {
586     int iw = jcp.iw;
587     int ow = jcp.ow;
588     int ow_block = jcp.ow_block;
589     int nb_ow = jcp.nb_ow;
590     int kw = jcp.kw;
591     int l_pad = jcp.l_pad;
592     int ur_w = jcp.ur_w;
593     int ur_w_tail = jcp.ur_w_tail;
594     int stride_w = jcp.stride_w;
595 
596     auto src_shift = get_src_offset(0, filter_w_to_src(0, ur_w));
597     auto dst_shift = get_dst_offset(ur_w, 0);
598 
599     auto src_shift_pad = get_src_offset(0, filter_w_to_src(0, ur_w, l_pad));
600     auto src_shift_pad_second_block
601             = get_src_offset(0, filter_w_to_src(0, 0, l_pad));
602 
603     preamble();
604     if (jcp.ndims == 5) sub(rsp, stack_space_needed_);
605 
606     if (jcp.is_1stconv || jcp.ic_tail) {
607         Xbyak::Reg64 reg_alt_mask = r8;
608         const auto odd_mask = size_t {0x5555555555555555};
609         const auto even_mask = size_t {0xaaaaaaaaaaaaaaaa};
610         mov(reg_alt_mask, odd_mask);
611         kmovq(odd_load_mask, reg_alt_mask);
612         mov(reg_alt_mask, even_mask);
613         kmovq(even_load_mask, reg_alt_mask);
614     }
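    // With these patterns, odd_load_mask selects the low bf16 word of every
    // dword lane and even_load_mask the high word; compute_loop() relies on
    // this for its masked vpbroadcastw loads in the 1st-conv / ic-tail paths.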
615 
616     if (jcp.simd_w == 4) {
617         auto reg_tail_32 = reg_oc.cvt32();
618         mov(reg_tail_32, (1 << jcp.simd_w) - 1);
619         kmovb(k_oc_tail_mask, reg_tail_32);
620     }
621 
622     if (jcp.oc_tail) {
623         Label done;
624         // dummy mask all 1's
625         if (jcp.simd_w != 4) { // simd_w == 4, has its dummy mask set already
626             kxnord(k_oc_tail_mask, k_oc_tail_mask, k_oc_tail_mask);
627         }
628         // To account for the special store optimization, where two oc_blocks
629         // are combined into a single write, extend the mask to 32 bits (32 bf16s)
630         const bool need_extended_mask = jcp.dst_dt == data_type::bf16
631                 && isa_has_bf16(jcp.isa) && jcp.nb_oc_blocking > 1;
632         if (need_extended_mask)
633             kxnord(k_oc_tail_mask_extended, k_oc_tail_mask_extended,
634                     k_oc_tail_mask_extended);
635 
636         test(byte[param1 + GET_OFF(load_work)], jcp.oc_block - 1);
637         jz(done, T_NEAR);
638         auto reg_tail_32 = reg_oc.cvt32();
639         mov(reg_tail_32, (1 << jcp.oc_tail) - 1);
640         kmovd(k_oc_tail_mask, reg_tail_32);
641         kmovd(postops_mask, reg_tail_32);
642         if (need_extended_mask) {
643             mov(reg_tail_32, (1 << (jcp.oc_tail + jcp.simd_w)) - 1);
644             kmovd(k_oc_tail_mask_extended, reg_tail_32);
645         }
646         L(done);
647     } else if (jcp.with_binary)
648         if (jcp.oc_block != isa_simd_width_) {
649             const int mask = (1 << jcp.oc_block) - 1;
650             const Reg32 regw_tmp = reg_oi.cvt32();
651             mov(regw_tmp, mask);
652             kmovd(postops_mask, regw_tmp);
653         }
654 
655     mov(reg_src, ptr[param1 + GET_OFF(src)]);
656     mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
657     mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
658     mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
659 
660     int r_pad = nstl::max(0, jcp.r_pad);
661     int n_oi = ow / ur_w;
662     int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
663             calculate_extended_filter_size(kw, jcp.dilate_w));
664 
665     if (!is_ow_threading_on(jcp)) {
666         // ow is being processed as a whole - with left and right paddings
667         if (r_pad1 > 0) n_oi--;
668 
669         xor_(reg_oi, reg_oi);
670         if (ow == ur_w) {
671             compute_loop(ur_w, l_pad, r_pad);
672         } else {
673             if (n_oi == 0) {
674                 compute_loop(ur_w, l_pad, r_pad1);
675                 add(reg_src, src_shift_pad);
676                 add(reg_dst, dst_shift);
677                 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
678             } else {
679                 if (l_pad > 0) {
680                     compute_loop(ur_w, l_pad, 0);
681                     add(reg_src, src_shift_pad);
682                     add(reg_dst, dst_shift);
683                     inc(reg_oi);
684                 }
685                 if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
686                     Label ow_loop_label;
687                     L(ow_loop_label);
688                     {
689                         compute_loop(ur_w, 0, 0);
690                         add(reg_src, src_shift);
691                         add(reg_dst, dst_shift);
692 
693                         inc(reg_oi);
694                         cmp(reg_oi, n_oi);
695                         jl(ow_loop_label, T_NEAR);
696                     }
697                 }
698                 if (r_pad1 > 0) {
699                     compute_loop(ur_w, 0, r_pad1);
700                     add(reg_src, src_shift);
701                     add(reg_dst, dst_shift);
702                 }
703                 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
704             }
705         }
706     } else {
707         // Only a single ow block is processed here.
708         // The block number is passed as the parameter owb,
709         // and padding handling depends on this number.
710 
711         Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
712         Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
713 
714         assert(ow_block % ur_w == 0);
715         int n_oi_not_last_ow_block = ow_block / ur_w;
716         // to simplify the code (and general-purpose register usage),
717         // the size of an ow block must be >= 2 * ur_w
718         assert(n_oi_not_last_ow_block > 1);
719         int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
720         int n_oi_first_ow_block = n_oi_not_last_ow_block;
721 
722         int n_oi_last_ow_block = (ow - ow_block * (nb_ow - 1)) / ur_w;
723 
724         // prepare right padding
725         bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
726         bool first_ow_block_padded
727                 = next_last_ow_block_padded && jcp.nb_ow == 2;
728         bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
729 
730         if (last_ow_block_padded)
731             n_oi_last_ow_block--;
732         else if (first_ow_block_padded)
733             n_oi_first_ow_block--;
734         else if (next_last_ow_block_padded)
735             n_oi_next_last_ow_block--;
736 
737         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
738         cmp(reg_owb, 0); // is that the first ow-block ?
739         jg(middle_ow_blocks_label, T_NEAR);
740 
741         // the first ow block, compute left padding
742 
743         mov(reg_oi, n_oi_first_ow_block);
744         if (l_pad > 0) {
745             compute_loop(ur_w, l_pad, 0);
746             add(reg_src, src_shift_pad);
747             add(reg_dst, dst_shift);
748             dec(reg_oi);
749         }
750         jmp(oi_loop_label, T_NEAR);
751 
752         // middle or last ow block entry
753 
754         L(middle_ow_blocks_label);
755 
756         if (l_pad > 0) {
757             // account for left padding only; nothing is computed here
758             add(reg_src, src_shift_pad_second_block);
759         }
760 
761         // set number of iteration for oi-loop
762         cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
763         mov(reg_oi, n_oi_last_ow_block);
764         je(oi_loop_label, T_NEAR);
765         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
766         mov(reg_oi, n_oi_next_last_ow_block);
767         je(oi_loop_label, T_NEAR);
768         mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
769 
770         // oi loop w/o padding
771         L(oi_loop_label);
772         L(oi_loop_start_label);
773         cmp(reg_oi, 0);
774         jle(oi_loop_end_label, T_NEAR);
775 
776         compute_loop(ur_w, 0, 0);
777         add(reg_src, src_shift);
778         add(reg_dst, dst_shift);
779         dec(reg_oi);
780         jmp(oi_loop_start_label, T_NEAR);
781         L(oi_loop_end_label);
782 
783         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
784 
785         cmp(reg_owb, 0); // first ow-block ?
786         if (first_ow_block_padded) {
787             je(last_oi_label, T_NEAR);
788         } else {
789             je(end_label, T_NEAR);
790         }
791         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
792         jl(end_label, T_NEAR);
793         if (next_last_ow_block_padded) {
794             je(last_oi_label, T_NEAR);
795         } else {
796             je(end_label, T_NEAR);
797         }
798         // that is last block
799         if (!last_ow_block_padded) { jmp(tail_label, T_NEAR); }
800 
801         // last oi block with right padding
802         L(last_oi_label);
803         compute_loop(ur_w, 0, r_pad1);
804         add(reg_src, src_shift);
805         add(reg_dst, dst_shift);
806 
807         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
808         cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
809         jl(end_label, T_NEAR);
810 
811         L(tail_label);
812         if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
813         L(end_label);
814     }
815 
816     if (jcp.ndims == 5) add(rsp, stack_space_needed_);
817     postamble();
818 
819     if (jcp.with_eltwise) postops_injector_->prepare_table();
820 }
821 
822 void jit_avx512_core_bf16_fwd_kernel::init_scratchpad(
823         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
824     using namespace memory_tracking::names;
825     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
826         assert(jcp.ngroups == 1);
827         scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
828     }
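    // The buffer booked above holds jcp.oc elements of typesize_bia bytes, so
    // a zero-padded copy of the bias can be used when oc is rounded up to a
    // multiple of oc_block.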
829 }
830 
831 status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
832         const convolution_desc_t &cd, memory_desc_t &src_md,
833         memory_desc_t &weights_md, memory_desc_t &dst_md,
834         memory_desc_t &bias_md, const primitive_attr_t &attr, int nthreads) {
835 
836     using namespace prop_kind;
837 
838     const memory_desc_wrapper src_d(&src_md);
839     const memory_desc_wrapper weights_d(&weights_md);
840     const memory_desc_wrapper dst_d(&dst_md);
841     const memory_desc_wrapper bias_d(&bias_md);
842 
843     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
844     int ndims = src_d.ndims();
845 
846     jcp = zero<decltype(jcp)>();
847     jcp.nthr = nthreads;
848     jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
849                                         : bf16_emulation_t::get_isa();
850     jcp.ver = ver_vnni;
851     jcp.ndims = ndims;
852     jcp.prop_kind = cd.prop_kind;
853     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
854     jcp.mb = src_d.dims()[0];
855     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
856     jcp.oc_without_padding = jcp.oc;
857     jcp.ic = src_d.dims()[1] / jcp.ngroups;
858     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
859     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
860     jcp.iw = src_d.dims()[ndims - 1];
861     jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
862     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
863     jcp.ow = dst_d.dims()[ndims - 1];
864     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
865     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
866     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
867     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
868     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
869     jcp.l_pad = cd.padding[0][ndims - 3];
870     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
871     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
872     jcp.stride_w = cd.strides[ndims - 3];
873     jcp.dst_dt = dst_d.data_type();
874 
875     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
876     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
877     jcp.dilate_w = cd.dilates[ndims - 3];
878 
879     jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
880 
881     jcp.typesize_in = types::data_type_size(src_d.data_type());
882     jcp.typesize_out = types::data_type_size(dst_d.data_type());
883 
884     jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
885     jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
886 
887     int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
888     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
889     int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
890     jcp.r_pad = calculate_end_padding(
891             jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
892     jcp.b_pad = calculate_end_padding(
893             jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
894     jcp.back_pad = calculate_end_padding(
895             jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
896     bool kernel_outside_src = false || ext_kw <= jcp.l_pad
897             || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
898             || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
899     if (kernel_outside_src) return status::unimplemented;
900 
901     const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
902     const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
903     const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
904     const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
905     const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
906     auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c,
907             dat_tag_nCx8c, dat_tag_nCx4c, dat_tag_ncx);
908     auto curr_dst_tag = dst_d.matches_one_of_tag(
909             dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
910     bool is_data_layout_nxc
911             = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);
912     jcp.is_1stconv = is_1stconv(jcp);
913 
914     const int regs = isa_has_bf16(jcp.isa) ? 31 /* expl_bcast case */ : 26;
915     const bool ok_to_pad_channels = jcp.ngroups == 1 && !is_data_layout_nxc;
916 
917     jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
918 
919     const bool ok_to_try_lower_zmm = true
920             && IMPLICATION(is_data_layout_nxc,
921                     jcp.oc < jcp.simd_w && jcp.ic < jcp.simd_w
922                             && jcp.ngroups > 1)
923             && !jcp.is_1stconv && !ok_to_pad_channels
924             && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0);
925 
926     if (ok_to_try_lower_zmm) {
927         for (auto simd : {8, 4}) {
928             if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
929                 jcp.simd_w = simd;
930                 break;
931             }
932         }
933     }
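    // For example, an nxc case with ngroups > 1 and ic == oc == 8 (and not a
    // 1st convolution) lowers simd_w from 16 to 8 here, so ymm-sized blocks
    // are used for the oc/ic blocking below.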
934 
935     jcp.oc_block = jcp.simd_w;
936     jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
937 
938     if (ok_to_pad_channels) {
939         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
940         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
941     }
942 
943     if (!IMPLICATION(!is_data_layout_nxc,
944                 jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
945         return status::unimplemented;
946 
947     format_tag_t src_tag, dst_tag, wei_tag;
948 
949     if (jcp.simd_w == 8) {
950         assert(with_groups);
951         dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
952         wei_tag = pick(ndims - 3, gOIw4i8o2i, gOIhw4i8o2i, gOIdhw4i8o2i);
953     } else if (jcp.simd_w == 4) {
954         assert(with_groups);
955         dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c;
956         wei_tag = pick(ndims - 3, gOIw2i4o2i, gOIhw2i4o2i, gOIdhw2i4o2i);
957     } else if (jcp.is_1stconv) {
958         dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
959         src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_ncx;
960         wei_tag = pick(2 * ndims - 6 + with_groups, OwI16o2i, gOwI16o2i,
961                 OhwI16o2i, gOhwI16o2i, OdhwI16o2i, gOdhwI16o2i);
962     } else {
963         dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
964         wei_tag = pick(2 * ndims - 6 + with_groups, OIw8i16o2i, gOIw8i16o2i,
965                 OIhw8i16o2i, gOIhw8i16o2i, OIdhw8i16o2i, gOIdhw8i16o2i);
966     }
967 
968     if (src_md.format_kind == format_kind::any)
969         CHECK(memory_desc_init_by_tag(src_md, src_tag));
970     else if (curr_src_tag != src_tag)
971         return status::unimplemented;
972     jcp.src_tag = src_tag;
973 
974     if (dst_md.format_kind == format_kind::any)
975         CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
976     else if (curr_dst_tag != dst_tag)
977         return status::unimplemented;
978     jcp.dst_tag = dst_tag;
979 
980     if (weights_md.format_kind == format_kind::any) {
981         CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
982         jcp.wei_tag = wei_tag;
983     } else {
984         jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
985         if (jcp.wei_tag != wei_tag) return status::unimplemented;
986     }
987 
988     if (jcp.with_bias) {
989         if (bias_d.format_kind() == format_kind::any)
990             CHECK(memory_desc_init_by_tag(bias_md, x));
991     }
992 
993     jcp.aligned_threads = 0;
994 
995     bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
996             && jcp.oc <= dst_d.padded_dims()[1]
997             && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
998             && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
999     if (!args_ok) return status::unimplemented;
1000 
1001     const auto &post_ops = attr.post_ops_;
1002     jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
1003     const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
1004     jcp.with_eltwise = eltwise_ind != -1;
1005     if (jcp.with_eltwise) {
1006         jcp.eltwise = post_ops.entry_[eltwise_ind].eltwise;
1007         if (dst_d.data_type() == data_type::s32) return status::unimplemented;
1008     }
1009     const int binary_ind = post_ops.find(primitive_kind::binary);
1010     jcp.with_binary = binary_ind != -1;
1011 
1012     jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
1013     if (is_data_layout_nxc)
1014         jcp.oc_tail = jcp.oc % jcp.simd_w;
1015     else
1016         jcp.oc_tail = jcp.with_binary ? jcp.oc_without_padding % jcp.simd_w : 0;
1017 
1018     jcp.post_ops = post_ops;
1019 
1020     using namespace injector;
1021     static constexpr bool sum_at_pos_0_only = true;
1022     static constexpr bool sum_requires_scale_one = true;
1023     const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
1024             jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
1025             {broadcasting_strategy_t::scalar,
1026                     broadcasting_strategy_t::per_oc}});
1027     if (!post_ops_ok_) return status::unimplemented;
1028 
1029     jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
1030     jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
1031     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1032 
1033     jcp.kernel_kind = expl_bcast;
1034     jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
1035     for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) {
1036         int ur_w = regs / (jcp.nb_oc_blocking + 1);
1037         if (jcp.nb_oc % jcp.nb_oc_blocking == 0
1038                 && (jcp.l_pad <= ur_w
1039                         && IMPLICATION(jcp.ow != 1, jcp.ow % ur_w != 1)))
1040             break;
1041     }
1042 
1043     jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
1044     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
1045     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
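    // E.g. with native bf16 (regs == 31) and nb_oc_blocking == 4 this gives
    // ur_w == 31 / 5 == 6, so a row with ow == 20 is processed as three full
    // ur_w blocks plus ur_w_tail == 2.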
1046 
1047     jcp.ow_block = jcp.ow;
1048     if (is_ow_threading_available(jcp)) {
1049         const int L1_part = platform::get_per_core_cache_size(1) * 5 / 8;
1050         int size_src_chunk = jcp.typesize_in * jcp.ic_block * jcp.ur_w;
1051         int size_dst_chunk = jcp.typesize_out * jcp.oc_block
1052                 * jcp.nb_oc_blocking * jcp.ur_w;
1053         int size_wei_chunk = jcp.typesize_in * jcp.oc_block * jcp.ic_block
1054                 * jcp.nb_oc_blocking * jcp.kw;
1055         int nurw = (L1_part - size_wei_chunk)
1056                 / (size_dst_chunk + size_src_chunk);
1057         // current design of generate() requires ow_block >= 2 * ur_w
1058         jcp.ow_block = jcp.ur_w * nstl::max(2, nurw);
1059     }
1060     jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1061 
1062     int r_pad_no_tail = nstl::max(0,
1063             calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
1064                     jcp.stride_w, ext_kw));
1065     if (jcp.l_pad > jcp.ur_w || r_pad_no_tail > jcp.ur_w)
1066         return status::unimplemented;
1067 
1068     /* Adjust the thread decomposition to improve performance for small
1069      * problem sizes. The threshold L1_cache_size / factor is used, where
1070      * the factor is empirical; simply cap the thread count at 4 for now.
1071      * TODO: Add a get_thr_eff function to get the optimal thread number */
1073 
1074     size_t wei_size = (size_t)sizeof(bfloat16_t) * jcp.ic * jcp.oc * jcp.kh
1075             * jcp.kw * jcp.kd;
1076     size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
1077             * jcp.iw * jcp.id;
1078     size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
1079             * jcp.ow * jcp.od;
1080     size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
1081     const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
1082 
1083     // The factor is 1 for 1d, 2 for 2d, and 4 for 3d.
1084     int factor = nstl::max(1, (2 * (ndims - 3)));
1085     if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size / factor) {
1086         jcp.nthr = nstl::min(jcp.nthr, 4);
1087     }
1088 
1089     pick_loop_order(jcp);
1090 
1091     return status::success;
1092 }
1093 
1094 template <typename Vmm>
1095 void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::prepare_output(int ur_w) {
1096     for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1097         for (int j = 0; j < ur_w; j++) {
1098             Vmm vmm = vmm_dsrc(j, k);
1099             vpxord(vmm, vmm, vmm);
1100         }
1101     }
1102 }
1103 
1104 template <typename Vmm>
1105 void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::store_output(int ur_w) {
1106     if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
1107     const int ic_tail = jcp.ic_tail;
1108 
1109     if (jcp.dst_dt == data_type::f32) {
1110         for (int k = 0; k < jcp.nb_ic_blocking; k++)
1111             for (int j = 0; j < ur_w; j++) {
1112                 Vmm vmm = vmm_dsrc(j, k);
1113                 size_t aux_diff_src_offset = get_diff_src_offset(j, k);
1114                 auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
1115                 // mask only needed for last ic_block
1116                 bool mask_flag = ic_tail && k + 1 == jcp.nb_ic_blocking
1117                         && is_dsrc_layout_nxc();
1118                 vmovups(addr, may_be_mask_vmm(vmm, mask_flag, false));
1119             }
1120     } else if (jcp.dst_dt == data_type::bf16) {
1121         if (isa_has_bf16(jcp.isa) && is_ddst_layout_nxc()) {
1122             // Optimization: use single store instruction for pair of the
1123             // nearest vectors along IC dimension
1124             for (int j = 0; j < ur_w; j++) {
1125                 int k = 0;
1126                 for (; k < rnd_dn(jcp.nb_ic_blocking, 2); k += 2) {
1127                     Vmm vmm = vmm_dsrc(j, k);
1128                     Vmm vmm_next = vmm_dsrc(j, k + 1);
1129                     size_t aux_dsrc_offset = get_diff_src_offset(j, k);
1130                     auto addr = EVEX_compress_addr(reg_src, aux_dsrc_offset);
1131                     vcvtne2ps2bf16(vmm, vmm_next, vmm);
1132                     bool mask_flag = ic_tail && k + 2 == jcp.nb_ic_blocking;
1133                     vmovdqu16(
1134                             addr, may_be_mask_vmm(vmm, mask_flag, false, true));
1135                 }
1136                 if (jcp.nb_ic_blocking % 2 != 0) {
1137                     Vmm vmm = vmm_dsrc(j, k);
1138                     auto vmm_down = Vmm_down_t(vmm.getIdx());
1139                     size_t aux_dsrc_offset = get_diff_src_offset(j, k);
1140                     auto addr = EVEX_compress_addr(reg_src, aux_dsrc_offset);
1141                     vcvtneps2bf16(vmm_down, vmm);
1142                     // for xmm, upper half is zero after conversion to
1143                     // bf16, so mask always & mask for tails
1144                     bool mask_flag = jcp.simd_w == 4 || ic_tail;
1145                     vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
1146                 }
1147             }
1148         } else if (isa_has_bf16(jcp.isa) /* && !is_ddst_layout_nxc() */) {
1149             // Optimization: use single store instruction for pair of the
1150             // nearest vectors along WIDTH dimension
1151             int store_idx = 0;
1152             const int max_regs = 32;
1153             const int free_regs_start_idx = jcp.ur_w * jcp.nb_ic_blocking;
1154             const int num_regs_available = max_regs - free_regs_start_idx;
1155             int reg_idx = 0;
1156             for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1157                 int n_2bf2ps = (ur_w / 2) * 2, j = 0;
1158                 for (j = 0; j < n_2bf2ps; j += 2) {
1159                     reg_idx = free_regs_start_idx
1160                             + store_idx % num_regs_available;
1161                     assert(reg_idx < max_regs);
1162                     size_t aux_diff_src_offset = get_diff_src_offset(j, k);
1163                     auto addr
1164                             = EVEX_compress_addr(reg_src, aux_diff_src_offset);
1165 
1166                     auto vmm_str = Vmm(reg_idx);
1167                     vcvtne2ps2bf16(vmm_str, vmm_dsrc(j + 1, k), vmm_dsrc(j, k));
1168                     vmovups(addr, vmm_str);
1169                     store_idx++;
1170                 }
1171                 if (j < ur_w) {
1172                     reg_idx = free_regs_start_idx
1173                             + store_idx % num_regs_available;
1174                     assert(reg_idx < max_regs);
1175 
1176                     size_t aux_diff_src_offset = get_diff_src_offset(j, k);
1177                     auto addr
1178                             = EVEX_compress_addr(reg_src, aux_diff_src_offset);
1179                     auto vmm_down_str = Vmm_down_t(reg_idx);
1180                     vcvtneps2bf16(vmm_down_str, vmm_dsrc(j, k));
1181                     // for xmm, upper half is zero after conversion to
1182                     // bf16, so mask always.
1183                     bool mask_flag = jcp.simd_w == 4;
1184                     vmovdqu16(addr, may_be_mask_vmm(vmm_down_str, mask_flag));
1185                     store_idx++;
1186                 }
1187             }
1188         } else {
1189             for (int k = 0; k < jcp.nb_ic_blocking; k++)
1190                 for (int j = 0; j < ur_w; j++) {
1191                     Vmm vmm = vmm_dsrc(j, k);
1192                     size_t aux_diff_src_offset = get_diff_src_offset(j, k);
1193                     auto addr
1194                             = EVEX_compress_addr(reg_src, aux_diff_src_offset);
1195                     auto vmm_down = vmm_ddst_down(0);
1196                     bf16_emu_->vcvtneps2bf16(
1197                             Ymm(vmm_down.getIdx()), Zmm(vmm.getIdx()));
1198                     bool mask_flag = (ic_tail && k + 1 == jcp.nb_ic_blocking
1199                                              && is_dsrc_layout_nxc())
1200                             // for xmm, upper half is zero after conversion to
1201                             // bf16, so mask always & mask for tails
1202                             || jcp.simd_w == 4;
1203                     vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
1204                 }
1205         }
1206     } else
1207         assert(!"unsupported diff_src type");
1208 }
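// As in the forward kernel's store_dst(), the stores above pair adjacent
// vectors for vcvtne2ps2bf16 when native bf16 is available and otherwise fall
// back to the emulated vcvtneps2bf16 sequence.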
1209 
1210 template <typename Vmm>
1211 void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::compute_loop(
1212         int ur_w, int l_overflow, int r_overflow) {
1213     int kw = jcp.kw;
1214     int dilate_w = jcp.dilate_w + 1;
1215     int stride_w = jcp.stride_w;
1216     int stride_h = jcp.stride_h;
1217     const int oc_tail = jcp.oc_tail;
1218     Label kh_label, skip_compute_label;
1219 
1220     prepare_output(ur_w);
1221 
1222     if (jcp.ndims == 5) {
1223         mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1224         cmp(reg_ki, 0);
1225         jle(skip_compute_label, T_NEAR);
1226     }
1227 
1228     cmp(reg_kh, 0);
1229     jle(skip_compute_label, T_NEAR);
1230 
1231     // OC loop
1232     Label ocb_label;
1233     mov(reg_oc, jcp.oc);
1234     L(ocb_label);
1235 
1236     if (jcp.ndims < 5) {
1237         mov(aux_reg_dst, reg_dst);
1238         mov(aux_reg_ker, reg_ker);
1239     }
1240     Label kd_label;
1241     if (jcp.ndims == 5) {
1242         mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1243         mov(aux_reg_dst_d, reg_dst);
1244         mov(aux_reg_ker_d, reg_ker);
1245 
1246         L(kd_label);
1247         mov(aux_reg_dst, aux_reg_dst_d);
1248         mov(aux_reg_ker, aux_reg_ker_d);
1249     }
1250 
1251     std::vector<Label> oc_tail_jmp(jcp.kw);
1252     mov(reg_kj, reg_kh);
1253     L(kh_label);
1254     {
1255         for (int ki = 0; ki < kw; ki++) {
1256             int jj_start = get_iw_start(ki, l_overflow);
1257             int jj_end = get_iw_end(ur_w, ki, r_overflow);
1258             const int ref_jj_start
1259                     = nstl::max(0, l_overflow - (kw - 1 - ki) * dilate_w);
1260             const int ref_jj_end
1261                     = ur_w - nstl::max(0, r_overflow - ki * dilate_w);
1262             assert(IMPLICATION(stride_w == 1,
1263                     jj_start == ref_jj_start && jj_end == ref_jj_end));
1264             UNUSED(ref_jj_start);
1265             UNUSED(ref_jj_end);
1266             const int oc_step = 2;
1267             for (int oc = 0;
1268                     oc < rnd_up(nstl::min(jcp.oc_block, jcp.oc), oc_step);
1269                     oc += oc_step) {
1270                 if (oc_tail && oc == rnd_up(oc_tail, oc_step)) {
1271                     cmp(reg_oc, oc_tail);
1272                     je(oc_tail_jmp[ki], T_NEAR);
1273                 }
1274                 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
1275                     assert((jj + jcp.l_pad - ki * dilate_w) % stride_w == 0);
1276                     int ow_idx = (jj + jcp.l_pad - ki * dilate_w) / stride_w;
1277                     auto aux_ddst_offset = get_diff_dst_offset(ow_idx, oc);
1278                     auto ddst = vmm_ddst(jj / stride_w);
1279                     const bool tail_load = oc_tail && oc == rnd_dn(oc_tail, 2);
1280                     const bool need_single_load = oc + 1 == oc_tail;
1281 
1282                     if (tail_load && need_single_load) {
1283                         Label mask_load, load_done;
1284                         cmp(reg_oc, oc + 2);
1285                         jl(mask_load, T_NEAR);
1286                         vpbroadcastd(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
1287                         jmp(load_done, T_NEAR);
1288                         L(mask_load);
1289                         // We broadcast w here. As the weights are zero-padded
1290                         // at oc + 1, vdpbf16ps({0, w}, {dst, dst}) is okay.
1291                         vpbroadcastw(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
1292                         L(load_done);
1293                     } else {
1294                         vpbroadcastd(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
1295                     }
1296                 }
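                // The diff_dst values above are loaded as (oc, oc + 1) bf16
                // pairs: vdpbf16ps consumes two bf16 elements per lane, so a
                // dword broadcast covers two output channels at once, and an
                // odd oc_tail falls back to a word broadcast paired with the
                // zero-padded weight at oc + 1.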
1297                 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
1298                     size_t aux_kernel_offset = get_kernel_offset(kk, oc, ki);
1299                     vmovups(vmm_wei,
1300                             EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
1301 
1302                     for (int jj = jj_start; jj < jj_end; jj += stride_w) {
1303                         auto ddst = vmm_ddst(jj / stride_w);
1304                         auto acc = vmm_dsrc(jj, kk);
1305 
1306                         if (isa_has_bf16(jcp.isa)) {
1307                             vdpbf16ps(acc, vmm_wei, ddst);
1308                         } else
1309                             bf16_emu_->vdpbf16ps(Zmm(acc.getIdx()),
1310                                     Zmm(vmm_wei.getIdx()), Zmm(ddst.getIdx()));
1311                     }
1312                 }
1313             }
1314             L(oc_tail_jmp[ki]);
1315         }
1316 
1317         add(aux_reg_ker, get_kernel_offset(0, 0, 0, stride_h));
1318         sub(aux_reg_dst, get_diff_dst_offset(filter_h_to_dst(1), 0));
1319 
1320         dec(reg_kj);
1321         cmp(reg_kj, 0);
1322         jg(kh_label, T_NEAR);
1323     }
1324 
1325     if (jcp.ndims == 5) {
1326         sub(aux_reg_dst_d, get_diff_dst_offset(filter_d_to_dst(1), 0));
1327         add(aux_reg_ker_d, get_kernel_offset(0, 0, 0, 0, jcp.stride_d));
1328 
1329         dec(reg_ki);
1330         cmp(reg_ki, 0);
1331         jg(kd_label, T_NEAR);
1332     }
1333 
1334     // End of OC Loop
1335     auto diff_dst_step = get_diff_dst_offset(0, 0, 1);
1336     auto ker_step = get_kernel_offset(0, jcp.oc_block, 0);
1337     add(reg_dst, diff_dst_step);
1338     add(reg_ker, ker_step);
1339 
1340     sub(reg_oc, jcp.oc_block);
1341     cmp(reg_oc, 0);
1342     jg(ocb_label, T_NEAR);
1343 
1344     sub(reg_dst, diff_dst_step * jcp.nb_oc);
1345     sub(reg_ker, ker_step * jcp.nb_oc);
1346 
1347     L(skip_compute_label);
1348     store_output(ur_w);
1349 }
1350 
1351 template <typename Vmm>
1352 void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::generate() {
1353     int iw = jcp.iw;
1354     int kw = jcp.kw;
1355     int ur_w = jcp.ur_w;
1356     int nb_iw = jcp.nb_iw;
1357     int iw_block = jcp.iw_block;
1358     int ur_w_tail = jcp.ur_w_tail;
1359     int dilate_w = jcp.dilate_w + 1;
1360     int stride_w = jcp.stride_w;
1361 
1362     const auto dst_shift = get_diff_dst_offset(ur_w / stride_w, 0);
1363     const auto src_shift = get_diff_src_offset(ur_w, 0);
1364 
1365     preamble();
1366 
1367     if (jcp.simd_w == 4) {
1368         Reg32 reg_tail_32 = reg_oc.cvt32();
1369         mov(reg_tail_32, (1 << jcp.simd_w) - 1);
1370         kmovb(k_ic_tail_mask, reg_tail_32);
1371     }
1372 
1373     if (jcp.ic_tail) {
1374         Label done;
1375         // dummy mask all 1's
1376         if (jcp.simd_w != 4)
1377             kxnord(k_ic_tail_mask, k_ic_tail_mask, k_ic_tail_mask);
1378         // To account for the special store optimization, where two ic_blocks
1379         // are combined into a single write, extend the mask to 32 bits (32 bf16s)
1380         const bool need_extended_mask
1381                 = isa_has_bf16(jcp.isa) && jcp.nb_ic_blocking > 1;
1382         if (need_extended_mask)
1383             kxnord(k_ic_tail_mask_extended, k_ic_tail_mask_extended,
1384                     k_ic_tail_mask_extended);
1385 
1386         test(byte[param1 + GET_OFF(load_work)], jcp.ic_block - 1);
1387         jz(done, T_NEAR);
1388         Reg32 reg_tail_32 = reg_ic.cvt32();
1389         mov(reg_tail_32, (1 << jcp.ic_tail) - 1);
1390         kmovd(k_ic_tail_mask, reg_tail_32);
1391         if (need_extended_mask) {
1392             mov(reg_tail_32, (1 << (jcp.ic_tail + jcp.simd_w)) - 1);
1393             kmovd(k_ic_tail_mask_extended, reg_tail_32);
1394         }
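        // For example (hypothetical sizes), ic_tail = 3 with simd_w = 16 yields
        // a tail mask of 0x7, while the extended mask covers
        // ic_tail + simd_w = 19 lanes (0x7ffff) so the paired 32-element bf16
        // store is clipped correctly.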
1395         L(done);
1396     }
1397 
1398     mov(reg_src, ptr[param + GET_OFF(src)]);
1399     mov(reg_dst, ptr[param + GET_OFF(dst)]);
1400     mov(reg_ker, ptr[param + GET_OFF(filt)]);
1401 
1402     mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
1403 
1404     int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
1405     int r_overflow = nstl::max(
1406             0, ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad)) / stride_w);
1407     int r_overflow1 = nstl::max(0,
1408             ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad + ur_w_tail))
1409                     / stride_w);
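    // l_overflow / r_overflow measure how far (in strided input positions) the
    // filter extends past the left/right edge of diff_dst, i.e. how many
    // diff_src columns at the borders need a reduced tap range (see
    // get_iw_start / get_iw_end); r_overflow1 is the right-edge value for the
    // block that precedes the ur_w_tail remainder.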
1410 
1411     int body_l_overflow = 0, body_r_overflow = 0;
1412     int n_oi = iw / ur_w;
1413     int head_n_oi = 0, body_n_oi = 0, pretail_n_oi = 0, tail_n_oi = 0;
1414     int head_thread = 0, pretail_thread = 0, tail_thread = 0;
1415     bool threaded = is_iw_threading_on(jcp);
1416     Label head_label, body_label, pretail_label, tail_label, end_label;
1417     assert(n_oi > 0);
1418 
1419     if (r_overflow1 > 0) n_oi--;
1420     if (l_overflow > 0) n_oi--;
1421     if (n_oi < 0) {
1422         // l_overflow and r_overflow1 are handled in the same compute_loop.
1423         // Perform one iteration of body handling l_overflow and r_overflow1.
1424         body_l_overflow = l_overflow;
1425         body_r_overflow = r_overflow1;
1426         n_oi = 1;
1427         l_overflow = 0;
1428         r_overflow1 = 0;
1429     }
1430 
1431     if (!threaded) {
1432         if (n_oi > 1) { mov(reg_oi, n_oi); }
1433     } else {
1434         // Setup for threaded code generation, and jump into the correct
1435         // portion of code for execution.
1436         head_thread = 0;
1437         tail_thread = nb_iw - 1;
1438         pretail_thread = tail_thread;
1439 
1440         int base_n_oi = iw_block / ur_w;
1441         head_n_oi = l_overflow > 0 ? base_n_oi - 1 : base_n_oi;
1442         tail_n_oi = (iw - iw_block * (nb_iw - 1)) / ur_w;
1443         pretail_n_oi = tail_n_oi;
1444         if (r_overflow1 > 0) {
1445             if (tail_n_oi > 0) {
1446                 pretail_n_oi--;
1447                 tail_n_oi = pretail_n_oi;
1448             } else {
1449                 // pretail_thread and tail_thread are different
1450                 pretail_n_oi = base_n_oi - 1;
1451                 pretail_thread = tail_thread - 1;
1452             }
1453             if (head_thread == pretail_thread) {
1454                 head_n_oi--;
1455                 pretail_n_oi = 0;
1456                 tail_n_oi = 0;
1457             }
1458         }
1459         body_n_oi = (head_thread < pretail_thread - 1) ? base_n_oi : 0;
1460 
1461         // n_oi is used to determine how much control flow needs to be
1462         // generated for the body portion of the code. As such, n_oi must be
1463         // set to the maximum number of iterations used by the body section.
1464         n_oi = nstl::max(body_n_oi, head_n_oi);
1465         n_oi = nstl::max(n_oi, pretail_n_oi);
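        // For example (hypothetical sizes): iw = 96, ur_w = 8, iw_block = 32
        // gives nb_iw = 3 and base_n_oi = 4. With both l_overflow > 0 and
        // r_overflow1 > 0, head_n_oi = pretail_n_oi = tail_n_oi = 3,
        // head_thread = 0, pretail_thread = tail_thread = 2, body_n_oi = 4,
        // and n_oi = 4 iterations are generated for the body loop.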
1466 
1467         assert(iw_block % ur_w == 0);
1468         mov(reg_iwb, ptr[param1 + GET_OFF(iwb)]);
1469 
1470         if (head_n_oi != 0) mov(reg_oi, head_n_oi);
1471         cmp(reg_iwb, head_thread);
1472         je(head_label, T_NEAR);
1473 
1474         cmp(reg_iwb, pretail_thread);
1475         if (pretail_n_oi == 0) {
1476             je(pretail_label, T_NEAR);
1477         } else {
1478             mov(reg_oi, pretail_n_oi);
1479             je(body_label, T_NEAR);
1480         }
1481         if (pretail_thread != tail_thread) {
1482             cmp(reg_iwb, tail_thread);
1483             je(tail_label, T_NEAR);
1484         }
1485         if (body_n_oi != 0) {
1486             mov(reg_oi, body_n_oi);
1487             jmp(body_label, T_NEAR);
1488         } else {
1489             jmp(end_label, T_NEAR);
1490         }
1491     }
1492     L(head_label);
1493     if (l_overflow > 0) {
1494         compute_loop(ur_w, l_overflow, 0);
1495         if (threaded && head_n_oi == 0 && head_thread != pretail_thread)
1496             jmp(end_label, T_NEAR);
1497         add(reg_src, src_shift);
1498         add(reg_dst, dst_shift);
1499     }
1500     L(body_label);
1501     if (n_oi > 0) {
1502         Label ow_loop_label;
1503         L(ow_loop_label);
1504         {
1505             compute_loop(ur_w, body_l_overflow, body_r_overflow);
1506             if (n_oi > 1 || r_overflow1 > 0 || ur_w_tail != 0) {
1507                 add(reg_src, src_shift);
1508                 add(reg_dst, dst_shift);
1509             }
1510             if (n_oi > 1) {
1511                 sub(reg_oi, 1);
1512                 jg(ow_loop_label, T_NEAR);
1513             }
1514         }
1515     }
1516     if (threaded) {
1517         cmp(reg_iwb, pretail_thread);
1518         jne(end_label, T_NEAR);
1519     }
1520     L(pretail_label);
1521     if (r_overflow1 > 0) {
1522         compute_loop(ur_w, 0, r_overflow1);
1523         if (ur_w_tail != 0) {
1524             if (threaded && tail_thread != pretail_thread)
1525                 jmp(end_label, T_NEAR);
1526             else {
1527                 add(reg_src, src_shift);
1528                 add(reg_dst, dst_shift);
1529             }
1530         }
1531     }
1532     L(tail_label);
1533     if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_overflow); }
1534     L(end_label);
1535 
1536     postamble();
1537 }
1538 
1539 status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp,
1540         const convolution_desc_t &cd, memory_desc_t &diff_src_md,
1541         memory_desc_t &weights_md, memory_desc_t &diff_dst_md, int nthreads) {
1542 
1543     const memory_desc_wrapper diff_src_d(&diff_src_md);
1544     const memory_desc_wrapper weights_d(&weights_md);
1545     const memory_desc_wrapper diff_dst_d(&diff_dst_md);
1546 
1547     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
1548     int ndims = diff_src_d.ndims();
1549 
1550     jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
1551                                         : bf16_emulation_t::get_isa();
1552     jcp.nthr = nthreads;
1553     jcp.ver = ver_vnni;
1554     jcp.ndims = ndims;
1555     jcp.prop_kind = cd.prop_kind;
1556 
1557     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1558     jcp.mb = diff_src_d.dims()[0];
1559 
1560     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1561     jcp.oc_without_padding = jcp.oc;
1562     jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
1563 
1564     jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
1565     jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2];
1566     jcp.iw = diff_src_d.dims()[ndims - 1];
1567     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
1568     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
1569     jcp.ow = diff_dst_d.dims()[ndims - 1];
1570 
1571     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
1572     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
1573     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1574 
1575     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1576     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
1577     jcp.l_pad = cd.padding[0][ndims - 3];
1578 
1579     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1580     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
1581     jcp.stride_w = cd.strides[ndims - 3];
1582 
1583     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1584     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
1585     jcp.dilate_w = cd.dilates[ndims - 3];
1586     jcp.dst_dt = cd.diff_src_desc.data_type;
1587     jcp.nb_iw = 1;
1588     jcp.iw_block = jcp.iw;
1589 
1590     /* Dilated convolutions are supported with unit strides only */
1591     if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
1592             || (jcp.dilate_d != 0 && jcp.stride_d != 1)
1593             || (jcp.dilate_h != 0 && jcp.stride_h != 1))
1594         return status::unimplemented;
1595 
1596     int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1597     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1598     int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1599     jcp.r_pad = calculate_end_padding(
1600             jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
1601     jcp.b_pad = calculate_end_padding(
1602             jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
1603     jcp.back_pad = calculate_end_padding(
1604             jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
1605     bool kernel_outside_src = false || ext_kw <= jcp.l_pad
1606             || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
1607             || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
1608     if (kernel_outside_src) return status::unimplemented;
1609 
1610     jcp.aligned_threads = 0;
1611 
1612     const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
1613     const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
1614     const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
1615     const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1616     auto curr_src_tag = diff_src_d.matches_one_of_tag(
1617             dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1618     auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
1619             dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1620     bool is_data_layout_nxc
1621             = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);
1622 
1623     bool ok_to_pad_channels = jcp.ngroups == 1 && !is_data_layout_nxc;
1624 
1625     jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
1626 
1627     const bool ok_to_try_lower_zmm = true
1628             && IMPLICATION(is_data_layout_nxc,
1629                     jcp.oc < jcp.simd_w && jcp.ic < jcp.simd_w
1630                             && jcp.ngroups > 1)
1631             && !ok_to_pad_channels
1632             && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0);
1633 
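    // When the channel counts cannot be padded (grouped nxc layout with few
    // channels) and are not multiples of the full zmm width, try a narrower
    // SIMD width (8 -> ymm, 4 -> xmm) so that ic and oc divide the block size
    // evenly.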
1634     if (ok_to_try_lower_zmm) {
1635         for (auto simd : {8, 4}) {
1636             if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
1637                 jcp.simd_w = simd;
1638                 break;
1639             }
1640         }
1641     }
1642 
1643     jcp.oc_block = jcp.simd_w;
1644     jcp.ic_block = jcp.simd_w;
1645 
1646     if (ok_to_pad_channels) {
1647         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1648         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1649     }
1650 
1651     if (!IMPLICATION(!is_data_layout_nxc,
1652                 jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
1653         return status::unimplemented;
1654     jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
1655     jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.simd_w : 0;
1656 
1657     format_tag_t wei_tag, dat_tag;
1658 
1659     if (jcp.simd_w == 8) {
1660         dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
1661         wei_tag = utils::pick(ndims - 3, gOIw4o8i2o, gOIhw4o8i2o, gOIdhw4o8i2o);
1662     } else if (jcp.simd_w == 4) {
1663         dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c;
1664         wei_tag = utils::pick(ndims - 3, gOIw2o4i2o, gOIhw2o4i2o, gOIdhw2o4i2o);
1665     } else {
1666         dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
1667         wei_tag = pick(2 * ndims - 6 + with_groups, OIw8o16i2o, gOIw8o16i2o,
1668                 OIhw8o16i2o, gOIhw8o16i2o, OIdhw8o16i2o, gOIdhw8o16i2o);
1669     }
1670 
1671     if (diff_src_md.format_kind == format_kind::any) {
1672         CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag));
1673     } else if (curr_src_tag != dat_tag)
1674         return status::unimplemented;
1675     jcp.src_tag = dat_tag;
1676 
1677     if (diff_dst_md.format_kind == format_kind::any) {
1678         CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag));
1679     } else if (curr_dst_tag != dat_tag)
1680         return status::unimplemented;
1681     jcp.dst_tag = dat_tag;
1682 
1683     if (weights_md.format_kind == format_kind::any) {
1684         CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
1685         jcp.wei_tag = wei_tag;
1686     } else {
1687         jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
1688         if (jcp.wei_tag != wei_tag) return status::unimplemented;
1689     }
1690 
1691     bool args_ok = true && jcp.ic <= diff_src_d.padded_dims()[1]
1692             && jcp.oc <= diff_dst_d.padded_dims()[1]
1693             && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
1694             && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
1695     if (!args_ok) return status::unimplemented;
1696 
1697     jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
1698     jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
1699 
1700     jcp.ur_w = jcp.stride_w;
1701 
1702     /* Maximum number of registers available for result accumulation and delta
1703        dst data. One additional register is reserved for weights data. */
1704     const int max_regs
1705             = isa_has_bf16(jcp.isa) ? 31 : 26; /* In case of cpx emulation
1706                                                   an additional 5 registers
1707                                                   are reserved */
1708     int l_overflow = nstl::max(
1709             0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
1710     int r_overflow1 = nstl::max(0,
1711             ((jcp.kw - 1) * (jcp.dilate_w + 1)
1712                     - nstl::max(0, jcp.r_pad + jcp.iw % jcp.ur_w))
1713                     / jcp.stride_w);
1714     int n_oi = jcp.iw / jcp.ur_w;
1715     if (r_overflow1 > 0) n_oi--;
1716 
1717     jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
1718     jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
1719 
1720     /* Find the blocking that maximizes the number of compute instructions
1721        per ur_w * nb_ic_blocking compute loop. The number of required registers
1722        is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs,
1723        and ur_w must be divisible by stride_w. */
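    /* For example (hypothetical values), with stride_w = 1 and max_regs = 31:
       nb_ic_blocking = 4 allows ur_w up to 6 (4 * 6 + 6 = 30 registers) for a
       pipeline of 24 accumulating instructions, which beats nb_ic_blocking = 2
       with ur_w = 10 (pipeline of 20), so the former is selected when
       jcp.nb_ic is divisible by 4 and the spatial size permits. */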
1724     if (jcp.stride_w + 1 > max_regs) /* Minimal possible register
1725                                          distribution exceeds max_regs */
1726         return status::unimplemented;
1727 
1728     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1729     {
1730         jcp.kernel_kind = expl_bcast;
1731         int best_compute_pipeline_length = 0;
1732         const int max_ic_blocks = 4;
1733         for (int b = 1; b <= max_ic_blocks; b++) {
1734             if (jcp.nb_ic % b != 0) continue;
1735 
1736             for (int u = jcp.stride_w; u * b + u / jcp.stride_w <= max_regs
1737                     && u < jcp.iw + jcp.stride_w;
1738                     u += jcp.stride_w) {
1739                 int ur_w = nstl::min(u, jcp.iw);
1740                 /* maximum 1 step with l_overflow so far */
1741                 if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
1742                     continue;
1743                 int pipeline_length = utils::div_up(ur_w, jcp.stride_w) * b;
1744                 if (pipeline_length > best_compute_pipeline_length
1745                         || (pipeline_length == best_compute_pipeline_length
1746                                 && jcp.ur_w < ur_w)) {
1747                     jcp.ur_w = ur_w;
1748                     jcp.nb_ic_blocking = b;
1749                     best_compute_pipeline_length = pipeline_length;
1750                 }
1751             }
1752         }
1753         if (best_compute_pipeline_length == 0) /* can't find
1754                                                   appropriate blocking */
1755             return status::unimplemented;
1756     }
1757     jcp.ur_w_tail = jcp.iw % jcp.ur_w;
1758 
1759     if (is_iw_threading_available(jcp)) {
1760         int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
1761         int work_units = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
1762         float no_iw_block_eff
1763                 = (float)work_units / rnd_up(work_units, jcp.nthr);
1764 
1765         // current design of generate() requires iw_block >= 2 * ur_w
1766         const int min_iw_block = jcp.ur_w * 2;
1767         int iw_threads = jcp.nthr / math::gcd(work_units, jcp.nthr);
1768         int iw_block = nstl::max(min_iw_block,
1769                 rnd_up(jcp.iw, jcp.ur_w * iw_threads) / iw_threads);
1770         int nb_iw = div_up(jcp.iw, iw_block);
1771 
1772         float block_eff = (float)jcp.iw / rnd_up(jcp.iw, iw_block);
1773         work_units = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih * nb_iw;
1774         float work_eff = (float)work_units / rnd_up(work_units, jcp.nthr);
1775         float iw_block_eff = block_eff * work_eff;
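        // no_iw_block_eff is the thread load balance without splitting rows,
        // while iw_block_eff combines the utilization of each iw block with
        // the balance of the enlarged work set; the split is taken only when
        // it wins after discounting the per-block overhead and the row is
        // large enough to amortize it.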
1776 
1777         const int iw_thread_min_size = 16 * 128;
1778         const float iw_block_cost = 20.0;
1779         float block_overhead = nstl::max(0.0f, 1.0f - iw_block_cost / iw_block);
1780 
1781         bool iw_thread_useful = no_iw_block_eff < block_overhead * iw_block_eff
1782                 && jcp.ic_block * jcp.iw > iw_thread_min_size;
1783 
1784         if (iw_thread_useful) {
1785             jcp.iw_block = iw_block;
1786             jcp.nb_iw = nb_iw;
1787         }
1788     }
1789 
1790     if (l_overflow * jcp.stride_w > jcp.ur_w) return status::unimplemented;
1791     int r_overflow_no_tail = nstl::max(0,
1792             ((jcp.kw - 1) * (jcp.dilate_w + 1)
1793                     - nstl::max(0, jcp.r_pad + jcp.ur_w_tail))
1794                     / jcp.stride_w);
1795     bool tails_not_ok = false
1796             /* maximum 1 ur_w block with r_overflow so far */
1797             || r_overflow_no_tail * jcp.stride_w > jcp.ur_w
1798             /* ur_w must be a multiple of stride */
1799             || ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
1800             /* r_pad must not extend beyond ur_w_tail */
1801             || ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
1802     if (tails_not_ok) return status::unimplemented;
1803 
1804     /* Adjust the thread decomposition to improve performance for small
1805      * problem sizes. The threshold L1_cache_size / factor and the factor
1806      * itself are empirical. For now, simply cap the thread count at 4.
1807      * TODO: Add a get_thr_eff function to compute the optimal number of
1808      * threads. */
1809     size_t wei_size = (size_t)sizeof(bfloat16_t) * jcp.ic * jcp.oc * jcp.kh
1810             * jcp.kw * jcp.kd;
1811     size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
1812             * jcp.iw * jcp.id;
1813     size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
1814             * jcp.ow * jcp.od;
1815     size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
1816     const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
1817 
1818     // The factor is 1 for 1D, 2 for 2D, and 4 for 3D.
1819     int factor = nstl::max(1, (2 * (ndims - 3)));
1820     if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size / factor) {
1821         jcp.nthr = nstl::min(jcp.nthr, 4);
1822     }
1823 
1824     pick_loop_order(jcp);
1825 
1826     return status::success;
1827 }
1828 
1829 const int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::max_ur_w = 28;
1830 
1831 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1832         od_step_comeback_pointers() {
1833     Label kd_comeback_label;
1834     mov(kj, reg_kd_count);
1835     L(kd_comeback_label);
1836     {
1837         sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
1838         sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
1839         dec(kj);
1840         cmp(kj, 0);
1841         jg(kd_comeback_label, T_NEAR);
1842     }
1843 }
1844 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1845         oh_step_comeback_pointers() {
1846     Label kh_comeback_label;
1847     mov(kj, reg_kh);
1848     L(kh_comeback_label);
1849     {
1850         sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
1851         sub(reg_kernel, get_kernel_offset(0, jcp.kw));
1852         dec(kj);
1853         cmp(kj, 0);
1854         jg(kh_comeback_label, T_NEAR);
1855     }
1856 }
1857 
1858 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1859         compute_ic_block_step_extern(int ur_w, int pad_l, int pad_r,
1860                 int ic_block_step, int src_offset, int kernel_offset,
1861                 int ddst_offset, bool is_tail) {
1862     assert(!is_src_layout_nxc() && !is_ddst_layout_nxc());
1863     int kw = jcp.kw;
1864     bool no_src_pad = jcp.is_1stconv && !jcp.transpose_src;
1865     const int ddst_zmm_base_idx = 24;
1866     const int num_ddst_zmm_regs = !isa_has_bf16(jcp.isa) ? 2 : 4;
1867     const int zmm_src_reg = ddst_zmm_base_idx + num_ddst_zmm_regs;
1868 
1869     auto zmm_ker = [=](int i_kw, int i_ic) {
1870         return Zmm(i_kw * ic_block_step + i_ic);
1871     };
1872     auto zmm_ddst = [=](int i_iw) {
1873         // TODO: move reg calc to global member funcs
1874         return Zmm(ddst_zmm_base_idx + i_iw % num_ddst_zmm_regs);
1875     };
1876 
1877     auto ker_addr = [=](int i_kw, int i_ic) {
1878         auto local_offset = get_kernel_offset(i_ic, i_kw);
1879         return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
1880     };
1881     auto src_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
1882                             bool vnni_bcast = false) {
1883         auto local_offset = get_src_offset(i_ic, i_iw);
1884         return EVEX_compress_addr(
1885                 reg_src, local_offset + src_offset + extra_offset, vnni_bcast);
1886     };
1887     auto ddst_addr = [=](int i_ur) {
1888         auto ow_scale = 2;
1889         return EVEX_compress_addr(
1890                 reg_ddst, get_ddst_offset(ow_scale * i_ur) + ddst_offset);
1891     };
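    // Register layout: zmm0..zmm(kw * ic_block_step - 1) accumulate the weight
    // diffs, zmm24..zmm(24 + num_ddst_zmm_regs - 1) hold diff_dst rows, and
    // zmm(zmm_src_reg) is scratch for src broadcasts when masking or bf16
    // emulation requires a separate register.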
1892 
1893     for (int i_kw = 0; i_kw < kw; i_kw++)
1894         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1895             auto zmm = zmm_ker(i_kw, i_ic);
1896             vpxord(zmm, zmm, zmm);
1897         }
1898     assert(ur_w % 2 == 0);
1899     auto steps = ur_w / 2;
1900 
1901     const int str_w = jcp.stride_w;
1902     const int underflow_boundary = -1;
1903     int i_iw_shift = jcp.tr_ow - ur_w - ((jcp.l_pad != pad_l) ? jcp.l_pad : 0);
1904     const int overflow_boundary = jcp.iw - 1 - i_iw_shift;
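    // With nchw 1st-conv src there is no zero padding in the buffer, so taps
    // whose i_iw falls outside [0, iw) are either skipped entirely or, exactly
    // at the boundary, loaded with an every-other-word mask so the
    // out-of-range half of each bf16 pair stays zero.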
1905 
1906     for (int s = 0; s < str_w; s++) {
1907         const int kw_start = s;
1908         assert(jcp.tr_iw % str_w == 0);
1909         const int src_stride_w_shift = jcp.tr_iw / str_w;
1910         for (int i_ur = 0; i_ur < steps; i_ur++) {
1911             auto zmm = zmm_ddst(i_ur);
1912             vmovdqu16(zmm, ddst_addr(i_ur));
1913 
1914             for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) {
1915                 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1916                     int i_iw = 2 * i_ur + (i_kw * (jcp.dilate_w + 1)) / str_w
1917                             + s * src_stride_w_shift;
1918                     bool underflow = false;
1919                     bool overflow = false;
1920                     if (no_src_pad) {
1921                         i_iw = i_iw - pad_l;
1922                         underflow = i_iw <= underflow_boundary;
1923                         overflow = is_tail && i_iw >= overflow_boundary;
1924                     }
1925 
1926                     auto src = Zmm(zmm_src_reg);
1927                     auto acc = zmm_ker(i_kw, i_ic);
1928                     auto ddst = zmm_ddst(i_ur);
1929                     if (underflow || overflow || !isa_has_bf16(jcp.isa)) {
1930                         assert(ddst != src);
1931                         assert(acc != src);
1932                     }
1933                     assert(ddst != acc);
1934                     if (underflow || overflow) {
1935                         if (underflow && i_iw == underflow_boundary)
1936                             vpbroadcastw(src | everyother_shift_mask | T_z,
1937                                     src_addr(i_iw + 1, i_ic, 0));
1938                         else if (overflow && i_iw == overflow_boundary)
1939                             vpbroadcastw(src | everyother_mask | T_z,
1940                                     src_addr(i_iw, i_ic, 0));
1941                         else
1942                             continue;
1943 
1944                         if (!isa_has_bf16(jcp.isa))
1945                             bf16_emu_->vdpbf16ps(acc, ddst, src);
1946                         else
1947                             vdpbf16ps(acc, ddst, src);
1948                     } else if (!isa_has_bf16(jcp.isa)) {
1949                         vpbroadcastd(src, src_addr(i_iw, i_ic, 0));
1950                         bf16_emu_->vdpbf16ps(acc, ddst, src);
1951                     } else
1952                         vdpbf16ps(acc, ddst, src_addr(i_iw, i_ic, 0, true));
1953                 }
1954             }
1955         }
1956         for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) {
1957             for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1958                 auto addr = ker_addr(i_kw, i_ic);
1959                 auto zmm = zmm_ker(i_kw, i_ic);
1960                 vaddps(zmm, zmm, addr);
1961                 vmovups(addr, zmm);
1962             }
1963         }
1964     }
1965 }
1966 
1967 int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::interleave_w_reorder_size(
1968         int ur_w) const {
1969     const int reorder_block = 16;
1970     return rnd_up(jcp.stride_w * (ur_w - 1) + jcp.kw, reorder_block);
1971 }
1972 int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1973         interleave_w_reorder_bytes(int ur_w) {
1974     return 2 * jcp.typesize_in * interleave_w_reorder_size(ur_w);
1975 }
1976 int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::interleave_stack_size(
1977         int ur_w, int ic_block_step) {
1978     return ic_block_step * interleave_w_reorder_bytes(ur_w);
1979 }
1980 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1981         compute_ic_block_step_interleave(int ur_w, int pad_l, int pad_r,
1982                 int ic_block_step, int src_offset, int kernel_offset,
1983                 int ddst_offset, bool is_tail) {
1984     // Only supports nchw format src
1985     assert(jcp.is_1stconv && !jcp.transpose_src);
1986     int kw = jcp.kw;
1987     const int ddst_zmm_base_idx = 24;
1988     const int in_zmm_base_idx = 24;
1989     const int num_ddst_zmm_regs = !isa_has_bf16(jcp.isa) ? 2 : 4;
1990     //const int num_in_zmm_regs = 8;
1991     const int zmm_src_reg = ddst_zmm_base_idx + num_ddst_zmm_regs;
1992     const int reorder_block = 16;
1993     const int reorder_size = interleave_w_reorder_size(ur_w);
1994     const int reorder_bytes = interleave_w_reorder_bytes(ur_w);
1995     const int stack_size = interleave_stack_size(ur_w, ic_block_step);
1996     if (stack_size > ic_block_step_stack_size) {
1997         // This is a guard. Ideally it is never used, but is included to defend
1998         // against overlooked edge cases.
1999         assert(stack_size <= ic_block_step_stack_size);
2000         sub(rsp, stack_size - ic_block_step_stack_size);
2001     }
2002 
2003     auto zmm_ker = [=](int i_kw, int i_ic) {
2004         return Zmm(i_kw * ic_block_step + i_ic);
2005     };
2006     auto zmm_ddst = [=](int i_iw) {
2007         return Zmm(ddst_zmm_base_idx + i_iw % num_ddst_zmm_regs);
2008     };
2009     auto zmm_in = [=](int i_iw, int i_ic, bool stride_reg) {
2010         int stride = stride_reg ? 1 : 0;
2011         return Zmm(in_zmm_base_idx + 4 * (i_ic % 2) + 2 * (i_iw % 2) + stride);
2012     };
2013 
2014     auto ker_addr = [=](int i_kw, int i_ic) {
2015         auto local_offset = get_kernel_offset(i_ic, i_kw);
2016         return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
2017     };
2018     auto src_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
2019                             bool vnni_bcast = false) {
2020         int local_offset = i_ic * reorder_bytes + 2 * jcp.typesize_in * i_iw;
2021         return EVEX_compress_addr(rsp, local_offset, vnni_bcast);
2022     };
2023     auto ddst_addr = [=](int i_ur) {
2024         auto ow_scale = 2;
2025         return EVEX_compress_addr(
2026                 reg_ddst, get_ddst_offset(ow_scale * i_ur) + ddst_offset);
2027     };
2028     auto load_src_to_stack = [=](int i_iw, int i_ic, Opmask mask,
2029                                      bool mask_empty, Opmask stride_mask,
2030                                      bool stride_mask_empty) {
2031         auto local_offset = get_src_offset(i_ic, i_iw);
2032         int stack_offset
2033                 = i_ic * reorder_bytes + 2 * jcp.typesize_in * (i_iw + pad_l);
2034 
2035         auto zmm = zmm_in(i_iw, i_ic, false);
2036         auto zmm_stride = zmm_in(i_iw, i_ic, true);
2037         auto base_addr
2038                 = EVEX_compress_addr(reg_src, local_offset + src_offset, false);
2039         auto stride_addr = EVEX_compress_addr(reg_src,
2040                 local_offset + src_offset + get_src_offset(0, jcp.stride_w));
2041         auto stack_addr = EVEX_compress_addr(rsp, stack_offset);
2042         assert(IMPLICATION(mask_empty, stride_mask_empty));
2043         if (mask_empty) {
2044             vpxord(zmm, zmm, zmm);
2045         } else {
2046             vpmovzxwd(zmm | mask | T_z, base_addr);
2047         }
2048         if (!stride_mask_empty) {
2049             vpmovzxwd(zmm_stride | stride_mask | T_z, stride_addr);
2050             vpslld(zmm_stride, zmm_stride, 16);
2051             vpord(zmm, zmm, zmm_stride);
2052         }
2053         vmovdqu16(stack_addr, zmm);
2054     };
2055 
2056     assert(ur_w % 2 == 0);
2057     auto steps = ur_w / 2;
2058 
2059     const int str_w = jcp.stride_w;
2060     int i_iw_shift = str_w * (jcp.tr_ow - ur_w)
2061             - ((jcp.l_pad != pad_l) ? jcp.l_pad : 0);
2062     const int overflow_boundary
2063             = is_tail ? jcp.iw - i_iw_shift : str_w * (ur_w - 1) + kw - pad_l;
2064 
2065     // Calculate padding required by the data reorder using 32 byte loads
2066     int reorder_overflow = reorder_size - pad_l - overflow_boundary;
2067     int reorder_stride_overflow = reorder_overflow + str_w;
2068     reorder_overflow = nstl::max(0, reorder_overflow);
2069     reorder_stride_overflow = nstl::max(0, reorder_stride_overflow);
2070     int reorder_pad_r = reorder_overflow % reorder_block;
2071     int reorder_stride_pad_r = reorder_stride_overflow % reorder_block;
2072     if (reorder_stride_overflow >= reorder_size && reorder_stride_pad_r == 0) {
2073         assert(reorder_stride_overflow == reorder_size);
2074         reorder_stride_pad_r = reorder_block;
2075     }
2076     reorder_overflow -= reorder_pad_r;
2077     reorder_stride_overflow -= reorder_stride_pad_r;
2078 
2079     int pad_l_mask = (0xffff << pad_l) & 0xffff;
2080     int pad_l_mask_strided
2081             = (0xffff << (pad_l >= str_w ? (pad_l - str_w) : 0)) & 0xffff;
2082     int pad_r_mask = 0xffff >> reorder_pad_r;
2083     int pad_r_mask_strided = 0xffff >> (reorder_stride_pad_r);
2084     pad_r_mask = pad_r_mask & 0xffff;
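    // Example (hypothetical values): pad_l = 2 gives pad_l_mask = 0xfffc, so
    // the first two lanes of the initial reorder block are zero-filled, and
    // reorder_pad_r = 3 gives pad_r_mask = 0x1fff, dropping the last three
    // lanes of the final block.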
2085 
2086     // Setup masks to load and reorder data
2087     if (reorder_size - reorder_stride_overflow > reorder_block) {
2088         // Overflow and underflow happen in different data reorder rounds
2089         kxnorw(overflow_stride_mask, overflow_stride_mask,
2090                 overflow_stride_mask);
2091         kshiftlw(underflow_mask, overflow_stride_mask, pad_l);
2092         kshiftlw(underflow_stride_mask, overflow_stride_mask,
2093                 pad_l >= str_w ? pad_l - str_w : 0);
2094         kshiftrw(overflow_mask, overflow_stride_mask, reorder_pad_r);
2095         kshiftrw(overflow_stride_mask, overflow_stride_mask,
2096                 reorder_stride_pad_r);
2097     } else if (reorder_size - reorder_overflow > reorder_block) {
2098         // Overflow and underflow happen in the same round for loading the data
2099         // at the stride offset.
2100         kxnorw(overflow_mask, overflow_mask, overflow_mask);
2101         kshiftlw(underflow_mask, overflow_mask, pad_l);
2102         kshiftrw(overflow_mask, overflow_mask, reorder_pad_r);
2103         mov(reg_tmp.cvt32(), pad_l_mask_strided & pad_r_mask_strided);
2104         kmovw(underflow_stride_mask, reg_tmp.cvt32());
2105     } else {
2106         // Overflow and underflow happen in the same round for all data loads
2107         mov(reg_tmp.cvt32(), pad_l_mask & pad_r_mask);
2108         kmovw(underflow_mask, reg_tmp.cvt32());
2109         mov(reg_tmp.cvt32(), pad_l_mask_strided & pad_r_mask_strided);
2110         kmovw(underflow_stride_mask, reg_tmp.cvt32());
2111     }
2112 
2113     // Load and reorder data to the stack
2114     int reorder_start = -pad_l;
2115     int reorder_end = reorder_size - pad_l;
2116     for (int i_iw = reorder_start; i_iw < reorder_end; i_iw += reorder_block) {
2117         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2118             Opmask mask, stride_mask;
2119             bool mask_empty, stride_mask_empty;
2120             // Performing this reorder on the stack may not always be
2121             // optimal. There are a couple of methods involving reordering
2122             // the data externally that were not considered due to time
2123             // constraints. The first is to transpose, similar to the
2124             // extern method. The other is to perform the same interleave
2125             // transform used here. The tradeoff between these methods is
2126             // that the transpose does not lend itself to SIMD instructions
2127             // (except possibly for some specific strides) since the data is
2128             // not blocked, whereas the transform performed here does, but
2129             // it uses twice as much space since most elements are duplicated.
2130 
2131             if (i_iw == reorder_start) {
2132                 mask = underflow_mask;
2133                 mask_empty = false;
2134                 if (pad_l_mask == 0) mask_empty = true;
2135             } else if (i_iw + reorder_overflow >= reorder_end) {
2136                 mask_empty = true;
2137             } else if (i_iw + reorder_block + reorder_overflow >= reorder_end) {
2138                 mask = overflow_mask;
2139                 mask_empty = false;
2140                 if (pad_r_mask == 0) mask_empty = true;
2141             } else {
2142                 mask = m_ffffffff;
2143                 mask_empty = false;
2144             }
2145             if (i_iw == reorder_start) {
2146                 stride_mask = underflow_stride_mask;
2147                 stride_mask_empty = false;
2148                 if (pad_l_mask_strided == 0) mask_empty = true;
2149             } else if (i_iw + reorder_stride_overflow >= reorder_end) {
2150                 stride_mask_empty = true;
2151             } else if (i_iw + reorder_block + reorder_stride_overflow
2152                     >= reorder_end) {
2153                 stride_mask = overflow_stride_mask;
2154                 stride_mask_empty = false;
2155                 if (pad_r_mask_strided == 0) mask_empty = true;
2156             } else {
2157                 stride_mask = m_ffffffff;
2158                 stride_mask_empty = false;
2159             }
2160             load_src_to_stack(i_iw, i_ic, mask, mask_empty, stride_mask,
2161                     stride_mask_empty);
2162         }
2163     }
2164 
2165     // Initialize kernel accumulators. It should sometimes be possible to skip
2166     // initializing and storing this data between calls to this function.
2167     for (int i_kw = 0; i_kw < kw; i_kw++)
2168         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2169             auto zmm = zmm_ker(i_kw, i_ic);
2170             vpxord(zmm, zmm, zmm);
2171         }
2172 
2173     // Calculate this block's contribution
2174     for (int i_ur = 0; i_ur < steps; i_ur++) {
2175         auto zmm = zmm_ddst(i_ur);
2176         vmovdqu16(zmm, ddst_addr(i_ur));
2177 
2178         for (int i_kw = 0; i_kw < kw; i_kw++) {
2179             for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2180                 int i_iw = 2 * i_ur * str_w + i_kw;
2181                 auto acc = zmm_ker(i_kw, i_ic);
2182                 auto ddst = zmm_ddst(i_ur);
2183 
2184                 const bool isa_supports_bf16 = isa_has_bf16(jcp.isa);
2185                 auto src_stack_addr
2186                         = src_addr(i_iw, i_ic, 0, isa_supports_bf16);
2187 
2188                 if (isa_supports_bf16)
2189                     vdpbf16ps(acc, ddst, src_stack_addr);
2190                 else {
2191                     auto src = Zmm(zmm_src_reg);
2192                     vpbroadcastd(src, src_stack_addr);
2193                     bf16_emu_->vdpbf16ps(acc, ddst, src);
2194                 }
2195             }
2196         }
2197     }
2198 
2199     // Store kernel accumulators
2200     for (int i_kw = 0; i_kw < kw; i_kw++) {
2201         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2202             auto addr = ker_addr(i_kw, i_ic);
2203             auto zmm = zmm_ker(i_kw, i_ic);
2204             vaddps(zmm, zmm, addr);
2205             vmovups(addr, zmm);
2206         }
2207     }
2208 
2209     if (stack_size > ic_block_step_stack_size) {
2210         // This is a guard. Ideally it is never used, but is included to defend
2211         // against overlooked edge cases.
2212         add(rsp, stack_size - ic_block_step_stack_size);
2213     }
2214 }
2215 
2216 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2217         convert_src_to_vnni_format(
2218                 int ur_w, int pad_l, int pad_r, int src_offset) {
2219     Reg64 reg_trans_tmp = r11;
2220     const int ic_tail = jcp.ic_tail;
2221     mov(EVEX_compress_addr(rsp, trans_tmp_offset), reg_trans_tmp);
2222 
2223     mov(reg_trans_tmp, dst_prm_table);
2224     vmovups(get_perm_reg(), ptr[reg_trans_tmp]);
2225 
2226     mov(reg_trans_tmp, EVEX_compress_addr(rsp, trans_tmp_offset));
2227     const int max_regs = 16;
2228     if (ic_tail) {
2229         Label skip_tail_mask;
2230         cmp(reg_icb, jcp.simd_w);
2231         jge(skip_tail_mask);
2232         kandd(m_0000ffff, m_0000ffff, m_0000_ic_tail);
2233         kandd(m_ffff0000, m_ffff0000, m_ic_tail_0000);
2234         L(skip_tail_mask);
2235     }
2236     for (int src_count = 0;
2237             sizeof_cacheline * src_count < permw_stack_size(ur_w);
2238             src_count++) {
2239         int i_ur = nstl::min(src_count, ur_w - 2);
2240         int i_kw = src_count - i_ur;
2241         int buffer_offset = permw_buffer_start + src_count * 64;
2242         auto bcast_values = Zmm(src_count % max_regs);
2243         bool check = check_borders(ur_w, pad_l, pad_r, i_ur, i_kw);
2244         if (check) {
2245             if (is_src_layout_nxc()) {
2246                 int iw_1, iw_2;
2247                 get_w_positions(ur_w, pad_l, pad_r, i_ur, i_kw, iw_1, iw_2);
2248                 if (iw_1 == -1)
2249                     vxorpd(bcast_values, bcast_values, bcast_values);
2250                 else {
2251                     dim_t local_src_offset = src_offset
2252                             + get_src_offset(
2253                                     0, filter_w_to_src(i_kw, i_ur, pad_l));
2254                     vmovdqu16(bcast_values | m_0000ffff | T_z,
2255                             ptr[reg_src + local_src_offset]);
2256                 }
2257                 if (iw_2 != -1) {
2258                     dim_t local_src_offset = src_offset - 32
2259                             + get_src_offset(
2260                                     0, filter_w_to_src(i_kw, i_ur + 1, pad_l));
2261                     vmovdqu16(bcast_values | m_ffff0000,
2262                             ptr[reg_src + local_src_offset]);
2263                 }
2264             } else {
2265                 Opmask load_mask;
2266                 get_load_mask(ur_w, pad_l, pad_r, i_ur, i_kw, load_mask);
2267 
2268                 dim_t local_src_offset = src_offset
2269                         + get_src_offset(0, filter_w_to_src(i_kw, i_ur, pad_l));
2270                 vmovdqu16(bcast_values | load_mask | T_z,
2271                         ptr[reg_src + local_src_offset]);
2272             }
2273             vpermw(bcast_values, get_perm_reg(), bcast_values);
2274         } else {
2275             vpxord(bcast_values, bcast_values, bcast_values);
2276         }
2277         vmovups(ptr[rsp + buffer_offset], bcast_values);
2278     }
2279     if (ic_tail) {
2280         // Reset-back the masks
2281         kxnorw(m_0000ffff, m_0000ffff, m_0000ffff);
2282         kshiftld(m_ffff0000, m_0000ffff, 16);
2283     }
2284 }
2285 
2286 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2287         may_be_set_oc_tail_mask() {
2288     if (jcp.oc_tail) {
2289         Label skip_tail_mask;
2290         cmp(dword[param + GET_OFF(load_work)], jcp.simd_w);
2291         jge(skip_tail_mask);
2292         kandd(m_0000ffff, m_0000ffff, m_0000_oc_tail);
2293         kandd(m_ffff0000, m_ffff0000, m_oc_tail_0000);
2294         L(skip_tail_mask);
2295     }
2296 }
2297 
2298 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2299         may_be_reset_oc_tail_mask() {
2300     if (jcp.oc_tail) {
2301         // Reset-back the masks
2302         kxnorw(m_0000ffff, m_0000ffff, m_0000ffff);
2303         kshiftld(m_ffff0000, m_0000ffff, 16);
2304     }
2305 }
2306 
2307 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2308         compute_ic_block_step_vpermw_expl(int ur_w, int pad_l, int pad_r,
2309                 int ic_block_step, int src_offset, int kernel_offset,
2310                 int ddst_offset, bool is_tail) {
2311     assert(!jcp.is_1stconv); // This method does not support nchw data
2312     int kw = jcp.kw;
2313     int src_count = 0;
2314     int ic_block_step_idx = src_offset / (jcp.typesize_in * ic_block_step);
2315     const int max_regs = (!isa_has_bf16(jcp.isa)) ? 26 : 31;
2316     int src_pl_len = kw;
2317     const int diff_dst_pl_start_reg_idx = ic_block_step * (kw + src_pl_len);
2318     const int diff_dst_pl_len = max_regs - diff_dst_pl_start_reg_idx;
2319 
2320     auto get_diff_wei_reg_idx
2321             = [=](int i_kw, int i_ic) { return i_kw * ic_block_step + i_ic; };
2322     auto get_src_reg_idx = [=](int i_iw, int i_ic) {
2323         return kw * ic_block_step + (i_iw % src_pl_len) * ic_block_step + i_ic;
2324     };
2325     auto get_diff_dst_reg_idx = [=](int i_ur) {
2326         return diff_dst_pl_start_reg_idx + (i_ur / 2) % diff_dst_pl_len;
2327     };
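    // Register layout: indices 0..kw * ic_block_step - 1 accumulate weight
    // diffs, the next kw * ic_block_step registers hold broadcast src values,
    // and the remaining registers up to max_regs form the diff_dst pipeline
    // starting at diff_dst_pl_start_reg_idx.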
2328 
2329     may_be_set_oc_tail_mask();
2330     auto load_dst = [=](int c) {
2331         bool is_tail = ur_w % 2 && c * 2 + 2 >= ur_w;
2332         bool is_ddst_nxc = is_ddst_layout_nxc();
2333         auto offset = get_ddst_offset(c * 2) + ddst_offset;
2334 
2335         Opmask load_mask = is_ddst_nxc || is_tail ? m_0000ffff : m_ffffffff;
2336         vmovdqu16(Zmm(get_diff_dst_reg_idx(2 * c)) | load_mask | T_z,
2337                 EVEX_compress_addr(reg_ddst, offset));
2338 
2339         if (is_ddst_nxc && !is_tail) {
2340             offset += get_ddst_offset(1) - 32;
2341             vmovdqu16(Zmm(get_diff_dst_reg_idx(2 * c)) | m_ffff0000,
2342                     EVEX_compress_addr(reg_ddst, offset));
2343         }
2344         vpermw(Zmm(get_diff_dst_reg_idx(2 * c)), get_perm_reg(),
2345                 Zmm(get_diff_dst_reg_idx(2 * c)));
2346     };
2347 
2348     for (int i_kw = 0; i_kw < kw; i_kw++)
2349         for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2350             vpxord(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2351                     Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2352                     Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));
2353 
2354     auto get_bcast_ptr = [=](int i_ur, int i_kw, int ic) {
2355         int scale = 2 * jcp.typesize_in;
2356         return rsp + b_ic * scale + permw_buffer_start + (i_ur + i_kw) * 64
2357                 + jcp.typesize_in * 2
2358                 * (ic_block_step_idx * ic_block_step + ic);
2359     };
2360     int src_count_last = 0;
2361     for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
2362         if (i_ur == 0) {
2363             for (int dst_count = 0;
2364                     dst_count < nstl::min(diff_dst_pl_len, div_up(ur_w, 2));
2365                     dst_count++) {
2366                 load_dst(dst_count);
2367             }
2368             for (src_count = 0; src_count < src_pl_len; src_count++) {
2369                 int _i_ur = src_count / kw;
2370                 int _i_kw = src_count % kw;
2371                 if (check_borders(ur_w, pad_l, pad_r, _i_ur, _i_kw))
2372                     for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2373                         vbroadcastss(Zmm(get_src_reg_idx(src_count, i_ic)),
2374                                 ptr[get_bcast_ptr(_i_ur, _i_kw, i_ic)]);
2375                     }
2376             }
2377             src_count_last = src_count;
2378         } else {
2379             int diff_dst_load_idx = i_ur + 2 * (diff_dst_pl_len - 1);
2380             if (diff_dst_load_idx < ur_w) load_dst(diff_dst_load_idx / 2);
2381             for (src_count = i_ur; src_count < i_ur + src_pl_len; src_count++) {
2382                 if (src_count < src_count_last) continue;
2383                 int _i_ur = (src_count - i_ur) / kw + i_ur;
2384                 int _i_kw = (src_count - i_ur) % kw;
2385                 if (check_borders(ur_w, pad_l, pad_r, _i_ur, _i_kw))
2386                     for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2387                         vbroadcastss(Zmm(get_src_reg_idx(src_count, i_ic)),
2388                                 ptr[get_bcast_ptr(_i_ur, _i_kw, i_ic)]);
2389                     }
2390             }
2391             src_count_last = src_count;
2392         }
2393         for (int i_kw = 0; i_kw < kw; i_kw++) {
2394             int i_iw = i_ur + i_kw;
2395             if (check_borders(ur_w, pad_l, pad_r, i_ur, i_kw)) {
2396                 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2397                     if (!isa_has_bf16(jcp.isa)) {
2398                         bf16_emu_->vdpbf16ps(
2399                                 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2400                                 Zmm(get_diff_dst_reg_idx(i_ur)),
2401                                 Zmm(get_src_reg_idx(i_iw, i_ic)));
2402                     } else {
2403                         vdpbf16ps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2404                                 Zmm(get_diff_dst_reg_idx(i_ur)),
2405                                 Zmm(get_src_reg_idx(i_iw, i_ic)));
2406                     }
2407                 }
2408             }
2409         }
2410     }
2411 
2412     for (int i_kw = 0; i_kw < kw; i_kw++)
2413         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2414             auto l_offset = get_kernel_offset(i_ic, i_kw);
2415             vaddps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2416                     EVEX_compress_addr(reg_kernel, l_offset + kernel_offset));
2417         }
2418 
2419     for (int i_kw = 0; i_kw < kw; i_kw++) {
2420         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2421             auto l_offset = get_kernel_offset(i_ic, i_kw);
2422             vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
2423                     Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));
2424         }
2425     }
2426 
2427     may_be_reset_oc_tail_mask();
2428 }
2429 
2430 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2431         compute_ic_block_step_vpermw(int ur_w, int pad_l, int pad_r,
2432                 int ic_block_step, int src_offset, int kernel_offset,
2433                 int ddst_offset, bool is_tail) {
2434     assert(!jcp.is_1stconv); // This method does not support nchw data
2435     int kw = jcp.kw;
2436 
2437     int dst_count = 0;
2438 
2439     int ic_block_step_idx = src_offset / (jcp.typesize_in * ic_block_step);
2440 
2441     int pipeline_length = (isa_has_bf16(jcp.isa))
2442             ? nstl::max(1, nstl::min(4, ur_w / 2))
2443             : 1;
2444     may_be_set_oc_tail_mask();
2445 
2446     const int dst_off_reg = (!isa_has_bf16(jcp.isa)) ? 26 : 31;
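    // diff_dst rows are software-pipelined through up to pipeline_length
    // registers counting down from dst_off_reg, so the load for a later row
    // can overlap the vdpbf16ps accumulation on the current one.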
2447     auto load_dst = [=](int c) {
2448         bool is_tail = ur_w % 2 && c * 2 + 2 >= ur_w;
2449         bool is_ddst_nxc = is_ddst_layout_nxc();
2450         auto offset = get_ddst_offset(2 * c) + ddst_offset;
2451 
2452         Opmask load_mask = is_ddst_nxc || is_tail ? m_0000ffff : m_ffffffff;
2453         vmovdqu16(Zmm(dst_off_reg - c % pipeline_length) | load_mask | T_z,
2454                 EVEX_compress_addr(reg_ddst, offset));
2455 
2456         if (is_ddst_nxc && !is_tail) {
2457             offset += get_ddst_offset(1) - 32;
2458             vmovdqu16(Zmm(dst_off_reg - c % pipeline_length) | m_ffff0000,
2459                     EVEX_compress_addr(reg_ddst, offset));
2460         }
2461         vpermw(Zmm(dst_off_reg - c % pipeline_length), get_perm_reg(),
2462                 Zmm(dst_off_reg - c % pipeline_length));
2463     };
2464 
2465     for (int i_kw = 0; i_kw < kw; i_kw++)
2466         for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2467             vmovups(Zmm(i_kw * ic_block_step + i_ic),
2468                     EVEX_compress_addr(reg_kernel,
2469                             get_kernel_offset(i_ic, i_kw) + kernel_offset));
2470 
2471     for (dst_count = 0; dst_count < pipeline_length; dst_count++) {
2472         load_dst(dst_count);
2473     }
2474     auto get_bcast_ptr = [=](int i_ur, int i_kw, int ic) {
2475         int scale = 2 * jcp.typesize_in;
2476         return rsp + b_ic * scale + permw_buffer_start + (i_ur + i_kw) * 64
2477                 + jcp.typesize_in * 2
2478                 * (ic_block_step_idx * ic_block_step + ic);
2479     };
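         // get_bcast_ptr addresses the on-stack buffer (filled beforehand by
         // convert_src_to_vnni_format) where src rows are already interleaved
         // into bf16 pairs; each zword_b broadcast picks up the pair for one
         // input channel.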
2480 
2481     for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
2482         for (int i_kw = 0; i_kw < kw; i_kw++) {
2483             if (check_borders(ur_w, pad_l, pad_r, i_ur, i_kw)) {
2484                 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2485                     if (!isa_has_bf16(jcp.isa)) {
2486                         auto zmm_src = Zmm(28);
2487                         vpbroadcastd(
2488                                 zmm_src, ptr[get_bcast_ptr(i_ur, i_kw, i_ic)]);
2489                         bf16_emu_->vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
2490                                 Zmm(dst_off_reg - dst_count % pipeline_length),
2491                                 zmm_src);
2492                     } else {
2493                         vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
2494                                 Zmm(dst_off_reg - dst_count % pipeline_length),
2495                                 zword_b[get_bcast_ptr(i_ur, i_kw, i_ic)]);
2496                     }
2497                 }
2498             }
2499         }
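             // Keep the diff_dst software pipeline full: load the pair that a
             // later iteration will consume while the FMAs above use earlier
             // loads.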
2500         if (dst_count * 2 < ur_w) load_dst(dst_count);
2501         dst_count++;
2502     }
2503     for (int i_kw = 0; i_kw < kw; i_kw++) {
2504         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2505             auto l_offset = get_kernel_offset(i_ic, i_kw);
2506             vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
2507                     Zmm(i_kw * ic_block_step + i_ic));
2508         }
2509     }
2510 
2511     may_be_reset_oc_tail_mask();
2512 }
2513 
2514 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2515         compute_diff_bias_init() {
2516     auto reg_unit_val = reg_tmp.cvt16();
2517     mov(reg_unit_val, 0x3f80); // bf16 value of 1.
2518     vpbroadcastw(vreg_bias_unit, reg_unit_val);
2519 
2520     mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2521     vmovups(vreg_bias_acc, ptr[reg_tmp]);
2522 
2523     if (jcp.uses_permw_transposition) {
2524         mov(reg_tmp, dst_prm_table);
2525         vmovups(get_perm_reg(), ptr[reg_tmp]);
2526     }
2527 }
2528 
2529 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_diff_bias_row(
2530         bool is_partial) {
2531     if (!jcp.with_bias) return;
2532     mov(reg_tmp, ptr[param + GET_OFF(flags)]);
2533     Label skip_label;
2534     test(reg_tmp, FLAG_IC_FIRST);
2535     jz(skip_label, T_NEAR);
2536 
2537     may_be_set_oc_tail_mask();
2538 
2539     if (is_partial) compute_diff_bias_init();
2540 
2541     auto compute_step = [&](bool is_tail) {
2542         if (jcp.transpose_dst) {
2543             UNUSED(is_tail);
2544             vmovups(vreg_bias_ddst, ptr[reg_ddst]);
2545         } else {
2546             auto vreg_ddst_load = is_ddst_layout_nxc() || is_tail
2547                     ? vreg_bias_ddst | m_0000ffff | T_z
2548                     : vreg_bias_ddst;
2549             vmovdqu16(vreg_ddst_load, ptr[reg_ddst]);
2550             if (is_ddst_layout_nxc() && !is_tail) {
2551                 const int shift_16_elems = 16 * jcp.typesize_in;
2552                 vmovdqu16(vreg_bias_ddst | m_ffff0000,
2553                         ptr[reg_ddst + get_ddst_offset(1) - shift_16_elems]);
2554             }
2555             vpermw(vreg_bias_ddst, get_perm_reg(), vreg_bias_ddst);
2556         }
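             // Accumulate into the bias: vdpbf16ps against a vector of bf16
             // ones sums the two interleaved ow points for each output channel.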
2557         if (!isa_has_bf16(jcp.isa))
2558             bf16_emu_->vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
2559         else
2560             vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
2561     };
2562 
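         // Walk the ddst row two ow points per iteration; an odd trailing
         // point goes through the masked tail step.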
2563     Label ow_loop, ow_tail;
2564     int niters = jcp.tr_ow / 2;
2565     if (niters > 0) {
2566         mov(reg_tmp, jcp.tr_ow / 2);
2567         L(ow_loop);
2568         compute_step(false);
2569         add(reg_ddst, get_ddst_offset(2));
2570         sub(reg_tmp, 1);
2571         jnz(ow_loop, T_NEAR);
2572     }
2573     if (jcp.tr_ow % 2) compute_step(true);
2574 
2575     if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));
2576 
2577     if (is_partial) {
2578         mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2579         vmovups(ptr[reg_tmp], vreg_bias_acc);
2580     }
2581 
2582     may_be_reset_oc_tail_mask();
2583 
2584     L(skip_label);
2585 }
2586 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2587         maybe_compute_diff_bias() {
2588     // In the harness_3d_reduction case diff_bias is computed separately
2589     // for every output row, to stay aligned with the od loop in
2590     // compute_od_loop_common()
2591     if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
2592     mov(reg_tmp, ptr[param + GET_OFF(flags)]);
2593 
2594     Label skip_label;
2595     test(reg_tmp, FLAG_IC_FIRST);
2596     jz(skip_label, T_NEAR);
2597 
2598     switch (jcp.harness) {
2599         case harness_2d_reduction:
2600             mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
2601             sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
2602             break;
2603         case harness_mb_reduction:
2604         case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
2605         case harness_3d_reduction:
2606         default: assert(!"Invalid harness type");
2607     }
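         // reg_oj now holds the number of ddst rows this kernel invocation
         // must accumulate into the bias.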
2608 
2609     compute_diff_bias_init();
2610 
2611     cmp(reg_oj, 0);
2612     jle(skip_label, T_NEAR); // nothing to do
2613     Label bias_loop;
2614     L(bias_loop);
2615     {
2616         compute_diff_bias_row(false);
2617         add(reg_ddst, get_ddst_offset(0, 1));
2618 
2619         sub(reg_oj, 1);
2620         jnz(bias_loop, T_NEAR);
2621     }
2622 
2623     mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2624     vmovups(ptr[reg_tmp], vreg_bias_acc);
2625 
2626     // restore reg_ddst value
2627     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
2628 
2629     L(skip_label);
2630 }
2631 
2632 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_ic_block_step(
2633         int ur_w, int pad_l, int pad_r, int ic_block_step, int src_offset,
2634         int kernel_offset, int ddst_offset, bool is_tail) {
2635 
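         // Dispatch by the data-preparation strategy: the vpermw-based variants
         // convert src/ddst on the fly, the interleave variant handles the first
         // convolution with non-transposed src and stride_w > 1, and the extern
         // variant is used otherwise.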
2636     if (jcp.uses_permw_transposition)
2637         if (jcp.kernel_kind == expl_bcast)
2638             compute_ic_block_step_vpermw_expl(ur_w, pad_l, pad_r, ic_block_step,
2639                     src_offset, kernel_offset, ddst_offset, is_tail);
2640         else
2641             compute_ic_block_step_vpermw(ur_w, pad_l, pad_r, ic_block_step,
2642                     src_offset, kernel_offset, ddst_offset, is_tail);
2643     else if (jcp.is_1stconv && !jcp.transpose_src && jcp.stride_w > 1)
2644         compute_ic_block_step_interleave(ur_w, pad_l, pad_r, ic_block_step,
2645                 src_offset, kernel_offset, ddst_offset, is_tail);
2646     else
2647         compute_ic_block_step_extern(ur_w, pad_l, pad_r, ic_block_step,
2648                 src_offset, kernel_offset, ddst_offset, is_tail);
2649 }
2650 
2651 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::get_ur_w(
2652         int &ur_w, int &ur_w_tail, int &ur_w_trips) {
2653     if (jcp.tr_ow <= max_ur_w) {
2654         ur_w = jcp.tr_ow;
2655         ur_w_tail = 0;
2656         ur_w_trips = 1;
2657         return;
2658     }
2659 
2660     int r_pad = 0;
2661     if (!jcp.transpose_src) {
2662         // If jcp.transpose_src, the buffer has physical padding
2663         int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2664         r_pad = nstl::max(0,
2665                 calculate_end_padding(
2666                         jcp.l_pad, jcp.tr_ow, jcp.tr_iw, jcp.stride_w, ext_kw));
2667     }
2668     int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2669     ur_w = max_ur_w;
2670     ur_w_trips = jcp.tr_ow / ur_w;
2671     ur_w_tail = jcp.tr_ow % ur_w;
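         // If the right padding does not fit into the tail block, rebalance:
         // either fold one full block into the tail or, with a single trip,
         // re-split the row so that l_pad and r_pad land inside their blocks.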
2672     if ((ur_w_tail == 0 && jcp.r_pad != 0) || r_pad >= ur_w_tail) {
2673         if (ur_w_trips > 1) {
2674             ur_w_tail += ur_w;
2675             ur_w_trips--;
2676         } else {
2677             int ur_w_tail_total = ur_w + ur_w_tail;
2678             ur_w = (ur_w_tail_total % 4 == 0) ? ur_w_tail_total / 2
2679                                               : ur_w_tail_total / 2 + 1;
2680             ur_w_tail = ur_w_tail_total - ur_w;
2681             if (l_pad > ur_w / 2) {
2682                 ur_w = (l_pad % 2 == 0) ? l_pad : l_pad + 1;
2683                 ur_w_tail = ur_w_tail_total - ur_w;
2684             } else if (r_pad > ur_w_tail) {
2685                 ur_w_tail = (r_pad % 2 == 0) ? r_pad : r_pad + 1;
2686                 ur_w = ur_w_tail_total - ur_w_tail;
2687             }
2688         }
2689     }
2690 }
2691 
2692 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::
2693         compute_oh_step_unroll_ow_icblock(int ic_block_step) {
2694     Label kh_label, kd_label;
2695 
2696     int ic_block = jcp.ic_block;
2697     int ic_tail = jcp.ic_tail;
2698     int ow = jcp.tr_ow;
2699     int r_pad = 0;
2700     int ur_w, ur_w_tail, ur_w_trips;
2701     get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2702     assert(ur_w_tail == 0 && ur_w_trips == 1);
2703 
2704     if (!jcp.transpose_src) {
2705         // If jcp.transpose_src, the buffer has physical padding
2706         int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2707         int iw = jcp.tr_iw;
2708         r_pad = nstl::max(0,
2709                 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2710     }
2711     int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2712 
2713     if (jcp.ndims == 5) {
2714         L(kd_label);
2715         mov(reg_src, aux_reg_src);
2716         mov(reg_kernel, aux_reg_kernel);
2717     }
2718 
2719     mov(kj, reg_kh);
2720     L(kh_label);
2721     {
2722         const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2723         // icb loop is supported for nxc layout only
2724         assert(IMPLICATION(generate_icb_loop,
2725                 is_src_layout_nxc() && is_ddst_layout_nxc()));
2726         Label icb_block_label, icb_block_label_end;
2727         if (generate_icb_loop || ic_tail) {
2728             mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
2729             mov(ptr[rsp + icb_loop_src_ptr], reg_src);
2730             mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2731             L(icb_block_label);
2732         }
2733 
2734         if (jcp.uses_permw_transposition) {
2735             convert_src_to_vnni_format(ur_w, l_pad, r_pad, 0);
2736             xor_(b_ic, b_ic);
2737         }
2738 
2739         const int ic_tail_loop_work = rnd_up(ic_tail, ic_block_step);
2740         for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
2741             const int src_offset = get_src_offset(i_b_ic, 0);
2742             compute_ic_block_step(ur_w, l_pad, r_pad, ic_block_step, src_offset,
2743                     get_kernel_offset(i_b_ic, 0), 0, true);
2744             if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
2745             // We relax the boundary check for reg_icb, as the src has
2746             // already been converted to VNNI format with appropriate padding,
2747             // either through transpose_src or convert_src_to_vnni_format. We
2748             // can safely let compute_ic_block_step overstep, since it operates
2749             // on the buffer instead of the src.
2750             if (ic_tail && i_b_ic + ic_block_step == ic_tail_loop_work) {
2751                 assert(jcp.transpose_src || jcp.uses_permw_transposition);
2752                 cmp(reg_icb, 0);
2753                 jle(icb_block_label_end, T_NEAR);
2754             }
2755         }
2756         L(icb_block_label_end);
2757 
2758         const auto src_icb_loop_shift_bytes = get_src_offset(ic_block, 0);
2759         const auto kernel_icb_loop_shift_bytes
2760                 = get_kernel_offset(0, jcp.kd * jcp.kh * jcp.kw);
2761         if (generate_icb_loop) {
2762             add(reg_src, src_icb_loop_shift_bytes);
2763             safe_add(reg_kernel, kernel_icb_loop_shift_bytes, reg_long_offt);
2764 
2765             assert(jcp.uses_permw_transposition);
2766             cmp(reg_icb, 0);
2767             jg(icb_block_label, T_NEAR);
2768         }
2769 
2770         if (generate_icb_loop || ic_tail) {
2771             // restore pointers
2772             mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2773             mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2774         }
2775 
2776         add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2777         add(reg_kernel, get_kernel_offset(0, jcp.kw));
2778         dec(kj);
2779         cmp(kj, 0);
2780         jg(kh_label, T_NEAR);
2781     }
2782 
2783     if (jcp.ndims == 5) {
2784         add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
2785         add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
2786         dec(ki);
2787         cmp(ki, 0);
2788         jg(kd_label, T_NEAR);
2789     }
2790 }
2791 
2792 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::
2793         compute_oh_step_unroll_ow(int ic_block_step) {
2794     Label kh_label, ic_block_label, kd_label;
2795 
2796     int ic_block = jcp.ic_block;
2797     const int ic_tail = jcp.ic_tail;
2798     int ow = jcp.tr_ow;
2799 
2800     int r_pad = 0;
2801     int ur_w, ur_w_tail, ur_w_trips;
2802     get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2803     assert(ur_w_tail == 0 && ur_w_trips == 1);
2804 
2805     if (!jcp.transpose_src) {
2806         // If jcp.transpose_src, the buffer has physical padding
2807         int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2808         int iw = jcp.tr_iw;
2809         r_pad = nstl::max(0,
2810                 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2811     }
2812     int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2813 
2814     if (jcp.ndims == 5) {
2815         L(kd_label);
2816         mov(reg_src, aux_reg_src);
2817         mov(reg_kernel, aux_reg_kernel);
2818     }
2819 
2820     mov(kj, reg_kh);
2821     L(kh_label);
2822     {
2823         size_t src_offset = get_src_offset(ic_block_step, 0);
2824 
2825         const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2826         // icb loop is supported for nxc layout only
2827         assert(IMPLICATION(generate_icb_loop,
2828                 is_src_layout_nxc() && is_ddst_layout_nxc()));
2829         Label icb_block_label, icb_block_label_end;
2830         if (generate_icb_loop || ic_tail) {
2831             mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
2832             mov(ptr[rsp + icb_loop_src_ptr], reg_src);
2833             mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2834             L(icb_block_label);
2835         }
2836 
2837         xor_(b_ic, b_ic);
2838         if (jcp.uses_permw_transposition) {
2839             convert_src_to_vnni_format(ow, l_pad, r_pad, 0);
2840             xor_(b_ic, b_ic);
2841         }
2842 
2843         L(ic_block_label);
2844         {
2845             compute_ic_block_step(
2846                     ur_w, l_pad, r_pad, ic_block_step, 0, 0, 0, true);
2847             assert(jcp.ic_block % jcp.ic_block_step == 0);
2848             safe_add(reg_src, src_offset, reg_long_offt);
2849             add(reg_kernel, get_kernel_offset(ic_block_step, 0));
2850             add(b_ic, ic_block_step);
2851             if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
2852             // We relax the boundary check for reg_icb, as the src has
2853             // already been converted to VNNI format with appropriate padding,
2854             // either through transpose_src or convert_src_to_vnni_format. We
2855             // can safely let compute_ic_block_step overstep, since it operates
2856             // on the buffer instead of the src.
2857             if (ic_tail) {
2858                 assert(jcp.transpose_src || jcp.uses_permw_transposition);
2859                 cmp(reg_icb, 0);
2860                 jle(icb_block_label_end, T_NEAR);
2861             }
2862             cmp(b_ic, jcp.ic_block);
2863             jl(ic_block_label, T_NEAR);
2864         }
2865         L(icb_block_label_end);
2866 
2867         if (jcp.uses_permw_transposition) {
2868             if (generate_icb_loop || ic_tail) {
2869                 // subtract the pointer shift made within the ic block loop
2870                 // and move to the next ic block
2871                 safe_add(reg_kernel,
2872                         get_kernel_offset(-ic_block, jcp.kd * jcp.kh * jcp.kw),
2873                         reg_long_offt);
2874 
2875                 cmp(reg_icb, 0);
2876                 jg(icb_block_label, T_NEAR);
2877                 // restore pointers
2878                 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2879                 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2880 
2881                 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2882                 add(reg_kernel, get_kernel_offset(0, jcp.kw));
2883             } else {
2884                 add(reg_src,
2885                         get_src_offset(0, 0, filter_h_to_src(1))
2886                                 - jcp.typesize_in * ic_block);
2887             }
2888         } else if (ic_tail) {
2889             // restore pointers
2890             mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2891             mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2892 
2893             add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2894             add(reg_kernel, get_kernel_offset(0, jcp.kw));
2895         } else if (jcp.is_1stconv && !jcp.transpose_src) {
2896             // Fixup reg_src to point to the correct location
2897             safe_add(reg_src,
2898                     get_src_offset(0, 0, filter_h_to_src(1))
2899                             - src_offset * (jcp.ic_block / ic_block_step),
2900                     reg_long_offt);
2901         } else {
2902             if (jcp.dilate_h > 0)
2903                 add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
2904         }
2905         if (!generate_icb_loop && !ic_tail)
2906             // subtract the pointer shift made within the ic block loop
2907             // and move to the next kh index
2908             add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
2909         dec(kj);
2910         cmp(kj, 0);
2911         jg(kh_label, T_NEAR);
2912     }
2913     if (jcp.ndims == 5) {
2914         add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
2915         add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
2916         dec(ki);
2917         cmp(ki, 0);
2918         jg(kd_label, T_NEAR);
2919     }
2920 }
2921 
2922 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_step_common(
2923         int ic_block_step) {
2924     Label kh_label, ic_block_label, ow_block_label, kd_label;
2925 
2926     int ic_block = jcp.ic_block;
2927     int ic_tail = jcp.ic_tail;
2928     int ow = jcp.tr_ow;
2929     int r_pad = 0;
2930     if (!jcp.transpose_src) {
2931         // If jcp.transpose_src, the buffer has physical padding
2932         int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2933         int iw = jcp.tr_iw;
2934         r_pad = nstl::max(0,
2935                 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2936     }
2937     int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2938 
2939     int ur_w, ur_w_trips, ur_w_tail;
2940     get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2941     assert(l_pad <= ur_w);
2942     assert(r_pad <= ur_w_tail);
2943 
2944     auto src_comeback
2945             = get_src_offset(0, filter_w_to_src(0, ur_w_trips * ur_w, l_pad));
2946     auto ddst_comeback = get_ddst_offset(ur_w_trips * ur_w);
2947 
2948     if (jcp.ndims == 5) {
2949         L(kd_label);
2950         mov(reg_src, aux_reg_src);
2951         mov(reg_kernel, aux_reg_kernel);
2952     }
2953 
2954     bool use_kh_ic_ow_loop_order = !jcp.uses_permw_transposition;
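         // Without permw transposition the src is consumed as-is, so the loops
         // run kh -> ic -> ow; with permw transposition every ow block has to
         // be converted first, so the ow loop moves outside the ic loop
         // (kh -> ow -> ic).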
2955     if (use_kh_ic_ow_loop_order) {
2956         assert(!jcp.uses_permw_transposition);
2957 
2958         auto ic_loop = [=](int ic_block_step) {
2959             Label ow_block_label;
2960             // create a local copy
2961             int ur_w_blocks = ur_w_trips;
2962             auto src_offset = get_src_offset(ic_block_step, 0);
2963             if (l_pad != 0) {
2964                 ur_w_blocks--;
2965                 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
2966                 add(reg_src,
2967                         get_src_offset(0, filter_w_to_src(0, ur_w, l_pad)));
2968                 add(reg_ddst, get_ddst_offset(ur_w));
2969             }
2970 
2971             if (ur_w_blocks > 0) {
2972                 xor_(reg_ur_w_trips, reg_ur_w_trips);
2973                 L(ow_block_label);
2974                 {
2975                     compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
2976                     add(reg_src,
2977                             get_src_offset(0, filter_w_to_src(0, ur_w, 0)));
2978                     add(reg_ddst, get_ddst_offset(ur_w));
2979 
2980                     inc(reg_ur_w_trips);
2981                     cmp(reg_ur_w_trips, ur_w_blocks);
2982                     jl(ow_block_label, T_NEAR);
2983                 }
2984             }
2985 
2986             if (ur_w_tail > 0) {
2987                 compute_ic_block_step(
2988                         ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0, true);
2989             }
2990 
2991             sub(reg_src, src_comeback);
2992             sub(reg_ddst, ddst_comeback);
2993 
2994             safe_add(reg_src, src_offset, reg_long_offt);
2995             add(reg_kernel, get_kernel_offset(ic_block_step, 0));
2996         };
2997 
2998         mov(kj, reg_kh);
2999         L(kh_label);
3000         {
3001             Label ic_tail_label, skip_ic_tail_offset_compensation;
3002             if (ic_tail) {
3003                 // Currently generate_icb_loop is not enabled here, which
3004                 // implies at most one icb is processed.
3005                 assert(jcp.nb_ic_blocking_max == 1);
3006                 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3007             } else {
3008                 mov(reg_icb, ic_block);
3009             }
3010 
3011             L(ic_block_label);
3012             {
3013                 ic_loop(ic_block_step);
3014                 sub(reg_icb, ic_block_step);
3015                 // We relax the boundary check for reg_icb, as the src has
3016                 // already been converted to VNNI format with appropriate
3017                 // padding through transpose_src or convert_src_to_vnni_format.
3018                 // We can safely let compute_ic_block_step overstep, since it
3019                 // operates on the buffer instead of the src.
3020                 if (ic_tail) {
3021                     assert(jcp.transpose_src || jcp.uses_permw_transposition);
3022                 }
3023                 cmp(reg_icb, 0);
3024                 jg(ic_block_label, T_NEAR);
3025             }
3026 
3027             if (ic_tail) {
3028                 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3029                 cmp(reg_icb, jcp.simd_w);
3030                 je(skip_ic_tail_offset_compensation);
3031                 add(reg_kernel,
3032                         get_kernel_offset(
3033                                 jcp.ic_block - rnd_up(ic_tail, ic_block_step),
3034                                 0));
3035                 safe_add(reg_src,
3036                         get_src_offset(0, 0, filter_h_to_src(1))
3037                                 - get_src_offset(
3038                                         rnd_up(ic_tail, ic_block_step), 0),
3039                         reg_long_offt);
3040                 L(skip_ic_tail_offset_compensation);
3041             }
3042             if (jcp.is_1stconv && !jcp.transpose_src) {
3043                 // Fixup reg_src to point to the correct location
3044                 auto src_offset = get_src_offset(ic_block_step, 0);
3045                 safe_add(reg_src,
3046                         get_src_offset(0, 0, filter_h_to_src(1))
3047                                 - src_offset * (jcp.ic_block / ic_block_step),
3048                         reg_long_offt);
3049             } else if (jcp.dilate_h > 0) {
3050                 add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
3051             }
3052             // subtract the pointer shift made within the ic block loop
3053             // and move to the next kh index
3054             add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
3055             dec(kj);
3056             cmp(kj, 0);
3057             jg(kh_label, T_NEAR);
3058         }
3059     } else {
3060         assert(!jcp.is_1stconv);
3061         auto src_icbstep_shift = get_src_offset(1, 0);
3062 
3063         auto ic_loop = [=](int ic_block_step) {
3064             int ic_work = ic_block;
3065             Label ow_block_label, ic_block_label_padl, ic_block_label_general,
3066                     ic_block_label_tail;
3067             int ur_w_blocks = ur_w_trips;
3068             if (l_pad != 0) {
3069                 ur_w_blocks--;
3070                 xor_(b_ic, b_ic);
3071                 if (jcp.uses_permw_transposition) {
3072                     convert_src_to_vnni_format(ur_w, l_pad, 0, 0);
3073                 }
3074                 L(ic_block_label_padl);
3075                 {
3076                     compute_ic_block_step(
3077                             ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
3078                     safe_add(reg_src, src_icbstep_shift * ic_block_step,
3079                             reg_long_offt);
3080                     add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3081 
3082                     add(b_ic, ic_block_step);
3083                     cmp(b_ic, ic_work);
3084                     jl(ic_block_label_padl, T_NEAR);
3085                 }
3086                 safe_sub(reg_src, src_icbstep_shift * ic_work, reg_long_offt);
3087                 sub(reg_kernel, get_kernel_offset(ic_work, 0));
3088                 add(reg_src,
3089                         get_src_offset(0, filter_w_to_src(0, ur_w, l_pad)));
3090                 add(reg_ddst, get_ddst_offset(ur_w));
3091             }
3092 
3093             if (ur_w_blocks > 0) {
3094                 xor_(reg_ur_w_trips, reg_ur_w_trips);
3095                 L(ow_block_label);
3096                 {
3097                     if (jcp.uses_permw_transposition) {
3098                         convert_src_to_vnni_format(ur_w, 0, 0, 0);
3099                     }
3100                     xor_(b_ic, b_ic);
3101                     L(ic_block_label_general);
3102                     {
3103                         compute_ic_block_step(
3104                                 ur_w, 0, 0, ic_block_step, 0, 0, 0);
3105                         safe_add(reg_src, src_icbstep_shift * ic_block_step,
3106                                 reg_long_offt);
3107                         add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3108 
3109                         add(b_ic, ic_block_step);
3110                         cmp(b_ic, ic_work);
3111                         jl(ic_block_label_general, T_NEAR);
3112                     }
3113                     safe_sub(reg_src, src_icbstep_shift * ic_work,
3114                             reg_long_offt);
3115                     sub(reg_kernel, get_kernel_offset(ic_work, 0));
3116                     add(reg_src, get_src_offset(0, filter_w_to_src(0, ur_w)));
3117                     add(reg_ddst, get_ddst_offset(ur_w));
3118 
3119                     inc(reg_ur_w_trips);
3120                     cmp(reg_ur_w_trips, ur_w_blocks);
3121                     jl(ow_block_label, T_NEAR);
3122                 }
3123             }
3124 
3125             if (ur_w_tail > 0) {
3126                 if (jcp.uses_permw_transposition) {
3127                     convert_src_to_vnni_format(ur_w_tail, 0, r_pad, 0);
3128                 }
3129                 xor_(b_ic, b_ic);
3130                 L(ic_block_label_tail);
3131                 {
3132                     compute_ic_block_step(
3133                             ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0, true);
3134                     safe_add(reg_src, src_icbstep_shift * ic_block_step,
3135                             reg_long_offt);
3136                     add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3137 
3138                     add(b_ic, ic_block_step);
3139                     cmp(b_ic, ic_work);
3140                     jl(ic_block_label_tail, T_NEAR);
3141                 }
3142                 safe_sub(reg_src, src_icbstep_shift * ic_work, reg_long_offt);
3143                 sub(reg_kernel, get_kernel_offset(ic_work, 0));
3144             }
3145 
3146             sub(reg_src, src_comeback);
3147             sub(reg_ddst, ddst_comeback);
3148         };
3149 
3150         mov(kj, reg_kh);
3151         L(kh_label);
3152         {
3153             const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
3154             // icb loop is supported for nxc layout only
3155             assert(IMPLICATION(generate_icb_loop,
3156                     is_src_layout_nxc() && is_ddst_layout_nxc()));
3157             Label icb_block_label, icb_block_label_cb, ic_tail_loop_label;
3158 
3159             if (generate_icb_loop) {
3160                 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
3161                 mov(ptr[rsp + icb_loop_src_ptr], reg_src);
3162             }
3163             if (ic_tail || generate_icb_loop)
3164                 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3165             L(icb_block_label);
3166 
3167             ic_loop(ic_block_step);
3168 
3169             if (generate_icb_loop) {
3170                 add(reg_src, get_src_offset(ic_block, 0));
3171                 safe_add(reg_kernel,
3172                         get_kernel_offset(0, jcp.kd * jcp.kh * jcp.kw),
3173                         reg_long_offt);
3174                 sub(reg_icb, ic_block);
3175                 cmp(reg_icb, 0);
3176                 jg(icb_block_label, T_NEAR);
3177             }
3178 
3179             if (generate_icb_loop) {
3180                 // restore pointers
3181                 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
3182                 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
3183             }
3184 
3185             add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
3186             add(reg_kernel, get_kernel_offset(0, jcp.kw));
3187             dec(kj);
3188             cmp(kj, 0);
3189             jg(kh_label, T_NEAR);
3190         }
3191     }
3192     if (jcp.ndims == 5) {
3193         add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
3194         add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
3195         dec(ki);
3196         cmp(ki, 0);
3197         jg(kd_label, T_NEAR);
3198     }
3199 }
3200 
3201 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_step_disp() {
3202     int ic_block_step = jcp.ic_block_step;
3203 
3204     bool too_large_to_unroll = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
3205             && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
3206 
3207     int ow = jcp.tr_ow;
3208     if (jcp.ndims == 5) {
3209         /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
3210          * 'movs' must be guaranteed. */
3211         mov(ki, reg_kd_count);
3212         mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
3213         mov(aux_reg_src, reg_src);
3214         mov(aux_reg_kernel, reg_kernel);
3215     }
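         // Pick the generator: fully unrolled over ow and the ic block for
         // small filters and short rows, unrolled over ow only when the row
         // fits into max_ur_w, and the generic ow-blocked version otherwise.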
3216     if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) {
3217         compute_oh_step_unroll_ow_icblock(ic_block_step);
3218     } else if (ow <= max_ur_w) {
3219         compute_oh_step_unroll_ow(ic_block_step);
3220     } else {
3221         compute_oh_step_common(ic_block_step);
3222     }
3223 
3224     // In the harness_3d_reduction case diff_bias is computed separately
3225     // for every output row, to stay aligned with the od loop in
3226     // compute_od_loop_common()
3227     if (jcp.harness == harness_3d_reduction) compute_diff_bias_row();
3228     if (jcp.ndims == 5) {
3229         mov(reg_src, aux_reg_src);
3230         mov(reg_kernel, aux_reg_kernel);
3231         mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
3232         od_step_comeback_pointers();
3233     } else {
3234         oh_step_comeback_pointers();
3235     }
3236 }
3237 
3238 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::maybe_zero_kernel() {
3239     if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
3240     Label skip_zeroing, zeroing_loop;
3241 
3242     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3243     cmp(reg_tmp, 0);
3244     jz(skip_zeroing, T_NEAR);
3245 
3246     Zmm zero = Zmm(0);
3247     vpxord(zero, zero, zero);
3248     if (jcp.with_bias) {
3249         Label skip_bias_zeroing;
3250         mov(reg_tmp, ptr[param + GET_OFF(flags)]);
3251         test(reg_tmp, FLAG_IC_FIRST);
3252         jz(skip_bias_zeroing, T_NEAR);
3253 
3254         mov(reg_tmp, ptr[param + GET_OFF(bias)]);
3255         vmovups(ptr[reg_tmp], zero);
3256 
3257         L(skip_bias_zeroing);
3258         if (jcp.harness == harness_compute_full_spatial)
3259             jmp(skip_zeroing, T_NEAR);
3260     }
3261 
3262     const size_t kernel_block_bytes
3263             = get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
3264     Label icb_block_label, icb_block_label_cb;
3265 
3266     const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
3267     // icb loop is supported for nxc layout only
3268     assert(IMPLICATION(
3269             generate_icb_loop, is_src_layout_nxc() && is_ddst_layout_nxc()));
3270     if (generate_icb_loop) {
3271         mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
3272         mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3273         L(icb_block_label);
3274     }
3275 
3276     xor_(reg_tmp, reg_tmp);
3277     L(zeroing_loop);
3278     {
3279         assert(get_kernel_offset(1, 0) == cpu_isa_traits<avx512_core>::vlen);
3280         for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3281             vmovups(ptr[reg_kernel + reg_tmp + get_kernel_offset(ic1, 0)],
3282                     zero);
3283         add(reg_tmp, get_kernel_offset(jcp.ic_block, 0));
3284         cmp(reg_tmp, kernel_block_bytes);
3285         jnz(zeroing_loop);
3286     }
3287 
3288     if (generate_icb_loop) {
3289         add(reg_kernel, kernel_block_bytes);
3290         sub(reg_icb, jcp.ic_block);
3291         cmp(reg_icb, 0);
3292         jg(icb_block_label, T_NEAR);
3293         // restore pointer
3294         mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
3295     }
3296 
3297     L(skip_zeroing);
3298 }
3299 
3300 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::compute_oh_loop_common(
3301         bool is_partial) {
3302     int b_pad = jcp.b_pad;
3303     int t_pad = jcp.t_pad;
3304     bool is_dilated = jcp.dilate_h != 0;
3305     int dilate_h = jcp.dilate_h + 1;
3306     int stride_h = jcp.stride_h;
3307     auto filter_step_size = get_kernel_offset(0, jcp.kw);
3308     auto src_step_size = get_src_offset(0, 0, 1);
3309     auto ddst_step_size = get_ddst_offset(0, 1);
3310     Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
3311             oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
3312             oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
3313             oh_dilate_label_end, oh_dilate_setup_label_shift,
3314             oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;
3315 
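         // Output rows fall into three regions: rows whose filter window
         // overlaps the top padding (reg_kh grows while reg_kernel is shifted
         // back), a body where the full kh window fits inside the src, and
         // rows that run into the bottom padding (reg_kh shrinks).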
3316     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
3317     int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
3318     int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
3319     int oh_head_overflow_end = div_up(t_pad, stride_h);
3320     int oh_tail_end = jcp.oh;
3321 
3322     int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
3323     int ih_body_end
3324             = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);
3325 
3326     if (is_partial)
3327         mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
3328     else
3329         xor_(reg_oj, reg_oj);
3330 
3331     /* Compute 'top' edge */
3332     if (t_pad > 0) {
3333         if (is_partial) {
3334             cmp(reg_oj, oh_head_overflow_end);
3335             jge(oh_tpad_tail_label_end, T_NEAR);
3336         }
3337         const int overflow
3338                 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
3339         const int underflow = div_up(t_pad, dilate_h);
3340         const int initial_kh = jcp.kh - overflow - underflow;
3341 
3342         // Setup reg_kh, reg_kernel, and reg_src
3343         mov(reg_kh, initial_kh);
3344         add(reg_kernel, filter_step_size * underflow);
3345         if (is_dilated) {
3346             const int tail = t_pad % dilate_h;
3347             const int shift = tail == 0 ? 0 : dilate_h - tail;
3348             mov(reg_ih_shift, shift);
3349             if (!is_partial) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3350             add(reg_src, src_step_size * shift);
3351         }
3352 
3353         if (is_partial) {
3354             Label head_setup, head_setup_finish;
3355             cmp(reg_oj, 0);
3356             je(head_setup_finish, T_NEAR);
3357             mov(reg_oj_setup, reg_oj);
3358 
3359             L(head_setup);
3360             if (is_dilated) {
3361                 inc(reg_ih_shift);
3362                 cmp(reg_ih_shift, dilate_h);
3363                 jl(oh_dilate_setup_label_shift, T_NEAR);
3364                 // unshift src as new kernel element enters
3365                 sub(reg_src, src_step_size * (dilate_h - 1));
3366                 xor_(reg_ih_shift, reg_ih_shift);
3367             }
3368             // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3369             add(reg_kh, stride_h);
3370             sub(reg_kernel, filter_step_size * stride_h);
3371             if (is_dilated) {
3372                 jmp(oh_dilate_setup_label_noshift, T_NEAR);
3373                 L(oh_dilate_setup_label_shift);
3374                 // shift src as old kernel element progresses
3375                 add(reg_src, src_step_size * stride_h);
3376                 L(oh_dilate_setup_label_noshift);
3377             }
3378             sub(reg_oj_setup, 1);
3379             jg(head_setup, T_NEAR);
3380             L(head_setup_finish);
3381 
3382             if (is_dilated) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3383             if (oh_head_end < oh_head_overflow_end) {
3384                 cmp(reg_oj, oh_head_end);
3385                 jge(oh_tpad_label_end, T_NEAR);
3386             }
3387         }
3388 
3389         // reg_kh, reg_kernel and, for the dilated case, reg_src have been
3390         // set up above; loop over the output rows that overlap the top
3391         // padding.
3392         L(oh_tpad_label);
3393         compute_oh_step_disp();
3394         add(reg_ddst, ddst_step_size);
3395         if (is_dilated) {
3396             mov(reg_ih_shift, ptr[rsp + ih_dilate_shift]);
3397             inc(reg_ih_shift);
3398             mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3399             cmp(reg_ih_shift, dilate_h);
3400             jl(oh_dilate_label_shift, T_NEAR);
3401             // unshift src as new kernel element enters
3402             sub(reg_src, src_step_size * (dilate_h - 1));
3403             xor_(reg_ih_shift, reg_ih_shift);
3404             mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3405         }
3406         // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3407         add(reg_kh, stride_h);
3408         sub(reg_kernel, filter_step_size * stride_h);
3409         if (is_dilated) {
3410             jmp(oh_dilate_label_noshift, T_NEAR);
3411             L(oh_dilate_label_shift);
3412             // shift src as old kernel element progresses
3413             add(reg_src, src_step_size * stride_h);
3414             L(oh_dilate_label_noshift);
3415         }
3416         inc(reg_oj);
3417 
3418         if (is_partial) {
3419             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3420             jge(oh_bpad_label_end, T_NEAR);
3421         }
3422         cmp(reg_oj, oh_head_end);
3423         jl(oh_tpad_label, T_NEAR);
3424 
3425         L(oh_tpad_label_end);
3426         // need second loop to process kernel if it is larger than the src
3427         // (does not apply to dilations as they must have unit stride)
3428         if (oh_head_end < oh_head_overflow_end) {
3429             assert(!is_dilated);
3430 
3431             cmp(reg_oj, oh_head_overflow_end);
3432             jge(oh_tpad_tail_label_end, T_NEAR);
3433 
3434             mov(reg_kh, jcp.ih);
3435             L(oh_tpad_tail_label);
3436             {
3437                 compute_oh_step_disp();
3438                 add(reg_ddst, ddst_step_size);
3439                 sub(reg_kernel, filter_step_size * stride_h);
3440 
3441                 inc(reg_oj);
3442 
3443                 if (is_partial) {
3444                     cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3445                     jge(oh_bpad_label_end, T_NEAR);
3446                 }
3447                 cmp(reg_oj, oh_head_overflow_end);
3448                 jl(oh_tpad_tail_label, T_NEAR);
3449             }
3450         }
3451         if (body_src_start_offset != 0) {
3452             add(reg_kernel, filter_step_size * body_src_start_offset);
3453             add(reg_src, src_step_size * body_src_start_offset);
3454         }
3455         L(oh_tpad_tail_label_end);
3456     }
3457 
3458     if (is_partial) {
3459         cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3460         jge(oh_bpad_label_end, T_NEAR);
3461     }
3462     cmp(reg_oj, oh_body_end);
3463     jge(oh_label_end, T_NEAR);
3464 
3465     /* Compute middle block(s) */
3466     mov(reg_kh, jcp.kh);
3467     L(oh_label);
3468     {
3469         compute_oh_step_disp();
3470         add(reg_src, src_step_size * stride_h);
3471         add(reg_ddst, ddst_step_size);
3472 
3473         inc(reg_oj);
3474 
3475         if (is_partial) {
3476             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3477             jge(oh_bpad_label_end, T_NEAR);
3478         }
3479 
3480         cmp(reg_oj, oh_body_end);
3481         jl(oh_label, T_NEAR);
3482     }
3483     L(oh_label_end);
3484 
3485     /* Compute bottom edge */
3486     if (b_pad > 0) {
3487         if (is_partial) {
3488             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3489             jge(oh_bpad_label_end, T_NEAR);
3490         }
3491         cmp(reg_oj, jcp.oh);
3492         jge(oh_bpad_label_end, T_NEAR);
3493 
3494         if (is_dilated) {
3495             // Assumes unit stride for dilations
3496             mov(reg_kh, jcp.kh - 1);
3497             xor_(reg_ih_shift, reg_ih_shift);
3498         } else {
3499             assert(jcp.dilate_h == 0);
3500             mov(reg_kh, jcp.ih - ih_body_end);
3501         }
3502         if (is_partial) {
3503             lea(reg_oj_setup,
3504                     ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
3505             if (stride_h == 1 && !is_dilated) {
3506                 sub(reg_kh, reg_oj_setup);
3507             } else {
3508                 Label body_setup, body_setup_finish, dilate_skip;
3509                 cmp(reg_oj_setup, 0);
3510                 je(body_setup_finish, T_NEAR);
3511 
3512                 L(body_setup);
3513                 if (is_dilated) {
3514                     inc(reg_ih_shift);
3515                     cmp(reg_ih_shift, dilate_h);
3516                     jl(dilate_skip, T_NEAR);
3517                     xor_(reg_ih_shift, reg_ih_shift);
3518                 }
3519                 sub(reg_kh, stride_h);
3520                 L(dilate_skip);
3521                 sub(reg_oj_setup, 1);
3522                 jg(body_setup, T_NEAR);
3523                 L(body_setup_finish);
3524             }
3525         }
3526 
3527         if (is_dilated) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3528         L(oh_bpad_label);
3529         {
3530             compute_oh_step_disp();
3531             add(reg_src, src_step_size * stride_h);
3532             add(reg_ddst, ddst_step_size);
3533 
3534             if (is_dilated) {
3535                 mov(reg_ih_shift, ptr[rsp + ih_dilate_shift]);
3536                 inc(reg_ih_shift);
3537                 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3538                 cmp(reg_ih_shift, dilate_h);
3539                 jl(oh_dilate_label_end, T_NEAR);
3540                 xor_(reg_ih_shift, reg_ih_shift);
3541                 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3542             }
3543             sub(reg_kh, stride_h);
3544             L(oh_dilate_label_end);
3545             inc(reg_oj);
3546             if (is_partial) {
3547                 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3548                 jge(oh_bpad_label_end, T_NEAR);
3549             }
3550             cmp(reg_oj, oh_tail_end);
3551             jl(oh_bpad_label, T_NEAR);
3552         }
3553     }
3554     L(oh_bpad_label_end);
3555 }
3556 
3557 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::compute_od_loop_common(
3558         bool is_partial) {
3559     assert(jcp.harness == harness_3d_reduction);
3560 
3561     const int src_backpad_overlap
3562             = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
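         // src_backpad_overlap is the first output depth index whose filter
         // window reaches past the end of the src into the back padding.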
3563 
3564     const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
3565     const auto src_shift = get_src_offset(0, 0, jcp.ih);
3566     const auto ddst_shift = get_ddst_offset(0, jcp.oh);
3567 
3568     const int kd_front_pad = nstl::max(0, jcp.f_pad);
3569     const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);
3570 
3571     Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
3572             backpad_end_label, backpad_label;
3573 
3574     /* initially offset 'kd' by f_pad */
3575     mov(reg_src_d, ptr[param + GET_OFF(src)]);
3576     mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);
3577 
3578     if (is_partial) {
3579         add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
3580         mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
3581         mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
3582     } else {
3583         const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
3584         const int kd_offset = get_kernel_offset(
3585                 0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
3586         add(reg_kernel, kd_offset);
3587         xor_(reg_d_index, reg_d_index);
3588         mov(reg_kd_count, kd_padding);
3589     }
3590 
3591     cmp(reg_kd_count, 0);
3592     jle(loop_end_label, T_NEAR); // no iterations along kd
3593     if (is_partial)
3594         cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
3595     else
3596         cmp(reg_d_index, jcp.od);
3597     jge(loop_end_label, T_NEAR); // no iterations along depth dimension
3598 
3599     L(d_loop_label);
3600 
3601     mov(reg_src, reg_src_d);
3602     mov(reg_ddst, reg_ddst_d);
3603 
3604     mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
3605     mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
3606     mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);
3607 
3608     compute_oh_loop_common();
3609 
3610     mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
3611     mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
3612     mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));
3613 
3614     /* Compute 'front' edge */
3615     if (jcp.f_pad > 0) {
3616         /* Check if within fpad region */
3617         cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
3618         jge(fpad_end_label, T_NEAR);
3619 
3620         /* Fpad steps */
3621         sub(reg_kernel, filter_shift * jcp.stride_d);
3622         add(reg_kd_count, jcp.stride_d);
3623 
3624         /* Final number of kernel elements that overlap with src */
3625         const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
3626         cmp(reg_kd_count, src_ker_overlap);
3627         jle(common_block_label, T_NEAR);
3628 
3629         /* Correct any excess shifts to kernel and src */
3630         if (jcp.f_pad <= jcp.od * jcp.stride_d) {
3631             /* Filter has moved beyond padding (adjust for stride effects) */
3632             if (jcp.f_pad % jcp.stride_d != 0) {
3633                 int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
3634                 add(reg_kernel, filter_shift * src_corr);
3635                 add(reg_src_d, src_shift * src_corr);
3636             }
3637         } else {
3638             /* Filter still overlaps padding (complete reset) */
3639             sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
3640         }
3641 
3642         /* Apply correction */
3643         mov(reg_kd_count, src_ker_overlap);
3644         jmp(common_block_label);
3645 
3646         L(fpad_end_label);
3647     }
3648 
3649     /* Compute 'back' edge */
3650     if (jcp.back_pad > 0) {
3651 
3652         /* Check if within back_pad region */
3653         cmp(reg_d_index, src_backpad_overlap - 1);
3654         jl(backpad_end_label, T_NEAR);
3655         jg(backpad_label, T_NEAR);
3656 
3657         /* Execute overlap correction between the filter and the initial
3658          * back_pad region. */
3659         mov(reg_kd_count,
3660                 jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
3661         jmp(backpad_end_label, T_NEAR);
3662 
3663         L(backpad_label);
3664         sub(reg_kd_count, jcp.stride_d);
3665         cmp(reg_kd_count, 0);
3666         jle(loop_end_label, T_NEAR);
3667 
3668         L(backpad_end_label);
3669     }
3670 
3671     /* Compute middle block */
3672     add(reg_src_d, src_shift * jcp.stride_d);
3673 
3674     /* Execute common block and loop */
3675     L(common_block_label);
3676     add(reg_ddst_d, ddst_shift);
3677     inc(reg_d_index);
3678     if (is_partial)
3679         cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
3680     else
3681         cmp(reg_d_index, jcp.od);
3682     jl(d_loop_label, T_NEAR);
3683 
3684     L(loop_end_label);
3685 }
3686 
3687 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
3688         compute_full_spat_loop() {
3689     // General code layout:
3690     //
3691     // Blocking over OH -- top level
3692     // (Reduces L2 pressure; not very useful right now)
3693     //  Loop over all KHxKW kernel -- emit_kh_kw_loop()
3694     //    Loop over OH block -- emit_h_loop()
3695     //      Loop over OW blocks -- emit_fma_block()
3696     //      (Supports both fully unrolled and partially unrolled
3697     //      versions to reduce code size)
3698     //          Loop over OW block -- emit_fma_step()
3699 
3700     auto src_row_size = get_src_offset(0, 0, 1);
3701     auto ddst_row_size = get_ddst_offset(0, 1);
3702     auto row_size = src_row_size + ddst_row_size;
3703 
3704     int h_block_size = jcp.oh;
3705     int h_last_block_size = h_block_size;
3706     int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
3707     auto working_set_size = row_size * h_block_size;
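         // If a full-height block of src + ddst rows exceeds the allowed
         // working set, shrink h_block_size by small integer factors towards
         // the optimal working-set target, but never below what the vertical
         // padding requires.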
3708 
3709     if (working_set_size > full_spat_max_working_set_size) {
3710         assert(full_spat_opt_working_set_size < full_spat_max_working_set_size);
3711 
3712         while (working_set_size > full_spat_opt_working_set_size
3713                 && h_block_size >= min_h_block_size) {
3714             for (int i = 2; i <= h_block_size; i++)
3715                 if (i == h_block_size)
3716                     h_block_size = h_block_size / 2;
3717                 else if (h_block_size % i == 0) {
3718                     h_block_size = h_block_size / i;
3719                     break;
3720                 }
3721             working_set_size = row_size * h_block_size;
3722         }
3723         h_block_size = nstl::max(min_h_block_size, h_block_size);
3724         h_last_block_size = jcp.oh % h_block_size;
3725         if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
3726     }
3727 
3728     Opmask reg_h_block = k1;
3729     Reg64 reg_kh = rax;
3730     Reg64 reg_kw = rbx;
3731     Reg64 reg_tmp = abi_not_param1;
3732     Reg32 reg_tmp_w = reg_tmp.cvt32();
3733     Reg64 reg_ohs = rdx;
3734     Reg64 reg_ihs = rsi;
3735     Reg64 reg_h = r8;
3736     Reg64 reg_i = r9;
3737     Reg64 reg_j = r10;
3738 
3739     Reg64 reg_src = r13;
3740     Reg64 reg_ddst = r14;
3741     Reg64 reg_ker = r15;
3742 
3743     Reg64 reg_src_save = abi_param1;
3744     Reg64 reg_ddst_save = reg_tmp;
3745 
3746     auto zmm_ddst = [&](int oi) { return Zmm(24 + oi % 8); };
3747     auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
3748     auto src_addr = [&](int oi, int ic1) {
3749         return zword_b[reg_src + get_src_offset(ic1, oi)];
3750     };
3751     auto ddst_addr = [&](int oi) {
3752         auto ow_per_oc = 2;
3753         return ptr[reg_ddst + get_ddst_offset(ow_per_oc * oi)];
3754     };
3755     auto ker_addr
3756             = [&](int ic1) { return ptr[reg_ker + get_kernel_offset(ic1, 0)]; };
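         // Register roles for the loops below: zmm0..zmm(ic_block-1)
         // accumulate diff_weights for the current (kh, kw) point,
         // zmm24..zmm31 stage diff_dst values (two ow points per dword), and
         // src pairs are broadcast straight from memory as the third
         // vdpbf16ps operand.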
3757 
3758     auto emit_block = [&]() {
3759         auto pad_ow = jcp.tr_ow;
3760         int ow_per_oc = 2;
3761         int def_step_size = 16;
3762         bool has_w_tail = pad_ow % def_step_size != 0;
3763         bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;
3764 
3765         auto emit_step = [&](int ur_ow, bool is_w_tail) {
3766             int tail_size = pad_ow % ur_ow;
3767             int this_ur_ow = (is_w_tail && tail_size) ? tail_size : ur_ow;
3768             auto numloads = 1;
3769 
3770             assert(this_ur_ow % ow_per_oc == 0);
3771             int steps = this_ur_ow / ow_per_oc;
3772             for (int oi_base = 0; oi_base < steps; oi_base += numloads) {
3773                 for (int oi_offset = 0; oi_offset < numloads; oi_offset++) {
3774                     int oi = oi_base + oi_offset;
3775                     if (oi < steps) {
3776                         vmovups(zmm_ddst(oi), ddst_addr(oi));
3777                     } else {
3778                         auto zmm = zmm_ddst(oi);
3779                         vpxord(zmm, zmm, zmm);
3780                     }
3781                 }
3782 
3783                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3784                     vdpbf16ps(zmm_ker(ic1), zmm_ddst(oi_base),
3785                             src_addr(ow_per_oc * oi_base, ic1));
3786                 }
3787             }
3788         };
3789 
3790         if (full_w_unroll) {
3791             emit_step(pad_ow, true);
3792         } else {
3793             Label w_loop;
3794             int num_w_iters = pad_ow / def_step_size;
3795             mov(reg_i, num_w_iters);
3796             L(w_loop);
3797             {
3798                 emit_step(def_step_size, false);
3799                 add(reg_src, get_src_offset(0, def_step_size));
3800                 add(reg_ddst, get_ddst_offset(def_step_size));
3801                 sub(reg_i, 1);
3802                 jnz(w_loop);
3803             }
3804             if (has_w_tail) { emit_step(def_step_size, true); }
3805             // reset reg_src and reg_ddst because emit_h_loop expects
3806             // unmodified pointers
3807             int w_offset = num_w_iters * def_step_size;
3808             sub(reg_src, get_src_offset(0, w_offset));
3809             sub(reg_ddst, get_ddst_offset(w_offset));
3810         }
3811     };
3812 
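    // emit_h_loop() runs emit_block() reg_h times in total: reg_h - 1 times
    // inside the loop plus once after the skip label, which also covers the
    // reg_h == 1 case without entering the loop.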
3813     auto emit_h_loop = [&]() {
3814         Label h_loop, skip_h_loop;
3815         mov(reg_j, 1);
3816         cmp(reg_j, reg_h);
3817         je(skip_h_loop, T_NEAR);
3818         L(h_loop);
3819         {
3820             emit_block();
3821 
3822             add(reg_src, get_src_offset(0, 0, 1));
3823             add(reg_ddst, get_ddst_offset(0, 1));
3824             add(reg_j, 1);
3825             cmp(reg_j, reg_h);
3826             jb(h_loop);
3827         }
3828         L(skip_h_loop);
3829 
3830         emit_block();
3831     };
3832 
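    // For every filter point (kh, kw) the loop below clips the output-row
    // range against the padding of the current h block, accumulates over that
    // range via emit_h_loop(), and then either overwrites or accumulates into
    // the current kernel row depending on the zeroing flag kept in bit 0 of
    // reg_ker. If no output row overlaps this filter row, the kernel row is
    // only zeroed (when the flag is set) and the row is skipped.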
3833     auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
3834         xor_(reg_kh, reg_kh);
3835         Label kh_loop, kh_loop_end;
3836 
3837         int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
3838         // NB: this is correct because we only support t_pad = kh / 2 and thus
3839         // ih == oh
3840         int ih_block_size = oh_block_size
3841                 + (!is_first_block + !is_last_block) * jcp.t_pad;
3842 
3843         L(kh_loop);
3844         {
3845             if (is_first_block) {
3846                 xor_(reg_tmp, reg_tmp);
3847                 mov(reg_ohs, jcp.t_pad);
3848                 sub(reg_ohs, reg_kh);
3849                 cmovb(reg_ohs, reg_tmp);
3850 
3851                 mov(reg_ihs, reg_ohs);
3852                 sub(reg_ihs, jcp.t_pad);
3853                 add(reg_ihs, reg_kh);
3854             } else {
3855                 xor_(reg_ohs, reg_ohs);
3856                 mov(reg_ihs, reg_kh);
3857             }
3858 
3859             mov(reg_tmp, oh_block_size);
3860             sub(reg_tmp, reg_ohs);
3861             mov(reg_h, ih_block_size);
3862             sub(reg_h, reg_ihs);
3863             cmp(reg_tmp, reg_h);
3864             cmovb(reg_h, reg_tmp);
3865 
3866             Label kh_loop_work;
3867             cmp(reg_h, 0);
3868             jg(kh_loop_work, T_NEAR);
3869 
3870             // empty h loop for this jcp.kh:
3871             // - zero the kernel (diff_weights) if necessary
3872             // - move the ker ptr
3873             // - jump to the end
3874             sub(reg_h, 1);
3875             Label skip_ker_zeroing;
3876 
3877             // The reg_ker ptr has its lowest bit set if the output
3878             // (diff_weights) needs to be zeroed. Those who have byte-aligned
3879             // their data will suffer the consequences :(
3880             // TODO: move the flag to a mask register? (Roma)
3881             test(reg_ker, 1);
3882             jz(skip_ker_zeroing, T_NEAR);
3883 
3884             Label zeroing_loop;
3885             vpxord(zmm0, zmm0, zmm0);
3886             and_(reg_ker, ~1); // temporarily clear the zeroing flag
3887             mov(reg_tmp, jcp.kw);
3888             L(zeroing_loop);
3889             {
3890                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3891                     vmovups(ker_addr(ic1), zmm0);
3892                 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
3893                 sub(reg_tmp, 1);
3894                 jnz(zeroing_loop, T_NEAR);
3895             }
3896             // restore the zeroing flag (it will be cleared after the end of
3897             // emit_kh_kw_loop, but we may need it until then)
3898             or_(reg_ker, 1);
3899             jmp(kh_loop_end, T_NEAR);
3900 
3901             L(skip_ker_zeroing);
3902             add(reg_ker, get_kernel_offset(0, jcp.kw));
3903             jmp(kh_loop_end, T_NEAR);
3904 
3905             L(kh_loop_work);
3906 
3907             mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
3908             mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));
3909 
3910             add(reg_src, reg_ihs);
3911             add(reg_ddst, reg_ohs);
3912 
3913             Label kw_loop;
3914             xor_(reg_kw, reg_kw);
3915             L(kw_loop);
3916             {
3917                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3918                     auto zmm = zmm_ker(ic1);
3919                     vpxord(zmm, zmm, zmm);
3920                 }
3921 
3922                 mov(reg_ddst_save, reg_ddst);
3923                 mov(reg_src_save, reg_src);
3924                 lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
3925 
3926                 emit_h_loop();
3927 
3928                 mov(reg_ddst, reg_ddst_save);
3929                 mov(reg_src, reg_src_save);
3930 
3931                 Label do_store;
3932                 // The reg_ker ptr has its lowest bit set if the output
3933                 // (diff_weights) needs to be zeroed. Those who have
3934                 // byte-aligned their data will suffer the consequences :(
3935                 mov(reg_tmp, reg_ker);
3936                 and_(reg_ker, ~1);
3937                 test(reg_tmp, 1);
3938                 jnz(do_store, T_NEAR);
3939 
3940                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3941                     auto zmm = zmm_ker(ic1);
3942                     vaddps(zmm, ker_addr(ic1));
3943                 }
3944 
3945                 L(do_store);
3946                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3947                     auto zmm = zmm_ker(ic1);
3948                     vmovups(ker_addr(ic1), zmm);
3949                 }
3950 
3951                 mov(reg_ker, reg_tmp);
3952                 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
3953                 add(reg_kw, 1);
3954                 cmp(reg_kw, jcp.kw);
3955                 jl(kw_loop);
3956             }
3957 
3958             sub(reg_src, reg_ihs);
3959             sub(reg_ddst, reg_ohs);
3960 
3961             L(kh_loop_end);
3962             add(reg_kh, 1);
3963             cmp(reg_kh, jcp.kh);
3964             jl(kh_loop);
3965         }
3966     };
3967 
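    // Kernel entry: load the src/diff_dst/diff_weights pointers, OR the
    // "needs zeroing" flag (passed in the `channel` field of the call params)
    // into bit 0 of reg_ker, then process the first oh block, the middle
    // blocks (counted in opmask k1, used as a spare counter register), and
    // finally the last, possibly shorter, block.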
3968     mov(reg_src, ptr[param + GET_OFF(src)]);
3969     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
3970     mov(reg_ker, ptr[param + GET_OFF(filt)]);
3971     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3972     or_(reg_ker, reg_tmp);
3973 
3974     bool single_kh_kw_loop = (h_last_block_size == jcp.oh);
3975 
3976     auto src_row_step = get_src_offset(0, 0, 1);
3977     auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
3978     auto ddst_block_step = get_ddst_offset(0, h_block_size);
3979 
3980     emit_kh_kw_loop(true, single_kh_kw_loop);
3981 
3982     if (!single_kh_kw_loop) {
3983         auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
3984         sub(reg_ker, ker_reset_offset);
3985         and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
3986 
3987         add(reg_src, first_src_block_step);
3988         add(reg_ddst, ddst_block_step);
3989 
3990         int num_innermost_iters
3991                 = (jcp.oh - h_last_block_size) / h_block_size - 1;
3992         if (num_innermost_iters > 0) {
3993             Label h_block_loop;
3994 
3995             mov(reg_tmp_w, num_innermost_iters);
3996             kmovw(reg_h_block, reg_tmp_w);
3997             L(h_block_loop);
3998             {
3999                 emit_kh_kw_loop(false, false);
4000                 sub(reg_ker, ker_reset_offset);
4001                 add(reg_src, src_row_step * h_block_size);
4002                 add(reg_ddst, ddst_block_step);
4003 
4004                 kmovw(reg_tmp_w, reg_h_block);
4005                 sub(reg_tmp_w, 1);
4006                 kmovw(reg_h_block, reg_tmp_w);
4007                 jnz(h_block_loop);
4008             }
4009         }
4010 
4011         emit_kh_kw_loop(false, true);
4012     }
4013 }
4014 
4015 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_loop() {
4016     Reg64 reg_mask_load = r11;
4017     if (jcp.uses_permw_transposition) {
4018 
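        // k-mask constants for the permw-based transposition: all 32 lanes,
        // the low and high 16-lane halves, plus tail masks for partially
        // filled oc/ic blocks (and their copies shifted into the upper half).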
4019         mov(reg_mask_load.cvt32(), 0xffffffff);
4020         kmovd(m_ffffffff, reg_mask_load.cvt32());
4021 
4022         mov(reg_mask_load.cvt32(), 0x0000ffff);
4023         kmovd(m_0000ffff, reg_mask_load.cvt32());
4024 
4025         mov(reg_mask_load.cvt32(), 0xffff0000);
4026         kmovd(m_ffff0000, reg_mask_load.cvt32());
4027         const int oc_tail = jcp.oc_tail;
4028         if (oc_tail) {
4029             mov(reg_mask_load.cvt32(), (1 << oc_tail) - 1);
4030             kmovd(m_0000_oc_tail, reg_mask_load.cvt32());
4031             kshiftld(m_oc_tail_0000, m_0000_oc_tail, 16);
4032         }
4033         const int ic_tail = jcp.ic_tail;
4034         if (ic_tail) {
4035             mov(reg_mask_load.cvt32(), (1 << ic_tail) - 1);
4036             kmovd(m_0000_ic_tail, reg_mask_load.cvt32());
4037             kshiftld(m_ic_tail_0000, m_0000_ic_tail, 16);
4038         }
4039     } else if (jcp.is_1stconv && !jcp.transpose_src) {
4040         if (jcp.stride_w == 1) {
4041             int ieveryother_mask = 0x55555555;
4042             mov(reg_mask_load.cvt32(), ieveryother_mask);
4043             kmovd(everyother_mask, reg_mask_load.cvt32());
4044             kshiftld(everyother_shift_mask, everyother_mask, 1);
4045         } else {
4046             mov(reg_mask_load.cvt32(), 0xffffffff);
4047             kmovd(m_ffffffff, reg_mask_load.cvt32());
4048         }
4049     }
4050 
4051     mov(reg_src, ptr[param + GET_OFF(src)]);
4052     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4053     mov(reg_kernel, ptr[param + GET_OFF(filt)]);
4054 
4055     maybe_zero_kernel();
4056     maybe_compute_diff_bias();
4057 
4058     switch (jcp.harness) {
4059         case harness_3d_reduction: compute_od_loop_common(true); break;
4060         case harness_2d_reduction: compute_oh_loop_common(true); break;
4061         case harness_mb_reduction: compute_oh_loop_common(); break;
4062         case harness_compute_full_spatial: compute_full_spat_loop(); break;
4063         default: assert(!"Invalid harness type");
4064     }
4065 }
4066 
4067 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::setup_stack_space() {
4068 
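    // Stack frame layout: the first ic_block_step_stack_size bytes are
    // scratch for the src/diff_dst transposition, followed by eight 8-byte
    // slots holding kd/d-dimension loop state and saved pointers for the
    // icb loop.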
4069     if ((jcp.is_1stconv && !jcp.transpose_src && jcp.stride_w > 1)
4070             || jcp.uses_permw_transposition) {
4071         int ur_w, ur_w_tail, ur_w_trips;
4072         get_ur_w(ur_w, ur_w_tail, ur_w_trips);
4073         ur_w = nstl::max(ur_w, ur_w_tail);
4074         ic_block_step_stack_size = jcp.uses_permw_transposition
4075                 ? permw_stack_size(ur_w)
4076                 : interleave_stack_size(ur_w, jcp.ic_block_step);
4077     } else
4078         ic_block_step_stack_size = extern_ic_block_step_stack_size;
4079 
4080     permw_buffer_start = 0;
4081     kd_count_offset = ic_block_step_stack_size;
4082     src_d_offset = ic_block_step_stack_size + 8;
4083     ddst_d_offset = ic_block_step_stack_size + 16;
4084     d_index_offset = ic_block_step_stack_size + 24;
4085     trans_tmp_offset = ic_block_step_stack_size + 32;
4086     ih_dilate_shift = ic_block_step_stack_size + 40;
4087     icb_loop_ker_ptr = ic_block_step_stack_size + 48;
4088     icb_loop_src_ptr = ic_block_step_stack_size + 56;
4089     stack_space_needed = ic_block_step_stack_size + 64;
4090 }
4091 
4092 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::generate() {
4093     preamble();
4094 
4095     setup_stack_space();
4096 
4097     sub(rsp, stack_space_needed);
4098 
4099     compute_loop();
4100 
4101     add(rsp, stack_space_needed);
4102 
4103     postamble();
4104 
4105     if (jcp.uses_permw_transposition) {
4106         align(64);
4107         L(dst_prm_table);
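        // vpermw index table that interleaves the 16 lanes of two registers
        // (0, 16, 1, 17, ...); emitted as data right after the generated code
        // and used by the permw transposition path.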
4108         const uint16_t dst_prm_array[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20,
4109                 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
4110                 29, 14, 30, 15, 31};
4111 
4112         for (size_t i = 0; i < 32; ++i)
4113             dw(dst_prm_array[i]);
4114     }
4115 }
4116 
4117 status_t jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(
4118         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
4119         memory_desc_t &src_md, memory_desc_t &diff_weights_md,
4120         memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
4121     const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
4122 
4123     const memory_desc_wrapper src_d(&src_md);
4124     const memory_desc_wrapper diff_weights_d(&diff_weights_md);
4125     const memory_desc_wrapper diff_dst_d(&diff_dst_md);
4126     const memory_desc_wrapper diff_bias_d(&diff_bias_md);
4127 
4128     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
4129     int ndims = src_d.ndims();
4130 
4131     jcp = zero<decltype(jcp)>();
4132     jcp.nthr = nthreads;
4133     jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
4134                                         : bf16_emulation_t::get_isa();
4135     jcp.ver = ver_vnni;
4136     jcp.ndims = ndims;
4137     jcp.prop_kind = cd.prop_kind;
4138 
4139     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
4140     jcp.mb = src_d.dims()[0];
4141 
4142     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
4143     jcp.oc_without_padding = jcp.oc;
4144     jcp.ic = src_d.dims()[1] / jcp.ngroups;
4145 
4146     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
4147     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
4148     jcp.iw = src_d.dims()[ndims - 1];
4149     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
4150     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
4151     jcp.ow = diff_dst_d.dims()[ndims - 1];
4152 
4153     jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
4154     jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
4155     jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];
4156 
4157     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
4158     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
4159     jcp.l_pad = cd.padding[0][ndims - 3];
4160 
4161     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
4162     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
4163     jcp.stride_w = cd.strides[ndims - 3];
4164 
4165     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
4166     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
4167     jcp.dilate_w = cd.dilates[ndims - 3];
4168 
4169     int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
4170     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4171     int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
4172 
4173     bool ok = true
4174             // general condition to simplify dilations
4175             && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
4176             && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
4177             && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
4178             // special condition to simplify dilations in compute_oh_loop_common
4179             && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
4180     if (!ok) return status::unimplemented;
4181 
4182     jcp.r_pad = nstl::max(0,
4183             calculate_end_padding(
4184                     jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
4185     jcp.b_pad = nstl::max(0,
4186             calculate_end_padding(
4187                     jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
4188     jcp.back_pad = nstl::max(0,
4189             calculate_end_padding(
4190                     jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));
4191 
4192     /* XXX: no support for padding when dilation_d > 0 */
4193     if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
4194         return status::unimplemented;
4195 
4196     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
4197     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
4198     jcp.ohp = jcp.oh;
4199     jcp.owp = jcp.ow;
4200     jcp.aligned_threads = 0;
4201 
4202     jcp.simd_w = simd_w;
4203     jcp.oc_block = simd_w;
4204     const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
4205     const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
4206     const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
4207     auto curr_src_tag = src_d.matches_one_of_tag(
4208             dat_tag_nxc, dat_tag_nCx16c, dat_tag_ncx);
4209     auto curr_dst_tag
4210             = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
4211     bool is_data_layout_nxc
4212             = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);
4213 
4214     jcp.is_1stconv = is_1stconv(jcp);
4215 
4216     bool ok_to_pad_channels
4217             = (jcp.ngroups == 1) && !jcp.is_1stconv && !is_data_layout_nxc;
4218 
4219     if (ok_to_pad_channels) {
4220         jcp.oc = rnd_up(jcp.oc, simd_w);
4221         jcp.ic = rnd_up(jcp.ic, simd_w);
4222     }
4223 
4224     auto src_tag = is_data_layout_nxc
4225             ? dat_tag_nxc
4226             : (jcp.is_1stconv ? dat_tag_ncx : dat_tag_nCx16c);
4227     auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
4228     auto wei_tag = jcp.is_1stconv
4229             ? pick(2 * ndims - 6 + with_groups, Owi16o, gOwi16o, Ohwi16o,
4230                     gOhwi16o, Odhwi16o, gOdhwi16o)
4231             : pick(2 * ndims - 6 + with_groups, OIw16i16o, gOIw16i16o,
4232                     OIhw16i16o, gOIhw16i16o, OIdhw16i16o, gOIdhw16i16o);
4233 
4234     if (src_md.format_kind == format_kind::any) {
4235         CHECK(memory_desc_init_by_tag(src_md, src_tag));
4236     } else if (curr_src_tag != src_tag)
4237         return status::unimplemented;
4238     jcp.src_tag = src_tag;
4239 
4240     if (diff_dst_md.format_kind == format_kind::any) {
4241         CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag));
4242     } else if (curr_dst_tag != dst_tag)
4243         return status::unimplemented;
4244     jcp.dst_tag = dst_tag;
4245 
4246     if (diff_weights_md.format_kind == format_kind::any) {
4247         CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
4248         jcp.wei_tag = wei_tag;
4249     } else {
4250         jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
4251         if (jcp.wei_tag != wei_tag) return status::unimplemented;
4252     }
4253 
4254     /* conditions on bias memory */
4255     jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
4256     if (jcp.with_bias) {
4257         if (diff_bias_d.format_kind() == format_kind::any)
4258             CHECK(memory_desc_init_by_tag(diff_bias_md, x));
4259     }
4260     jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
4261     jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
4262 
4263     jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
4264 
4265     /* kernel applicability check wrt boundaries
4266      * the conditions are quite general across the kernels we have,
4267      * but ideally the check should belong to a specific kernel... */
4268     const int max_pad_h = ext_kh / 2;
4269     const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
4270             && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
4271             && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd
4272             && IMPLICATION(jcp.is_1stconv && jcp.ow > max_ur_w,
4273                     jcp.l_pad < max_ur_w && ext_kw <= jcp.ow);
4274     if (!boundaries_ok) return status::unimplemented;
4275 
4276     const int max_kw = jcp.is_1stconv ? 24 : 14;
4277     /* yet another common check */
4278     if (jcp.kw > max_kw) return status::unimplemented;
4279 
4280     jcp.wei_dt = diff_weights_d.data_type();
4281 
4282     jcp.ic_block = jcp.is_1stconv ? jcp.ic : simd_w;
4283     if (ok_to_pad_channels) jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
4284     jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
4285     ok = true && one_of(ndims, 3, 4, 5)
4286             && everyone_is(
4287                     data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
4288             && one_of(diff_weights_d.data_type(), data_type::f32,
4289                     data_type::bf16);
4290     if (!ok) return status::unimplemented;
4291 
4292     jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.ic_block : 0;
4293     jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.oc_block : 0;
4294 
4295     if (jcp.is_1stconv) {
4296         jcp.ic_block_step = 24 / jcp.kw;
4297         while (jcp.ic_block % jcp.ic_block_step != 0)
4298             jcp.ic_block_step--;
4299     } else {
4300         jcp.ic_block_step
4301                 = jcp.kw <= 3 ? 8 : (jcp.kw < 7 ? 4 : (jcp.kw <= 12 ? 2 : 1));
4302     }
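    // e.g., a 1st convolution with ic = 3 and kw = 7 starts from 24 / 7 = 3,
    // which already divides ic_block (= ic = 3), so ic_block_step = 3; a
    // regular convolution with kw = 5 gets ic_block_step = 4.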
4303 
4304     // jcp.uses_permw_transposition = false shows better performance for
4305     // resnet50 v1.5 problems
4306     // jcp.uses_permw_transposition = true works better for 3d 1x1x1 problems
4307     const bool is_permw_applicable
4308             = !jcp.is_1stconv && jcp.stride_w == 1 && jcp.dilate_w == 0;
4309     const bool apply_permw_blocked = !is_data_layout_nxc && ndims == 5
4310             && jcp.kw == 1 && jcp.ic_block_step > 4;
4311     // Threshold is based on performance measurements
4312     const bool apply_permw_nxc = is_data_layout_nxc && ndims == 3
4313             && nstl::max(jcp.ic, jcp.oc) <= 32;
4314     jcp.uses_permw_transposition
4315             = is_permw_applicable && (apply_permw_blocked || apply_permw_nxc);
4316 
4317     jcp.kernel_kind = embd_bcast;
4318     if (jcp.uses_permw_transposition && jcp.kw <= 3)
4319         jcp.kernel_kind = expl_bcast;
4320     if (jcp.uses_permw_transposition && jcp.kernel_kind == expl_bcast)
4321         jcp.ic_block_step = jcp.kw <= 3 ? 4 : (jcp.kw < 7 ? 2 : 1);
4322 
4323     if (jcp.uses_permw_transposition) {
4324         jcp.transpose_src = false;
4325         jcp.transpose_dst = false;
4326     } else if (jcp.is_1stconv && IMPLICATION(is_data_layout_nxc, jcp.ic == 1)) {
4327         jcp.transpose_src = false;
4328         jcp.transpose_dst = true;
4329     } else {
4330         jcp.transpose_src = true;
4331         jcp.transpose_dst = true;
4332     }
4333 
4334     const bool is_2d = (ndims == 4);
4335     const bool is_3d = (ndims == 5);
4336     jcp.typesize_in = sizeof(bfloat16_t);
4337     jcp.typesize_out = sizeof(float);
4338     const dim_t cache_l2
4339             = platform::get_per_core_cache_size(2) / jcp.typesize_out;
4340 
4341     // Observation: Given large 3D shapes with large filter size, 1st nspc
4342     // bwd_w convolution benefits from non-temporal stores in diff_dst
4343     // transformation but not so much from blocking w.r.t. the depth dimension.
4344     // In particular, it is optimized for the i3D 1st convolution.
4345     const bool nt_stores_ok = is_data_layout_nxc
4346             && dim_t(jcp.oc) * jcp.od * jcp.oh * jcp.ow >= 2 * cache_l2
4347             && jcp.kd >= 6 && jcp.kh >= 6 && jcp.kw >= 6;
4348 
4349     // Performance-wise, transposition of the diff_dst tensor is one of the
4350     // major bottlenecks in the 1st convolution. Thus, for a large diff_dst
4351     // size we can potentially split the transposition further into smaller
4352     // chunks to achieve better cache reuse.
4353     const bool large_diff_dst_size
4354             = dim_t(jcp.oc) * jcp.od * jcp.oh * jcp.ow >= cache_l2;
4355 
4356     // For a two-dimensional diff_dst tensor, blocking along the height
4357     // demands non-trivial work along the width dimension. Similarly, for a
4358     // three-dimensional diff_dst tensor, enough work must be present in the
4359     // joint width-height dimension. Finally, there is no blocking along width.
4360     const bool blocking_ok = large_diff_dst_size
4361             && IMPLICATION(is_2d, jcp.ow >= 124 && jcp.oh > 1)
4362             && IMPLICATION(is_3d, jcp.ow * jcp.oh >= 64 * 124 && jcp.od > 1)
4363             && (is_2d || is_3d);
4364 
4365     // TODO: Find more shapes (especially 3D with large spatials) for which
4366     // local transposition will be beneficial. Furthermore, for TBB threads
4367     // more shapes can potentially benefit from spatial blocking
4368     bool use_spatial_blocking = jcp.is_1stconv && !nt_stores_ok && blocking_ok;
4369     int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;
4370     if (use_spatial_blocking) {
4371         // Default value, works best most of the times
4372         // TODO: For 3D shapes with intermediate sizes especially the ones not
4373         // belonging to the 1st convolution, we potentially have more scope
4374         // for optimization
4375         optimal_blk_size = 1;
4376 
4377         // Diff_weights computation can be roughly broken down into
4378         // the following three steps
4379         // = [Src transform*] + [Diff_dst transform] + [Weights computation]
4380         //
4381         // where the bottleneck lies with diff_dst transform that spatial
4382         // blocking tries to mitigate by avoiding cache thrashing.
4383         // *note: Src transform may not always be needed.
4384         //
4385         // In an idealistic scenario, optimal_blk_size will be an explicit
4386         // function of the following form
4387         // optimal_blk_size = f(od, oh, ow, oc)
4388         //
4389         // though, owing to a lack of data points w.r.t. 1st convolution
4390         // shapes, it is approximated by the constant 1 with a few exceptional
4391         // cases [found by manual optimization] as written below
4392 
4393         if (is_2d && utils::one_of(jcp.oh, 149, 300, 224, 512, 608)) {
4394             switch (jcp.oh) {
4395                 case 149: optimal_blk_size = 10; break;
4396                 case 224: optimal_blk_size = 56; break;
4397                 case 300: optimal_blk_size = 30; break;
4398                 case 512: optimal_blk_size = 8; break;
4399                 case 608: optimal_blk_size = 10; break;
4400             }
4401         }
4402     }
4403 
4404     jcp.global_transpose = dnnl_thr_syncable() && !use_spatial_blocking;
4405     jcp.use_nt_stores_ddst = jcp.global_transpose && nt_stores_ok;
4406     jcp.spatial_blk_size = optimal_blk_size;
4407 
4408     const bool padding_ok = IMPLICATION(!jcp.transpose_src,
4409             jcp.l_pad < max_ur_w && jcp.r_pad < max_ur_w
4410                     && ext_kw <= jcp.iw + 1);
4411     if (!padding_ok) return status::unimplemented;
4412 
4413     const int tr_round = 2;
4414     // Logic for tr_pad calculation: transpose is used in the extern kernel.
4415     // There is a memory usage optimization where physical padding is shared
4416     // between transpose buffers. When computing a row, data is read from the
4417     // src two elements at a time due to the bf16 broadcast. Calculation starts
4418     // at the beginning of the left padding and ends at the end of the right
4419     // padding. Because elements are read two at a time, we may need r_pad + 1
4420     // padding on the right. As such, the shared padding is the max of l_pad and
4421     // r_pad + 1, rounded as necessary for the transpose data format.
4422     int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
4423     jcp.tr_iw = jcp.transpose_src
4424             ? rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
4425                     * jcp.stride_w
4426             : jcp.iw;
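    // e.g., iw = 17, stride_w = 1, l_pad = 1, r_pad = 1: tr_pad
    // = rnd_up(max(1, 2), 2) = 2 and tr_iw = rnd_up(17 + 2, 2) * 1 = 20.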
4427 
4428     jcp.tr_src_num_guard_elems = tr_pad; // upper bound
4429     jcp.tr_ow = jcp.transpose_dst ? rnd_up(jcp.ow, 2) : jcp.ow;
4430 
4431     bool args_ok = true
4432             && IMPLICATION(!is_data_layout_nxc,
4433                     jcp.ic % jcp.ic_block == 0 && jcp.oc % jcp.oc_block == 0)
4434             && jcp.ic <= src_d.padded_dims()[1]
4435             && jcp.oc <= diff_dst_d.padded_dims()[1]
4436             && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
4437             && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
4438     if (!args_ok) return status::unimplemented;
4439 
4440     int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
4441     int out_row_size = jcp.oc_block * jcp.tr_ow * jcp.typesize_in;
4442     int full_spat_min_h_block_size
4443             = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
4444     int full_spat_working_set_size
4445             = (inp_row_size + out_row_size) * full_spat_min_h_block_size;
4446     bool use_full_spat_loop = isa_has_bf16(jcp.isa) && jcp.ndims < 5
4447             && jcp.ih == jcp.oh && jcp.iw == jcp.ow
4448             && !one_of(1, jcp.kh, jcp.kw)
4449             && everyone_is(1, jcp.stride_h, jcp.stride_w)
4450             && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
4451             && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
4452             && !jcp.uses_permw_transposition && !jcp.is_1stconv
4453             && full_spat_working_set_size <= full_spat_opt_working_set_size
4454             && jcp.ic >= 128;
4455 
4456     jcp.harness = ndims == 5
4457             ? harness_3d_reduction
4458             : (use_full_spat_loop ? harness_compute_full_spatial
4459                                   : (ndims == 4) ? harness_2d_reduction
4460                                                  : harness_mb_reduction);
4461 
4462     switch (jcp.harness) {
4463         case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
4464         case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
4465         case harness_compute_full_spatial:
4466         case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
4467         default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
4468     }
4469     { // balancing
4470         int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
4471         balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
4472         jcp.nthr = nthr;
4473         jcp.nthr_mb = nthr_mb;
4474         jcp.nthr_g = nthr_g;
4475         jcp.nthr_oc_b = nthr_oc_b;
4476         jcp.nthr_ic_b = nthr_ic_b;
4477 
4478         // TODO: Optimize memory allocation when threaded on height and depth
4479         if (jcp.transpose_src) {
4480             jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
4481             jcp.tr_src_buf_count = jcp.global_transpose
4482                     ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
4483                     : jcp.nthr;
4484         }
4485         if (jcp.transpose_dst) {
4486             jcp.tr_diff_dst_buf_size
4487                     = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
4488             jcp.tr_diff_dst_buf_count = jcp.global_transpose
4489                     ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
4490                     : jcp.nthr;
4491         }
4492     }
4493 
4494     jcp.nb_ic_blocking_max = 1;
4495     if (is_data_layout_nxc && jcp.uses_permw_transposition
4496             && (jcp.ow > max_ur_w || jcp.ndims == 5))
4497         jcp.nb_ic_blocking_max = nstl::min(8, div_up(jcp.nb_ic, jcp.nthr_ic_b));
4498     return status::success;
4499 }
4500 
4501 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_scratchpad(
4502         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
4503 
4504     if (!jcp.uses_permw_transposition) {
4505         // XXX: See the comment about tr_iw and guarding elements in
4506         // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf()
4507         const size_t tr_src_size = jcp.tr_src_buf_count * jcp.tr_src_buf_size
4508                 + jcp.tr_src_num_guard_elems;
4509         scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);
4510 
4511         /* prepare synchronization contexts */
4512         if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
4513             const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
4514             scratchpad.book<simple_barrier::ctx_t>(
4515                     key_conv_tr_src_bctx, tr_src_bctx_size);
4516         }
4517 
4518         const size_t tr_diff_dst_size
4519                 = jcp.tr_diff_dst_buf_count * jcp.tr_diff_dst_buf_size;
4520 
4521         const size_t min_align = jcp.use_nt_stores_ddst ? 64 : jcp.typesize_in;
4522         scratchpad.book(key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in,
4523                 min_align);
4524 
4525         /* prepare synchronization contexts */
4526         if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
4527             const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
4528             scratchpad.book<simple_barrier::ctx_t>(
4529                     key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
4530         }
4531     }
4532 
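    // A float reduction buffer is booked when threading over the minibatch
    // (jcp.nthr_mb > 1) or when diff_weights / diff_bias are bf16, in which
    // case accumulation happens in f32 scratch and is converted afterwards;
    // for an f32 destination one copy is the user memory itself, hence the
    // nthr_mb - 1 extra buffers.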
4533     if (IMPLICATION(jcp.nthr_mb == 1,
4534                 (jcp.with_bias && jcp.bia_dt == data_type::bf16)
4535                         || jcp.wei_dt == data_type::bf16)) {
4536         const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
4537                 * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
4538         const size_t bia_size
4539                 = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;
4540 
4541         const int num_wei_buffers
4542                 = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
4543         const int num_bia_buffers = jcp.with_bias
4544                 ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
4545                                                  : jcp.nthr_mb - 1)
4546                 : 0;
4547 
4548         const size_t wei_bia_reduction_size
4549                 = wei_size * num_wei_buffers + bia_size * num_bia_buffers;
4550 
4551         scratchpad.book<float>(
4552                 key_conv_wei_bia_reduction, wei_bia_reduction_size);
4553 
4554         if (jcp.global_transpose)
4555             scratchpad.book<simple_barrier::ctx_t>(
4556                     key_conv_wei_bia_reduction_bctx, 1);
4557     }
4558 
4559     if (jcp.with_bias) {
4560         if ((jcp.oc_without_padding % jcp.oc_block != 0)
4561                 && jcp.bia_dt == data_type::f32)
4562             scratchpad.book(key_conv_padded_bias,
4563                     jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
4564     }
4565 }
4566 
4567 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::balance(
4568         const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
4569         int &nthr_oc_b_, int &nthr_ic_b_) {
4570     nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
4571 
4572     const int max_threads = dnnl_get_max_threads();
4573 
4574     if (max_threads < j.ngroups) {
4575         /* simplification... fortunately it doesn't hurt much */
4576         nthr_ = nthr_g_ = max_threads;
4577         return;
4578     }
4579 
4580     nthr_g_ = j.ngroups;
4581     const int nthr = max_threads / nthr_g_;
4582 
4583     auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
4584         /* calculate per-thread memory cost (read/write). The high-level
4585          * optimizer tries to minimize memory consumption. A few notes:
4586          *  (n1) if the weights tensor size is less than the source and
4587          *       destination tensors, we apply the ratio of the source and
4588          *       destination tensor sizes to the weights one as a compensation
4589          *       coefficient to avoid parallelization across the batch size
4590          *       only; otherwise we apply an additional coefficient to the
4591          *       source component based on performance measurements
4592          *  (n2) use scales based on the output-to-input channels ratio for
4593          *       the source and destination components to improve threading
4594          *       balance across input and output channels */
4595 
4596         const dim_t src_type_size = 2;
4597         const dim_t wei_type_size = 4;
4598 
4599         dim_t src_size
4600                 = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
4601         dim_t dst_size
4602                 = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
4603         dim_t wei_size
4604                 = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;
4605 
4606         float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
4607         float oi_channels_ratio = (float)j.nb_oc / j.nb_ic;
4608         auto get_src_coef = [=]() {
4609             float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
4610             if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;
4611 
4612             return src_coef;
4613         };
4614 
4615         auto get_dst_coef
4616                 = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };
4617 
4618         auto get_wei_coef
4619                 = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };
4620 
4621         const float src_coef = get_src_coef();
4622         const float dst_coef = get_dst_coef();
4623         const float wei_coef = get_wei_coef();
4624 
4625         float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
4626                 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_ic, nthr_ic_b) * j.mb
4627                 * j.ic_block * j.id * j.ih * j.tr_iw / j.nthr_mb_work
4628                 / j.stride_d / j.stride_h / j.stride_w;
4629         float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
4630                 * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) * j.kh
4631                 * j.kw * j.kd * j.ic_block * j.oc_block;
4632         float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
4633                 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_oc, nthr_oc_b) * j.mb
4634                 * j.oc_block * j.od * j.oh * j.tr_ow / j.nthr_mb_work;
4635 
4636         return src_v + dst_v + wei_v;
4637     };
4638 
4639     float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4640 
4641     /* find the best thread distribution with lowest memory cost */
4642     const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
4643     for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4644         const int nthr_par = nthr / nthr_mb;
4645         const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4646         for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4647             int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4648 
4649             float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4650             if (mem_cost <= best_mem_cost) {
4651                 best_mem_cost = mem_cost;
4652                 nthr_mb_ = nthr_mb;
4653                 nthr_oc_b_ = nthr_oc_b;
4654                 nthr_ic_b_ = nthr_ic_b;
4655             }
4656         }
4657     }
4658 
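    // If the search settled on more than half (but not all) of the threads
    // for the minibatch reduction, round nthr_mb_ up so that the remaining
    // threads do not stay idle.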
4659     if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
4660         nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
4661     nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
4662 
4663     assert(nthr_ <= max_threads);
4664 }
4665 
4666 template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Zmm>;
4667 template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Ymm>;
4668 template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Xmm>;
4669 template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Zmm>;
4670 template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Ymm>;
4671 template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Xmm>;
4672 } // namespace x64
4673 } // namespace cpu
4674 } // namespace impl
4675 } // namespace dnnl
4676 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
4677