/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/math_utils.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/cpu_barrier.hpp"

#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp"

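// GET_OFF(field) yields the byte offset of `field` within the jit_conv_call_s
// argument structure; the generated code uses it to address kernel call
// parameters relative to the param1 register.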
#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace format_tag;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace Xbyak;

namespace {

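// Spatial sizes at or below this threshold are treated as "small" when
// choosing the loop order below.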
constexpr auto small_spatial = 14;

inline void pick_loop_order(jit_conv_conf_t &jcp) {
    using namespace prop_kind;
    assert(one_of(
            jcp.prop_kind, forward_training, forward_inference, backward_data));
    auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
    auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;

    if (utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc)
            && jcp.ngroups > 1 && jcp.oc < 16) {
        jcp.loop_order = loop_nhwcg;
    } else if (jcp.prop_kind == backward_data) {
        // ow-threading is currently implemented for forward only
        // TODO: use a single code path for fwd and bwd once ow-threading
        // is supported for bwd (a meaningless switch was removed)
        if (jcp.ndims < 5)
            jcp.loop_order = (w <= small_spatial && h <= small_spatial)
                    ? loop_cwgn
                    : loop_gncw;
        else
            jcp.loop_order = (w <= small_spatial && h <= small_spatial)
                    ? loop_cgn
                    : loop_gnc;
    } else {
        jcp.loop_order = (w <= small_spatial && h <= small_spatial) ? loop_cwgn
                                                                    : loop_gncw;
    }
}
inline bool is_ow_threading_available(const jit_conv_conf_t &jcp) {
    /* is 1D conv */
    return (jcp.id == 1 && jcp.ih == 1 && jcp.kd == 1 && jcp.kh == 1);
}
inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
    return (jcp.nb_ow > 1);
}
inline bool is_iw_threading_available(const jit_conv_conf_t &jcp) {
    return one_of(jcp.ndims, 3, 4);
}
inline bool is_iw_threading_on(const jit_conv_conf_t &jcp) {
    return (jcp.nb_iw > 1);
}
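// The "1st convolution" path is taken for single-group convolutions with a
// small input channel count (< 16), provided all in-kernel offsets fit into
// a 32-bit integer.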
inline bool is_1stconv(const jit_conv_conf_t &jcp) {
    const bool no_big_offt = nstl::max<size_t>(jcp.ic, jcp.oc)
                    * nstl::max(jcp.typesize_in, jcp.typesize_out) * jcp.id
                    * jcp.ih * jcp.iw
            < INT_MAX;
    return jcp.ic < 16 && jcp.ngroups == 1 && no_big_offt;
}
} // namespace

template <typename Vmm>
_jit_avx512_core_bf16_fwd_kernel<Vmm>::_jit_avx512_core_bf16_fwd_kernel(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(nullptr, ker_code_size, true, avx512_core_bf16)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr size_t helper_vmm_idx = 31;
        const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
        const size_t tail_size = oc_block_tail
                ? oc_block_tail
                : jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r14, r15, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec),
                memory_desc_wrapper(dst_md), tail_size, postops_mask,
                use_exact_tail_scalar_bcast};
        const static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp.post_ops, static_params);
    }
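    // Without native bf16 support (plain avx512_core), fall back to the bf16
    // emulation helper, which needs several reserved registers.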
    if (!isa_has_bf16(jcp.isa))
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_5);
}

template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::prepare_dst(int ur_w) {
    for (int k = 0; k < jcp.nb_oc_blocking; k++)
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_dst(j, k);
            vpxord(vmm, vmm, vmm);
        }
}

template <typename Vmm>
int _jit_avx512_core_bf16_fwd_kernel<Vmm>::vmm_dst_idx(
        int i_ur, int i_oc) const {
    const int idx = i_ur * jcp.nb_oc_blocking + i_oc;
    assert(idx < ker_reg_base_idx);
    return idx;
}

template <typename Vmm>
Vmm _jit_avx512_core_bf16_fwd_kernel<Vmm>::vmm_dst(int i_ur, int i_oc) const {
    return Vmm(vmm_dst_idx(i_ur, i_oc));
}

template <typename F>
static void iterate(const int nb_oc_block, const int ur_w, const bool mask_tail,
        const bool force_masking, const F &f) {
    for (int k = 0; k < nb_oc_block; k++) {
        const bool mask_flag
                = force_masking || (mask_tail && k + 1 == nb_oc_block);
        for (int j = 0; j < ur_w; j++)
            f(mask_flag, k, j);
    }
}
template <typename F>
static void iterate(const int nb_oc_block, const int ur_w, const F &f) {
    iterate(nb_oc_block, ur_w, false, false, f);
}

template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::apply_postops(int ur_w) {
    if (jcp.with_eltwise || jcp.with_binary) {
        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;
            const auto temp_offset_reg = this->r12;
            const auto mask_tail = jcp.oc_without_padding % jcp.simd_w;
            const bool oc_blk_is_smaller_than_vmm
                    = jcp.oc_block < isa_simd_width_;
            iterate(jcp.nb_oc_blocking, ur_w, mask_tail,
                    oc_blk_is_smaller_than_vmm,
                    [&](const bool mask_flag, const int k, const int j) {
                        const int aux_output_l_off
                                = get_dst_offset(j, k) / jcp.typesize_out;
                        const auto vmm_idx = vmm_dst_idx(j, k);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_oc_elem_off_addr.emplace(
                                vmm_idx, ptr[param1 + GET_OFF(oc_l_off)]);
                        rhs_arg_params_tail.vmm_idx_to_oc_elem_off_val.emplace(
                                vmm_idx, k * jcp.oc_block);
                        rhs_arg_params_tail.vmm_idx_to_out_off_oprnd.emplace(
                                vmm_idx, temp_offset_reg);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_l_off);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            const injector_utils::register_preserve_guard_t register_guard(
                    this, {temp_offset_reg});
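            // Compute reg_dst's element offset from the destination origin:
            // subtract dst_orig to get a byte offset, then shift right by
            // log2(sizeof(float)) == 2 to convert bytes to elements.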
            mov(temp_offset_reg, reg_dst);
            sub(temp_offset_reg, ptr[param1 + GET_OFF(dst_orig)]);
            shr(temp_offset_reg, std::log2(sizeof(float)));

            Label postops_done;
            if (mask_tail || oc_blk_is_smaller_than_vmm) {
                Label postops_no_tail;
                if (mask_tail) {
                    test(byte[param1 + GET_OFF(load_work)], jcp.oc_block - 1);
                    jz(postops_no_tail, T_NEAR);
                }
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
            }
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
            L(postops_done);

        } else {
            iterate(jcp.nb_oc_blocking, ur_w,
                    [&](const bool, const int k, const int j) {
                        vmm_idxs.emplace(vmm_dst_idx(j, k));
                    });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
    }
}

template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::store_dst(int ur_w) {
    Label store_label;
    const int oc_tail = jcp.oc_tail;
    if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();

    if (jcp.with_sum) {
        for (int k = 0; k < jcp.nb_oc_blocking; k++) {
            for (int j = 0; j < ur_w; j++) {
                // mask only needed for last oc_block
                bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking;
                Vmm vmm = vmm_dst(j, k);
                size_t aux_dst_offset = get_dst_offset(j, k);
                if (jcp.dst_dt == data_type::bf16) {
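                    // Upconvert bf16 -> f32: zero-extend each 16-bit value to
                    // 32 bits, then shift it into the upper half of the dword
                    // (a bf16 value is the high 16 bits of an f32).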
                    vpmovzxwd(may_be_mask_vmm(vmm_prev_dst, mask_flag, true),
                            make_safe_addr(
                                    reg_dst, aux_dst_offset, reg_long_offt));
                    vpslld(vmm_prev_dst, vmm_prev_dst, 16);
                    vaddps(vmm, vmm_prev_dst);
                } else {
                    vaddps(may_be_mask_vmm(vmm, mask_flag, true),
                            make_safe_addr(
                                    reg_dst, aux_dst_offset, reg_long_offt));
                }
            }
        }
    }

    if (jcp.with_bias) {
        mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
        for (int k = 0; k < jcp.nb_oc_blocking; k++) {
            int bias_offset = jcp.typesize_bia * k * jcp.oc_block;
            for (int j = 0; j < ur_w; j++) {
                // mask only needed for last oc_block
                bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking;
                Vmm vmm = vmm_dst(j, k);
                if (jcp.bia_dt == data_type::bf16) {
                    vpmovzxwd(may_be_mask_vmm(vmm_bias, mask_flag, true),
                            EVEX_compress_addr(reg_bias, bias_offset));
                    vpslld(vmm_bias, vmm_bias, 16);
                    vaddps(vmm, vmm_bias);
                } else
                    vaddps(may_be_mask_vmm(vmm, mask_flag, true),
                            EVEX_compress_addr(reg_bias, bias_offset));
            }
        }
    }

    apply_postops(ur_w);

    L(store_label);
    if (jcp.dst_dt == data_type::f32) {
        for (int k = 0; k < jcp.nb_oc_blocking; k++)
            for (int j = 0; j < ur_w; j++) {
                Vmm vmm = vmm_dst(j, k);
                size_t aux_dst_offset = get_dst_offset(j, k);
                auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                // mask only needed for last oc_block
                bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking
                        && is_dst_layout_nxc();
                vmovups(addr, may_be_mask_vmm(vmm, mask_flag, false));
            }
    } else if (jcp.dst_dt == data_type::bf16) {
        if (isa_has_bf16(jcp.isa) && is_dst_layout_nxc()) {
            // Optimization: use a single store instruction for each pair of
            // adjacent vectors along the OC dimension
            for (int j = 0; j < ur_w; j++) {
                int k = 0;
                for (; k < rnd_dn(jcp.nb_oc_blocking, 2); k += 2) {
                    Vmm vmm = vmm_dst(j, k);
                    Vmm vmm_next = vmm_dst(j, k + 1);
                    size_t aux_dst_offset = get_dst_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                    vcvtne2ps2bf16(vmm, vmm_next, vmm);
                    // mask only needed for last oc_block
                    bool mask_flag = oc_tail && k + 2 == jcp.nb_oc_blocking;
                    vmovdqu16(
                            addr, may_be_mask_vmm(vmm, mask_flag, false, true));
                }
                if (jcp.nb_oc_blocking % 2 != 0) {
                    Vmm vmm = vmm_dst(j, k);
                    auto vmm_down = Vmm_down_t(vmm.getIdx());
                    size_t aux_dst_offset = get_dst_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                    vcvtneps2bf16(vmm_down, vmm);
                    // for xmm, the upper half is zero after conversion to
                    // bf16, so always mask; also mask for tails
                    bool mask_flag = jcp.simd_w == 4 || oc_tail;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
                }
            }
        } else if (isa_has_bf16(jcp.isa) /* !is_dst_layout_nxc() */) {
            // Optimization: use a single store instruction for each pair of
            // adjacent vectors along the WIDTH dimension
            for (int k = 0; k < jcp.nb_oc_blocking; k++) {
                int n_2bf2ps = (ur_w / 2) * 2, j = 0;
                for (j = 0; j < n_2bf2ps; j += 2) {
                    size_t aux_dst_offset = get_dst_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);

                    auto vmm_str = vmm_src(j, jcp.nb_oc_blocking);
                    vcvtne2ps2bf16(vmm_str, vmm_dst(j + 1, k), vmm_dst(j, k));
                    vmovups(addr, vmm_str);
                }
                if (j < ur_w) {
                    size_t aux_dst_offset = get_dst_offset(j, k);

                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                    auto vmm_down_str = vmm_src_down(j, jcp.nb_oc_blocking);
                    vcvtneps2bf16(vmm_down_str, vmm_dst(j, k));
                    // for xmm, the upper half is zero after conversion to
                    // bf16, so always mask
                    const bool mask_flag = jcp.simd_w == 4;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down_str, mask_flag));
                }
            }
        } else {
            for (int k = 0; k < jcp.nb_oc_blocking; k++)
                for (int j = 0; j < ur_w; j++) {
                    Vmm vmm = vmm_dst(j, k);
                    size_t aux_dst_offset = get_dst_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
                    auto vmm_down = vmm_src_down(0, jcp.nb_oc_blocking);
                    bf16_emu_->vcvtneps2bf16(
                            Ymm(vmm_down.getIdx()), Zmm(vmm.getIdx()));
                    bool mask_flag = (oc_tail && k + 1 == jcp.nb_oc_blocking
                                             && is_dst_layout_nxc())
                            // for xmm, the upper half is zero after conversion
                            // to bf16, so always mask; also mask for tails
                            || jcp.simd_w == 4;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
                }
        }
    } else
        assert(!"unsupported destination type");
}

template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::compute_loop(
        int ur_w, int pad_l, int pad_r) {
    Label kh_label, kd_label;
    const int ic_tail = jcp.ic_tail;
    const int ic_step = 2;

    /* max_src_offset is explicitly used in the 1st convolution.
     * Set its value so that accessing the double-word memory
     * referenced by ptr[src_base + offset] is safe whenever
     * 0 <= offset < max_src_offset
     *
     * Note: since the arguments pad_l and pad_r might not exactly match
     * jcp.l_pad and jcp.r_pad respectively, this value needs to be
     * computed separately for each invocation of compute_loop.
     */
    dim_t max_src_offset = 0;
    if (jcp.is_1stconv || ic_tail) {
        for (int ki = 0; ki < jcp.kw; ki++) {
            int ow_fst = get_ow_start(ki, pad_l);
            int ow_last = get_ow_end(ur_w, ki, pad_r) - 1;
            if (ow_fst > ow_last) continue;
            int ic_last = rnd_up(nstl::min(jcp.ic_block,
                                         nstl::max(jcp.ic, ic_tail)),
                                  ic_step)
                    - ic_step;

            dim_t src_offset = get_src_offset(
                    ic_last, filter_w_to_src(ki, ow_last, pad_l));
            if (src_offset > max_src_offset) max_src_offset = src_offset;
        }
    }

    prepare_dst(ur_w);

    Label skip_compute_loop;
    if (jcp.ndims == 5) {
        mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
        if ((jcp.dilate_d >= jcp.id)
                || (jcp.kd - 1) * (jcp.dilate_d + 1)
                        < nstl::max(jcp.f_pad, jcp.back_pad)) {
            cmp(reg_kj, 0);
            je(skip_compute_loop, T_NEAR);
        }
    }
    mov(reg_kj, reg_kh);
    if ((jcp.dilate_h >= jcp.ih)
            || (jcp.kh - 1) * (jcp.dilate_h + 1)
                    < nstl::max(jcp.t_pad, jcp.b_pad)) {
        cmp(reg_kj, 0);
        je(skip_compute_loop, T_NEAR);
    }

    // IC loop
    Label icb_label;
    mov(reg_ic, jcp.ic);
    L(icb_label);

    if (jcp.ndims == 5) {
        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(ptr[rsp + off_reg_ker_], reg_ker);
        mov(ptr[rsp + off_reg_src_], reg_src);

        L(kd_label);
    }

    mov(aux_reg_src, reg_src);
    mov(aux_reg_ker, reg_ker);

    mov(reg_kj, reg_kh);

    std::vector<Label> ic_tail_jmp(jcp.kw);
    L(kh_label);
    {
        for (int ki = 0; ki < jcp.kw; ki++) {
            int ow_start = get_ow_start(ki, pad_l);
            int ow_end = get_ow_end(ur_w, ki, pad_r);
            for (int ic = 0;
                    ic < rnd_up(nstl::min(jcp.ic_block, jcp.ic), ic_step);
                    ic += ic_step) {
                if (ic_tail && ic == rnd_up(ic_tail, ic_step)) {
                    // insert this check at most once per icb, no more
                    cmp(reg_ic, ic_tail);
                    je(ic_tail_jmp[ki], T_NEAR);
                }
                for (int oi = ow_start; oi < ow_end; oi++) {
                    dim_t src_offset = get_src_offset(
                            ic, filter_w_to_src(ki, oi, pad_l));
                    auto vmm_in = vmm_src(oi, jcp.nb_oc_blocking);
                    const auto addr_base = EVEX_compress_addr_safe(
                            aux_reg_src, src_offset, reg_long_offt);
                    const bool tail_load
                            = ic_tail && ic == rnd_dn(ic_tail, ic_step);
                    if (jcp.is_1stconv || tail_load) {
                        const bool need_single_load
                                = (ic + 1 == jcp.ic || ic + 1 == ic_tail);
                        const bool safe_overstep
                                = (src_offset < max_src_offset)
                                && !is_src_layout_nxc();

                        /* For the comment below, let us define three words
                         * x_b = ptr[addr_base] and x_s = ptr[addr_strided]
                         * x_g = ptr[addr_base + 2]
                         *
                         * For the single load case:
                         * without overstep, the zmm_in register is loaded as
                         * [0, x_b, ..., 0, x_b, 0, x_b]
                         * On the other hand, "with overstep" the zmm_in
                         * register is loaded as
                         * [x_g, x_b, ..., x_g, x_b, x_g, x_b]
                         * where x_g is a garbage word.
                         *
                         * Note:
                         * 1. In the single load case with safe_overstep
                         * enabled, it is implicitly assumed that the element
                         * in the zmm_wei register corresponding to the
                         * "garbage value x_g" in the zmm_in register is zero.
                         * 2. There is a potential problem when x_g is
                         * either Inf or NaN, since it is multiplied by zero
                         * in the accumulation. But since x_g is a valid input
                         * at a different offset, one may assume that x_g is
                         * neither Inf nor NaN.
                         *
                         * For the non single load case:
                         * the zmm_in register is loaded as
                         * [x_s, x_b, ...., x_s, x_b, x_s, x_b]
                         */
                        if (tail_load) {
                            if (need_single_load) {
                                Label mask_load, load_done;
                                cmp(reg_ic, ic + ic_step);
                                jl(mask_load, T_NEAR);
                                vpbroadcastd(vmm_in, addr_base);
                                jmp(load_done, T_NEAR);
                                L(mask_load);
                                vpbroadcastw(vmm_in | odd_load_mask | T_z,
                                        addr_base);
                                L(load_done);
                            } else {
                                vpbroadcastd(vmm_in, addr_base);
                            }
                        } else if (need_single_load && !safe_overstep)
                            vpbroadcastw(
                                    vmm_in | odd_load_mask | T_z, addr_base);
                        else if (IMPLICATION(!is_src_layout_nxc(),
                                         need_single_load && safe_overstep))
                            vpbroadcastd(vmm_in, addr_base);
                        else {
                            const auto addr_strided
                                    = EVEX_compress_addr_safe(aux_reg_src,
                                            src_offset + get_src_offset(1, 0),
                                            reg_long_offt);
                            vpbroadcastd(vmm_in, addr_base);
                            vpbroadcastw(vmm_in | even_load_mask, addr_strided);
                        }
                    } else {
                        vpbroadcastd(vmm_in, addr_base);
                    }
                }
                for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
                    auto wei_off = get_kernel_offset(kk, ic, ki);
                    vmovups(vmm_wei,
                            EVEX_compress_addr_safe(
                                    aux_reg_ker, wei_off, reg_long_offt));
                    for (int oi = ow_start; oi < ow_end; oi++) {
                        auto acc = vmm_dst(oi, kk);
                        auto src = vmm_src(oi, jcp.nb_oc_blocking);
                        if (isa_has_bf16(jcp.isa)) {
                            vdpbf16ps(acc, vmm_wei, src);
                        } else
                            bf16_emu_->vdpbf16ps(Zmm(acc.getIdx()),
                                    Zmm(vmm_wei.getIdx()), Zmm(src.getIdx()));
                    }
                }
            }
            L(ic_tail_jmp[ki]);
        }
        safe_add(aux_reg_ker, get_kernel_offset(0, 0, 0, 1), reg_long_offt);
        safe_add(aux_reg_src, get_src_offset(0, filter_h_to_src(1)),
                reg_long_offt);

        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }

    if (jcp.ndims == 5) {
        safe_add(reg_src, get_src_offset(0, filter_d_to_src(1)), reg_long_offt);
        safe_add(reg_ker, get_kernel_offset(0, 0, 0, 0, 1), reg_long_offt);
        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        mov(reg_ker, ptr[rsp + off_reg_ker_]);
        mov(reg_src, ptr[rsp + off_reg_src_]);
    }

    // End of IC loop
    dim_t src_step = get_src_offset(jcp.ic_block, 0);
    const size_t ker_step = get_kernel_offset(0, jcp.ic_block, 0);
    safe_add(reg_src, src_step, reg_long_offt);
    safe_add(reg_ker, ker_step, reg_long_offt);

    sub(reg_ic, jcp.ic_block);
    cmp(reg_ic, 0);
    jg(icb_label, T_NEAR);

    safe_sub(reg_src, src_step * jcp.nb_ic, reg_long_offt);
    safe_sub(reg_ker, ker_step * jcp.nb_ic, reg_long_offt);

    L(skip_compute_loop);
    store_dst(ur_w);
}

template <typename Vmm>
void _jit_avx512_core_bf16_fwd_kernel<Vmm>::generate() {
    int iw = jcp.iw;
    int ow = jcp.ow;
    int ow_block = jcp.ow_block;
    int nb_ow = jcp.nb_ow;
    int kw = jcp.kw;
    int l_pad = jcp.l_pad;
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int stride_w = jcp.stride_w;

    auto src_shift = get_src_offset(0, filter_w_to_src(0, ur_w));
    auto dst_shift = get_dst_offset(ur_w, 0);

    auto src_shift_pad = get_src_offset(0, filter_w_to_src(0, ur_w, l_pad));
    auto src_shift_pad_second_block
            = get_src_offset(0, filter_w_to_src(0, 0, l_pad));

    preamble();
    if (jcp.ndims == 5) sub(rsp, stack_space_needed_);

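    // For the 1st convolution or when an ic tail exists, set up k-masks that
    // select alternating 16-bit lanes; they let a single bf16 word be
    // broadcast into the odd or even half of each dword lane.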
    if (jcp.is_1stconv || jcp.ic_tail) {
        Xbyak::Reg64 reg_alt_mask = r8;
        const auto odd_mask = size_t {0x5555555555555555};
        const auto even_mask = size_t {0xaaaaaaaaaaaaaaaa};
        mov(reg_alt_mask, odd_mask);
        kmovq(odd_load_mask, reg_alt_mask);
        mov(reg_alt_mask, even_mask);
        kmovq(even_load_mask, reg_alt_mask);
    }

    if (jcp.simd_w == 4) {
        auto reg_tail_32 = reg_oc.cvt32();
        mov(reg_tail_32, (1 << jcp.simd_w) - 1);
        kmovb(k_oc_tail_mask, reg_tail_32);
    }

    if (jcp.oc_tail) {
        Label done;
        // dummy mask, all 1's
        if (jcp.simd_w != 4) { // simd_w == 4 has its dummy mask set already
            kxnord(k_oc_tail_mask, k_oc_tail_mask, k_oc_tail_mask);
        }
        // To account for the special store optimization, where two oc_blocks
        // are combined into a single write, extend the mask to 32 bits
        // (32 bf16 values)
        const bool need_extended_mask = jcp.dst_dt == data_type::bf16
                && isa_has_bf16(jcp.isa) && jcp.nb_oc_blocking > 1;
        if (need_extended_mask)
            kxnord(k_oc_tail_mask_extended, k_oc_tail_mask_extended,
                    k_oc_tail_mask_extended);

        test(byte[param1 + GET_OFF(load_work)], jcp.oc_block - 1);
        jz(done, T_NEAR);
        auto reg_tail_32 = reg_oc.cvt32();
        mov(reg_tail_32, (1 << jcp.oc_tail) - 1);
        kmovd(k_oc_tail_mask, reg_tail_32);
        kmovd(postops_mask, reg_tail_32);
        if (need_extended_mask) {
            mov(reg_tail_32, (1 << (jcp.oc_tail + jcp.simd_w)) - 1);
            kmovd(k_oc_tail_mask_extended, reg_tail_32);
        }
        L(done);
    } else if (jcp.with_binary)
        if (jcp.oc_block != isa_simd_width_) {
            const int mask = (1 << jcp.oc_block) - 1;
            const Reg32 regw_tmp = reg_oi.cvt32();
            mov(regw_tmp, mask);
            kmovd(postops_mask, regw_tmp);
        }

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
    mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
    mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);

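    // r_pad1 below is the right padding that remains after n_oi full ur_w
    // steps; if it is positive, the last full ur_w iteration must handle
    // right padding itself.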
    int r_pad = nstl::max(0, jcp.r_pad);
    int n_oi = ow / ur_w;
    int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
            calculate_extended_filter_size(kw, jcp.dilate_w));

    if (!is_ow_threading_on(jcp)) {
        // ow is processed as a whole - with left and right paddings
        if (r_pad1 > 0) n_oi--;

        xor_(reg_oi, reg_oi);
        if (ow == ur_w) {
            compute_loop(ur_w, l_pad, r_pad);
        } else {
            if (n_oi == 0) {
                compute_loop(ur_w, l_pad, r_pad1);
                add(reg_src, src_shift_pad);
                add(reg_dst, dst_shift);
                if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
            } else {
                if (l_pad > 0) {
                    compute_loop(ur_w, l_pad, 0);
                    add(reg_src, src_shift_pad);
                    add(reg_dst, dst_shift);
                    inc(reg_oi);
                }
                if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
                    Label ow_loop_label;
                    L(ow_loop_label);
                    {
                        compute_loop(ur_w, 0, 0);
                        add(reg_src, src_shift);
                        add(reg_dst, dst_shift);

                        inc(reg_oi);
                        cmp(reg_oi, n_oi);
                        jl(ow_loop_label, T_NEAR);
                    }
                }
                if (r_pad1 > 0) {
                    compute_loop(ur_w, 0, r_pad1);
                    add(reg_src, src_shift);
                    add(reg_dst, dst_shift);
                }
                if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
            }
        }
    } else {
        // Only one ow block is processed per kernel call.
        // The block index is passed via the owb parameter,
        // and padding handling depends on it.

        Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
        Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;

        assert(ow_block % ur_w == 0);
        int n_oi_not_last_ow_block = ow_block / ur_w;
        // to simplify the code (and general register usage),
        // the size of an ow block must be >= 2 * ur_w
        assert(n_oi_not_last_ow_block > 1);
        int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
        int n_oi_first_ow_block = n_oi_not_last_ow_block;

        int n_oi_last_ow_block = (ow - ow_block * (nb_ow - 1)) / ur_w;

        // prepare right padding
        bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
        bool first_ow_block_padded
                = next_last_ow_block_padded && jcp.nb_ow == 2;
        bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;

        if (last_ow_block_padded)
            n_oi_last_ow_block--;
        else if (first_ow_block_padded)
            n_oi_first_ow_block--;
        else if (next_last_ow_block_padded)
            n_oi_next_last_ow_block--;

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, 0); // is this the first ow-block?
        jg(middle_ow_blocks_label, T_NEAR);

        // the first ow block, compute left padding

        mov(reg_oi, n_oi_first_ow_block);
        if (l_pad > 0) {
            compute_loop(ur_w, l_pad, 0);
            add(reg_src, src_shift_pad);
            add(reg_dst, dst_shift);
            dec(reg_oi);
        }
        jmp(oi_loop_label, T_NEAR);

        // middle or last ow block entry

        L(middle_ow_blocks_label);

        if (l_pad > 0) {
            // just to account for left padding, no computation
            add(reg_src, src_shift_pad_second_block);
        }

        // set the number of iterations for the oi-loop
        cmp(reg_owb, jcp.nb_ow - 1); // last ow-block?
        mov(reg_oi, n_oi_last_ow_block);
        je(oi_loop_label, T_NEAR);
        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block?
        mov(reg_oi, n_oi_next_last_ow_block);
        je(oi_loop_label, T_NEAR);
        mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks

        // oi loop w/o padding
        L(oi_loop_label);
        L(oi_loop_start_label);
        cmp(reg_oi, 0);
        jle(oi_loop_end_label, T_NEAR);

        compute_loop(ur_w, 0, 0);
        add(reg_src, src_shift);
        add(reg_dst, dst_shift);
        dec(reg_oi);
        jmp(oi_loop_start_label, T_NEAR);
        L(oi_loop_end_label);

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);

        cmp(reg_owb, 0); // first ow-block?
        if (first_ow_block_padded) {
            je(last_oi_label, T_NEAR);
        } else {
            je(end_label, T_NEAR);
        }
        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block?
        jl(end_label, T_NEAR);
        if (next_last_ow_block_padded) {
            je(last_oi_label, T_NEAR);
        } else {
            je(end_label, T_NEAR);
        }
        // this is the last block
        if (!last_ow_block_padded) { jmp(tail_label, T_NEAR); }

        // last oi block with right padding
        L(last_oi_label);
        compute_loop(ur_w, 0, r_pad1);
        add(reg_src, src_shift);
        add(reg_dst, dst_shift);

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
        jl(end_label, T_NEAR);

        L(tail_label);
        if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
        L(end_label);
    }

    if (jcp.ndims == 5) add(rsp, stack_space_needed_);
    postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}

void jit_avx512_core_bf16_fwd_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    using namespace memory_tracking::names;
    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
        assert(jcp.ngroups == 1);
        scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
    }
}

status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, const primitive_attr_t &attr, int nthreads) {

    using namespace prop_kind;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    jcp = zero<decltype(jcp)>();
    jcp.nthr = nthreads;
    jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
                                        : bf16_emulation_t::get_isa();
    jcp.ver = ver_vnni;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];
    jcp.dst_dt = dst_d.data_type();

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());

    jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
    jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
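    // Reject configurations where the dilated filter can fit entirely inside
    // the padding area on some border.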
    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
    const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
    const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c,
            dat_tag_nCx8c, dat_tag_nCx4c, dat_tag_ncx);
    auto curr_dst_tag = dst_d.matches_one_of_tag(
            dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
    bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);
    jcp.is_1stconv = is_1stconv(jcp);

    const int regs = isa_has_bf16(jcp.isa) ? 31 /* expl_bcast case */ : 26;
    const bool ok_to_pad_channels = jcp.ngroups == 1 && !is_data_layout_nxc;

    jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);

    const bool ok_to_try_lower_zmm = true
            && IMPLICATION(is_data_layout_nxc,
                    jcp.oc < jcp.simd_w && jcp.ic < jcp.simd_w
                            && jcp.ngroups > 1)
            && !jcp.is_1stconv && !ok_to_pad_channels
            && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0);

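    // When the channel counts are not multiples of the full zmm width, try
    // narrower vectors (ymm: simd 8, xmm: simd 4) that divide both ic and oc.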
    if (ok_to_try_lower_zmm) {
        for (auto simd : {8, 4}) {
            if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
                jcp.simd_w = simd;
                break;
            }
        }
    }

    jcp.oc_block = jcp.simd_w;
    jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    }

    if (!IMPLICATION(!is_data_layout_nxc,
                jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
        return status::unimplemented;

    format_tag_t src_tag, dst_tag, wei_tag;

    if (jcp.simd_w == 8) {
        assert(with_groups);
        dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
        wei_tag = pick(ndims - 3, gOIw4i8o2i, gOIhw4i8o2i, gOIdhw4i8o2i);
    } else if (jcp.simd_w == 4) {
        assert(with_groups);
        dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c;
        wei_tag = pick(ndims - 3, gOIw2i4o2i, gOIhw2i4o2i, gOIdhw2i4o2i);
    } else if (jcp.is_1stconv) {
        dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
        src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_ncx;
        wei_tag = pick(2 * ndims - 6 + with_groups, OwI16o2i, gOwI16o2i,
                OhwI16o2i, gOhwI16o2i, OdhwI16o2i, gOdhwI16o2i);
    } else {
        dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
        wei_tag = pick(2 * ndims - 6 + with_groups, OIw8i16o2i, gOIw8i16o2i,
                OIhw8i16o2i, gOIhw8i16o2i, OIdhw8i16o2i, gOIdhw8i16o2i);
    }

    if (src_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(src_md, src_tag));
    else if (curr_src_tag != src_tag)
        return status::unimplemented;
    jcp.src_tag = src_tag;

    if (dst_md.format_kind == format_kind::any)
        CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
    else if (curr_dst_tag != dst_tag)
        return status::unimplemented;
    jcp.dst_tag = dst_tag;

    if (weights_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, x));
    }

    jcp.aligned_threads = 0;

    bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1]
            && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    const auto &post_ops = attr.post_ops_;
    jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
    const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) {
        jcp.eltwise = post_ops.entry_[eltwise_ind].eltwise;
        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
    }
    const int binary_ind = post_ops.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
    if (is_data_layout_nxc)
        jcp.oc_tail = jcp.oc % jcp.simd_w;
    else
        jcp.oc_tail = jcp.with_binary ? jcp.oc_without_padding % jcp.simd_w : 0;

    jcp.post_ops = post_ops;

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = true;
    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            {broadcasting_strategy_t::scalar,
                    broadcasting_strategy_t::per_oc}});
    if (!post_ops_ok_) return status::unimplemented;

    jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
    jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;

    jcp.kernel_kind = expl_bcast;
    jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
    for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) {
        int ur_w = regs / (jcp.nb_oc_blocking + 1);
        if (jcp.nb_oc % jcp.nb_oc_blocking == 0
                && (jcp.l_pad <= ur_w
                        && IMPLICATION(jcp.ow != 1, jcp.ow % ur_w != 1)))
            break;
    }

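    // Register budget: each of the ur_w unrolled outputs needs one
    // accumulator per oc block plus one broadcast register, so
    // ur_w * (nb_oc_blocking + 1) must fit into the available vmm registers.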
    jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    jcp.ow_block = jcp.ow;
    if (is_ow_threading_available(jcp)) {
        const int L1_part = platform::get_per_core_cache_size(1) * 5 / 8;
        int size_src_chunk = jcp.typesize_in * jcp.ic_block * jcp.ur_w;
        int size_dst_chunk = jcp.typesize_out * jcp.oc_block
                * jcp.nb_oc_blocking * jcp.ur_w;
        int size_wei_chunk = jcp.typesize_in * jcp.oc_block * jcp.ic_block
                * jcp.nb_oc_blocking * jcp.kw;
        int nurw = (L1_part - size_wei_chunk)
                / (size_dst_chunk + size_src_chunk);
        // current design of generate() requires ow_block >= 2 * ur_w
        jcp.ow_block = jcp.ur_w * nstl::max(2, nurw);
    }
    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);

    int r_pad_no_tail = nstl::max(0,
            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                    jcp.stride_w, ext_kw));
    if (jcp.l_pad > jcp.ur_w || r_pad_no_tail > jcp.ur_w)
        return status::unimplemented;

    /* adjust the thread decomposition
     * to improve performance for small problem sizes;
     * the threshold L1_cache_size / factor is empirical;
     * simply cap the thread count at 4 for now
     * TODO: add a get_thr_eff function to get the optimal thread number */

    size_t wei_size = (size_t)sizeof(bfloat16_t) * jcp.ic * jcp.oc * jcp.kh
            * jcp.kw * jcp.kd;
    size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
            * jcp.iw * jcp.id;
    size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
            * jcp.ow * jcp.od;
    size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
    const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);

    // The factor is 1 for 1d, 2 for 2d, and 4 for 3d
    int factor = nstl::max(1, (2 * (ndims - 3)));
    if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size / factor) {
        jcp.nthr = nstl::min(jcp.nthr, 4);
    }

    pick_loop_order(jcp);

    return status::success;
}

template <typename Vmm>
void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::prepare_output(int ur_w) {
    for (int k = 0; k < jcp.nb_ic_blocking; k++) {
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_dsrc(j, k);
            vpxord(vmm, vmm, vmm);
        }
    }
}

template <typename Vmm>
void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::store_output(int ur_w) {
    if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
    const int ic_tail = jcp.ic_tail;

    if (jcp.dst_dt == data_type::f32) {
        for (int k = 0; k < jcp.nb_ic_blocking; k++)
            for (int j = 0; j < ur_w; j++) {
                Vmm vmm = vmm_dsrc(j, k);
                size_t aux_diff_src_offset = get_diff_src_offset(j, k);
                auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
                // mask only needed for last ic_block
                bool mask_flag = ic_tail && k + 1 == jcp.nb_ic_blocking
                        && is_dsrc_layout_nxc();
                vmovups(addr, may_be_mask_vmm(vmm, mask_flag, false));
            }
    } else if (jcp.dst_dt == data_type::bf16) {
        if (isa_has_bf16(jcp.isa) && is_ddst_layout_nxc()) {
            // Optimization: use a single store instruction for each pair of
            // adjacent vectors along the IC dimension
            for (int j = 0; j < ur_w; j++) {
                int k = 0;
                for (; k < rnd_dn(jcp.nb_ic_blocking, 2); k += 2) {
                    Vmm vmm = vmm_dsrc(j, k);
                    Vmm vmm_next = vmm_dsrc(j, k + 1);
                    size_t aux_dsrc_offset = get_diff_src_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_src, aux_dsrc_offset);
                    vcvtne2ps2bf16(vmm, vmm_next, vmm);
                    bool mask_flag = ic_tail && k + 2 == jcp.nb_ic_blocking;
                    vmovdqu16(
                            addr, may_be_mask_vmm(vmm, mask_flag, false, true));
                }
                if (jcp.nb_ic_blocking % 2 != 0) {
                    Vmm vmm = vmm_dsrc(j, k);
                    auto vmm_down = Vmm_down_t(vmm.getIdx());
                    size_t aux_dsrc_offset = get_diff_src_offset(j, k);
                    auto addr = EVEX_compress_addr(reg_src, aux_dsrc_offset);
                    vcvtneps2bf16(vmm_down, vmm);
                    // for xmm, the upper half is zero after conversion to
                    // bf16, so always mask; also mask for tails
                    bool mask_flag = jcp.simd_w == 4 || ic_tail;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
                }
            }
        } else if (isa_has_bf16(jcp.isa) /* && !is_ddst_layout_nxc() */) {
            // Optimization: use a single store instruction for each pair of
            // adjacent vectors along the WIDTH dimension
            int store_idx = 0;
            const int max_regs = 32;
            const int free_regs_start_idx = jcp.ur_w * jcp.nb_ic_blocking;
            const int num_regs_available = max_regs - free_regs_start_idx;
            int reg_idx = 0;
            for (int k = 0; k < jcp.nb_ic_blocking; k++) {
                int n_2bf2ps = (ur_w / 2) * 2, j = 0;
                for (j = 0; j < n_2bf2ps; j += 2) {
                    reg_idx = free_regs_start_idx
                            + store_idx % num_regs_available;
                    assert(reg_idx < max_regs);
                    size_t aux_diff_src_offset = get_diff_src_offset(j, k);
                    auto addr
                            = EVEX_compress_addr(reg_src, aux_diff_src_offset);

                    auto vmm_str = Vmm(reg_idx);
                    vcvtne2ps2bf16(vmm_str, vmm_dsrc(j + 1, k), vmm_dsrc(j, k));
                    vmovups(addr, vmm_str);
                    store_idx++;
                }
                if (j < ur_w) {
                    reg_idx = free_regs_start_idx
                            + store_idx % num_regs_available;
                    assert(reg_idx < max_regs);

                    size_t aux_diff_src_offset = get_diff_src_offset(j, k);
                    auto addr
                            = EVEX_compress_addr(reg_src, aux_diff_src_offset);
                    auto vmm_down_str = Vmm_down_t(reg_idx);
                    vcvtneps2bf16(vmm_down_str, vmm_dsrc(j, k));
                    // for xmm, the upper half is zero after conversion to
                    // bf16, so always mask
                    bool mask_flag = jcp.simd_w == 4;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down_str, mask_flag));
                    store_idx++;
                }
            }
        } else {
            for (int k = 0; k < jcp.nb_ic_blocking; k++)
                for (int j = 0; j < ur_w; j++) {
                    Vmm vmm = vmm_dsrc(j, k);
                    size_t aux_diff_src_offset = get_diff_src_offset(j, k);
                    auto addr
                            = EVEX_compress_addr(reg_src, aux_diff_src_offset);
                    auto vmm_down = vmm_ddst_down(0);
                    bf16_emu_->vcvtneps2bf16(
                            Ymm(vmm_down.getIdx()), Zmm(vmm.getIdx()));
                    bool mask_flag = (ic_tail && k + 1 == jcp.nb_ic_blocking
                                             && is_dsrc_layout_nxc())
                            // for xmm, the upper half is zero after conversion
                            // to bf16, so always mask; also mask for tails
                            || jcp.simd_w == 4;
                    vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
                }
        }
    } else
        assert(!"unsupported diff_src type");
}

template <typename Vmm>
void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::compute_loop(
        int ur_w, int l_overflow, int r_overflow) {
    int kw = jcp.kw;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;
    int stride_h = jcp.stride_h;
    const int oc_tail = jcp.oc_tail;
    Label kh_label, skip_compute_label;

    prepare_output(ur_w);

    if (jcp.ndims == 5) {
        mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
        cmp(reg_ki, 0);
        jle(skip_compute_label, T_NEAR);
    }

    cmp(reg_kh, 0);
    jle(skip_compute_label, T_NEAR);

    // OC loop
    Label ocb_label;
    mov(reg_oc, jcp.oc);
    L(ocb_label);

    if (jcp.ndims < 5) {
        mov(aux_reg_dst, reg_dst);
        mov(aux_reg_ker, reg_ker);
    }
    Label kd_label;
    if (jcp.ndims == 5) {
        mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
        mov(aux_reg_dst_d, reg_dst);
        mov(aux_reg_ker_d, reg_ker);

        L(kd_label);
        mov(aux_reg_dst, aux_reg_dst_d);
        mov(aux_reg_ker, aux_reg_ker_d);
    }

    std::vector<Label> oc_tail_jmp(jcp.kw);
    mov(reg_kj, reg_kh);
    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = get_iw_start(ki, l_overflow);
            int jj_end = get_iw_end(ur_w, ki, r_overflow);
            const int ref_jj_start
                    = nstl::max(0, l_overflow - (kw - 1 - ki) * dilate_w);
            const int ref_jj_end
                    = ur_w - nstl::max(0, r_overflow - ki * dilate_w);
            assert(IMPLICATION(stride_w == 1,
                    jj_start == ref_jj_start && jj_end == ref_jj_end));
            UNUSED(ref_jj_start);
            UNUSED(ref_jj_end);
            const int oc_step = 2;
            for (int oc = 0;
                    oc < rnd_up(nstl::min(jcp.oc_block, jcp.oc), oc_step);
                    oc += oc_step) {
                if (oc_tail && oc == rnd_up(oc_tail, oc_step)) {
                    cmp(reg_oc, oc_tail);
                    je(oc_tail_jmp[ki], T_NEAR);
                }
                for (int jj = jj_start; jj < jj_end; jj += stride_w) {
                    assert((jj + jcp.l_pad - ki * dilate_w) % stride_w == 0);
                    int ow_idx = (jj + jcp.l_pad - ki * dilate_w) / stride_w;
                    auto aux_ddst_offset = get_diff_dst_offset(ow_idx, oc);
                    auto ddst = vmm_ddst(jj / stride_w);
                    const bool tail_load = oc_tail && oc == rnd_dn(oc_tail, 2);
                    const bool need_single_load = oc + 1 == oc_tail;

                    if (tail_load && need_single_load) {
                        Label mask_load, load_done;
                        cmp(reg_oc, oc + 2);
                        jl(mask_load, T_NEAR);
                        vpbroadcastd(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
                        jmp(load_done, T_NEAR);
                        L(mask_load);
                        // We broadcast a single word w here. As the weights
                        // are zero-padded at oc + 1,
                        // vdpbf16ps({0, w}, {dst, dst}) is okay.
                        vpbroadcastw(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
                        L(load_done);
                    } else {
                        vpbroadcastd(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
                    }
                }
                for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
                    size_t aux_kernel_offset = get_kernel_offset(kk, oc, ki);
                    vmovups(vmm_wei,
                            EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));

                    for (int jj = jj_start; jj < jj_end; jj += stride_w) {
                        auto ddst = vmm_ddst(jj / stride_w);
                        auto acc = vmm_dsrc(jj, kk);

                        if (isa_has_bf16(jcp.isa)) {
                            vdpbf16ps(acc, vmm_wei, ddst);
                        } else
                            bf16_emu_->vdpbf16ps(Zmm(acc.getIdx()),
                                    Zmm(vmm_wei.getIdx()), Zmm(ddst.getIdx()));
                    }
                }
            }
            L(oc_tail_jmp[ki]);
        }

        add(aux_reg_ker, get_kernel_offset(0, 0, 0, stride_h));
        sub(aux_reg_dst, get_diff_dst_offset(filter_h_to_dst(1), 0));

        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }

    if (jcp.ndims == 5) {
        sub(aux_reg_dst_d, get_diff_dst_offset(filter_d_to_dst(1), 0));
        add(aux_reg_ker_d, get_kernel_offset(0, 0, 0, 0, jcp.stride_d));

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);
    }

    // End of OC loop
    auto diff_dst_step = get_diff_dst_offset(0, 0, 1);
    auto ker_step = get_kernel_offset(0, jcp.oc_block, 0);
    add(reg_dst, diff_dst_step);
    add(reg_ker, ker_step);

    sub(reg_oc, jcp.oc_block);
    cmp(reg_oc, 0);
    jg(ocb_label, T_NEAR);

    sub(reg_dst, diff_dst_step * jcp.nb_oc);
    sub(reg_ker, ker_step * jcp.nb_oc);

    L(skip_compute_label);
    store_output(ur_w);
}

template <typename Vmm>
void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::generate() {
    int iw = jcp.iw;
    int kw = jcp.kw;
    int ur_w = jcp.ur_w;
    int nb_iw = jcp.nb_iw;
    int iw_block = jcp.iw_block;
    int ur_w_tail = jcp.ur_w_tail;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    const auto dst_shift = get_diff_dst_offset(ur_w / stride_w, 0);
    const auto src_shift = get_diff_src_offset(ur_w, 0);

    preamble();

    if (jcp.simd_w == 4) {
        Reg32 reg_tail_32 = reg_oc.cvt32();
        mov(reg_tail_32, (1 << jcp.simd_w) - 1);
        kmovb(k_ic_tail_mask, reg_tail_32);
    }

    if (jcp.ic_tail) {
        Label done;
        // dummy mask, all 1's
        if (jcp.simd_w != 4)
            kxnord(k_ic_tail_mask, k_ic_tail_mask, k_ic_tail_mask);
        // To account for the special store optimization, where two ic_blocks
        // are combined into a single write, extend the mask to 32 bits
        // (32 bf16 values)
        const bool need_extended_mask
                = isa_has_bf16(jcp.isa) && jcp.nb_ic_blocking > 1;
        if (need_extended_mask)
            kxnord(k_ic_tail_mask_extended, k_ic_tail_mask_extended,
                    k_ic_tail_mask_extended);

        test(byte[param1 + GET_OFF(load_work)], jcp.ic_block - 1);
        jz(done, T_NEAR);
        Reg32 reg_tail_32 = reg_ic.cvt32();
        mov(reg_tail_32, (1 << jcp.ic_tail) - 1);
        kmovd(k_ic_tail_mask, reg_tail_32);
        if (need_extended_mask) {
            mov(reg_tail_32, (1 << (jcp.ic_tail + jcp.simd_w)) - 1);
            kmovd(k_ic_tail_mask_extended, reg_tail_32);
        }
        L(done);
    }

    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_dst, ptr[param + GET_OFF(dst)]);
    mov(reg_ker, ptr[param + GET_OFF(filt)]);

    mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);

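    // The overflow values below estimate, in ur_w-sized steps, how far the
    // filter support spills past the diff_dst borders on the left and right;
    // r_overflow1 is the variant used once the ur_w_tail block is split off.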
1404 int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
1405 int r_overflow = nstl::max(
1406 0, ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad)) / stride_w);
1407 int r_overflow1 = nstl::max(0,
1408 ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad + ur_w_tail))
1409 / stride_w);
1410
1411 int body_l_overflow = 0, body_r_overflow = 0;
1412 int n_oi = iw / ur_w;
1413 int head_n_oi = 0, body_n_oi = 0, pretail_n_oi = 0, tail_n_oi = 0;
1414 int head_thread = 0, pretail_thread = 0, tail_thread = 0;
1415 bool threaded = is_iw_threading_on(jcp);
1416 Label head_label, body_label, pretail_label, tail_label, end_label;
1417 assert(n_oi > 0);
1418
1419 if (r_overflow1 > 0) n_oi--;
1420 if (l_overflow > 0) n_oi--;
1421 if (n_oi < 0) {
1422 // l_overflow and r_overflow1 are handled in the same compute_loop.
1423 // Perform one iteration of body handling l_overflow and r_overflow1.
1424 body_l_overflow = l_overflow;
1425 body_r_overflow = r_overflow1;
1426 n_oi = 1;
1427 l_overflow = 0;
1428 r_overflow1 = 0;
1429 }
1430
1431 if (!threaded) {
1432 if (n_oi > 1) { mov(reg_oi, n_oi); }
1433 } else {
1434 // Setup for threaded code generation, and jump into the correct
1435 // portion of code for execution.
1436 head_thread = 0;
1437 tail_thread = nb_iw - 1;
1438 pretail_thread = tail_thread;
1439
1440 int base_n_oi = iw_block / ur_w;
1441 head_n_oi = l_overflow > 0 ? base_n_oi - 1 : base_n_oi;
1442 tail_n_oi = (iw - iw_block * (nb_iw - 1)) / ur_w;
1443 pretail_n_oi = tail_n_oi;
1444 if (r_overflow1 > 0) {
1445 if (tail_n_oi > 0) {
1446 pretail_n_oi--;
1447 tail_n_oi = pretail_n_oi;
1448 } else {
1449 // pretail_thread and tail_thread are different
1450 pretail_n_oi = base_n_oi - 1;
1451 pretail_thread = tail_thread - 1;
1452 }
1453 if (head_thread == pretail_thread) {
1454 head_n_oi--;
1455 pretail_n_oi = 0;
1456 tail_n_oi = 0;
1457 }
1458 }
1459 body_n_oi = (head_thread < pretail_thread - 1) ? base_n_oi : 0;
1460
1461 // n_oi is used to determine how much control flow in the body portion
1462 // of the code needs generated. As such, n_oi needs to be set to the
1463 // maximum number of iterations it will be used the body code section.
1464 n_oi = nstl::max(body_n_oi, head_n_oi);
1465 n_oi = nstl::max(n_oi, pretail_n_oi);
1466
1467 assert(iw_block % ur_w == 0);
1468 mov(reg_iwb, ptr[param1 + GET_OFF(iwb)]);
1469
1470 if (head_n_oi != 0) mov(reg_oi, head_n_oi);
1471 cmp(reg_iwb, head_thread);
1472 je(head_label, T_NEAR);
1473
1474 cmp(reg_iwb, pretail_thread);
1475 if (pretail_n_oi == 0) {
1476 je(pretail_label, T_NEAR);
1477 } else {
1478 mov(reg_oi, pretail_n_oi);
1479 je(body_label, T_NEAR);
1480 }
1481 if (pretail_thread != tail_thread) {
1482 cmp(reg_iwb, tail_thread);
1483 je(tail_label, T_NEAR);
1484 }
1485 if (body_n_oi != 0) {
1486 mov(reg_oi, body_n_oi);
1487 jmp(body_label, T_NEAR);
1488 } else {
1489 jmp(end_label, T_NEAR);
1490 }
1491 }
1492 L(head_label);
1493 if (l_overflow > 0) {
1494 compute_loop(ur_w, l_overflow, 0);
1495 if (threaded && head_n_oi == 0 && head_thread != pretail_thread)
1496 jmp(end_label, T_NEAR);
1497 add(reg_src, src_shift);
1498 add(reg_dst, dst_shift);
1499 }
1500 L(body_label);
1501 if (n_oi > 0) {
1502 Label ow_loop_label;
1503 L(ow_loop_label);
1504 {
1505 compute_loop(ur_w, body_l_overflow, body_r_overflow);
1506 if (n_oi > 1 || r_overflow1 > 0 || ur_w_tail != 0) {
1507 add(reg_src, src_shift);
1508 add(reg_dst, dst_shift);
1509 }
1510 if (n_oi > 1) {
1511 sub(reg_oi, 1);
1512 jg(ow_loop_label, T_NEAR);
1513 }
1514 }
1515 }
1516 if (threaded) {
1517 cmp(reg_iwb, pretail_thread);
1518 jne(end_label, T_NEAR);
1519 }
1520 L(pretail_label);
1521 if (r_overflow1 > 0) {
1522 compute_loop(ur_w, 0, r_overflow1);
1523 if (ur_w_tail != 0) {
1524 if (threaded && tail_thread != pretail_thread)
1525 jmp(end_label, T_NEAR);
1526 else {
1527 add(reg_src, src_shift);
1528 add(reg_dst, dst_shift);
1529 }
1530 }
1531 }
1532 L(tail_label);
1533 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_overflow); }
1534 L(end_label);
1535
1536 postamble();
1537 }
1538
init_conf(jit_conv_conf_t & jcp,const convolution_desc_t & cd,memory_desc_t & diff_src_md,memory_desc_t & weights_md,memory_desc_t & diff_dst_md,int nthreads)1539 status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp,
1540 const convolution_desc_t &cd, memory_desc_t &diff_src_md,
1541 memory_desc_t &weights_md, memory_desc_t &diff_dst_md, int nthreads) {
1542
1543 const memory_desc_wrapper diff_src_d(&diff_src_md);
1544 const memory_desc_wrapper weights_d(&weights_md);
1545 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
1546
1547 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
1548 int ndims = diff_src_d.ndims();
1549
1550 jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
1551 : bf16_emulation_t::get_isa();
1552 jcp.nthr = nthreads;
1553 jcp.ver = ver_vnni;
1554 jcp.ndims = ndims;
1555 jcp.prop_kind = cd.prop_kind;
1556
1557 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1558 jcp.mb = diff_src_d.dims()[0];
1559
1560 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1561 jcp.oc_without_padding = jcp.oc;
1562 jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
1563
1564 jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
1565 jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2];
1566 jcp.iw = diff_src_d.dims()[ndims - 1];
1567 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
1568 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
1569 jcp.ow = diff_dst_d.dims()[ndims - 1];
1570
1571 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
1572 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
1573 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1574
1575 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1576 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
1577 jcp.l_pad = cd.padding[0][ndims - 3];
1578
1579 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1580 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
1581 jcp.stride_w = cd.strides[ndims - 3];
1582
1583 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1584 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
1585 jcp.dilate_w = cd.dilates[ndims - 3];
1586 jcp.dst_dt = cd.diff_src_desc.data_type;
1587 jcp.nb_iw = 1;
1588 jcp.iw_block = jcp.iw;
1589
1590 /* Dilated convolutions supported with unit strides only */
1591 if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
1592 || (jcp.dilate_d != 0 && jcp.stride_d != 1)
1593 || (jcp.dilate_h != 0 && jcp.stride_h != 1))
1594 return status::unimplemented;
1595
1596 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1597 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1598 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1599 jcp.r_pad = calculate_end_padding(
1600 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
1601 jcp.b_pad = calculate_end_padding(
1602 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
1603 jcp.back_pad = calculate_end_padding(
1604 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
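    // If the extended (dilated) filter fits entirely inside the padding on
    // some side (ext_k <= pad), there are spatial points whose receptive
    // field lies completely in the padded area; such shapes are rejected
    // below.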
1605 bool kernel_outside_src = false || ext_kw <= jcp.l_pad
1606 || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
1607 || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
1608 if (kernel_outside_src) return status::unimplemented;
1609
1610 jcp.aligned_threads = 0;
1611
1612 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
1613 const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
1614 const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
1615 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1616 auto curr_src_tag = diff_src_d.matches_one_of_tag(
1617 dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1618 auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
1619 dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1620 bool is_data_layout_nxc
1621 = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);
1622
1623 bool ok_to_pad_channels = jcp.ngroups == 1 && !is_data_layout_nxc;
1624
1625 jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
1626
1627 const bool ok_to_try_lower_zmm = true
1628 && IMPLICATION(is_data_layout_nxc,
1629 jcp.oc < jcp.simd_w && jcp.ic < jcp.simd_w
1630 && jcp.ngroups > 1)
1631 && !ok_to_pad_channels
1632 && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0);
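    // Illustrative example (values assumed, not exhaustive): a grouped
    // problem with ic = oc = 8 leaves a 16-lane zmm half empty and cannot
    // be channel-padded, so the loop below drops simd_w to 8 (or 4) when
    // both channel counts divide it evenly.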
1633
1634 if (ok_to_try_lower_zmm) {
1635 for (auto simd : {8, 4}) {
1636 if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
1637 jcp.simd_w = simd;
1638 break;
1639 }
1640 }
1641 }
1642
1643 jcp.oc_block = jcp.simd_w;
1644 jcp.ic_block = jcp.simd_w;
1645
1646 if (ok_to_pad_channels) {
1647 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1648 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1649 }
1650
1651 if (!IMPLICATION(!is_data_layout_nxc,
1652 jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
1653 return status::unimplemented;
1654 jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
1655 jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.simd_w : 0;
1656
1657 format_tag_t wei_tag, dat_tag;
1658
1659 if (jcp.simd_w == 8) {
1660 dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
1661 wei_tag = utils::pick(ndims - 3, gOIw4o8i2o, gOIhw4o8i2o, gOIdhw4o8i2o);
1662 } else if (jcp.simd_w == 4) {
1663 dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c;
1664 wei_tag = utils::pick(ndims - 3, gOIw2o4i2o, gOIhw2o4i2o, gOIdhw2o4i2o);
1665 } else {
1666 dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
1667 wei_tag = pick(2 * ndims - 6 + with_groups, OIw8o16i2o, gOIw8o16i2o,
1668 OIhw8o16i2o, gOIhw8o16i2o, OIdhw8o16i2o, gOIdhw8o16i2o);
1669 }
1670
1671 if (diff_src_md.format_kind == format_kind::any) {
1672 CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag));
1673 } else if (curr_src_tag != dat_tag)
1674 return status::unimplemented;
1675 jcp.src_tag = dat_tag;
1676
1677 if (diff_dst_md.format_kind == format_kind::any) {
1678 CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag));
1679 } else if (curr_dst_tag != dat_tag)
1680 return status::unimplemented;
1681 jcp.dst_tag = dat_tag;
1682
1683 if (weights_md.format_kind == format_kind::any) {
1684 CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
1685 jcp.wei_tag = wei_tag;
1686 } else {
1687 jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
1688 if (jcp.wei_tag != wei_tag) return status::unimplemented;
1689 }
1690
1691 bool args_ok = true && jcp.ic <= diff_src_d.padded_dims()[1]
1692 && jcp.oc <= diff_dst_d.padded_dims()[1]
1693 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
1694 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
1695 if (!args_ok) return status::unimplemented;
1696
1697 jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
1698 jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
1699
1700 jcp.ur_w = jcp.stride_w;
1701
1702 /* Maximum number of registers available for result accumulation and delta
1703 dst data. One additional register is reserved for weights data. */
1704 const int max_regs
1705             = isa_has_bf16(jcp.isa) ? 31 : 26; /* In case of bf16 (cpx)
1706                                                   instruction emulation, 5
1707                                                   additional registers are
                                                      reserved */
1708 int l_overflow = nstl::max(
1709 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
1710 int r_overflow1 = nstl::max(0,
1711 ((jcp.kw - 1) * (jcp.dilate_w + 1)
1712 - nstl::max(0, jcp.r_pad + jcp.iw % jcp.ur_w))
1713 / jcp.stride_w);
1714 int n_oi = jcp.iw / jcp.ur_w;
1715 if (r_overflow1 > 0) n_oi--;
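    // l_overflow / r_overflow1 measure how many diff_src columns at the
    // left / right borders need a trimmed kw range because some filter taps
    // would read outside diff_dst; n_oi counts the border-free ur_w
    // iterations of the main loop.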
1716
1717 jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
1718 jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
1719
1720 /* Find the best blocking with maximum number of compute instructions
1721 per ur_w * nb_ic_blocking compute loops. Number of required registers
1722 is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
1723 ur_w must be divisible by stride_w */
1724 if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers
1725 distribution exceeds max_regs */
1726 return status::unimplemented;
1727
1728 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1729 {
1730 jcp.kernel_kind = expl_bcast;
1731 int best_compute_pipeline_length = 0;
1732 const int max_ic_blocks = 4;
1733 for (int b = 1; b <= max_ic_blocks; b++) {
1734 if (jcp.nb_ic % b != 0) continue;
1735
1736 for (int u = jcp.stride_w; u * b + u / jcp.stride_w <= max_regs
1737 && u < jcp.iw + jcp.stride_w;
1738 u += jcp.stride_w) {
1739 int ur_w = nstl::min(u, jcp.iw);
1740 /* maximum 1 step with l_overflow so far */
1741 if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
1742 continue;
1743 int pipeline_length = utils::div_up(ur_w, jcp.stride_w) * b;
1744 if (pipeline_length > best_compute_pipeline_length
1745 || (pipeline_length == best_compute_pipeline_length
1746 && jcp.ur_w < ur_w)) {
1747 jcp.ur_w = ur_w;
1748 jcp.nb_ic_blocking = b;
1749 best_compute_pipeline_length = pipeline_length;
1750 }
1751 }
1752 }
1753 if (best_compute_pipeline_length == 0) /* can't find
1754 appropriate blocking */
1755 return status::unimplemented;
1756 }
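    // Illustration of the search above (assumed values): with stride_w = 1
    // and max_regs = 31, b = 4 (provided nb_ic % 4 == 0) requires
    // ur_w * 4 + ur_w <= 31, so ur_w <= 6, giving a pipeline of
    // 6 * 4 = 24 compute instructions per loop body.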
1757 jcp.ur_w_tail = jcp.iw % jcp.ur_w;
1758
1759 if (is_iw_threading_available(jcp)) {
1760 int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
1761 int work_units = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
1762 float no_iw_block_eff
1763 = (float)work_units / rnd_up(work_units, jcp.nthr);
1764
1765 // current design of generate() requires iw_block >= 2 * ur_w
1766 const int min_iw_block = jcp.ur_w * 2;
1767 int iw_threads = jcp.nthr / math::gcd(work_units, jcp.nthr);
1768 int iw_block = nstl::max(min_iw_block,
1769 rnd_up(jcp.iw, jcp.ur_w * iw_threads) / iw_threads);
1770 int nb_iw = div_up(jcp.iw, iw_block);
1771
1772 float block_eff = (float)jcp.iw / rnd_up(jcp.iw, iw_block);
1773 work_units = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih * nb_iw;
1774 float work_eff = (float)work_units / rnd_up(work_units, jcp.nthr);
1775 float iw_block_eff = block_eff * work_eff;
1776
1777 const int iw_thread_min_size = 16 * 128;
1778 const float iw_block_cost = 20.0;
1779 float block_overhead = nstl::max(0.0f, 1.0f - iw_block_cost / iw_block);
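    // block_overhead models the per-block ramp-up cost: with the empirical
    // iw_block_cost of 20, an iw_block of 40 keeps an efficiency factor of
    // 1 - 20 / 40 = 0.5, so iw blocking must beat the unblocked efficiency
    // by a wide margin to be enabled below.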
1780
1781 bool iw_thread_useful = no_iw_block_eff < block_overhead * iw_block_eff
1782 && jcp.ic_block * jcp.iw > iw_thread_min_size;
1783
1784 if (iw_thread_useful) {
1785 jcp.iw_block = iw_block;
1786 jcp.nb_iw = nb_iw;
1787 }
1788 }
1789
1790 if (l_overflow * jcp.stride_w > jcp.ur_w) return status::unimplemented;
1791 int r_overflow_no_tail = nstl::max(0,
1792 ((jcp.kw - 1) * (jcp.dilate_w + 1)
1793 - nstl::max(0, jcp.r_pad + jcp.ur_w_tail))
1794 / jcp.stride_w);
1795 bool tails_not_ok = false
1796 /* maximum 1 ur_w block with r_overflow so far */
1797 || r_overflow_no_tail * jcp.stride_w > jcp.ur_w
1798 /* ur_w must be a multiple of stride */
1799 || ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
1800 /* r_pad must not extend beyond ur_w_tail */
1801 || ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
1802 if (tails_not_ok) return status::unimplemented;
1803
1804     /* Adjust the thread decomposition to improve performance for small
1805      * problem sizes. The threshold L1_cache_size / factor and the factor
1806      * itself are empirical; for now the thread count is simply capped
1807      * at 4.
1808      * TODO: add a get_thr_eff function to compute the optimal thread count. */
1809 size_t wei_size = (size_t)sizeof(bfloat16_t) * jcp.ic * jcp.oc * jcp.kh
1810 * jcp.kw * jcp.kd;
1811 size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
1812 * jcp.iw * jcp.id;
1813 size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
1814 * jcp.ow * jcp.od;
1815 size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
1816 const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
1817
1818     // The factor is 1 for 1d, 2 for 2d, and 4 for 3d.
1819 int factor = nstl::max(1, (2 * (ndims - 3)));
1820 if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size / factor) {
1821 jcp.nthr = nstl::min(jcp.nthr, 4);
1822 }
1823
1824 pick_loop_order(jcp);
1825
1826 return status::success;
1827 }
1828
1829 const int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::max_ur_w = 28;
1830
1831 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1832         od_step_comeback_pointers() {
1833 Label kd_comeback_label;
1834 mov(kj, reg_kd_count);
1835 L(kd_comeback_label);
1836 {
1837 sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
1838 sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
1839 dec(kj);
1840 cmp(kj, 0);
1841 jg(kd_comeback_label, T_NEAR);
1842 }
1843 }
1844 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1845         oh_step_comeback_pointers() {
1846 Label kh_comeback_label;
1847 mov(kj, reg_kh);
1848 L(kh_comeback_label);
1849 {
1850 sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
1851 sub(reg_kernel, get_kernel_offset(0, jcp.kw));
1852 dec(kj);
1853 cmp(kj, 0);
1854 jg(kh_comeback_label, T_NEAR);
1855 }
1856 }
1857
1858 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1859         compute_ic_block_step_extern(int ur_w, int pad_l, int pad_r,
1860 int ic_block_step, int src_offset, int kernel_offset,
1861 int ddst_offset, bool is_tail) {
1862 assert(!is_src_layout_nxc() && !is_ddst_layout_nxc());
1863 int kw = jcp.kw;
1864 bool no_src_pad = jcp.is_1stconv && !jcp.transpose_src;
1865 const int ddst_zmm_base_idx = 24;
1866 const int num_ddst_zmm_regs = !isa_has_bf16(jcp.isa) ? 2 : 4;
1867 const int zmm_src_reg = ddst_zmm_base_idx + num_ddst_zmm_regs;
1868
1869 auto zmm_ker = [=](int i_kw, int i_ic) {
1870 return Zmm(i_kw * ic_block_step + i_ic);
1871 };
1872 auto zmm_ddst = [=](int i_iw) {
1873 // TODO: move reg calc to global member funcs
1874 return Zmm(ddst_zmm_base_idx + i_iw % num_ddst_zmm_regs);
1875 };
1876
1877 auto ker_addr = [=](int i_kw, int i_ic) {
1878 auto local_offset = get_kernel_offset(i_ic, i_kw);
1879 return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
1880 };
1881 auto src_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
1882 bool vnni_bcast = false) {
1883 auto local_offset = get_src_offset(i_ic, i_iw);
1884 return EVEX_compress_addr(
1885 reg_src, local_offset + src_offset + extra_offset, vnni_bcast);
1886 };
1887 auto ddst_addr = [=](int i_ur) {
1888 auto ow_scale = 2;
1889 return EVEX_compress_addr(
1890 reg_ddst, get_ddst_offset(ow_scale * i_ur) + ddst_offset);
1891 };
1892
1893 for (int i_kw = 0; i_kw < kw; i_kw++)
1894 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1895 auto zmm = zmm_ker(i_kw, i_ic);
1896 vpxord(zmm, zmm, zmm);
1897 }
1898 assert(ur_w % 2 == 0);
1899 auto steps = ur_w / 2;
1900
1901 const int str_w = jcp.stride_w;
1902 const int underflow_boundary = -1;
1903 int i_iw_shift = jcp.tr_ow - ur_w - ((jcp.l_pad != pad_l) ? jcp.l_pad : 0);
1904 const int overflow_boundary = jcp.iw - 1 - i_iw_shift;
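    // Each vdpbf16ps consumes a pair of adjacent ow points (two bf16
    // values per dword lane), hence the ur_w / 2 steps above. The
    // boundaries computed here mark the border columns that, in the
    // unpadded 1st-conv path, must be loaded with masked broadcasts
    // instead of full ones.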
1905
1906 for (int s = 0; s < str_w; s++) {
1907 const int kw_start = s;
1908 assert(jcp.tr_iw % str_w == 0);
1909 const int src_stride_w_shift = jcp.tr_iw / str_w;
1910 for (int i_ur = 0; i_ur < steps; i_ur++) {
1911 auto zmm = zmm_ddst(i_ur);
1912 vmovdqu16(zmm, ddst_addr(i_ur));
1913
1914 for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) {
1915 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1916 int i_iw = 2 * i_ur + (i_kw * (jcp.dilate_w + 1)) / str_w
1917 + s * src_stride_w_shift;
1918 bool underflow = false;
1919 bool overflow = false;
1920 if (no_src_pad) {
1921 i_iw = i_iw - pad_l;
1922 underflow = i_iw <= underflow_boundary;
1923 overflow = is_tail && i_iw >= overflow_boundary;
1924 }
1925
1926 auto src = Zmm(zmm_src_reg);
1927 auto acc = zmm_ker(i_kw, i_ic);
1928 auto ddst = zmm_ddst(i_ur);
1929 if (underflow || overflow || !isa_has_bf16(jcp.isa)) {
1930 assert(ddst != src);
1931 assert(acc != src);
1932 }
1933 assert(ddst != acc);
1934 if (underflow || overflow) {
1935 if (underflow && i_iw == underflow_boundary)
1936 vpbroadcastw(src | everyother_shift_mask | T_z,
1937 src_addr(i_iw + 1, i_ic, 0));
1938 else if (overflow && i_iw == overflow_boundary)
1939 vpbroadcastw(src | everyother_mask | T_z,
1940 src_addr(i_iw, i_ic, 0));
1941 else
1942 continue;
1943
1944 if (!isa_has_bf16(jcp.isa))
1945 bf16_emu_->vdpbf16ps(acc, ddst, src);
1946 else
1947 vdpbf16ps(acc, ddst, src);
1948 } else if (!isa_has_bf16(jcp.isa)) {
1949 vpbroadcastd(src, src_addr(i_iw, i_ic, 0));
1950 bf16_emu_->vdpbf16ps(acc, ddst, src);
1951 } else
1952 vdpbf16ps(acc, ddst, src_addr(i_iw, i_ic, 0, true));
1953 }
1954 }
1955 }
1956 for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) {
1957 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1958 auto addr = ker_addr(i_kw, i_ic);
1959 auto zmm = zmm_ker(i_kw, i_ic);
1960 vaddps(zmm, zmm, addr);
1961 vmovups(addr, zmm);
1962 }
1963 }
1964 }
1965 }
1966
1967 int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::interleave_w_reorder_size(
1968 int ur_w) const {
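    // The scratch row must cover every input point touched by ur_w output
    // points: stride_w * (ur_w - 1) + kw of them, rounded up to the
    // 16-word granularity of the masked loads.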
1969 const int reorder_block = 16;
1970 return rnd_up(jcp.stride_w * (ur_w - 1) + jcp.kw, reorder_block);
1971 }
1972 int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1973         interleave_w_reorder_bytes(int ur_w) {
1974 return 2 * jcp.typesize_in * interleave_w_reorder_size(ur_w);
1975 }
1976 int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::interleave_stack_size(
1977 int ur_w, int ic_block_step) {
1978 return ic_block_step * interleave_w_reorder_bytes(ur_w);
1979 }
1980 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1981         compute_ic_block_step_interleave(int ur_w, int pad_l, int pad_r,
1982 int ic_block_step, int src_offset, int kernel_offset,
1983 int ddst_offset, bool is_tail) {
1984 // Only supports nchw format src
1985 assert(jcp.is_1stconv && !jcp.transpose_src);
1986 int kw = jcp.kw;
1987 const int ddst_zmm_base_idx = 24;
1988 const int in_zmm_base_idx = 24;
1989 const int num_ddst_zmm_regs = !isa_has_bf16(jcp.isa) ? 2 : 4;
1990 //const int num_in_zmm_regs = 8;
1991 const int zmm_src_reg = ddst_zmm_base_idx + num_ddst_zmm_regs;
1992 const int reorder_block = 16;
1993 const int reorder_size = interleave_w_reorder_size(ur_w);
1994 const int reorder_bytes = interleave_w_reorder_bytes(ur_w);
1995 const int stack_size = interleave_stack_size(ur_w, ic_block_step);
1996 if (stack_size > ic_block_step_stack_size) {
1997 // This is a guard. Ideally it is never used, but is included to defend
1998 // against overlooked edge cases.
1999 assert(stack_size <= ic_block_step_stack_size);
2000 sub(rsp, stack_size - ic_block_step_stack_size);
2001 }
2002
2003 auto zmm_ker = [=](int i_kw, int i_ic) {
2004 return Zmm(i_kw * ic_block_step + i_ic);
2005 };
2006 auto zmm_ddst = [=](int i_iw) {
2007 return Zmm(ddst_zmm_base_idx + i_iw % num_ddst_zmm_regs);
2008 };
2009 auto zmm_in = [=](int i_iw, int i_ic, bool stride_reg) {
2010 int stride = stride_reg ? 1 : 0;
2011 return Zmm(in_zmm_base_idx + 4 * (i_ic % 2) + 2 * (i_iw % 2) + stride);
2012 };
2013
2014 auto ker_addr = [=](int i_kw, int i_ic) {
2015 auto local_offset = get_kernel_offset(i_ic, i_kw);
2016 return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
2017 };
2018 auto src_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
2019 bool vnni_bcast = false) {
2020 int local_offset = i_ic * reorder_bytes + 2 * jcp.typesize_in * i_iw;
2021 return EVEX_compress_addr(rsp, local_offset, vnni_bcast);
2022 };
2023 auto ddst_addr = [=](int i_ur) {
2024 auto ow_scale = 2;
2025 return EVEX_compress_addr(
2026 reg_ddst, get_ddst_offset(ow_scale * i_ur) + ddst_offset);
2027 };
2028 auto load_src_to_stack = [=](int i_iw, int i_ic, Opmask mask,
2029 bool mask_empty, Opmask stride_mask,
2030 bool stride_mask_empty) {
2031 auto local_offset = get_src_offset(i_ic, i_iw);
2032 int stack_offset
2033 = i_ic * reorder_bytes + 2 * jcp.typesize_in * (i_iw + pad_l);
2034
2035 auto zmm = zmm_in(i_iw, i_ic, false);
2036 auto zmm_stride = zmm_in(i_iw, i_ic, true);
2037 auto base_addr
2038 = EVEX_compress_addr(reg_src, local_offset + src_offset, false);
2039 auto stride_addr = EVEX_compress_addr(reg_src,
2040 local_offset + src_offset + get_src_offset(0, jcp.stride_w));
2041 auto stack_addr = EVEX_compress_addr(rsp, stack_offset);
2042 assert(IMPLICATION(mask_empty, stride_mask_empty));
2043 if (mask_empty) {
2044 vpxord(zmm, zmm, zmm);
2045 } else {
2046 vpmovzxwd(zmm | mask | T_z, base_addr);
2047 }
2048 if (!stride_mask_empty) {
2049 vpmovzxwd(zmm_stride | stride_mask | T_z, stride_addr);
2050 vpslld(zmm_stride, zmm_stride, 16);
2051 vpord(zmm, zmm, zmm_stride);
2052 }
2053 vmovdqu16(stack_addr, zmm);
2054 };
2055
2056 assert(ur_w % 2 == 0);
2057 auto steps = ur_w / 2;
2058
2059 const int str_w = jcp.stride_w;
2060 int i_iw_shift = str_w * (jcp.tr_ow - ur_w)
2061 - ((jcp.l_pad != pad_l) ? jcp.l_pad : 0);
2062 const int overflow_boundary
2063 = is_tail ? jcp.iw - i_iw_shift : str_w * (ur_w - 1) + kw - pad_l;
2064
2065 // Calculate padding required by the data reorder using 32 byte loads
2066 int reorder_overflow = reorder_size - pad_l - overflow_boundary;
2067 int reorder_stride_overflow = reorder_overflow + str_w;
2068 reorder_overflow = nstl::max(0, reorder_overflow);
2069 reorder_stride_overflow = nstl::max(0, reorder_stride_overflow);
2070 int reorder_pad_r = reorder_overflow % reorder_block;
2071 int reorder_stride_pad_r = reorder_stride_overflow % reorder_block;
2072 if (reorder_stride_overflow >= reorder_size && reorder_stride_pad_r == 0) {
2073 assert(reorder_stride_overflow == reorder_size);
2074 reorder_stride_pad_r = reorder_block;
2075 }
2076 reorder_overflow -= reorder_pad_r;
2077 reorder_stride_overflow -= reorder_stride_pad_r;
2078
2079 int pad_l_mask = (0xffff << pad_l) & 0xffff;
2080 int pad_l_mask_strided
2081 = (0xffff << (pad_l >= str_w ? (pad_l - str_w) : 0)) & 0xffff;
2082 int pad_r_mask = 0xffff >> reorder_pad_r;
2083 int pad_r_mask_strided = 0xffff >> (reorder_stride_pad_r);
2084 pad_r_mask = pad_r_mask & 0xffff;
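    // Illustrative bit patterns: with pad_l = 2 and reorder_pad_r = 3,
    // pad_l_mask = 0xfffc zeroes the two leftmost (padded) words and
    // pad_r_mask = 0x1fff zeroes the three rightmost; the *_strided
    // variants apply the same idea to the loads offset by stride_w.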
2085
2086 // Setup masks to load and reorder data
2087 if (reorder_size - reorder_stride_overflow > reorder_block) {
2088 // Overflow and underflow happen in different data reorder rounds
2089 kxnorw(overflow_stride_mask, overflow_stride_mask,
2090 overflow_stride_mask);
2091 kshiftlw(underflow_mask, overflow_stride_mask, pad_l);
2092 kshiftlw(underflow_stride_mask, overflow_stride_mask,
2093 pad_l >= str_w ? pad_l - str_w : 0);
2094 kshiftrw(overflow_mask, overflow_stride_mask, reorder_pad_r);
2095 kshiftrw(overflow_stride_mask, overflow_stride_mask,
2096 reorder_stride_pad_r);
2097 } else if (reorder_size - reorder_overflow > reorder_block) {
2098 // Overflow and underflow happen in the same round for loading the data
2099 // at the stride offset.
2100 kxnorw(overflow_mask, overflow_mask, overflow_mask);
2101 kshiftlw(underflow_mask, overflow_mask, pad_l);
2102 kshiftrw(overflow_mask, overflow_mask, reorder_pad_r);
2103 mov(reg_tmp.cvt32(), pad_l_mask_strided & pad_r_mask_strided);
2104 kmovw(underflow_stride_mask, reg_tmp.cvt32());
2105 } else {
2106 // Overflow and underflow happen in the same round for all data loads
2107 mov(reg_tmp.cvt32(), pad_l_mask & pad_r_mask);
2108 kmovw(underflow_mask, reg_tmp.cvt32());
2109 mov(reg_tmp.cvt32(), pad_l_mask_strided & pad_r_mask_strided);
2110 kmovw(underflow_stride_mask, reg_tmp.cvt32());
2111 }
2112
2113 // Load and reorder data to the stack
2114 int reorder_start = -pad_l;
2115 int reorder_end = reorder_size - pad_l;
2116 for (int i_iw = reorder_start; i_iw < reorder_end; i_iw += reorder_block) {
2117 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2118 Opmask mask, stride_mask;
2119 bool mask_empty, stride_mask_empty;
2120             // Performing this reorder on the stack may not always be
2121             // optimal. There are a couple of methods involving externally
2122             // reordering the data that were not considered due to time
2123             // constraints. The first is to transpose the data, similar to
2124             // the extern method. The other is to perform the same
2125             // interleave transform used here. The tradeoff is that the
2126             // transpose does not lend itself to SIMD instructions (except
2127             // possibly for some specific strides) since the data is not
2128             // blocked, while the transform performed here does, but uses
2129             // twice as much space since most data elements are duplicated.
2130
2131 if (i_iw == reorder_start) {
2132 mask = underflow_mask;
2133 mask_empty = false;
2134 if (pad_l_mask == 0) mask_empty = true;
2135 } else if (i_iw + reorder_overflow >= reorder_end) {
2136 mask_empty = true;
2137 } else if (i_iw + reorder_block + reorder_overflow >= reorder_end) {
2138 mask = overflow_mask;
2139 mask_empty = false;
2140 if (pad_r_mask == 0) mask_empty = true;
2141 } else {
2142 mask = m_ffffffff;
2143 mask_empty = false;
2144 }
2145 if (i_iw == reorder_start) {
2146 stride_mask = underflow_stride_mask;
2147 stride_mask_empty = false;
2148                 if (pad_l_mask_strided == 0) stride_mask_empty = true;
2149 } else if (i_iw + reorder_stride_overflow >= reorder_end) {
2150 stride_mask_empty = true;
2151 } else if (i_iw + reorder_block + reorder_stride_overflow
2152 >= reorder_end) {
2153 stride_mask = overflow_stride_mask;
2154 stride_mask_empty = false;
2155                 if (pad_r_mask_strided == 0) stride_mask_empty = true;
2156 } else {
2157 stride_mask = m_ffffffff;
2158 stride_mask_empty = false;
2159 }
2160 load_src_to_stack(i_iw, i_ic, mask, mask_empty, stride_mask,
2161 stride_mask_empty);
2162 }
2163 }
2164
2165 // Initialize kernel accumulators. It should sometimes be possible to skip
2166 // initializing and storing this data between calls to this function.
2167 for (int i_kw = 0; i_kw < kw; i_kw++)
2168 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2169 auto zmm = zmm_ker(i_kw, i_ic);
2170 vpxord(zmm, zmm, zmm);
2171 }
2172
2173     // Calculate this block's contribution
2174 for (int i_ur = 0; i_ur < steps; i_ur++) {
2175 auto zmm = zmm_ddst(i_ur);
2176 vmovdqu16(zmm, ddst_addr(i_ur));
2177
2178 for (int i_kw = 0; i_kw < kw; i_kw++) {
2179 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2180 int i_iw = 2 * i_ur * str_w + i_kw;
2181 auto acc = zmm_ker(i_kw, i_ic);
2182 auto ddst = zmm_ddst(i_ur);
2183
2184 const bool isa_supports_bf16 = isa_has_bf16(jcp.isa);
2185 auto src_stack_addr
2186 = src_addr(i_iw, i_ic, 0, isa_supports_bf16);
2187
2188 if (isa_supports_bf16)
2189 vdpbf16ps(acc, ddst, src_stack_addr);
2190 else {
2191 auto src = Zmm(zmm_src_reg);
2192 vpbroadcastd(src, src_stack_addr);
2193 bf16_emu_->vdpbf16ps(acc, ddst, src);
2194 }
2195 }
2196 }
2197 }
2198
2199 // Store kernel accumulators
2200 for (int i_kw = 0; i_kw < kw; i_kw++) {
2201 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2202 auto addr = ker_addr(i_kw, i_ic);
2203 auto zmm = zmm_ker(i_kw, i_ic);
2204 vaddps(zmm, zmm, addr);
2205 vmovups(addr, zmm);
2206 }
2207 }
2208
2209 if (stack_size > ic_block_step_stack_size) {
2210 // This is a guard. Ideally it is never used, but is included to defend
2211 // against overlooked edge cases.
2212 add(rsp, stack_size - ic_block_step_stack_size);
2213 }
2214 }
2215
2216 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2217         convert_src_to_vnni_format(
2218 int ur_w, int pad_l, int pad_r, int src_offset) {
2219 Reg64 reg_trans_tmp = r11;
2220 const int ic_tail = jcp.ic_tail;
2221 mov(EVEX_compress_addr(rsp, trans_tmp_offset), reg_trans_tmp);
2222
2223 mov(reg_trans_tmp, dst_prm_table);
2224 vmovups(get_perm_reg(), ptr[reg_trans_tmp]);
2225
2226 mov(reg_trans_tmp, EVEX_compress_addr(rsp, trans_tmp_offset));
2227 const int max_regs = 16;
2228 if (ic_tail) {
2229 Label skip_tail_mask;
2230 cmp(reg_icb, jcp.simd_w);
2231 jge(skip_tail_mask);
2232 kandd(m_0000ffff, m_0000ffff, m_0000_ic_tail);
2233 kandd(m_ffff0000, m_ffff0000, m_ic_tail_0000);
2234 L(skip_tail_mask);
2235 }
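    // The loop below fills one 64-byte scratch line per iteration: the
    // channels of two adjacent w positions land in the low and high halves
    // of a zmm, and vpermw interleaves them so that each dword lane holds
    // the bf16 pair expected by vdpbf16ps.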
2236 for (int src_count = 0;
2237 sizeof_cacheline * src_count < permw_stack_size(ur_w);
2238 src_count++) {
2239 int i_ur = nstl::min(src_count, ur_w - 2);
2240 int i_kw = src_count - i_ur;
2241 int buffer_offset = permw_buffer_start + src_count * 64;
2242 auto bcast_values = Zmm(src_count % max_regs);
2243 bool check = check_borders(ur_w, pad_l, pad_r, i_ur, i_kw);
2244 if (check) {
2245 if (is_src_layout_nxc()) {
2246 int iw_1, iw_2;
2247 get_w_positions(ur_w, pad_l, pad_r, i_ur, i_kw, iw_1, iw_2);
2248 if (iw_1 == -1)
2249 vxorpd(bcast_values, bcast_values, bcast_values);
2250 else {
2251 dim_t local_src_offset = src_offset
2252 + get_src_offset(
2253 0, filter_w_to_src(i_kw, i_ur, pad_l));
2254 vmovdqu16(bcast_values | m_0000ffff | T_z,
2255 ptr[reg_src + local_src_offset]);
2256 }
2257 if (iw_2 != -1) {
2258 dim_t local_src_offset = src_offset - 32
2259 + get_src_offset(
2260 0, filter_w_to_src(i_kw, i_ur + 1, pad_l));
2261 vmovdqu16(bcast_values | m_ffff0000,
2262 ptr[reg_src + local_src_offset]);
2263 }
2264 } else {
2265 Opmask load_mask;
2266 get_load_mask(ur_w, pad_l, pad_r, i_ur, i_kw, load_mask);
2267
2268 dim_t local_src_offset = src_offset
2269 + get_src_offset(0, filter_w_to_src(i_kw, i_ur, pad_l));
2270 vmovdqu16(bcast_values | load_mask | T_z,
2271 ptr[reg_src + local_src_offset]);
2272 }
2273 vpermw(bcast_values, get_perm_reg(), bcast_values);
2274 } else {
2275 vpxord(bcast_values, bcast_values, bcast_values);
2276 }
2277 vmovups(ptr[rsp + buffer_offset], bcast_values);
2278 }
2279 if (ic_tail) {
2280         // Reset the masks back to their defaults
2281 kxnorw(m_0000ffff, m_0000ffff, m_0000ffff);
2282 kshiftld(m_ffff0000, m_0000ffff, 16);
2283 }
2284 }
2285
2286 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2287         may_be_set_oc_tail_mask() {
2288 if (jcp.oc_tail) {
2289 Label skip_tail_mask;
2290 cmp(dword[param + GET_OFF(load_work)], jcp.simd_w);
2291 jge(skip_tail_mask);
2292 kandd(m_0000ffff, m_0000ffff, m_0000_oc_tail);
2293 kandd(m_ffff0000, m_ffff0000, m_oc_tail_0000);
2294 L(skip_tail_mask);
2295 }
2296 }
2297
2298 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2299         may_be_reset_oc_tail_mask() {
2300 if (jcp.oc_tail) {
2301         // Reset the masks back to their defaults
2302 kxnorw(m_0000ffff, m_0000ffff, m_0000ffff);
2303 kshiftld(m_ffff0000, m_0000ffff, 16);
2304 }
2305 }
2306
2307 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2308         compute_ic_block_step_vpermw_expl(int ur_w, int pad_l, int pad_r,
2309 int ic_block_step, int src_offset, int kernel_offset,
2310 int ddst_offset, bool is_tail) {
2311 assert(!jcp.is_1stconv); // This method does not support nchw data
2312 int kw = jcp.kw;
2313 int src_count = 0;
2314 int ic_block_step_idx = src_offset / (jcp.typesize_in * ic_block_step);
2315 const int max_regs = (!isa_has_bf16(jcp.isa)) ? 26 : 31;
2316 int src_pl_len = kw;
2317 const int diff_dst_pl_start_reg_idx = ic_block_step * (kw + src_pl_len);
2318 const int diff_dst_pl_len = max_regs - diff_dst_pl_start_reg_idx;
2319
2320 auto get_diff_wei_reg_idx
2321 = [=](int i_kw, int i_ic) { return i_kw * ic_block_step + i_ic; };
2322 auto get_src_reg_idx = [=](int i_iw, int i_ic) {
2323 return kw * ic_block_step + (i_iw % src_pl_len) * ic_block_step + i_ic;
2324 };
2325 auto get_diff_dst_reg_idx = [=](int i_ur) {
2326 return diff_dst_pl_start_reg_idx + (i_ur / 2) % diff_dst_pl_len;
2327 };
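    // Register layout implied by the helpers above: the first
    // kw * ic_block_step zmms accumulate diff_weights, the next
    // kw * ic_block_step hold broadcast src values, and the rest (up to
    // max_regs) form a rotating diff_dst pipeline. For example, kw = 3
    // with ic_block_step = 4 uses 12 + 12 registers and leaves a 7-deep
    // pipeline on native bf16 hardware (max_regs = 31).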
2328
2329 may_be_set_oc_tail_mask();
2330 auto load_dst = [=](int c) {
2331 bool is_tail = ur_w % 2 && c * 2 + 2 >= ur_w;
2332 bool is_ddst_nxc = is_ddst_layout_nxc();
2333 auto offset = get_ddst_offset(c * 2) + ddst_offset;
2334
2335 Opmask load_mask = is_ddst_nxc || is_tail ? m_0000ffff : m_ffffffff;
2336 vmovdqu16(Zmm(get_diff_dst_reg_idx(2 * c)) | load_mask | T_z,
2337 EVEX_compress_addr(reg_ddst, offset));
2338
2339 if (is_ddst_nxc && !is_tail) {
2340 offset += get_ddst_offset(1) - 32;
2341 vmovdqu16(Zmm(get_diff_dst_reg_idx(2 * c)) | m_ffff0000,
2342 EVEX_compress_addr(reg_ddst, offset));
2343 }
2344 vpermw(Zmm(get_diff_dst_reg_idx(2 * c)), get_perm_reg(),
2345 Zmm(get_diff_dst_reg_idx(2 * c)));
2346 };
2347
2348 for (int i_kw = 0; i_kw < kw; i_kw++)
2349 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2350 vpxord(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2351 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2352 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));
2353
2354 auto get_bcast_ptr = [=](int i_ur, int i_kw, int ic) {
2355 int scale = 2 * jcp.typesize_in;
2356 return rsp + b_ic * scale + permw_buffer_start + (i_ur + i_kw) * 64
2357 + jcp.typesize_in * 2
2358 * (ic_block_step_idx * ic_block_step + ic);
2359 };
2360 int src_count_last = 0;
2361 for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
2362 if (i_ur == 0) {
2363 for (int dst_count = 0;
2364 dst_count < nstl::min(diff_dst_pl_len, div_up(ur_w, 2));
2365 dst_count++) {
2366 load_dst(dst_count);
2367 }
2368 for (src_count = 0; src_count < src_pl_len; src_count++) {
2369 int _i_ur = src_count / kw;
2370 int _i_kw = src_count % kw;
2371 if (check_borders(ur_w, pad_l, pad_r, _i_ur, _i_kw))
2372 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2373 vbroadcastss(Zmm(get_src_reg_idx(src_count, i_ic)),
2374 ptr[get_bcast_ptr(_i_ur, _i_kw, i_ic)]);
2375 }
2376 }
2377 src_count_last = src_count;
2378 } else {
2379 int diff_dst_load_idx = i_ur + 2 * (diff_dst_pl_len - 1);
2380 if (diff_dst_load_idx < ur_w) load_dst(diff_dst_load_idx / 2);
2381 for (src_count = i_ur; src_count < i_ur + src_pl_len; src_count++) {
2382 if (src_count < src_count_last) continue;
2383 int _i_ur = (src_count - i_ur) / kw + i_ur;
2384 int _i_kw = (src_count - i_ur) % kw;
2385 if (check_borders(ur_w, pad_l, pad_r, _i_ur, _i_kw))
2386 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2387 vbroadcastss(Zmm(get_src_reg_idx(src_count, i_ic)),
2388 ptr[get_bcast_ptr(_i_ur, _i_kw, i_ic)]);
2389 }
2390 }
2391 src_count_last = src_count;
2392 }
2393 for (int i_kw = 0; i_kw < kw; i_kw++) {
2394 int i_iw = i_ur + i_kw;
2395 if (check_borders(ur_w, pad_l, pad_r, i_ur, i_kw)) {
2396 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2397 if (!isa_has_bf16(jcp.isa)) {
2398 bf16_emu_->vdpbf16ps(
2399 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2400 Zmm(get_diff_dst_reg_idx(i_ur)),
2401 Zmm(get_src_reg_idx(i_iw, i_ic)));
2402 } else {
2403 vdpbf16ps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2404 Zmm(get_diff_dst_reg_idx(i_ur)),
2405 Zmm(get_src_reg_idx(i_iw, i_ic)));
2406 }
2407 }
2408 }
2409 }
2410 }
2411
2412 for (int i_kw = 0; i_kw < kw; i_kw++)
2413 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2414 auto l_offset = get_kernel_offset(i_ic, i_kw);
2415 vaddps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2416 EVEX_compress_addr(reg_kernel, l_offset + kernel_offset));
2417 }
2418
2419 for (int i_kw = 0; i_kw < kw; i_kw++) {
2420 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2421 auto l_offset = get_kernel_offset(i_ic, i_kw);
2422 vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
2423 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));
2424 }
2425 }
2426
2427 may_be_reset_oc_tail_mask();
2428 }
2429
2430 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2431         compute_ic_block_step_vpermw(int ur_w, int pad_l, int pad_r,
2432 int ic_block_step, int src_offset, int kernel_offset,
2433 int ddst_offset, bool is_tail) {
2434 assert(!jcp.is_1stconv); // This method does not support nchw data
2435 int kw = jcp.kw;
2436
2437 int dst_count = 0;
2438
2439 int ic_block_step_idx = src_offset / (jcp.typesize_in * ic_block_step);
2440
2441 int pipeline_length = (isa_has_bf16(jcp.isa))
2442 ? nstl::max(1, nstl::min(4, ur_w / 2))
2443 : 1;
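    // The diff_dst loads are software-pipelined: up to four rotating zmms
    // are filled ahead of the FMAs that consume them (only one when bf16
    // is emulated, since the emulator reserves scratch registers).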
2444 may_be_set_oc_tail_mask();
2445
2446 const int dst_off_reg = (!isa_has_bf16(jcp.isa)) ? 26 : 31;
2447 auto load_dst = [=](int c) {
2448 bool is_tail = ur_w % 2 && c * 2 + 2 >= ur_w;
2449 bool is_ddst_nxc = is_ddst_layout_nxc();
2450 auto offset = get_ddst_offset(2 * c) + ddst_offset;
2451
2452 Opmask load_mask = is_ddst_nxc || is_tail ? m_0000ffff : m_ffffffff;
2453 vmovdqu16(Zmm(dst_off_reg - c % pipeline_length) | load_mask | T_z,
2454 EVEX_compress_addr(reg_ddst, offset));
2455
2456 if (is_ddst_nxc && !is_tail) {
2457 offset += get_ddst_offset(1) - 32;
2458 vmovdqu16(Zmm(dst_off_reg - c % pipeline_length) | m_ffff0000,
2459 EVEX_compress_addr(reg_ddst, offset));
2460 }
2461 vpermw(Zmm(dst_off_reg - c % pipeline_length), get_perm_reg(),
2462 Zmm(dst_off_reg - c % pipeline_length));
2463 };
2464
2465 for (int i_kw = 0; i_kw < kw; i_kw++)
2466 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2467 vmovups(Zmm(i_kw * ic_block_step + i_ic),
2468 EVEX_compress_addr(reg_kernel,
2469 get_kernel_offset(i_ic, i_kw) + kernel_offset));
2470
2471 for (dst_count = 0; dst_count < pipeline_length; dst_count++) {
2472 load_dst(dst_count);
2473 }
2474 auto get_bcast_ptr = [=](int i_ur, int i_kw, int ic) {
2475 int scale = 2 * jcp.typesize_in;
2476 return rsp + b_ic * scale + permw_buffer_start + (i_ur + i_kw) * 64
2477 + jcp.typesize_in * 2
2478 * (ic_block_step_idx * ic_block_step + ic);
2479 };
2480
2481 for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
2482 for (int i_kw = 0; i_kw < kw; i_kw++) {
2483 if (check_borders(ur_w, pad_l, pad_r, i_ur, i_kw)) {
2484 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2485 if (!isa_has_bf16(jcp.isa)) {
2486 auto zmm_src = Zmm(28);
2487 vpbroadcastd(
2488 zmm_src, ptr[get_bcast_ptr(i_ur, i_kw, i_ic)]);
2489 bf16_emu_->vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
2490 Zmm(dst_off_reg - dst_count % pipeline_length),
2491 zmm_src);
2492 } else {
2493 vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
2494 Zmm(dst_off_reg - dst_count % pipeline_length),
2495 zword_b[get_bcast_ptr(i_ur, i_kw, i_ic)]);
2496 }
2497 }
2498 }
2499 }
2500 if (dst_count * 2 < ur_w) load_dst(dst_count);
2501 dst_count++;
2502 }
2503 for (int i_kw = 0; i_kw < kw; i_kw++) {
2504 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2505 auto l_offset = get_kernel_offset(i_ic, i_kw);
2506 vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
2507 Zmm(i_kw * ic_block_step + i_ic));
2508 }
2509 }
2510
2511 may_be_reset_oc_tail_mask();
2512 }
2513
2514 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2515         compute_diff_bias_init() {
2516 auto reg_unit_val = reg_tmp.cvt16();
2517 mov(reg_unit_val, 0x3f80); // bf16 value of 1.
2518 vpbroadcastw(vreg_bias_unit, reg_unit_val);
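    // diff_bias is accumulated as a dot product with a vector of bf16
    // ones: vdpbf16ps then adds each adjacent pair of diff_dst values to
    // the f32 accumulator, avoiding a separate horizontal reduction.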
2519
2520 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2521 vmovups(vreg_bias_acc, ptr[reg_tmp]);
2522
2523 if (jcp.uses_permw_transposition) {
2524 mov(reg_tmp, dst_prm_table);
2525 vmovups(get_perm_reg(), ptr[reg_tmp]);
2526 }
2527 }
2528
2529 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_diff_bias_row(
2530 bool is_partial) {
2531 if (!jcp.with_bias) return;
2532 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
2533 Label skip_label;
2534 test(reg_tmp, FLAG_IC_FIRST);
2535 jz(skip_label, T_NEAR);
2536
2537 may_be_set_oc_tail_mask();
2538
2539 if (is_partial) compute_diff_bias_init();
2540
2541 auto compute_step = [&](bool is_tail) {
2542 if (jcp.transpose_dst) {
2543 UNUSED(is_tail);
2544 vmovups(vreg_bias_ddst, ptr[reg_ddst]);
2545 } else {
2546 auto vreg_ddst_load = is_ddst_layout_nxc() || is_tail
2547 ? vreg_bias_ddst | m_0000ffff | T_z
2548 : vreg_bias_ddst;
2549 vmovdqu16(vreg_ddst_load, ptr[reg_ddst]);
2550 if (is_ddst_layout_nxc() && !is_tail) {
2551 const int shift_16_elems = 16 * jcp.typesize_in;
2552 vmovdqu16(vreg_bias_ddst | m_ffff0000,
2553 ptr[reg_ddst + get_ddst_offset(1) - shift_16_elems]);
2554 }
2555 vpermw(vreg_bias_ddst, get_perm_reg(), vreg_bias_ddst);
2556 }
2557 if (!isa_has_bf16(jcp.isa))
2558 bf16_emu_->vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
2559 else
2560 vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
2561 };
2562
2563 Label ow_loop, ow_tail;
2564 int niters = jcp.tr_ow / 2;
2565 if (niters > 0) {
2566 mov(reg_tmp, jcp.tr_ow / 2);
2567 L(ow_loop);
2568 compute_step(false);
2569 add(reg_ddst, get_ddst_offset(2));
2570 sub(reg_tmp, 1);
2571 jnz(ow_loop, T_NEAR);
2572 }
2573 if (jcp.tr_ow % 2) compute_step(true);
2574
2575 if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));
2576
2577 if (is_partial) {
2578 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2579 vmovups(ptr[reg_tmp], vreg_bias_acc);
2580 }
2581
2582 may_be_reset_oc_tail_mask();
2583
2584 L(skip_label);
2585 }
2586 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2587         maybe_compute_diff_bias() {
2588     // In the harness_3d_reduction case the diff_bias calculation is
2589     // called for every ow row separately, to stay aligned with the od
2590     // loop in compute_od_loop_common()
2591 if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
2592 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
2593
2594 Label skip_label;
2595 test(reg_tmp, FLAG_IC_FIRST);
2596 jz(skip_label, T_NEAR);
2597
2598 switch (jcp.harness) {
2599 case harness_2d_reduction:
2600 mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
2601 sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
2602 break;
2603 case harness_mb_reduction:
2604 case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
2605 case harness_3d_reduction:
2606 default: assert(!"Invalid harness type");
2607 }
2608
2609 compute_diff_bias_init();
2610
2611 cmp(reg_oj, 0);
2612 jle(skip_label, T_NEAR); // nothing to do
2613 Label bias_loop;
2614 L(bias_loop);
2615 {
2616 compute_diff_bias_row(false);
2617 add(reg_ddst, get_ddst_offset(0, 1));
2618
2619 sub(reg_oj, 1);
2620 jnz(bias_loop, T_NEAR);
2621 }
2622
2623 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2624 vmovups(ptr[reg_tmp], vreg_bias_acc);
2625
2626 // restore reg_ddst value
2627 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
2628
2629 L(skip_label);
2630 }
2631
2632 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_ic_block_step(
2633 int ur_w, int pad_l, int pad_r, int ic_block_step, int src_offset,
2634 int kernel_offset, int ddst_offset, bool is_tail) {
2635
2636 if (jcp.uses_permw_transposition)
2637 if (jcp.kernel_kind == expl_bcast)
2638 compute_ic_block_step_vpermw_expl(ur_w, pad_l, pad_r, ic_block_step,
2639 src_offset, kernel_offset, ddst_offset, is_tail);
2640 else
2641 compute_ic_block_step_vpermw(ur_w, pad_l, pad_r, ic_block_step,
2642 src_offset, kernel_offset, ddst_offset, is_tail);
2643 else if (jcp.is_1stconv && !jcp.transpose_src && jcp.stride_w > 1)
2644 compute_ic_block_step_interleave(ur_w, pad_l, pad_r, ic_block_step,
2645 src_offset, kernel_offset, ddst_offset, is_tail);
2646 else
2647 compute_ic_block_step_extern(ur_w, pad_l, pad_r, ic_block_step,
2648 src_offset, kernel_offset, ddst_offset, is_tail);
2649 }
2650
2651 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::get_ur_w(
2652 int &ur_w, int &ur_w_tail, int &ur_w_trips) {
2653 if (jcp.tr_ow <= max_ur_w) {
2654 ur_w = jcp.tr_ow;
2655 ur_w_tail = 0;
2656 ur_w_trips = 1;
2657 return;
2658 }
2659
2660 int r_pad = 0;
2661 if (!jcp.transpose_src) {
2662 // If jcp.transpose_src, the buffer has physical padding
2663 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2664 r_pad = nstl::max(0,
2665 calculate_end_padding(
2666 jcp.l_pad, jcp.tr_ow, jcp.tr_iw, jcp.stride_w, ext_kw));
2667 }
2668 int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2669 ur_w = max_ur_w;
2670 ur_w_trips = jcp.tr_ow / ur_w;
2671 ur_w_tail = jcp.tr_ow % ur_w;
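    // The adjustment below keeps the tail large enough to hold the right
    // padding: either one full ur_w block is folded into the tail, or,
    // when there is a single trip, the total is re-split so that ur_w
    // covers l_pad and the tail covers r_pad.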
2672 if ((ur_w_tail == 0 && jcp.r_pad != 0) || r_pad >= ur_w_tail) {
2673 if (ur_w_trips > 1) {
2674 ur_w_tail += ur_w;
2675 ur_w_trips--;
2676 } else {
2677 int ur_w_tail_total = ur_w + ur_w_tail;
2678             ur_w = (ur_w_tail_total % 4 == 0) ? ur_w_tail_total / 2
2679                                               : ur_w_tail_total / 2 + 1;
2680 ur_w_tail = ur_w_tail_total - ur_w;
2681 if (l_pad > ur_w / 2) {
2682 ur_w = (l_pad % 2 == 0) ? l_pad : l_pad + 1;
2683 ur_w_tail = ur_w_tail_total - ur_w;
2684 } else if (r_pad > ur_w_tail) {
2685 ur_w_tail = (r_pad % 2 == 0) ? r_pad : r_pad + 1;
2686 ur_w = ur_w_tail_total - ur_w_tail;
2687 }
2688 }
2689 }
2690 }
2691
2692 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::
2693         compute_oh_step_unroll_ow_icblock(int ic_block_step) {
2694 Label kh_label, kd_label;
2695
2696 int ic_block = jcp.ic_block;
2697 int ic_tail = jcp.ic_tail;
2698 int ow = jcp.tr_ow;
2699 int r_pad = 0;
2700 int ur_w, ur_w_tail, ur_w_trips;
2701 get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2702 assert(ur_w_tail == 0 && ur_w_trips == 1);
2703
2704 if (!jcp.transpose_src) {
2705 // If jcp.transpose_src, the buffer has physical padding
2706 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2707 int iw = jcp.tr_iw;
2708 r_pad = nstl::max(0,
2709 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2710 }
2711 int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2712
2713 if (jcp.ndims == 5) {
2714 L(kd_label);
2715 mov(reg_src, aux_reg_src);
2716 mov(reg_kernel, aux_reg_kernel);
2717 }
2718
2719 mov(kj, reg_kh);
2720 L(kh_label);
2721 {
2722 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2723 // icb loop is supported for nxc layout only
2724 assert(IMPLICATION(generate_icb_loop,
2725 is_src_layout_nxc() && is_ddst_layout_nxc()));
2726 Label icb_block_label, icb_block_label_end;
2727 if (generate_icb_loop || ic_tail) {
2728 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
2729 mov(ptr[rsp + icb_loop_src_ptr], reg_src);
2730 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2731 L(icb_block_label);
2732 }
2733
2734 if (jcp.uses_permw_transposition) {
2735 convert_src_to_vnni_format(ur_w, l_pad, r_pad, 0);
2736 xor_(b_ic, b_ic);
2737 }
2738
2739 const int ic_tail_loop_work = rnd_up(ic_tail, ic_block_step);
2740 for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
2741 const int src_offset = get_src_offset(i_b_ic, 0);
2742 compute_ic_block_step(ur_w, l_pad, r_pad, ic_block_step, src_offset,
2743 get_kernel_offset(i_b_ic, 0), 0, true);
2744 if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
2745             // We relax the boundary for reg_icb, as the src is already
2746             // converted to vnni format with appropriate padding, either
2747             // through transpose_src or convert_src_to_vnni_format. We can
2748             // safely let compute_ic_block_step overstep, as it operates on
2749             // the buffer instead of src.
2750 if (ic_tail && i_b_ic + ic_block_step == ic_tail_loop_work) {
2751 assert(jcp.transpose_src || jcp.uses_permw_transposition);
2752 cmp(reg_icb, 0);
2753 jle(icb_block_label_end, T_NEAR);
2754 }
2755 }
2756 L(icb_block_label_end);
2757
2758 const auto src_icb_loop_shift_bytes = get_src_offset(ic_block, 0);
2759 const auto kernel_icb_loop_shift_bytes
2760 = get_kernel_offset(0, jcp.kd * jcp.kh * jcp.kw);
2761 if (generate_icb_loop) {
2762 add(reg_src, src_icb_loop_shift_bytes);
2763 safe_add(reg_kernel, kernel_icb_loop_shift_bytes, reg_long_offt);
2764
2765 assert(jcp.uses_permw_transposition);
2766 cmp(reg_icb, 0);
2767 jg(icb_block_label, T_NEAR);
2768 }
2769
2770 if (generate_icb_loop || ic_tail) {
2771 // restore pointers
2772 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2773 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2774 }
2775
2776 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2777 add(reg_kernel, get_kernel_offset(0, jcp.kw));
2778 dec(kj);
2779 cmp(kj, 0);
2780 jg(kh_label, T_NEAR);
2781 }
2782
2783 if (jcp.ndims == 5) {
2784 add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
2785 add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
2786 dec(ki);
2787 cmp(ki, 0);
2788 jg(kd_label, T_NEAR);
2789 }
2790 }
2791
2792 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::
2793         compute_oh_step_unroll_ow(int ic_block_step) {
2794 Label kh_label, ic_block_label, kd_label;
2795
2796 int ic_block = jcp.ic_block;
2797 const int ic_tail = jcp.ic_tail;
2798 int ow = jcp.tr_ow;
2799
2800 int r_pad = 0;
2801 int ur_w, ur_w_tail, ur_w_trips;
2802 get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2803 assert(ur_w_tail == 0 && ur_w_trips == 1);
2804
2805 if (!jcp.transpose_src) {
2806 // If jcp.transpose_src, the buffer has physical padding
2807 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2808 int iw = jcp.tr_iw;
2809 r_pad = nstl::max(0,
2810 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2811 }
2812 int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2813
2814 if (jcp.ndims == 5) {
2815 L(kd_label);
2816 mov(reg_src, aux_reg_src);
2817 mov(reg_kernel, aux_reg_kernel);
2818 }
2819
2820 mov(kj, reg_kh);
2821 L(kh_label);
2822 {
2823 size_t src_offset = get_src_offset(ic_block_step, 0);
2824
2825 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2826 // icb loop is supported for nxc layout only
2827 assert(IMPLICATION(generate_icb_loop,
2828 is_src_layout_nxc() && is_ddst_layout_nxc()));
2829 Label icb_block_label, icb_block_label_end;
2830 if (generate_icb_loop || ic_tail) {
2831 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
2832 mov(ptr[rsp + icb_loop_src_ptr], reg_src);
2833 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2834 L(icb_block_label);
2835 }
2836
2837 xor_(b_ic, b_ic);
2838 if (jcp.uses_permw_transposition) {
2839 convert_src_to_vnni_format(ow, l_pad, r_pad, 0);
2840 xor_(b_ic, b_ic);
2841 }
2842
2843 L(ic_block_label);
2844 {
2845 compute_ic_block_step(
2846 ur_w, l_pad, r_pad, ic_block_step, 0, 0, 0, true);
2847 assert(jcp.ic_block % jcp.ic_block_step == 0);
2848 safe_add(reg_src, src_offset, reg_long_offt);
2849 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
2850 add(b_ic, ic_block_step);
2851 if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
2852             // We relax the boundary for reg_icb, as the src is already
2853             // converted to vnni format with appropriate padding, either
2854             // through transpose_src or convert_src_to_vnni_format. We can
2855             // safely let compute_ic_block_step overstep, as it operates on
2856             // the buffer instead of src.
2857 if (ic_tail) {
2858 assert(jcp.transpose_src || jcp.uses_permw_transposition);
2859 cmp(reg_icb, 0);
2860 jle(icb_block_label_end, T_NEAR);
2861 }
2862 cmp(b_ic, jcp.ic_block);
2863 jl(ic_block_label, T_NEAR);
2864 }
2865 L(icb_block_label_end);
2866
2867 if (jcp.uses_permw_transposition) {
2868 if (generate_icb_loop || ic_tail) {
2869                 // subtract the pointer shift made within the ic block loop
2870                 // and move to the next ic block
2871 safe_add(reg_kernel,
2872 get_kernel_offset(-ic_block, jcp.kd * jcp.kh * jcp.kw),
2873 reg_long_offt);
2874
2875 cmp(reg_icb, 0);
2876 jg(icb_block_label, T_NEAR);
2877 // restore pointers
2878 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2879 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2880
2881 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2882 add(reg_kernel, get_kernel_offset(0, jcp.kw));
2883 } else {
2884 add(reg_src,
2885 get_src_offset(0, 0, filter_h_to_src(1))
2886 - jcp.typesize_in * ic_block);
2887 }
2888 } else if (ic_tail) {
2889 // restore pointers
2890 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2891 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2892
2893 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2894 add(reg_kernel, get_kernel_offset(0, jcp.kw));
2895 } else if (jcp.is_1stconv && !jcp.transpose_src) {
2896 // Fixup reg_src to point to the correct location
2897 safe_add(reg_src,
2898 get_src_offset(0, 0, filter_h_to_src(1))
2899 - src_offset * (jcp.ic_block / ic_block_step),
2900 reg_long_offt);
2901 } else {
2902 if (jcp.dilate_h > 0)
2903 add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
2904 }
2905 if (!generate_icb_loop && !ic_tail)
2906             // subtract the pointer shift made within the ic block loop
2907             // and move to the next kh index
2908 add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
2909 dec(kj);
2910 cmp(kj, 0);
2911 jg(kh_label, T_NEAR);
2912 }
2913 if (jcp.ndims == 5) {
2914 add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
2915 add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
2916 dec(ki);
2917 cmp(ki, 0);
2918 jg(kd_label, T_NEAR);
2919 }
2920 }
2921
2922 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_step_common(
2923 int ic_block_step) {
2924 Label kh_label, ic_block_label, ow_block_label, kd_label;
2925
2926 int ic_block = jcp.ic_block;
2927 int ic_tail = jcp.ic_tail;
2928 int ow = jcp.tr_ow;
2929 int r_pad = 0;
2930 if (!jcp.transpose_src) {
2931 // If jcp.transpose_src, the buffer has physical padding
2932 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2933 int iw = jcp.tr_iw;
2934 r_pad = nstl::max(0,
2935 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2936 }
2937 int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2938
2939 int ur_w, ur_w_trips, ur_w_tail;
2940 get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2941 assert(l_pad <= ur_w);
2942 assert(r_pad <= ur_w_tail);
2943
2944 auto src_comeback
2945 = get_src_offset(0, filter_w_to_src(0, ur_w_trips * ur_w, l_pad));
2946 auto ddst_comeback = get_ddst_offset(ur_w_trips * ur_w);
2947
2948 if (jcp.ndims == 5) {
2949 L(kd_label);
2950 mov(reg_src, aux_reg_src);
2951 mov(reg_kernel, aux_reg_kernel);
2952 }
2953
2954 bool use_kh_ic_ow_loop_order = !jcp.uses_permw_transposition;
2955 if (use_kh_ic_ow_loop_order) {
2956 assert(!jcp.uses_permw_transposition);
2957
2958 auto ic_loop = [=](int ic_block_step) {
2959 Label ow_block_label;
2960 // create a local copy
2961 int ur_w_blocks = ur_w_trips;
2962 auto src_offset = get_src_offset(ic_block_step, 0);
2963 if (l_pad != 0) {
2964 ur_w_blocks--;
2965 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
2966 add(reg_src,
2967 get_src_offset(0, filter_w_to_src(0, ur_w, l_pad)));
2968 add(reg_ddst, get_ddst_offset(ur_w));
2969 }
2970
2971 if (ur_w_blocks > 0) {
2972 xor_(reg_ur_w_trips, reg_ur_w_trips);
2973 L(ow_block_label);
2974 {
2975 compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
2976 add(reg_src,
2977 get_src_offset(0, filter_w_to_src(0, ur_w, 0)));
2978 add(reg_ddst, get_ddst_offset(ur_w));
2979
2980 inc(reg_ur_w_trips);
2981 cmp(reg_ur_w_trips, ur_w_blocks);
2982 jl(ow_block_label, T_NEAR);
2983 }
2984 }
2985
2986 if (ur_w_tail > 0) {
2987 compute_ic_block_step(
2988 ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0, true);
2989 }
2990
2991 sub(reg_src, src_comeback);
2992 sub(reg_ddst, ddst_comeback);
2993
2994 safe_add(reg_src, src_offset, reg_long_offt);
2995 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
2996 };
2997
2998 mov(kj, reg_kh);
2999 L(kh_label);
3000 {
3001 Label ic_tail_label, skip_ic_tail_offset_compensation;
3002 if (ic_tail) {
3003                 // Currently generate_icb_loop is not enabled here,
3004                 // implying that at most one icb is processed.
3005 assert(jcp.nb_ic_blocking_max == 1);
3006 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3007 } else {
3008 mov(reg_icb, ic_block);
3009 }
3010
3011 L(ic_block_label);
3012 {
3013 ic_loop(ic_block_step);
3014 sub(reg_icb, ic_block_step);
3015                 // We relax the boundary for reg_icb, as the src is
3016                 // already converted to vnni format with appropriate
3017                 // padding, either through transpose_src or
3018                 // convert_src_to_vnni_format. We can safely let
3019                 // compute_ic_block_step overstep, as it operates on the
                    // buffer instead of src.
3020 if (ic_tail) {
3021 assert(jcp.transpose_src || jcp.uses_permw_transposition);
3022 }
3023 cmp(reg_icb, 0);
3024 jg(ic_block_label, T_NEAR);
3025 }
3026
3027 if (ic_tail) {
3028 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3029 cmp(reg_icb, jcp.simd_w);
3030 je(skip_ic_tail_offset_compensation);
3031 add(reg_kernel,
3032 get_kernel_offset(
3033 jcp.ic_block - rnd_up(ic_tail, ic_block_step),
3034 0));
3035 safe_add(reg_src,
3036 get_src_offset(0, 0, filter_h_to_src(1))
3037 - get_src_offset(
3038 rnd_up(ic_tail, ic_block_step), 0),
3039 reg_long_offt);
3040 L(skip_ic_tail_offset_compensation);
3041 }
3042 if (jcp.is_1stconv && !jcp.transpose_src) {
3043 // Fixup reg_src to point to the correct location
3044 auto src_offset = get_src_offset(ic_block_step, 0);
3045 safe_add(reg_src,
3046 get_src_offset(0, 0, filter_h_to_src(1))
3047 - src_offset * (jcp.ic_block / ic_block_step),
3048 reg_long_offt);
3049 } else if (jcp.dilate_h > 0) {
3050 add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
3051 }
3052             // subtract the pointer shift made within the ic block loop
3053             // and move to the next kh index
3054 add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
3055 dec(kj);
3056 cmp(kj, 0);
3057 jg(kh_label, T_NEAR);
3058 }
3059 } else {
3060 assert(!jcp.is_1stconv);
3061 auto src_icbstep_shift = get_src_offset(1, 0);
3062
3063 auto ic_loop = [=](int ic_block_step) {
3064 int ic_work = ic_block;
3065 Label ow_block_label, ic_block_label_padl, ic_block_label_general,
3066 ic_block_label_tail;
3067 int ur_w_blocks = ur_w_trips;
3068 if (l_pad != 0) {
3069 ur_w_blocks--;
3070 xor_(b_ic, b_ic);
3071 if (jcp.uses_permw_transposition) {
3072 convert_src_to_vnni_format(ur_w, l_pad, 0, 0);
3073 }
3074 L(ic_block_label_padl);
3075 {
3076 compute_ic_block_step(
3077 ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
3078 safe_add(reg_src, src_icbstep_shift * ic_block_step,
3079 reg_long_offt);
3080 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3081
3082 add(b_ic, ic_block_step);
3083 cmp(b_ic, ic_work);
3084 jl(ic_block_label_padl, T_NEAR);
3085 }
3086 safe_sub(reg_src, src_icbstep_shift * ic_work, reg_long_offt);
3087 sub(reg_kernel, get_kernel_offset(ic_work, 0));
3088 add(reg_src,
3089 get_src_offset(0, filter_w_to_src(0, ur_w, l_pad)));
3090 add(reg_ddst, get_ddst_offset(ur_w));
3091 }
3092
3093 if (ur_w_blocks > 0) {
3094 xor_(reg_ur_w_trips, reg_ur_w_trips);
3095 L(ow_block_label);
3096 {
3097 if (jcp.uses_permw_transposition) {
3098 convert_src_to_vnni_format(ur_w, 0, 0, 0);
3099 }
3100 xor_(b_ic, b_ic);
3101 L(ic_block_label_general);
3102 {
3103 compute_ic_block_step(
3104 ur_w, 0, 0, ic_block_step, 0, 0, 0);
3105 safe_add(reg_src, src_icbstep_shift * ic_block_step,
3106 reg_long_offt);
3107 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3108
3109 add(b_ic, ic_block_step);
3110 cmp(b_ic, ic_work);
3111 jl(ic_block_label_general, T_NEAR);
3112 }
3113 safe_sub(reg_src, src_icbstep_shift * ic_work,
3114 reg_long_offt);
3115 sub(reg_kernel, get_kernel_offset(ic_work, 0));
3116 add(reg_src, get_src_offset(0, filter_w_to_src(0, ur_w)));
3117 add(reg_ddst, get_ddst_offset(ur_w));
3118
3119 inc(reg_ur_w_trips);
3120 cmp(reg_ur_w_trips, ur_w_blocks);
3121 jl(ow_block_label, T_NEAR);
3122 }
3123 }
3124
3125 if (ur_w_tail > 0) {
3126 if (jcp.uses_permw_transposition) {
3127 convert_src_to_vnni_format(ur_w_tail, 0, r_pad, 0);
3128 }
3129 xor_(b_ic, b_ic);
3130 L(ic_block_label_tail);
3131 {
3132 compute_ic_block_step(
3133 ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0, true);
3134 safe_add(reg_src, src_icbstep_shift * ic_block_step,
3135 reg_long_offt);
3136 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3137
3138 add(b_ic, ic_block_step);
3139 cmp(b_ic, ic_work);
3140 jl(ic_block_label_tail, T_NEAR);
3141 }
3142 safe_sub(reg_src, src_icbstep_shift * ic_work, reg_long_offt);
3143 sub(reg_kernel, get_kernel_offset(ic_work, 0));
3144 }
3145
3146 sub(reg_src, src_comeback);
3147 sub(reg_ddst, ddst_comeback);
3148 };
3149
3150 mov(kj, reg_kh);
3151 L(kh_label);
3152 {
3153 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
3154 // icb loop is supported for nxc layout only
3155 assert(IMPLICATION(generate_icb_loop,
3156 is_src_layout_nxc() && is_ddst_layout_nxc()));
3157 Label icb_block_label, icb_block_label_cb, ic_tail_loop_label;
3158
3159 if (generate_icb_loop) {
3160 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
3161 mov(ptr[rsp + icb_loop_src_ptr], reg_src);
3162 }
3163 if (ic_tail || generate_icb_loop)
3164 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3165 L(icb_block_label);
3166
3167 ic_loop(ic_block_step);
3168
3169 if (generate_icb_loop) {
3170 add(reg_src, get_src_offset(ic_block, 0));
3171 safe_add(reg_kernel,
3172 get_kernel_offset(0, jcp.kd * jcp.kh * jcp.kw),
3173 reg_long_offt);
3174 sub(reg_icb, ic_block);
3175 cmp(reg_icb, 0);
3176 jg(icb_block_label, T_NEAR);
3177 }
3178
3179 if (generate_icb_loop) {
3180 // restore pointers
3181 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
3182 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
3183 }
3184
3185 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
3186 add(reg_kernel, get_kernel_offset(0, jcp.kw));
3187 dec(kj);
3188 cmp(kj, 0);
3189 jg(kh_label, T_NEAR);
3190 }
3191 }
3192 if (jcp.ndims == 5) {
3193 add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
3194 add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
3195 dec(ki);
3196 cmp(ki, 0);
3197 jg(kd_label, T_NEAR);
3198 }
3199 }
3200
3201 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_step_disp() {
3202 int ic_block_step = jcp.ic_block_step;
3203
3204 bool too_large_to_unroll = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
3205 && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
3206
3207 int ow = jcp.tr_ow;
3208 if (jcp.ndims == 5) {
3209 /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
3210 * 'movs' must be guaranteed. */
3211 mov(ki, reg_kd_count);
3212 mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
3213 mov(aux_reg_src, reg_src);
3214 mov(aux_reg_kernel, reg_kernel);
3215 }
3216 if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) {
3217 compute_oh_step_unroll_ow_icblock(ic_block_step);
3218 } else if (ow <= max_ur_w) {
3219 compute_oh_step_unroll_ow(ic_block_step);
3220 } else {
3221 compute_oh_step_common(ic_block_step);
3222 }
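    // A sketch of the dispatch above: small filters (kw <= 3) over a short
    // transposed row (ow <= 16) that are not both multi-tap and strided take
    // the fully unrolled path; rows fitting max_ur_w unroll ow only; all
    // remaining shapes fall back to the ow-blocked compute_oh_step_common().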
3223
3224     // In the harness_3d_reduction case the diff_bias calculation is invoked
3225     // for every ow row separately so that it stays aligned with the od loop
3226     // in compute_od_loop_common()
3227 if (jcp.harness == harness_3d_reduction) compute_diff_bias_row();
3228 if (jcp.ndims == 5) {
3229 mov(reg_src, aux_reg_src);
3230 mov(reg_kernel, aux_reg_kernel);
3231 mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
3232 od_step_comeback_pointers();
3233 } else {
3234 oh_step_comeback_pointers();
3235 }
3236 }
3237
3238 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::maybe_zero_kernel() {
3239 if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
3240 Label skip_zeroing, zeroing_loop;
3241
3242 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3243 cmp(reg_tmp, 0);
3244 jz(skip_zeroing, T_NEAR);
3245
3246 Zmm zero = Zmm(0);
3247 vpxord(zero, zero, zero);
3248 if (jcp.with_bias) {
3249 Label skip_bias_zeroing;
3250 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
3251 test(reg_tmp, FLAG_IC_FIRST);
3252 jz(skip_bias_zeroing, T_NEAR);
3253
3254 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
3255 vmovups(ptr[reg_tmp], zero);
3256
3257 L(skip_bias_zeroing);
3258 if (jcp.harness == harness_compute_full_spatial)
3259 jmp(skip_zeroing, T_NEAR);
3260 }
3261
3262 const size_t kernel_block_bytes
3263 = get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
3264 Label icb_block_label, icb_block_label_cb;
3265
3266 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
3267 // icb loop is supported for nxc layout only
3268 assert(IMPLICATION(
3269 generate_icb_loop, is_src_layout_nxc() && is_ddst_layout_nxc()));
3270 if (generate_icb_loop) {
3271 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
3272 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3273 L(icb_block_label);
3274 }
3275
3276 xor_(reg_tmp, reg_tmp);
3277 L(zeroing_loop);
3278 {
3279 assert(get_kernel_offset(1, 0) == cpu_isa_traits<avx512_core>::vlen);
3280 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3281 vmovups(ptr[reg_kernel + reg_tmp + get_kernel_offset(ic1, 0)],
3282 zero);
3283 add(reg_tmp, get_kernel_offset(jcp.ic_block, 0));
3284 cmp(reg_tmp, kernel_block_bytes);
3285 jnz(zeroing_loop);
3286 }
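    // The assert above documents the layout this zeroing relies on: one
    // (ic, 0) kernel step is exactly one 64-byte zmm store (oc_block f32
    // values), so each pass of zeroing_loop clears ic_block such rows.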
3287
3288 if (generate_icb_loop) {
3289 add(reg_kernel, kernel_block_bytes);
3290 sub(reg_icb, jcp.ic_block);
3291 cmp(reg_icb, 0);
3292 jg(icb_block_label, T_NEAR);
3293 // restore pointer
3294 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
3295 }
3296
3297 L(skip_zeroing);
3298 }
3299
3300 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_loop_common(
3301 bool is_partial) {
3302 int b_pad = jcp.b_pad;
3303 int t_pad = jcp.t_pad;
3304 bool is_dilated = jcp.dilate_h != 0;
3305 int dilate_h = jcp.dilate_h + 1;
3306 int stride_h = jcp.stride_h;
3307 auto filter_step_size = get_kernel_offset(0, jcp.kw);
3308 auto src_step_size = get_src_offset(0, 0, 1);
3309 auto ddst_step_size = get_ddst_offset(0, 1);
3310 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
3311 oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
3312 oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
3313 oh_dilate_label_end, oh_dilate_setup_label_shift,
3314 oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;
3315
3316 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
3317 int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
3318 int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
3319 int oh_head_overflow_end = div_up(t_pad, stride_h);
3320 int oh_tail_end = jcp.oh;
3321
3322 int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
3323 int ih_body_end
3324 = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);
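    // Worked example on a hypothetical shape: kh = 3, dilate_h = 0,
    // stride_h = 2, t_pad = 1, ih = 10 gives ext_kh = 3,
    // oh_body_end = div_up(1 + 10 - 3 + 1, 2) = 5, oh_head_end = min(1, 5) = 1
    // and body_src_start_offset = (2 - 1 % 2) % 2 = 1, i.e. body rows start
    // reading src one line below the top edge.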
3325
3326 if (is_partial)
3327 mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
3328 else
3329 xor_(reg_oj, reg_oj);
3330
3331 /* Compute 'top' edge */
3332 if (t_pad > 0) {
3333 if (is_partial) {
3334 cmp(reg_oj, oh_head_overflow_end);
3335 jge(oh_tpad_tail_label_end, T_NEAR);
3336 }
3337 const int overflow
3338 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
3339 const int underflow = div_up(t_pad, dilate_h);
3340 const int initial_kh = jcp.kh - overflow - underflow;
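        // E.g. on a hypothetical shape kh = 5, t_pad = 3, ih = 20 without
        // dilation (dilate_h == 1 here): overflow = max(0, 5 - 23) = 0 and
        // underflow = 3, so only initial_kh = 2 filter rows overlap the src
        // for the first output row.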
3341
3342 // Setup reg_kh, reg_kernel, and reg_src
3343 mov(reg_kh, initial_kh);
3344 add(reg_kernel, filter_step_size * underflow);
3345 if (is_dilated) {
3346 const int tail = t_pad % dilate_h;
3347 const int shift = tail == 0 ? 0 : dilate_h - tail;
3348 mov(reg_ih_shift, shift);
3349 if (!is_partial) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3350 add(reg_src, src_step_size * shift);
3351 }
3352
3353 if (is_partial) {
3354 Label head_setup, head_setup_finish;
3355 cmp(reg_oj, 0);
3356 je(head_setup_finish, T_NEAR);
3357 mov(reg_oj_setup, reg_oj);
3358
3359 L(head_setup);
3360 if (is_dilated) {
3361 inc(reg_ih_shift);
3362 cmp(reg_ih_shift, dilate_h);
3363 jl(oh_dilate_setup_label_shift, T_NEAR);
3364 // unshift src as new kernel element enters
3365 sub(reg_src, src_step_size * (dilate_h - 1));
3366 xor_(reg_ih_shift, reg_ih_shift);
3367 }
3368 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3369 add(reg_kh, stride_h);
3370 sub(reg_kernel, filter_step_size * stride_h);
3371 if (is_dilated) {
3372 jmp(oh_dilate_setup_label_noshift, T_NEAR);
3373 L(oh_dilate_setup_label_shift);
3374 // shift src as old kernel element progresses
3375 add(reg_src, src_step_size * stride_h);
3376 L(oh_dilate_setup_label_noshift);
3377 }
3378 sub(reg_oj_setup, 1);
3379 jg(head_setup, T_NEAR);
3380 L(head_setup_finish);
3381
3382 if (is_dilated) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3383 if (oh_head_end < oh_head_overflow_end) {
3384 cmp(reg_oj, oh_head_end);
3385 jge(oh_tpad_label_end, T_NEAR);
3386 }
3387 }
3388
3389         // Setup reg_kernel (and, if dilated, shift the src pointer), then
3390         // loop over the remaining top-padding output rows, growing reg_kh
3391         // and stepping reg_kernel back as the filter slides into the src
3392 L(oh_tpad_label);
3393 compute_oh_step_disp();
3394 add(reg_ddst, ddst_step_size);
3395 if (is_dilated) {
3396 mov(reg_ih_shift, ptr[rsp + ih_dilate_shift]);
3397 inc(reg_ih_shift);
3398 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3399 cmp(reg_ih_shift, dilate_h);
3400 jl(oh_dilate_label_shift, T_NEAR);
3401 // unshift src as new kernel element enters
3402 sub(reg_src, src_step_size * (dilate_h - 1));
3403 xor_(reg_ih_shift, reg_ih_shift);
3404 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3405 }
3406 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3407 add(reg_kh, stride_h);
3408 sub(reg_kernel, filter_step_size * stride_h);
3409 if (is_dilated) {
3410 jmp(oh_dilate_label_noshift, T_NEAR);
3411 L(oh_dilate_label_shift);
3412 // shift src as old kernel element progresses
3413 add(reg_src, src_step_size * stride_h);
3414 L(oh_dilate_label_noshift);
3415 }
3416 inc(reg_oj);
3417
3418 if (is_partial) {
3419 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3420 jge(oh_bpad_label_end, T_NEAR);
3421 }
3422 cmp(reg_oj, oh_head_end);
3423 jl(oh_tpad_label, T_NEAR);
3424
3425 L(oh_tpad_label_end);
3426 // need second loop to process kernel if it is larger than the src
3427 // (does not apply to dilations as they must have unit stride)
3428 if (oh_head_end < oh_head_overflow_end) {
3429 assert(!is_dilated);
3430
3431 cmp(reg_oj, oh_head_overflow_end);
3432 jge(oh_tpad_tail_label_end, T_NEAR);
3433
3434 mov(reg_kh, jcp.ih);
3435 L(oh_tpad_tail_label);
3436 {
3437 compute_oh_step_disp();
3438 add(reg_ddst, ddst_step_size);
3439 sub(reg_kernel, filter_step_size * stride_h);
3440
3441 inc(reg_oj);
3442
3443 if (is_partial) {
3444 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3445 jge(oh_bpad_label_end, T_NEAR);
3446 }
3447 cmp(reg_oj, oh_head_overflow_end);
3448 jl(oh_tpad_tail_label, T_NEAR);
3449 }
3450 }
3451 if (body_src_start_offset != 0) {
3452 add(reg_kernel, filter_step_size * body_src_start_offset);
3453 add(reg_src, src_step_size * body_src_start_offset);
3454 }
3455 L(oh_tpad_tail_label_end);
3456 }
3457
3458 if (is_partial) {
3459 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3460 jge(oh_bpad_label_end, T_NEAR);
3461 }
3462 cmp(reg_oj, oh_body_end);
3463 jge(oh_label_end, T_NEAR);
3464
3465 /* Compute middle block(s) */
3466 mov(reg_kh, jcp.kh);
3467 L(oh_label);
3468 {
3469 compute_oh_step_disp();
3470 add(reg_src, src_step_size * stride_h);
3471 add(reg_ddst, ddst_step_size);
3472
3473 inc(reg_oj);
3474
3475 if (is_partial) {
3476 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3477 jge(oh_bpad_label_end, T_NEAR);
3478 }
3479
3480 cmp(reg_oj, oh_body_end);
3481 jl(oh_label, T_NEAR);
3482 }
3483 L(oh_label_end);
3484
3485 /* Compute bottom edge */
3486 if (b_pad > 0) {
3487 if (is_partial) {
3488 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3489 jge(oh_bpad_label_end, T_NEAR);
3490 }
3491 cmp(reg_oj, jcp.oh);
3492 jge(oh_bpad_label_end, T_NEAR);
3493
3494 if (is_dilated) {
3495 // Assumes unit stride for dilations
3496 mov(reg_kh, jcp.kh - 1);
3497 xor_(reg_ih_shift, reg_ih_shift);
3498 } else {
3499 assert(jcp.dilate_h == 0);
3500 mov(reg_kh, jcp.ih - ih_body_end);
3501 }
3502 if (is_partial) {
3503 lea(reg_oj_setup,
3504 ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
3505 if (stride_h == 1 && !is_dilated) {
3506 sub(reg_kh, reg_oj_setup);
3507 } else {
3508 Label body_setup, body_setup_finish, dilate_skip;
3509 cmp(reg_oj_setup, 0);
3510 je(body_setup_finish, T_NEAR);
3511
3512 L(body_setup);
3513 if (is_dilated) {
3514 inc(reg_ih_shift);
3515 cmp(reg_ih_shift, dilate_h);
3516 jl(dilate_skip, T_NEAR);
3517 xor_(reg_ih_shift, reg_ih_shift);
3518 }
3519 sub(reg_kh, stride_h);
3520 L(dilate_skip);
3521 sub(reg_oj_setup, 1);
3522 jg(body_setup, T_NEAR);
3523 L(body_setup_finish);
3524 }
3525 }
3526
3527 if (is_dilated) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3528 L(oh_bpad_label);
3529 {
3530 compute_oh_step_disp();
3531 add(reg_src, src_step_size * stride_h);
3532 add(reg_ddst, ddst_step_size);
3533
3534 if (is_dilated) {
3535 mov(reg_ih_shift, ptr[rsp + ih_dilate_shift]);
3536 inc(reg_ih_shift);
3537 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3538 cmp(reg_ih_shift, dilate_h);
3539 jl(oh_dilate_label_end, T_NEAR);
3540 xor_(reg_ih_shift, reg_ih_shift);
3541 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3542 }
3543 sub(reg_kh, stride_h);
3544 L(oh_dilate_label_end);
3545 inc(reg_oj);
3546 if (is_partial) {
3547 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3548 jge(oh_bpad_label_end, T_NEAR);
3549 }
3550 cmp(reg_oj, oh_tail_end);
3551 jl(oh_bpad_label, T_NEAR);
3552 }
3553 }
3554 L(oh_bpad_label_end);
3555 }
3556
3557 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_od_loop_common(
3558 bool is_partial) {
3559 assert(jcp.harness == harness_3d_reduction);
3560
3561 const int src_backpad_overlap
3562 = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
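    // src_backpad_overlap is the first od index whose filter footprint reaches
    // into the back padding; e.g. a hypothetical id = 8, f_pad = 1, kd = 3,
    // stride_d = 1 gives div_up(8 + 1 - 2, 1) = 7.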
3563
3564 const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
3565 const auto src_shift = get_src_offset(0, 0, jcp.ih);
3566 const auto ddst_shift = get_ddst_offset(0, jcp.oh);
3567
3568 const int kd_front_pad = nstl::max(0, jcp.f_pad);
3569 const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);
3570
3571 Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
3572 backpad_end_label, backpad_label;
3573
3574 /* initially offset 'kd' by f_pad */
3575 mov(reg_src_d, ptr[param + GET_OFF(src)]);
3576 mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);
3577
3578 if (is_partial) {
3579 add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
3580 mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
3581 mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
3582 } else {
3583 const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
3584 const int kd_offset = get_kernel_offset(
3585 0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
3586 add(reg_kernel, kd_offset);
3587 xor_(reg_d_index, reg_d_index);
3588 mov(reg_kd_count, kd_padding);
3589 }
3590
3591 cmp(reg_kd_count, 0);
3592 jle(loop_end_label, T_NEAR); // no iterations along kd
3593 if (is_partial)
3594 cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
3595 else
3596 cmp(reg_d_index, jcp.od);
3597 jge(loop_end_label, T_NEAR); // no iterations along depth dimension
3598
3599 L(d_loop_label);
3600
3601 mov(reg_src, reg_src_d);
3602 mov(reg_ddst, reg_ddst_d);
3603
3604 mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
3605 mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
3606 mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);
3607
3608 compute_oh_loop_common();
3609
3610 mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
3611 mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
3612 mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));
3613
3614 /* Compute 'front' edge */
3615 if (jcp.f_pad > 0) {
3616 /* Check if within fpad region */
3617 cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
3618 jge(fpad_end_label, T_NEAR);
3619
3620 /* Fpad steps */
3621 sub(reg_kernel, filter_shift * jcp.stride_d);
3622 add(reg_kd_count, jcp.stride_d);
3623
3624 /* Final number of kernel elements that overlap with src */
3625 const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
3626 cmp(reg_kd_count, src_ker_overlap);
3627 jle(common_block_label, T_NEAR);
3628
3629 /* Correct any excess shifts to kernel and src */
3630 if (jcp.f_pad <= jcp.od * jcp.stride_d) {
3631 /* Filter has moved beyond padding (adjust for stride effects) */
3632 if (jcp.f_pad % jcp.stride_d != 0) {
3633 int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
3634 add(reg_kernel, filter_shift * src_corr);
3635 add(reg_src_d, src_shift * src_corr);
3636 }
3637 } else {
3638 /* Filter still overlaps padding (complete reset) */
3639 sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
3640 }
3641
3642 /* Apply correction */
3643 mov(reg_kd_count, src_ker_overlap);
3644 jmp(common_block_label);
3645
3646 L(fpad_end_label);
3647 }
3648
3649     /* Compute 'back' edge */
3650 if (jcp.back_pad > 0) {
3651
3652 /* Check if within back_pad region */
3653 cmp(reg_d_index, src_backpad_overlap - 1);
3654 jl(backpad_end_label, T_NEAR);
3655 jg(backpad_label, T_NEAR);
3656
3657 /* Execute overlap correction between the filter and the initial
3658 * back_pad region. */
3659 mov(reg_kd_count,
3660 jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
3661 jmp(backpad_end_label, T_NEAR);
3662
3663 L(backpad_label);
3664 sub(reg_kd_count, jcp.stride_d);
3665 cmp(reg_kd_count, 0);
3666 jle(loop_end_label, T_NEAR);
3667
3668 L(backpad_end_label);
3669 }
3670
3671 /* Compute middle block */
3672 add(reg_src_d, src_shift * jcp.stride_d);
3673
3674 /* Execute common block and loop */
3675 L(common_block_label);
3676 add(reg_ddst_d, ddst_shift);
3677 inc(reg_d_index);
3678 if (is_partial)
3679 cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
3680 else
3681 cmp(reg_d_index, jcp.od);
3682 jl(d_loop_label, T_NEAR);
3683
3684 L(loop_end_label);
3685 }
3686
3687 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
3688         compute_full_spat_loop() {
3689 // General code layout:
3690 //
3691 // Blocking over OH -- top level
3692 // (Reduces L2 pressure; not very useful right now)
3693 // Loop over all KHxKW kernel -- emit_kh_kw_loop()
3694 // Loop over OH block -- emit_h_loop()
3695 // Loop over OW blocks -- emit_fma_block()
3696 // (Supports both fully unrolled and partially unrolled
3697 // versions to reduce code size)
3698 // Loop over OW block -- emit_fma_step()
3699
3700 auto src_row_size = get_src_offset(0, 0, 1);
3701 auto ddst_row_size = get_ddst_offset(0, 1);
3702 auto row_size = src_row_size + ddst_row_size;
3703
3704 int h_block_size = jcp.oh;
3705 int h_last_block_size = h_block_size;
3706 int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
3707 auto working_set_size = row_size * h_block_size;
3708
3709 if (working_set_size > full_spat_max_working_set_size) {
3710 assert(full_spat_opt_working_set_size < full_spat_max_working_set_size);
3711
3712 while (working_set_size > full_spat_opt_working_set_size
3713 && h_block_size >= min_h_block_size) {
3714 for (int i = 2; i <= h_block_size; i++)
3715 if (i == h_block_size)
3716 h_block_size = h_block_size / 2;
3717 else if (h_block_size % i == 0) {
3718 h_block_size = h_block_size / i;
3719 break;
3720 }
3721 working_set_size = row_size * h_block_size;
3722 }
3723 h_block_size = nstl::max(min_h_block_size, h_block_size);
3724 h_last_block_size = jcp.oh % h_block_size;
3725 if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
3726 }
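    // The loop above repeatedly divides h_block_size by its smallest factor
    // >= 2 (halving it when prime) until a block of src + ddst rows fits the
    // optimal working-set budget or reaches the padding-derived minimum, then
    // clamps the block so it still covers the vertical padding.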
3727
3728 Opmask reg_h_block = k1;
3729 Reg64 reg_kh = rax;
3730 Reg64 reg_kw = rbx;
3731 Reg64 reg_tmp = abi_not_param1;
3732 Reg32 reg_tmp_w = reg_tmp.cvt32();
3733 Reg64 reg_ohs = rdx;
3734 Reg64 reg_ihs = rsi;
3735 Reg64 reg_h = r8;
3736 Reg64 reg_i = r9;
3737 Reg64 reg_j = r10;
3738
3739 Reg64 reg_src = r13;
3740 Reg64 reg_ddst = r14;
3741 Reg64 reg_ker = r15;
3742
3743 Reg64 reg_src_save = abi_param1;
3744 Reg64 reg_ddst_save = reg_tmp;
3745
3746 auto zmm_ddst = [&](int oi) { return Zmm(24 + oi % 8); };
3747 auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
3748 auto src_addr = [&](int oi, int ic1) {
3749 return zword_b[reg_src + get_src_offset(ic1, oi)];
3750 };
3751 auto ddst_addr = [&](int oi) {
3752 auto ow_per_oc = 2;
3753 return ptr[reg_ddst + get_ddst_offset(ow_per_oc * oi)];
3754 };
3755 auto ker_addr
3756 = [&](int ic1) { return ptr[reg_ker + get_kernel_offset(ic1, 0)]; };
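    // Addressing note: ddst is consumed in bf16 pairs (ow_per_oc = 2), so a
    // single 64-byte load covers two output columns across the 16-channel
    // oc block, which is the pair layout vdpbf16ps reduces over.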
3757
3758 auto emit_block = [&]() {
3759 auto pad_ow = jcp.tr_ow;
3760 int ow_per_oc = 2;
3761 int def_step_size = 16;
3762 bool has_w_tail = pad_ow % def_step_size != 0;
3763 bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;
3764
3765 auto emit_step = [&](int ur_ow, bool is_w_tail) {
3766 int tail_size = pad_ow % ur_ow;
3767 int this_ur_ow = (is_w_tail && tail_size) ? tail_size : ur_ow;
3768 auto numloads = 1;
3769
3770 assert(this_ur_ow % ow_per_oc == 0);
3771 int steps = this_ur_ow / ow_per_oc;
3772 for (int oi_base = 0; oi_base < steps; oi_base += numloads) {
3773 for (int oi_offset = 0; oi_offset < numloads; oi_offset++) {
3774 int oi = oi_base + oi_offset;
3775 if (oi < steps) {
3776 vmovups(zmm_ddst(oi), ddst_addr(oi));
3777 } else {
3778 auto zmm = zmm_ddst(oi);
3779 vpxord(zmm, zmm, zmm);
3780 }
3781 }
3782
3783 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3784 vdpbf16ps(zmm_ker(ic1), zmm_ddst(oi_base),
3785 src_addr(ow_per_oc * oi_base, ic1));
3786 }
3787 }
3788 };
3789
3790 if (full_w_unroll) {
3791 emit_step(pad_ow, true);
3792 } else {
3793 Label w_loop;
3794 int num_w_iters = pad_ow / def_step_size;
3795 mov(reg_i, num_w_iters);
3796 L(w_loop);
3797 {
3798 emit_step(def_step_size, false);
3799 add(reg_src, get_src_offset(0, def_step_size));
3800 add(reg_ddst, get_ddst_offset(def_step_size));
3801 sub(reg_i, 1);
3802 jnz(w_loop);
3803 }
3804 if (has_w_tail) { emit_step(def_step_size, true); }
3805 // reset reg_src and reg_ddst because emit_h_loop expects
3806 // unmodified pointers
3807 int w_offset = num_w_iters * def_step_size;
3808 sub(reg_src, get_src_offset(0, w_offset));
3809 sub(reg_ddst, get_ddst_offset(w_offset));
3810 }
3811 };
3812
3813 auto emit_h_loop = [&]() {
3814 Label h_loop, skip_h_loop;
3815 mov(reg_j, 1);
3816 cmp(reg_j, reg_h);
3817 je(skip_h_loop, T_NEAR);
3818 L(h_loop);
3819 {
3820 emit_block();
3821
3822 add(reg_src, get_src_offset(0, 0, 1));
3823 add(reg_ddst, get_ddst_offset(0, 1));
3824 add(reg_j, 1);
3825 cmp(reg_j, reg_h);
3826 jb(h_loop);
3827 }
3828 L(skip_h_loop);
3829
3830 emit_block();
3831 };
3832
3833 auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
3834 xor_(reg_kh, reg_kh);
3835 Label kh_loop, kh_loop_end;
3836
3837 int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
3838 // NB: this is correct because we only support t_pad = kh / 2 and thus
3839 // ih == oh
3840 int ih_block_size = oh_block_size
3841 + (!is_first_block + !is_last_block) * jcp.t_pad;
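        // E.g. with a hypothetical h_block_size = 32 and t_pad = 1, a middle
        // block reads 32 + 2 src rows, while the first/last blocks extend by
        // the padding rows on one side only.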
3842
3843 L(kh_loop);
3844 {
3845 if (is_first_block) {
3846 xor_(reg_tmp, reg_tmp);
3847 mov(reg_ohs, jcp.t_pad);
3848 sub(reg_ohs, reg_kh);
3849 cmovb(reg_ohs, reg_tmp);
3850
3851 mov(reg_ihs, reg_ohs);
3852 sub(reg_ihs, jcp.t_pad);
3853 add(reg_ihs, reg_kh);
3854 } else {
3855 xor_(reg_ohs, reg_ohs);
3856 mov(reg_ihs, reg_kh);
3857 }
3858
3859 mov(reg_tmp, oh_block_size);
3860 sub(reg_tmp, reg_ohs);
3861 mov(reg_h, ih_block_size);
3862 sub(reg_h, reg_ihs);
3863 cmp(reg_tmp, reg_h);
3864 cmovb(reg_h, reg_tmp);
3865
3866 Label kh_loop_work;
3867 cmp(reg_h, 0);
3868 jg(kh_loop_work, T_NEAR);
3869
3870             // empty h loop for this jcp.kh:
3871             // - zero the kernel block if necessary
3872             // - move the ker pointer
3873             // - jump to the end
3874 sub(reg_h, 1);
3875 Label skip_ker_zeroing;
3876
3877             // The reg_ker ptr has its lowest bit set if the kernel needs to
3878             // be zeroed. Those who have byte-aligned their data will suffer
3879             // the consequences :(
3880 // TODO: move the flag to a mask register? (Roma)
3881 test(reg_ker, 1);
3882 jz(skip_ker_zeroing, T_NEAR);
3883
3884 Label zeroing_loop;
3885 vpxord(zmm0, zmm0, zmm0);
3886 and_(reg_ker, ~1); // temporarily clear the zeroing flag
3887 mov(reg_tmp, jcp.kw);
3888 L(zeroing_loop);
3889 {
3890 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3891 vmovups(ker_addr(ic1), zmm0);
3892 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
3893 sub(reg_tmp, 1);
3894 jnz(zeroing_loop, T_NEAR);
3895 }
3896 // restore the zeroing flag (it will be cleared after the end of
3897 // emit_kh_kw_loop, but we may need it until then)
3898 or_(reg_ker, 1);
3899 jmp(kh_loop_end, T_NEAR);
3900
3901 L(skip_ker_zeroing);
3902 add(reg_ker, get_kernel_offset(0, jcp.kw));
3903 jmp(kh_loop_end, T_NEAR);
3904
3905 L(kh_loop_work);
3906
3907 mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
3908 mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));
3909
3910 add(reg_src, reg_ihs);
3911 add(reg_ddst, reg_ohs);
3912
3913 Label kw_loop;
3914 xor_(reg_kw, reg_kw);
3915 L(kw_loop);
3916 {
3917 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3918 auto zmm = zmm_ker(ic1);
3919 vpxord(zmm, zmm, zmm);
3920 }
3921
3922 mov(reg_ddst_save, reg_ddst);
3923 mov(reg_src_save, reg_src);
3924 lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
3925
3926 emit_h_loop();
3927
3928 mov(reg_ddst, reg_ddst_save);
3929 mov(reg_src, reg_src_save);
3930
3931 Label do_store;
3932                 // The reg_ker ptr has its lowest bit set if the kernel
3933                 // needs to be zeroed. Those who have byte-aligned their
3934                 // data will suffer the consequences :(
3935 mov(reg_tmp, reg_ker);
3936 and_(reg_ker, ~1);
3937 test(reg_tmp, 1);
3938 jnz(do_store, T_NEAR);
3939
3940 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3941 auto zmm = zmm_ker(ic1);
3942 vaddps(zmm, ker_addr(ic1));
3943 }
3944
3945 L(do_store);
3946 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3947 auto zmm = zmm_ker(ic1);
3948 vmovups(ker_addr(ic1), zmm);
3949 }
3950
3951 mov(reg_ker, reg_tmp);
3952 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
3953 add(reg_kw, 1);
3954 cmp(reg_kw, jcp.kw);
3955 jl(kw_loop);
3956 }
3957
3958 sub(reg_src, reg_ihs);
3959 sub(reg_ddst, reg_ohs);
3960
3961 L(kh_loop_end);
3962 add(reg_kh, 1);
3963 cmp(reg_kh, jcp.kh);
3964 jl(kh_loop);
3965 }
3966 };
3967
3968 mov(reg_src, ptr[param + GET_OFF(src)]);
3969 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
3970 mov(reg_ker, ptr[param + GET_OFF(filt)]);
3971 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3972 or_(reg_ker, reg_tmp);
3973
3974 bool single_kh_kw_loop = (h_last_block_size == jcp.oh);
3975
3976 auto src_row_step = get_src_offset(0, 0, 1);
3977 auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
3978 auto ddst_block_step = get_ddst_offset(0, h_block_size);
3979
3980 emit_kh_kw_loop(true, single_kh_kw_loop);
3981
3982 if (!single_kh_kw_loop) {
3983 auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
3984 sub(reg_ker, ker_reset_offset);
3985 and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
3986
3987 add(reg_src, first_src_block_step);
3988 add(reg_ddst, ddst_block_step);
3989
3990 int num_innermost_iters
3991 = (jcp.oh - h_last_block_size) / h_block_size - 1;
3992 if (num_innermost_iters > 0) {
3993 Label h_block_loop;
3994
3995 mov(reg_tmp_w, num_innermost_iters);
3996 kmovw(reg_h_block, reg_tmp_w);
3997 L(h_block_loop);
3998 {
3999 emit_kh_kw_loop(false, false);
4000 sub(reg_ker, ker_reset_offset);
4001 add(reg_src, src_row_step * h_block_size);
4002 add(reg_ddst, ddst_block_step);
4003
4004 kmovw(reg_tmp_w, reg_h_block);
4005 sub(reg_tmp_w, 1);
4006 kmovw(reg_h_block, reg_tmp_w);
4007 jnz(h_block_loop);
4008 }
4009 }
4010
4011 emit_kh_kw_loop(false, true);
4012 }
4013 }
4014
4015 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_loop() {
4016 Reg64 reg_mask_load = r11;
4017 if (jcp.uses_permw_transposition) {
4018
4019 mov(reg_mask_load.cvt32(), 0xffffffff);
4020 kmovd(m_ffffffff, reg_mask_load.cvt32());
4021
4022 mov(reg_mask_load.cvt32(), 0x0000ffff);
4023 kmovd(m_0000ffff, reg_mask_load.cvt32());
4024
4025 mov(reg_mask_load.cvt32(), 0xffff0000);
4026 kmovd(m_ffff0000, reg_mask_load.cvt32());
4027 const int oc_tail = jcp.oc_tail;
4028 if (oc_tail) {
4029 mov(reg_mask_load.cvt32(), (1 << oc_tail) - 1);
4030 kmovd(m_0000_oc_tail, reg_mask_load.cvt32());
4031 kshiftld(m_oc_tail_0000, m_0000_oc_tail, 16);
4032 }
4033 const int ic_tail = jcp.ic_tail;
4034 if (ic_tail) {
4035 mov(reg_mask_load.cvt32(), (1 << ic_tail) - 1);
4036 kmovd(m_0000_ic_tail, reg_mask_load.cvt32());
4037 kshiftld(m_ic_tail_0000, m_0000_ic_tail, 16);
4038 }
4039 } else if (jcp.is_1stconv && !jcp.transpose_src) {
4040 if (jcp.stride_w == 1) {
4041 int ieveryother_mask = 0x55555555;
4042 mov(reg_mask_load.cvt32(), ieveryother_mask);
4043 kmovd(everyother_mask, reg_mask_load.cvt32());
4044 kshiftld(everyother_shift_mask, everyother_mask, 1);
4045 } else {
4046 mov(reg_mask_load.cvt32(), 0xffffffff);
4047 kmovd(m_ffffffff, reg_mask_load.cvt32());
4048 }
4049 }
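    // A reading of the mask setup above: m_0000ffff / m_ffff0000 select the
    // low / high 16 words of a zmm during the permw transposition, while the
    // 0x55555555 pattern (and its shifted counterpart) picks every other word
    // for the stride-1 1st-conv source layout.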
4050
4051 mov(reg_src, ptr[param + GET_OFF(src)]);
4052 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4053 mov(reg_kernel, ptr[param + GET_OFF(filt)]);
4054
4055 maybe_zero_kernel();
4056 maybe_compute_diff_bias();
4057
4058 switch (jcp.harness) {
4059 case harness_3d_reduction: compute_od_loop_common(true); break;
4060 case harness_2d_reduction: compute_oh_loop_common(true); break;
4061 case harness_mb_reduction: compute_oh_loop_common(); break;
4062 case harness_compute_full_spatial: compute_full_spat_loop(); break;
4063 default: assert(!"Invalid harness type");
4064 }
4065 }
4066
4067 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::setup_stack_space() {
4068
4069 if ((jcp.is_1stconv && !jcp.transpose_src && jcp.stride_w > 1)
4070 || jcp.uses_permw_transposition) {
4071 int ur_w, ur_w_tail, ur_w_trips;
4072 get_ur_w(ur_w, ur_w_tail, ur_w_trips);
4073 ur_w = nstl::max(ur_w, ur_w_tail);
4074 ic_block_step_stack_size = jcp.uses_permw_transposition
4075 ? permw_stack_size(ur_w)
4076 : interleave_stack_size(ur_w, jcp.ic_block_step);
4077 } else
4078 ic_block_step_stack_size = extern_ic_block_step_stack_size;
4079
4080 permw_buffer_start = 0;
4081 kd_count_offset = ic_block_step_stack_size;
4082 src_d_offset = ic_block_step_stack_size + 8;
4083 ddst_d_offset = ic_block_step_stack_size + 16;
4084 d_index_offset = ic_block_step_stack_size + 24;
4085 trans_tmp_offset = ic_block_step_stack_size + 32;
4086 ih_dilate_shift = ic_block_step_stack_size + 40;
4087 icb_loop_ker_ptr = ic_block_step_stack_size + 48;
4088 icb_loop_src_ptr = ic_block_step_stack_size + 56;
4089 stack_space_needed = ic_block_step_stack_size + 64;
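    // Resulting stack layout, summarizing the offsets above: the transpose /
    // ic_block_step scratch occupies [0, ic_block_step_stack_size), followed
    // by eight 8-byte slots for the depth bookkeeping and the saved icb-loop
    // pointers.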
4090 }
4091
4092 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::generate() {
4093 preamble();
4094
4095 setup_stack_space();
4096
4097 sub(rsp, stack_space_needed);
4098
4099 compute_loop();
4100
4101 add(rsp, stack_space_needed);
4102
4103 postamble();
4104
4105 if (jcp.uses_permw_transposition) {
4106 align(64);
4107 L(dst_prm_table);
4108 const uint16_t dst_prm_array[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20,
4109 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
4110 29, 14, 30, 15, 31};
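        // The table zips words 0..15 with words 16..31: a vpermw with these
        // indices interleaves two 16-wide rows into (a0, b0, a1, b1, ...),
        // the bf16 pair order the VNNI dot products expect.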
4111
4112 for (size_t i = 0; i < 32; ++i)
4113 dw(dst_prm_array[i]);
4114 }
4115 }
4116
4117 status_t jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(
4118 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
4119 memory_desc_t &src_md, memory_desc_t &diff_weights_md,
4120 memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
4121 const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
4122
4123 const memory_desc_wrapper src_d(&src_md);
4124 const memory_desc_wrapper diff_weights_d(&diff_weights_md);
4125 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
4126 const memory_desc_wrapper diff_bias_d(&diff_bias_md);
4127
4128 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
4129 int ndims = src_d.ndims();
4130
4131 jcp = zero<decltype(jcp)>();
4132 jcp.nthr = nthreads;
4133 jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
4134 : bf16_emulation_t::get_isa();
4135 jcp.ver = ver_vnni;
4136 jcp.ndims = ndims;
4137 jcp.prop_kind = cd.prop_kind;
4138
4139 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
4140 jcp.mb = src_d.dims()[0];
4141
4142 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
4143 jcp.oc_without_padding = jcp.oc;
4144 jcp.ic = src_d.dims()[1] / jcp.ngroups;
4145
4146 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
4147 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
4148 jcp.iw = src_d.dims()[ndims - 1];
4149 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
4150 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
4151 jcp.ow = diff_dst_d.dims()[ndims - 1];
4152
4153 jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
4154 jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
4155 jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];
4156
4157 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
4158 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
4159 jcp.l_pad = cd.padding[0][ndims - 3];
4160
4161 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
4162 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
4163 jcp.stride_w = cd.strides[ndims - 3];
4164
4165 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
4166 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
4167 jcp.dilate_w = cd.dilates[ndims - 3];
4168
4169 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
4170 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4171 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
4172
4173 bool ok = true
4174 // general condition to simplify dilations
4175 && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
4176 && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
4177 && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
4178 // special condition to simplify dilations in compute_oh_loop_common
4179 && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
4180 if (!ok) return status::unimplemented;
4181
4182 jcp.r_pad = nstl::max(0,
4183 calculate_end_padding(
4184 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
4185 jcp.b_pad = nstl::max(0,
4186 calculate_end_padding(
4187 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
4188 jcp.back_pad = nstl::max(0,
4189 calculate_end_padding(
4190 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));
4191
4192 /* XXX: no support for padding when dilation_d > 0 */
4193 if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
4194 return status::unimplemented;
4195
4196 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
4197 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
4198 jcp.ohp = jcp.oh;
4199 jcp.owp = jcp.ow;
4200 jcp.aligned_threads = 0;
4201
4202 jcp.simd_w = simd_w;
4203 jcp.oc_block = simd_w;
4204 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
4205 const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
4206 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
4207 auto curr_src_tag = src_d.matches_one_of_tag(
4208 dat_tag_nxc, dat_tag_nCx16c, dat_tag_ncx);
4209 auto curr_dst_tag
4210 = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
4211 bool is_data_layout_nxc
4212 = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);
4213
4214 jcp.is_1stconv = is_1stconv(jcp);
4215
4216 bool ok_to_pad_channels
4217 = (jcp.ngroups == 1) && !jcp.is_1stconv && !is_data_layout_nxc;
4218
4219 if (ok_to_pad_channels) {
4220 jcp.oc = rnd_up(jcp.oc, simd_w);
4221 jcp.ic = rnd_up(jcp.ic, simd_w);
4222 }
4223
4224 auto src_tag = is_data_layout_nxc
4225 ? dat_tag_nxc
4226 : (jcp.is_1stconv ? dat_tag_ncx : dat_tag_nCx16c);
4227 auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
4228 auto wei_tag = jcp.is_1stconv
4229 ? pick(2 * ndims - 6 + with_groups, Owi16o, gOwi16o, Ohwi16o,
4230 gOhwi16o, Odhwi16o, gOdhwi16o)
4231 : pick(2 * ndims - 6 + with_groups, OIw16i16o, gOIw16i16o,
4232 OIhw16i16o, gOIhw16i16o, OIdhw16i16o, gOIdhw16i16o);
4233
4234 if (src_md.format_kind == format_kind::any) {
4235 CHECK(memory_desc_init_by_tag(src_md, src_tag));
4236 } else if (curr_src_tag != src_tag)
4237 return status::unimplemented;
4238 jcp.src_tag = src_tag;
4239
4240 if (diff_dst_md.format_kind == format_kind::any) {
4241 CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag));
4242 } else if (curr_dst_tag != dst_tag)
4243 return status::unimplemented;
4244 jcp.dst_tag = dst_tag;
4245
4246 if (diff_weights_md.format_kind == format_kind::any) {
4247 CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
4248 jcp.wei_tag = wei_tag;
4249 } else {
4250 jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
4251 if (jcp.wei_tag != wei_tag) return status::unimplemented;
4252 }
4253
4254 /* conditions on bias memory */
4255 jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
4256 if (jcp.with_bias) {
4257 if (diff_bias_d.format_kind() == format_kind::any)
4258 CHECK(memory_desc_init_by_tag(diff_bias_md, x));
4259 }
4260 jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
4261 jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
4262
4263 jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
4264
4265 /* kernel applicability check wrt boundaries
4266 * the conditions are quite general across the kernels we have,
4267 * but ideally the check should belong to a specific kernel... */
4268 const int max_pad_h = ext_kh / 2;
4269 const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
4270 && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
4271 && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd
4272 && IMPLICATION(jcp.is_1stconv && jcp.ow > max_ur_w,
4273 jcp.l_pad < max_ur_w && ext_kw <= jcp.ow);
4274 if (!boundaries_ok) return status::unimplemented;
4275
4276 const int max_kw = jcp.is_1stconv ? 24 : 14;
4277 /* yet another common check */
4278 if (jcp.kw > max_kw) return status::unimplemented;
4279
4280 jcp.wei_dt = diff_weights_d.data_type();
4281
4282 jcp.ic_block = jcp.is_1stconv ? jcp.ic : simd_w;
4283 if (ok_to_pad_channels) jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
4284 jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
4285 ok = true && one_of(ndims, 3, 4, 5)
4286 && everyone_is(
4287 data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
4288 && one_of(diff_weights_d.data_type(), data_type::f32,
4289 data_type::bf16);
4290 if (!ok) return status::unimplemented;
4291
4292 jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.ic_block : 0;
4293 jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.oc_block : 0;
4294
4295 if (jcp.is_1stconv) {
4296 jcp.ic_block_step = 24 / jcp.kw;
4297 while (jcp.ic_block % jcp.ic_block_step != 0)
4298 jcp.ic_block_step--;
4299 } else {
4300 jcp.ic_block_step
4301 = jcp.kw <= 3 ? 8 : (jcp.kw < 7 ? 4 : (jcp.kw <= 12 ? 2 : 1));
4302 }
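    // E.g. on a hypothetical 1st conv with kw = 7 and ic_block = ic = 3:
    // 24 / 7 = 3 and 3 % 3 == 0, so ic_block_step = 3; the blocked branch
    // with kw = 5 would pick 4 instead.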
4303
4304 // jcp.uses_permw_transposition = false shows better performance for
4305 // resnet50 v1.5 problems
4306 // jcp.uses_permw_transposition = true works better for 3d 1x1x1 problems
4307 const bool is_permw_applicable
4308 = !jcp.is_1stconv && jcp.stride_w == 1 && jcp.dilate_w == 0;
4309 const bool apply_permw_blocked = !is_data_layout_nxc && ndims == 5
4310 && jcp.kw == 1 && jcp.ic_block_step > 4;
4311 // Threshold is based on performance measurements
4312 const bool apply_permw_nxc = is_data_layout_nxc && ndims == 3
4313 && nstl::max(jcp.ic, jcp.oc) <= 32;
4314 jcp.uses_permw_transposition
4315 = is_permw_applicable && (apply_permw_blocked || apply_permw_nxc);
4316
4317 jcp.kernel_kind = embd_bcast;
4318 if (jcp.uses_permw_transposition && jcp.kw <= 3)
4319 jcp.kernel_kind = expl_bcast;
4320 if (jcp.uses_permw_transposition && jcp.kernel_kind == expl_bcast)
4321 jcp.ic_block_step = jcp.kw <= 3 ? 4 : (jcp.kw < 7 ? 2 : 1);
4322
4323 if (jcp.uses_permw_transposition) {
4324 jcp.transpose_src = false;
4325 jcp.transpose_dst = false;
4326 } else if (jcp.is_1stconv && IMPLICATION(is_data_layout_nxc, jcp.ic == 1)) {
4327 jcp.transpose_src = false;
4328 jcp.transpose_dst = true;
4329 } else {
4330 jcp.transpose_src = true;
4331 jcp.transpose_dst = true;
4332 }
4333
4334 const bool is_2d = (ndims == 4);
4335 const bool is_3d = (ndims == 5);
4336 jcp.typesize_in = sizeof(bfloat16_t);
4337 jcp.typesize_out = sizeof(float);
4338 const dim_t cache_l2
4339 = platform::get_per_core_cache_size(2) / jcp.typesize_out;
4340
4341     // Observation: given large 3D shapes with a large filter size, 1st nspc
4342     // bwd_w convolution benefits from non-temporal stores in the diff_dst
4343     // transformation, but not so much from blocking w.r.t. the depth
4344     // dimension. In particular, it's optimized for the i3D 1st convolution.
4345 const bool nt_stores_ok = is_data_layout_nxc
4346 && dim_t(jcp.oc) * jcp.od * jcp.oh * jcp.ow >= 2 * cache_l2
4347 && jcp.kd >= 6 && jcp.kh >= 6 && jcp.kw >= 6;
4348
4349     // Performance-wise, transposition of the diff_dst tensor is one of the
4350     // major bottlenecks in 1st convolution. Thus for large diff_dst sizes we
4351     // can potentially split the transposition into smaller chunks to achieve
4352     // better cache reuse
4353 const bool large_diff_dst_size
4354 = dim_t(jcp.oc) * jcp.od * jcp.oh * jcp.ow >= cache_l2;
4355
4356     // For a two-dimensional diff_dst tensor, blocking along height demands
4357     // non-trivial work along the width dimension. Similarly, a three-
4358     // dimensional diff_dst tensor must have enough work in the joint
4359     // width-height dimension. Finally, there is no blocking along width
4360 const bool blocking_ok = large_diff_dst_size
4361 && IMPLICATION(is_2d, jcp.ow >= 124 && jcp.oh > 1)
4362 && IMPLICATION(is_3d, jcp.ow * jcp.oh >= 64 * 124 && jcp.od > 1)
4363 && (is_2d || is_3d);
4364
4365 // TODO: Find more shapes (especially 3D with large spatials) for which
4366 // local transposition will be beneficial. Furthermore, for TBB threads
4367 // more shapes can potentially benefit from spatial blocking
4368 bool use_spatial_blocking = jcp.is_1stconv && !nt_stores_ok && blocking_ok;
4369 int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;
4370 if (use_spatial_blocking) {
4371 // Default value, works best most of the times
4372 // TODO: For 3D shapes with intermediate sizes especially the ones not
4373 // belonging to the 1st convolution, we potentially have more scope
4374 // for optimization
4375 optimal_blk_size = 1;
4376
4377 // Diff_weights computation can be roughly broken down into
4378 // the following three steps
4379 // = [Src transform*] + [Diff_dst transform] + [Weights computation]
4380 //
4381 // where the bottleneck lies with diff_dst transform that spatial
4382 // blocking tries to mitigate by avoiding cache thrashing.
4383 // *note: Src transform may not always be needed.
4384 //
4385 // In an idealistic scenario, optimal_blk_size will be an explicit
4386 // function of the following form
4387 // optimal_blk_size = f(od, oh, ow, oc)
4388 //
4389 // though owing to lack of data points w.r.t. 1st convolution shapes it
4390 // is approximated by one with few exceptional cases [found by manual
4391 // optimization] as written below
4392
4393 if (is_2d && utils::one_of(jcp.oh, 149, 300, 224, 512, 608)) {
4394 switch (jcp.oh) {
4395 case 149: optimal_blk_size = 10; break;
4396 case 224: optimal_blk_size = 56; break;
4397 case 300: optimal_blk_size = 30; break;
4398 case 512: optimal_blk_size = 8; break;
4399 case 608: optimal_blk_size = 10; break;
4400 }
4401 }
4402 }
4403
4404 jcp.global_transpose = dnnl_thr_syncable() && !use_spatial_blocking;
4405 jcp.use_nt_stores_ddst = jcp.global_transpose && nt_stores_ok;
4406 jcp.spatial_blk_size = optimal_blk_size;
4407
4408 const bool padding_ok = IMPLICATION(!jcp.transpose_src,
4409 jcp.l_pad < max_ur_w && jcp.r_pad < max_ur_w
4410 && ext_kw <= jcp.iw + 1);
4411 if (!padding_ok) return status::unimplemented;
4412
4413 const int tr_round = 2;
4414 // Logic for tr_pad calculation: transpose is used in the extern kernel.
4415 // There is a memory usage optimization where physical padding is shared
4416     // between transpose buffers. When computing a row, data is read from the
4417 // src 2 elements at a time due to the bf16 broadcast. Calculation starts
4418 // at the beginning of the left padding and ends at the end of the right
4419 // padding. Because elements are read two at a time, we may need r_pad + 1
4420 // padding on the right. As such, the shared padding is the max of l_pad and
4421 // r_pad + 1, rounded as necessary for the transpose data format.
4422 int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
4423 jcp.tr_iw = jcp.transpose_src
4424 ? rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
4425 * jcp.stride_w
4426 : jcp.iw;
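    // E.g. a hypothetical iw = 17, stride_w = 1, l_pad = 3, r_pad = 3 gives
    // tr_pad = rnd_up(max(3, 3 + 1), 2) = 4 and
    // tr_iw = rnd_up(17 + 4, 2) * 1 = 22: an even-length transposed row with
    // shared left/right guard space.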
4427
4428 jcp.tr_src_num_guard_elems = tr_pad; // upper bound
4429 jcp.tr_ow = jcp.transpose_dst ? rnd_up(jcp.ow, 2) : jcp.ow;
4430
4431 bool args_ok = true
4432 && IMPLICATION(!is_data_layout_nxc,
4433 jcp.ic % jcp.ic_block == 0 && jcp.oc % jcp.oc_block == 0)
4434 && jcp.ic <= src_d.padded_dims()[1]
4435 && jcp.oc <= diff_dst_d.padded_dims()[1]
4436 && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
4437 && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
4438 if (!args_ok) return status::unimplemented;
4439
4440 int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
4441 int out_row_size = jcp.oc_block * jcp.tr_ow * jcp.typesize_in;
4442 int full_spat_min_h_block_size
4443 = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
4444 int full_spat_working_set_size
4445 = (inp_row_size + out_row_size) * full_spat_min_h_block_size;
4446 bool use_full_spat_loop = isa_has_bf16(jcp.isa) && jcp.ndims < 5
4447 && jcp.ih == jcp.oh && jcp.iw == jcp.ow
4448 && !one_of(1, jcp.kh, jcp.kw)
4449 && everyone_is(1, jcp.stride_h, jcp.stride_w)
4450 && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
4451 && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
4452 && !jcp.uses_permw_transposition && !jcp.is_1stconv
4453 && full_spat_working_set_size <= full_spat_opt_working_set_size
4454 && jcp.ic >= 128;
4455
4456 jcp.harness = ndims == 5
4457 ? harness_3d_reduction
4458 : (use_full_spat_loop ? harness_compute_full_spatial
4459 : (ndims == 4) ? harness_2d_reduction
4460 : harness_mb_reduction);
4461
4462 switch (jcp.harness) {
4463 case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
4464 case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
4465 case harness_compute_full_spatial:
4466 case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
4467 default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
4468 }
4469 { // balancing
4470 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
4471 balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
4472 jcp.nthr = nthr;
4473 jcp.nthr_mb = nthr_mb;
4474 jcp.nthr_g = nthr_g;
4475 jcp.nthr_oc_b = nthr_oc_b;
4476 jcp.nthr_ic_b = nthr_ic_b;
4477
4478 // TODO: Optimize memory allocation when threaded on height and depth
4479 if (jcp.transpose_src) {
4480 jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
4481 jcp.tr_src_buf_count = jcp.global_transpose
4482 ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
4483 : jcp.nthr;
4484 }
4485 if (jcp.transpose_dst) {
4486 jcp.tr_diff_dst_buf_size
4487 = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
4488 jcp.tr_diff_dst_buf_count = jcp.global_transpose
4489 ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
4490 : jcp.nthr;
4491 }
4492 }
4493
4494 jcp.nb_ic_blocking_max = 1;
4495 if (is_data_layout_nxc && jcp.uses_permw_transposition
4496 && (jcp.ow > max_ur_w || jcp.ndims == 5))
4497 jcp.nb_ic_blocking_max = nstl::min(8, div_up(jcp.nb_ic, jcp.nthr_ic_b));
4498 return status::success;
4499 }
4500
4501 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_scratchpad(
4502 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
4503
4504 if (!jcp.uses_permw_transposition) {
4505 // XXX: See the comment about tr_iw and guarding elements in
4506 // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf()
4507 const size_t tr_src_size = jcp.tr_src_buf_count * jcp.tr_src_buf_size
4508 + jcp.tr_src_num_guard_elems;
4509 scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);
4510
4511 /* prepare synchronization contexts */
4512 if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
4513 const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
4514 scratchpad.book<simple_barrier::ctx_t>(
4515 key_conv_tr_src_bctx, tr_src_bctx_size);
4516 }
4517
4518 const size_t tr_diff_dst_size
4519 = jcp.tr_diff_dst_buf_count * jcp.tr_diff_dst_buf_size;
4520
4521 const size_t min_align = jcp.use_nt_stores_ddst ? 64 : jcp.typesize_in;
4522 scratchpad.book(key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in,
4523 min_align);
4524
4525 /* prepare synchronization contexts */
4526 if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
4527 const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
4528 scratchpad.book<simple_barrier::ctx_t>(
4529 key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
4530 }
4531 }
4532
4533 if (IMPLICATION(jcp.nthr_mb == 1,
4534 (jcp.with_bias && jcp.bia_dt == data_type::bf16)
4535 || jcp.wei_dt == data_type::bf16)) {
4536 const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
4537 * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
4538 const size_t bia_size
4539 = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;
4540
4541 const int num_wei_buffers
4542 = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
4543 const int num_bia_buffers = jcp.with_bias
4544 ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
4545 : jcp.nthr_mb - 1)
4546 : 0;
4547
4548 const size_t wei_bia_reduction_size
4549 = wei_size * num_wei_buffers + bia_size * num_bia_buffers;
4550
4551 scratchpad.book<float>(
4552 key_conv_wei_bia_reduction, wei_bia_reduction_size);
4553
4554 if (jcp.global_transpose)
4555 scratchpad.book<simple_barrier::ctx_t>(
4556 key_conv_wei_bia_reduction_bctx, 1);
4557 }
4558
4559 if (jcp.with_bias) {
4560 if ((jcp.oc_without_padding % jcp.oc_block != 0)
4561 && jcp.bia_dt == data_type::f32)
4562 scratchpad.book(key_conv_padded_bias,
4563 jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
4564 }
4565 }
4566
4567 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::balance(
4568 const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
4569 int &nthr_oc_b_, int &nthr_ic_b_) {
4570 nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
4571
4572 const int max_threads = dnnl_get_max_threads();
4573
4574 if (max_threads < j.ngroups) {
4575 /* simplification... fortunately it doesn't hurt much */
4576 nthr_ = nthr_g_ = max_threads;
4577 return;
4578 }
4579
4580 nthr_g_ = j.ngroups;
4581 const int nthr = max_threads / nthr_g_;
4582
4583 auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
4584         /* calculate per-thread memory cost (read/write). the high-level
4585          * optimizer tries to minimize memory consumption. a few notes:
4586          * (n1) if the weights tensor size is less than the source and
4587          *      destination tensors, we apply the ratio of the source and
4588          *      destination tensor sizes to the weights one as a compensation
4589          *      coefficient to avoid parallelizing across batch size only;
4590          *      otherwise we apply an additional coefficient to the source
4591          *      component based on performance measurements
4592          * (n2) use scales based on the output-to-input channels ratio for the
4593          *      source and destination components to improve threading balance
4594          *      across input and output channels */
4595
4596 const dim_t src_type_size = 2;
4597 const dim_t wei_type_size = 4;
4598
4599 dim_t src_size
4600 = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
4601 dim_t dst_size
4602 = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
4603 dim_t wei_size
4604 = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;
4605
4606 float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
4607 float oi_channels_ratio = (float)j.nb_oc / j.nb_ic;
4608 auto get_src_coef = [=]() {
4609 float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
4610 if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;
4611
4612 return src_coef;
4613 };
4614
4615 auto get_dst_coef
4616 = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };
4617
4618 auto get_wei_coef
4619 = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };
4620
4621 const float src_coef = get_src_coef();
4622 const float dst_coef = get_dst_coef();
4623 const float wei_coef = get_wei_coef();
4624
4625 float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
4626 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_ic, nthr_ic_b) * j.mb
4627 * j.ic_block * j.id * j.ih * j.tr_iw / j.nthr_mb_work
4628 / j.stride_d / j.stride_h / j.stride_w;
4629 float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
4630 * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) * j.kh
4631 * j.kw * j.kd * j.ic_block * j.oc_block;
4632 float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
4633 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_oc, nthr_oc_b) * j.mb
4634 * j.oc_block * j.od * j.oh * j.tr_ow / j.nthr_mb_work;
4635
4636 return src_v + dst_v + wei_v;
4637 };
4638
4639 float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4640
4641 /* find the best thread distribution with lowest memory cost */
4642 const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
4643 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4644 const int nthr_par = nthr / nthr_mb;
4645 const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4646 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4647 int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4648
4649 float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4650 if (mem_cost <= best_mem_cost) {
4651 best_mem_cost = mem_cost;
4652 nthr_mb_ = nthr_mb;
4653 nthr_oc_b_ = nthr_oc_b;
4654 nthr_ic_b_ = nthr_ic_b;
4655 }
4656 }
4657 }
4658
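    // Heuristic below: once more than half of the available threads go to the
    // reduction dimension anyway, take all of them (capped by the actual
    // reduction work) rather than leaving a partially idle reduction tail.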
4659 if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
4660 nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
4661 nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
4662
4663 assert(nthr_ <= max_threads);
4664 }
4665
4666 template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Zmm>;
4667 template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Ymm>;
4668 template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Xmm>;
4669 template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Zmm>;
4670 template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Ymm>;
4671 template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Xmm>;
4672 } // namespace x64
4673 } // namespace cpu
4674 } // namespace impl
4675 } // namespace dnnl
4676 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
4677