/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)
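// GET_OFF(field) yields the byte offset of 'field' within the jit_conv_call_s
// argument struct, so the generated code can read run-time parameters with
// ptr[param1 + GET_OFF(field)].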

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::utils;
using namespace Xbyak;

void jit_avx512_core_amx_compute_zp_pbuff_t::prepare_output(int ur_w) {
    for (int oc = 0; oc < jcp.nb_oc_blocking; oc++)
        for (int ur = 0; ur < ur_w; ur++) {
            const Zmm zmm = zmm_out(ur, oc);
            vpxord(zmm, zmm, zmm);
        }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::store_output(
        int ur_w, bool last_oc_block_flag) {
    assert(jcp.is_nspc);

    const int nb_oc_block = jcp.nb_oc_blocking;
    const int oc_block = jcp.oc_block;

    const auto src_zp_addr = EVEX_compress_addr(reg_src_zero_point, 0, true);

    /* write out register to output_addr */
    for (int oc = 0; oc < nb_oc_block; oc++) {
        const bool mask_flag = last_oc_block_flag && oc == nb_oc_block - 1;
        for (int ur = 0; ur < ur_w; ur++) {
            const int output_offset = sizeof(int32_t)
                    * (oc * oc_block
                            + ur * jcp.oc_without_padding * jcp.ngroups);
            const Zmm zmm_dst = zmm_out(ur, oc);
            const Zmm m_zmm_dst = mask_flag ? zmm_dst | ktail_mask : zmm_dst;
            // multiply dst by src_zero_point
            vpmulld(m_zmm_dst, zmm_dst, src_zp_addr);
            vmovups(EVEX_compress_addr(reg_zp_pbuff, output_offset), m_zmm_dst);
        }
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool padded) {

    const int kw = jcp.kw;
    const int ic_block = jcp.ic_block_int_np;
    const int oc_block = jcp.oc_block;
    const int nb_oc_block = jcp.nb_oc_blocking;

    const bool ic_tail
            = (jcp.ic_without_padding % (jcp.ic_block / ic_inner_block)) > 0;
    const bool masked_write = ic_tail && last_ic_block_flag == last_ic_block;

    /* Skip the last loads of input
            if (ic%16)/ic_sub_step < ic_block/ic_sub_step */
    const int icb = (last_ic_block_flag == last_ic_block)
            ? div_up(
                    (jcp.ic_without_padding % jcp.ic_block_int), ic_inner_block)
            : ic_block / ic_inner_block;
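    // Illustration (hypothetical values, arithmetic only): with
    // ic_inner_block == 4 (int8 VNNI) and jcp.ic_without_padding %
    // jcp.ic_block_int == 10, the tail block loads icb = div_up(10, 4) = 3
    // four-element ic sub-groups instead of the full ic_block / 4 groups.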

    auto get_filter_offset = [=](int ocb, int ic, int ki) {
        size_t w_step = jcp.is_relo ? jcp.kh : 1;
        size_t kw_offset = static_cast<size_t>(ki) * w_step
                * jcp.ic_block_int_np * jcp.oc_block;
        size_t oc_subblock_step = static_cast<size_t>(jcp.kd) * jcp.kh * jcp.kw
                * jcp.ic_block_int_np * jcp.oc_block;
        size_t offset = kw_offset
                + static_cast<size_t>(ocb) * jcp.nb_ic_int * oc_subblock_step
                + static_cast<size_t>(ic) * oc_block * ic_inner_block;
        return sizeof(char) * offset;
    };
    auto compute_fma = [=](const Zmm zmm_accum, const int ic,
                               const Address addr) {
        if (jcp.is_relo) {
            vmovups(zmm_permb, ptr[reg_scratch]); // get permute index table
            const Zmm r_zmm = masked_write && ic == icb - 1
                    ? zmm_permb | kmask_ic_block | T_z
                    : zmm_permb;
            // only values from 'src2' are used to write dst
            vpermi2b(r_zmm, zmm_permb, addr);
            vpdpbusd(zmm_accum, zmm_one,
                    zmm_permb); // XXX - using the same register for all ur_w
        } else {
            vpdpbusd(zmm_accum, zmm_one, addr);
        }
    };

    if (jcp.is_relo && last_ic_block_flag == last_ic_block && ic_tail) {
        const Reg64 reg_tmp = reg_scratch;
        mov(reg_tmp, ic_mask_label);
        kmovq(kmask_ic_block, qword[reg_tmp]);
    }
    if (jcp.is_relo) mov(reg_scratch, permb_idx_label);

    for (int ki = 0; ki < kw; ki++) {
        const int ur_start = get_ow_start(ki, pad_l);
        const int ur_end = get_ow_end(ur_w, ki, pad_r);
        for (int ur = 0; ur < ur_w; ur++) {
            // Calculate zero_point padding as:
            // accum = is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0
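            // conv(1, wei_s8) is realized below by vpdpbusd with an all-ones
            // u8 multiplicand (zmm_one): each s32 accumulator lane receives
            // the plain sum of the s8 weights over ic, and store_output()
            // later multiplies that sum by the run-time src zero-point.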
            if (ur < ur_start || ur >= ur_end || padded) {
                for (int oc = 0; oc < nb_oc_block; oc++) {
                    const Zmm zmm_accum = zmm_out(ur, oc);
                    for (int ic = 0; ic < icb; ic++) {
                        const auto addr_filt = EVEX_compress_addr(
                                aux_reg_filt, get_filter_offset(oc, ic, ki));
                        compute_fma(zmm_accum, ic, addr_filt);
                    }
                }
            }
        }
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::kh_loop(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {

    Label kh_label, skip_kh_loop;
    const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
    const size_t shift_wei_h_step = sizeof(char)
            * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
            * jcp.oc_block;

    // Compute zero_point compensation for the padded region. Total compute
    // area is 'overflow * kw' where 'overflow' indicates the overlap
    // between the filter and either top_pad or bottom_pad region.
    auto compute_kh_loop = [=](size_t param_overflow) {
        Label overflow_label, no_overflow_label;

        mov(reg_overflow, ptr[param1 + param_overflow]);
        cmp(reg_overflow, 0);
        je(no_overflow_label, T_NEAR);
        L(overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
            add(aux_reg_filt, shift_wei_h_step);
            dec(reg_overflow);
            jne(overflow_label, T_NEAR);
        }
        L(no_overflow_label);
    };

    if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(t_overflow));

    // check for holes and skip computation due to dilation
    mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    if (jcp.dilate_h >= jcp.ih) {
        cmp(reg_kj, 0);
        je(skip_kh_loop, T_NEAR);
    }

    L(kh_label);
    {
        compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);

        add(aux_reg_filt, shift_wei_h_step);
        dec(reg_kj);
        jne(kh_label, T_NEAR);
    }

    L(skip_kh_loop);

    if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(b_overflow));
}

void jit_avx512_core_amx_compute_zp_pbuff_t::kd_loop(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {

    Label kd_label, skip_kd_loop;
    const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
    const size_t shift_wei_h_step = sizeof(char)
            * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
            * jcp.oc_block;

    // Compute zero_point compensation for the padded region. Total compute
    // area is 'overflow * kh * kw' where 'overflow' indicates the overlap
    // between the filter and either front_pad or back_pad region.
    auto compute_kd_loop = [=](size_t param_overflow) {
        Label kh_loop_label;
        Label no_overflow_label, overflow_label;

        mov(reg_ki, ptr[param1 + param_overflow]);
        cmp(reg_ki, 0);
        je(no_overflow_label, T_NEAR);
        L(overflow_label);
        {
            mov(aux_reg_filt, aux_reg_filt_d);
            mov(reg_kj, jcp.kh);
            L(kh_loop_label);
            {
                compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                add(aux_reg_filt, shift_wei_h_step);
                dec(reg_kj);
                jne(kh_loop_label, T_NEAR);
            }
            add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
            dec(reg_ki);
            jne(overflow_label, T_NEAR);
        }
        L(no_overflow_label);
    };

    const bool zp_d_padding
            = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
    if (zp_d_padding) {
        mov(aux_reg_filt_d, reg_filt);
        compute_kd_loop(GET_OFF(f_overflow));

        // check for holes and skip computation due to dilation
        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        if (jcp.dilate_d >= jcp.id) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        L(kd_label);
        mov(aux_reg_filt, aux_reg_filt_d);

    } else {
        mov(aux_reg_filt, reg_filt);
    }

    kh_loop(ur_w, pad_l, pad_r, last_ic_block_flag, handle_h_pad);

    if (zp_d_padding) {
        add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
        dec(reg_ki);
        jne(kd_label, T_NEAR);

        L(skip_kd_loop);

        compute_kd_loop(GET_OFF(back_overflow));
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::icb_loop(
        int ur_w, int pad_l, int pad_r, bool handle_h_pad) {

    Label icb_label;
    const size_t nb_ic = jcp.nb_ic_int;
    const bool do_icb_loop = nb_ic > 1;

    /* Initialize zmm_one for weight accumulation */
    xor_(reg_scratch, reg_scratch);
    const Reg8 _t8 = reg_scratch.cvt8();
    mov(_t8, 0x1);
    vpbroadcastb(zmm_one, _t8);
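    // zmm_one now holds 64 x u8 ones; vpdpbusd(acc, zmm_one, wei) therefore
    // reduces each group of four s8 weights into one s32 lane, i.e. it
    // computes the conv(1, wei_s8) term of the zero-point compensation.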

    prepare_output(ur_w);

    mov(reg_icb, nb_ic);

    L(icb_label);
    if (jcp.ic_without_padding != jcp.ic) {
        Label common_ker, end_ker;
        if (do_icb_loop) {
            cmp(reg_icb, 1); // The last ic block
            jne(common_ker, T_NEAR);
        }
        kd_loop(ur_w, pad_l, pad_r, last_ic_block, handle_h_pad);
        if (do_icb_loop) {
            jmp(end_ker, T_NEAR);

            L(common_ker);
            kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);

            L(end_ker);
        }
    } else {
        kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);
    }
    // End of IC Loop
    if (do_icb_loop) {
        const size_t shift_wei_icb_step = static_cast<size_t>(jcp.kd) * jcp.kh
                * jcp.kw * jcp.oc_block * jcp.ic_block_int_np;
        add(reg_filt, sizeof(char) * shift_wei_icb_step);

        dec(reg_icb);
        cmp(reg_icb, 0);
        jg(icb_label, T_NEAR);

        sub(reg_filt, sizeof(char) * shift_wei_icb_step * nb_ic);
    }

    if (jcp.oc_without_padding != jcp.oc) {
        Label common_store, end_store;

        cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(common_store, T_NEAR);

        store_output(ur_w, true); // last oc block
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);
    } else {
        store_output(ur_w, false);
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::unroll_width(
        const bool h_padding) {

    auto ur_w_shift = [&](const int ur_w) {
        return sizeof(int32_t) * (ur_w * jcp.oc_without_padding * jcp.ngroups);
    };

    const int max_ur_w = jit_avx512_core_amx_compute_zp_pbuff_t::max_regs_ur
            / (jcp.nb_oc_blocking);
    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int l_pad = jcp.l_pad;

    const int l_pad_output = jcp.l_pad_output;
    const int r_pad_output = jcp.r_pad_output;

    // a single middle element (if required) containing only height padding
    const int no_pad = nstl::max(0, jcp.ow - l_pad_output - r_pad_output);
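    // As the loops below suggest, the zero-point buffer walks ow in three
    // regions: l_pad_output columns of left padding, at most one shared
    // middle entry (used only when height padding applies), and then the
    // right padding columns.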

    const int ow_start = nstl::max(jcp.ow - r_pad_output, l_pad_output);
    const int r_pad_start = nstl::min(jcp.ow_pad - l_pad_output, r_pad_output);

    int ow = 0;
    int cur_l_pad_output = l_pad_output;
    while (cur_l_pad_output > 0) {
        const int ur_w = nstl::min(cur_l_pad_output, max_ur_w);
        ow += ur_w;
        const int cur_r_pad = calculate_end_padding(
                jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
        icb_loop(ur_w, l_pad, cur_r_pad, h_padding);
        add(reg_zp_pbuff, ur_w_shift(ur_w));

        l_pad = nstl::max(l_pad - ur_w * jcp.stride_w, 0);
        cur_l_pad_output = nstl::max(cur_l_pad_output - ur_w, 0);
    }

    if (no_pad > 0) {
        const int ur_w = 1;
        if (h_padding) icb_loop(ur_w, 0, 0, true);
        if (h_padding || jcp.ow_mid) add(reg_zp_pbuff, ur_w_shift(ur_w));
    }
    assert(ow + no_pad == ow_start);

    ow = ow_start;
    int cur_r_pad_output = r_pad_start;
    while (cur_r_pad_output > 0 && ow < jcp.ow) {
        const int ur_w = nstl::min(cur_r_pad_output, max_ur_w);
        ow += ur_w;
        const int cur_r_pad = calculate_end_padding(
                jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
        icb_loop(ur_w, 0, cur_r_pad, h_padding);
        add(reg_zp_pbuff, ur_w_shift(ur_w));

        cur_r_pad_output = nstl::max(cur_r_pad_output - ur_w, 0);
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::generate() {
    Label h_pad_label, end_label;

    assert(jcp.req_zero_point_buffer);
    assert(jcp.typesize_in == sizeof(char));

    preamble();

    mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
    mov(reg_zp_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
    mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);

    if (jcp.oc_without_padding != jcp.oc) {
        const Reg32 reg_tmp = reg_scratch.cvt32();
        const int tail_size = jcp.oc_without_padding % jcp.oc_block;
        const int mask = (1 << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovw(ktail_mask, reg_tmp);
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
    }

    mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
    cmp(reg_overflow, 0);
    jne(h_pad_label, T_NEAR);
    mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
    cmp(reg_overflow, 0);
    jne(h_pad_label, T_NEAR);
    if (jcp.ndims == 5 && (jcp.f_pad_output > 0 || jcp.back_pad_output > 0)) {
        mov(reg_overflow, ptr[param1 + GET_OFF(kd_padding)]);
        cmp(reg_overflow, jcp.kd);
        jne(h_pad_label, T_NEAR);
    }

    // Handle width padding region
    unroll_width(false);
    jmp(end_label, T_NEAR);

    // handle height padding region
    L(h_pad_label);
    unroll_width(true);

    L(end_label);

    postamble();

    // reduced-lowering ('is_relo' == true) weights format is '..i16o', so
    // permute elements through permb into the VNNI layout '...16o4i'.
    if (jcp.is_relo) {
        align(64);
        L(permb_idx_label);
        // permb: id-bit for table selection is bit[6]
        const uint8_t select_src2_bit = 0x40;
        // permb: bits [5:0] select the element within each input table
        const uint8_t permb_idx_table[64] = {0, 16, 32, 48, 1, 17, 33, 49, 2,
                18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22,
                38, 54, 7, 23, 39, 55, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42,
                58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46,
                62, 15, 31, 47, 63};
        for (size_t i = 0; i < 64; ++i)
            db(select_src2_bit | permb_idx_table[i]);
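
        // Illustration of the permutation: destination bytes (0..3) pick up
        // source elements {0, 16, 32, 48}, i.e. the four ic values belonging
        // to oc 0; bytes (4..7) pick {1, 17, 33, 49} for oc 1, and so on,
        // which is exactly the '16o4i' VNNI grouping consumed by vpdpbusd.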

        // write zero-mask (permb) for ic_tail in VNNI format '..16o4i'
        const int ic_tail_size
                = jcp.ic_without_padding % (jcp.ic_block / ic_inner_block);
        if (jcp.ic_without_padding != jcp.ic && ic_tail_size > 0) {
            align(64);
            L(ic_mask_label);

            assert(4 > ic_tail_size);
            // mask is on a 4-bit basis from the 4 ic elements in a zmm
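            // e.g. ic_tail_size == 3 gives nibble 0b0111, so every emitted
            // byte is 0x77: within each group of four ic bytes the first
            // three stay enabled and the fourth is zeroed by the k-mask.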
            const int nibble = (1 << ic_tail_size) - 1;
            for (int i = 0; i < 16; ++i) {
                db(nibble | (nibble << 4));
            }
        }
    }
}

void jit_avx512_core_amx_copy_to_wbuffer_t::generate() {

    const bool is_bf16 = jcp.src_dt == data_type::bf16;

    // required for use of VPERMB instruction
    assert(IMPLICATION(!is_bf16, cpu().has(Xbyak::util::Cpu::tAVX512_VBMI)));
    assert(jcp.ic_block_int * jcp.typesize_in == 64);

    preamble();

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_dst, ptr[param1 + GET_OFF(dst)]);

    // load permute indices from data section
    Label permute_index_table;
    mov(reg_tmp, permute_index_table);
    if (is_bf16)
        vmovdqu16(zmm_idx, ptr[reg_tmp]);
    else
        vmovdqu8(zmm_idx, ptr[reg_tmp]);

    const int vnni_width = is_bf16 ? 2 : 4;
    const int r = jcp.kh * jcp.kw * jcp.ic_without_padding;
    const int nb_r = div_up(r, vnni_width);
    const int rtail = (r % vnni_width) * jcp.oc_block;
    if (rtail > 0) {
        uint64_t mask = (UINT64_C(1) << rtail) - 1;
        mov(reg_tmp, mask);
        kmovq(kmask_load, reg_tmp);
    }
    const int nb_z = rnd_up(nb_r, jcp.ic_block);
    if (nb_r < nb_z) vpxord(zmm_zero, zmm_zero, zmm_zero);

    const int tile_size = jcp.ic_block_int * jcp.oc_block * jcp.typesize_in;
    const int ocb_src_step = r * jcp.oc_block * jcp.typesize_in;
    const int ocb_dst_step = rnd_up(ocb_src_step, tile_size);

    // reorder from ~Owhi16o -> ~OR16oVr with r := whi and V := vnni_width
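    // For instance (values chosen only to illustrate the arithmetic): an
    // int8 3x3 filter with ic = 2 gives r = 3 * 3 * 2 = 18 reduction
    // elements; with vnni_width = 4, nb_r = div_up(18, 4) = 5 groups and
    // rtail = (18 % 4) * 16 = 32, so the last load is masked to 32 bytes.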
    for (int g = 0; g < jcp.ngroups; g++) {
        for (int ocb = 0; ocb < jcp.nb_oc; ocb++) {
            int offset = 0;
            int rb = 0;
            for (; rb < nb_r; offset += 64, rb++) {
                auto zmm_src_tmp = (rtail > 0 && rb == nb_r - 1)
                        ? zmm_src | kmask_load | T_z
                        : zmm_src;
                if (is_bf16) {
                    vmovdqu16(zmm_src_tmp, ptr[reg_src + offset]);
                    vpermw(zmm_dst, zmm_idx, zmm_src);
                    vmovdqu16(ptr[reg_dst + offset], zmm_dst);
                } else {
                    vmovdqu8(zmm_src_tmp, ptr[reg_src + offset]);
                    vpermb(zmm_dst, zmm_idx, zmm_src);
                    vmovdqu8(ptr[reg_dst + offset], zmm_dst);
                }
            }
            for (; rb < nb_z; offset += 64, rb++) {
                if (is_bf16)
                    vmovdqu16(ptr[reg_dst + offset], zmm_zero);
                else
                    vmovdqu8(ptr[reg_dst + offset], zmm_zero);
            }
            add(reg_src, ocb_src_step);
            add(reg_dst, ocb_dst_step);
        }
    }

    postamble();

    align(64);
    L(permute_index_table);
    const uint8_t no = 16; // 16o
    const uint8_t nr = is_bf16 ? 2 : 4; // 2r or 4r
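    // The emitted index sequence is 0, 16, 32, 48, 1, 17, 33, 49, ... for
    // int8 (and 0, 16, 1, 17, ... for bf16): each output channel's
    // vnni_width reduction elements become adjacent, turning '16o' rows
    // into the '16oVr' layout expected by the AMX tiles.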
    for (uint8_t o = 0; o < no; ++o) {
        for (uint8_t r = 0; r < nr; r++) {
            const uint8_t index = o + r * no;
            if (is_bf16)
                dw(index);
            else
                db(index);
        }
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_body(
        int lpad, int iw_len, int icb) {

    const bool is_bf16 = jcp.src_dt == data_type::bf16;
    int iwp_idx = 0;
    // there are min(gen_kw, jcp.stride_w) continuous sets of input
    // data (for each stride idx), they are placed one by one
    // without additional padding
    const bool are_sets_interleaved
            = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
    const int num_sets = are_sets_interleaved ? jcp.n_stride_sets : jcp.kw;
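    // Example of the strided layout (illustrative numbers): for kw = 3,
    // stride_w = 2, dilate_w = 0 there are min(3, 2) = 2 sets; set 0 holds
    // the input columns used by kernel positions {0, 2} and set 1 those used
    // by position {1}, so columns with equal (iw mod stride_w) are stored
    // contiguously.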
    for (int set_idx = 0; set_idx < num_sets; set_idx++) {
        int set_width_padded = !jcp.is_pbuffer_strided
                ? (jcp.ow_block - 1) * jcp.stride_w + gen_kw
                : are_sets_interleaved ? jcp.ow_block - 1 + gen_kw / num_sets
                                + (set_idx < gen_kw % num_sets ? 1 : 0)
                                       : jcp.ow_block;
        for (int set_shift = 0; set_shift < set_width_padded;
                set_shift++, iwp_idx++) {
            int iw_idx = set_idx * (jcp.dilate_w + 1)
                    + set_shift * (jcp.is_pbuffer_strided ? jcp.stride_w : 1)
                    - lpad;
            size_t out_base_offset
                    = (size_t)jcp.typesize_in * iwp_idx * jcp.ic_block_int_np;
            if (iw_idx < 0 || iw_idx >= iw_len) {
                // left or right padding
                vmovups(ptr[reg_aux_out_ptr + out_base_offset], zmm_zero);
            } else if (jcp.is_nspc) {
                size_t inp_w_offset = (size_t)jcp.typesize_in * iw_idx
                        * jcp.ngroups * jcp.ic_without_padding;
                int ic = icb * jcp.ic_block_int_np;
                // TODO: use Xmm or Ymm moves for better small ic efficiency
                auto zmm_tmp_mask
                        = ic + jcp.ic_block_int <= jcp.ic_without_padding
                        ? zmm_tmp
                        : zmm_tmp | ktail_mask | T_z;
                if (is_bf16) {
                    vmovdqu16(
                            zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
                    vmovdqu16(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
                } else {
                    vmovdqu8(zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
                    vmovdqu8(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
                }
            } else {
                assert(is_bf16);
                size_t inp_w_offset
                        = (size_t)jcp.typesize_in * iw_idx * jcp.ic_block;
                for (int j = 0; j < jcp.ic_block_int_np / jcp.ic_block; j++) {
                    int ic = icb * jcp.ic_block_int_np + j * jcp.ic_block;
                    size_t inp_c_w_offset = (size_t)jcp.typesize_in * j * jcp.ih
                                    * jcp.iw * jcp.ic_block
                            + inp_w_offset;
                    if (ic + jcp.ic_block <= jcp.ic) {
                        vmovdqu16(
                                ymm_tmp, ptr[reg_aux_inp_ptr + inp_c_w_offset]);
                    } else {
                        vpxord(ymm_tmp, ymm_tmp, ymm_tmp);
                    }
                    size_t out_offset = out_base_offset
                            + (size_t)jcp.typesize_in * j * jcp.ic_block;
                    vmovdqu16(ptr[reg_aux_out_ptr + out_offset], ymm_tmp);
                }
            }
        }
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row(int icb) {
    if (jcp.nb_ow == 1) {
        copy_row_body(jcp.l_pad, jcp.iw, icb);
    } else {
        auto get_iw_len_required = [&](int cur_ow_block, int cur_lpad) {
            return (cur_ow_block - 1) * jcp.stride_w
                    + (jcp.kw - 1) * (jcp.dilate_w + 1) + 1 - cur_lpad;
        };

        auto get_iw_len_limited = [&](int owb, int cur_ow_block, int cur_lpad) {
            auto len_req = get_iw_len_required(cur_ow_block, cur_lpad);
            if (owb < 0) return len_req;
            int ow_block_start = nstl::max(
                    0, owb * jcp.ow_block * jcp.stride_w - jcp.l_pad);
            return nstl::min(jcp.iw - ow_block_start, len_req);
        };
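
        // e.g. ow_block = 16, stride_w = 1, kw = 3, dilate_w = 0, lpad = 0
        // (illustrative values): a full output block needs
        // (16 - 1) * 1 + (3 - 1) * 1 + 1 = 18 input columns, which
        // get_iw_len_limited() then clips against the end of the input row.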

        int general_owb_cases = jcp.nb_ow;
        Xbyak::Label copy_row_done_label;
        bool special_first_block_case = jcp.l_pad > 0;
        if (special_first_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_first_block_case_label;
            cmp(reg_owb, 0);
            jne(skip_first_block_case_label, T_NEAR);
            copy_row_body(jcp.l_pad,
                    get_iw_len_limited(0, jcp.ow_block, jcp.l_pad), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_first_block_case_label);
        }
        bool special_last_block_case = false
                // has ow_block_tail
                || jcp.ow % jcp.ow_block != 0
                // there is no ow_block_tail but right padding exists
                || get_iw_len_limited(jcp.nb_ow - 1, jcp.ow_block, 0)
                        != get_iw_len_required(jcp.ow_block, 0);
        if (special_last_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_last_block_case_label;
            cmp(reg_owb, jcp.nb_ow - 1);
            jne(skip_last_block_case_label, T_NEAR);
            int ow_block_tail = jcp.ow % jcp.ow_block;
            int cur_ow_block = ow_block_tail > 0 ? ow_block_tail : jcp.ow_block;
            copy_row_body(
                    0, get_iw_len_limited(jcp.nb_ow - 1, cur_ow_block, 0), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_last_block_case_label);
        }

        bool special_penult_block_case = true
                // if nb_ow = 2 and l_pad > 0 it's the same as
                // special_first_block_case
                && jcp.nb_ow >= (special_first_block_case ? 3 : 2)
                // right padding exists in penult block
                && get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0)
                        != get_iw_len_required(jcp.ow_block, 0);
        if (special_penult_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_penult_block_case_label;
            cmp(reg_owb, jcp.nb_ow - 2);
            jne(skip_penult_block_case_label, T_NEAR);
            copy_row_body(
                    0, get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_penult_block_case_label);
        }

        if (general_owb_cases > 0) // general case
            copy_row_body(0, get_iw_len_required(jcp.ow_block, 0), icb);

        L(copy_row_done_label);
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_reduced_lowering() {
    assert(jcp.nb_ic_int == 1);
    assert(jcp.ic_block_int * jcp.typesize_in == 64);
    assert(jcp.is_nspc);

    auto load_mask = [=](int tail, Opmask kmask) {
        uint64_t mask = (UINT64_C(1) << tail) - 1;
        mov(reg_tmp, mask);
        kmovq(kmask, reg_tmp);
    };

    const bool is_bf16 = jcp.src_dt == data_type::bf16;
    const int inp_w_step
            = jcp.ngroups * jcp.ic_without_padding * jcp.typesize_in;
    const int inp_h_step = jcp.iw * inp_w_step;
    const int out_h_step = jcp.ic_without_padding * jcp.typesize_in;
    const int out_w_step = jcp.kh * out_h_step;
    const int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
    if (tail_size > 0) load_mask(tail_size, ktail_mask);

    auto zero_it = [=](reg64_t tmp_out_ptr) {
        for (int ic = 0; ic < jcp.ic_without_padding; ic += jcp.ic_block_int) {
            const int offset = ic * jcp.typesize_in;
            const bool masked = ic + jcp.ic_block_int > jcp.ic_without_padding;
            Zmm zmm = masked ? zmm_zero | ktail_mask : zmm_zero;
            if (is_bf16)
                vmovdqu16(ptr[tmp_out_ptr + offset], zmm);
            else
                vmovdqu8(ptr[tmp_out_ptr + offset], zmm);
        }
    };

    // pointer to 1st needed element in src buffer
    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    // pointer to 1st needed element in dst buffer
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);

    // total number of rows to copy
    mov(reg_kht, ptr[param1 + GET_OFF(kh_offset)]);

    // number of rows of src buffer to copy
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
    // number of zero-padded rows above src buffer to copy
    mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
    // number of zero-padded rows below src buffer to copy
    mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);

    // number of columns of src buffer to copy
    mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
    // number of zero-padded columns before src buffer to copy
    mov(reg_lov, ptr[param1 + GET_OFF(f_overflow)]);
    // number of zero-padded columns past src buffer to copy
    mov(reg_rov, ptr[param1 + GET_OFF(back_overflow)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    { // Handle Left Overflow
        Label label_lov, label_lov_skip;
        test(reg_lov, reg_lov);
        jz(label_lov_skip, T_NEAR);
        L(label_lov); // handle left or right overflow
        {
            Label label_lov_inner;
            mov(reg_aux_out_ptr, reg_out_ptr);
            mov(reg_cnt, reg_kht);
            L(label_lov_inner);
            {
                zero_it(reg_aux_out_ptr);
                add(reg_aux_out_ptr, out_h_step);
                dec(reg_cnt);
                jnz(label_lov_inner, T_NEAR);
            }
            add(reg_out_ptr, out_w_step);
            dec(reg_lov);
            jnz(label_lov, T_NEAR);
        }
        L(label_lov_skip);
    }

    // save output pointer for later use
    mov(reg_save_out_ptr, reg_out_ptr);

    // just in case there is no meat...
    Label label_kwp_end;
    test(reg_kwp, reg_kwp);
    jz(label_kwp_end, T_NEAR);

    // Unroll over W-dimension in powers of 2
    Label label_tov;
    Label label_khp, label_no_khp;
    Label label_bov;
    test(reg_tov, reg_tov);
    jnz(label_tov, T_NEAR);
    test(reg_khp, reg_khp);
    jnz(label_khp, T_NEAR);
    test(reg_bov, reg_bov);
    jnz(label_bov, T_NEAR);
    jmp(label_kwp_end, T_NEAR); // safe exit in case of bad parameters

    L(label_tov); // handle top overflow
    {
        Label label_tov_inner;
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_tov_inner);
        {
            zero_it(reg_aux_out_ptr);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_tov_inner, T_NEAR);
        }
        add(reg_out_ptr, out_h_step);
        dec(reg_tov);
        jnz(label_tov, T_NEAR);
    }
    test(reg_khp, reg_khp);
    jz(label_no_khp, T_NEAR);
    L(label_khp); // handle kh padding (not fully unrolled)
    {
        Label label_khp_inner;
        mov(reg_aux_inp_ptr, reg_inp_ptr);
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_khp_inner);
        {
            for (int ic = 0; ic < jcp.ic_without_padding;
                    ic += jcp.ic_block_int) {
                const int offset = ic * jcp.typesize_in;
                const bool masked
                        = ic + jcp.ic_block_int > jcp.ic_without_padding;
                // zero masking is needed to avoid dependency on destination
                Zmm zmm_load = masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
                Zmm zmm_store = masked ? zmm_tmp | ktail_mask : zmm_tmp;
                if (is_bf16) {
                    vmovdqu16(zmm_load, ptr[reg_aux_inp_ptr + offset]);
                    vmovdqu16(ptr[reg_aux_out_ptr + offset], zmm_store);
                } else {
                    vmovdqu8(zmm_load, ptr[reg_aux_inp_ptr + offset]);
                    vmovdqu8(ptr[reg_aux_out_ptr + offset], zmm_store);
                }
            }
            add(reg_aux_inp_ptr, inp_w_step);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_khp_inner, T_NEAR);
        }
        add(reg_inp_ptr, inp_h_step);
        add(reg_out_ptr, out_h_step);
        dec(reg_khp);
        jnz(label_khp, T_NEAR);
    }
    L(label_no_khp);
    test(reg_bov, reg_bov);
    jz(label_kwp_end, T_NEAR);
    L(label_bov); // handle bottom overflow
    {
        Label label_bov_inner;
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_bov_inner);
        {
            zero_it(reg_aux_out_ptr);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_bov_inner, T_NEAR);
        }
        add(reg_out_ptr, out_h_step);
        dec(reg_bov);
        jnz(label_bov, T_NEAR);
    }
    L(label_kwp_end);

    { // Handle Right Overflow
        Label label_rov, label_rov_skip;
        // retrieve output pointer
        mov(reg_out_ptr, reg_save_out_ptr);
        // calculate the shift
        imul(reg_tmp, reg_kwp, out_w_step);
        // shift past the body
        add(reg_out_ptr, reg_tmp);
        // skip if no right overflow
        test(reg_rov, reg_rov);
        jz(label_rov_skip, T_NEAR);

        L(label_rov); // handle left or right overflow
        {
            Label label_rov_inner;
            mov(reg_aux_out_ptr, reg_out_ptr);
            mov(reg_cnt, reg_kht);
            L(label_rov_inner);
            {
                zero_it(reg_aux_out_ptr);
                add(reg_aux_out_ptr, out_h_step);
                dec(reg_cnt);
                jnz(label_rov_inner, T_NEAR);
            }
            add(reg_out_ptr, out_w_step);
            dec(reg_rov);
            jnz(label_rov, T_NEAR);
        }
        L(label_rov_skip);
    }

    // For bf16, zero-pad an extra cacheline to avoid NaNs
    // For int8, it is sufficient to zero-pad the weights only
    if (is_bf16) {
        // shift forward to align h index to end of needed buffer
        imul(reg_tmp, reg_kht, out_h_step);
        add(reg_out_ptr, reg_tmp);
        // shift backward to align w index to end of needed buffer
        sub(reg_out_ptr, out_w_step);
        vmovdqu16(ptr[reg_out_ptr], zmm_zero);
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::generate() {

    // Special copy kernel for reduced lowering
    if (jcp.is_relo) {
        assert(jcp.nb_ic_int == 1);
        preamble();
        copy_row_reduced_lowering();
        postamble();
        return;
    }

    preamble();

    const bool is_3d = jcp.ndims == 5;
    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
    if (is_3d) mov(reg_kdp, ptr[param1 + GET_OFF(kd_padding)]);
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
    mov(reg_tover, ptr[param1 + GET_OFF(t_overflow)]);
    mov(reg_bover, ptr[param1 + GET_OFF(b_overflow)]);
    mov(reg_owb, ptr[param1 + GET_OFF(owb)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    if (jcp.is_nspc && jcp.ic_without_padding % jcp.ic_block_int) {
        int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
        Xbyak::Label kd_label, no_kd_label;
        Xbyak::Label kh_label, no_kh_label, icb_label;
        Xbyak::Label kh_tover_label, kh_bover_label;
        Xbyak::Label no_kh_tover_label, no_kh_bover_label;

        mov(reg_aux_inp_ptr, reg_inp_ptr);
        mov(reg_aux_out_ptr, reg_out_ptr);
        if (is_3d) {
            cmp(reg_kdp, 0);
            jle(no_kd_label, T_NEAR);
            mov(reg_kdc, reg_kdp);
            L(kd_label);
            push(reg_aux_inp_ptr);
            push(reg_aux_out_ptr);
        }
        cmp(reg_khp, 0);
        jle(no_kh_bover_label, T_NEAR); // nothing to do
        mov(reg_khc, reg_khp);

        cmp(reg_tover, 0);
        jle(no_kh_tover_label, T_NEAR);

        mov(reg_kh_over, reg_tover);
        L(kh_tover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for (int iw = 0; iw < jcp.iwp; iw++)
                vmovups(ptr[reg_aux_out_ptr
                                + jcp.typesize_in * iw * jcp.ic_block_int_np],
                        zmm_zero);
            int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_kh_over);
            jnz(kh_tover_label, T_NEAR);
        }
        sub(reg_khc, reg_tover);
        L(no_kh_tover_label);

        cmp(reg_khc, reg_bover);
        jle(no_kh_label, T_NEAR);

        L(kh_label);
        {
            copy_row(icb);
            size_t inp_h_offset = !jcp.is_nspc
                    ? (size_t)jcp.typesize_in * jcp.iw * jcp.ic_block
                    : (size_t)jcp.typesize_in * jcp.iw * jcp.ngroups
                            * jcp.ic_without_padding;
            size_t out_h_offset
                    = (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;

            add(reg_aux_inp_ptr, inp_h_offset);
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_khc);
            cmp(reg_khc, reg_bover);
            jg(kh_label, T_NEAR);
        }
        L(no_kh_label);

        cmp(reg_khc, 0);
        jle(no_kh_bover_label, T_NEAR);

        L(kh_bover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for (int iw = 0; iw < jcp.iwp; iw++)
                vmovups(ptr[reg_aux_out_ptr
                                + jcp.typesize_in * iw * jcp.ic_block_int_np],
                        zmm_zero);
            int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_khc);
            jnz(kh_bover_label, T_NEAR);
        }
        size_t out_d_offset = (size_t)jcp.typesize_in
                * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
        L(no_kh_bover_label);
        if (is_3d) {
            size_t inp_d_offset = !jcp.is_nspc
                    ? (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block
                            * (jcp.dilate_d + 1)
                    : (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ngroups
                            * jcp.ic_without_padding * (jcp.dilate_d + 1);
            pop(reg_aux_out_ptr);
            pop(reg_aux_inp_ptr);
            add(reg_aux_inp_ptr, inp_d_offset);
            add(reg_aux_out_ptr, out_d_offset);
            dec(reg_kdc);
            jnz(kd_label, T_NEAR);
            L(no_kd_label);
        }
        // End IC Loop
        size_t inp_cb_offset = !jcp.is_nspc
                ? (size_t)jcp.typesize_in * (jcp.ic_block_int_np / jcp.ic_block)
                        * jcp.id * jcp.ih * jcp.iw * jcp.ic_block
                : (size_t)jcp.typesize_in * jcp.ic_block_int_np;
        size_t out_cb_offset = (size_t)jcp.kd * out_d_offset;

        add(reg_inp_ptr, inp_cb_offset);
        add(reg_out_ptr, out_cb_offset);
    }

    postamble();
}

jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        using namespace binary_injector;
        const auto &rhs_addr_reg = bin_injector_helper_reg_1;
        const auto &rhs_helper_reg = bin_injector_helper_reg_2;
        static constexpr bool preserve_gpr = false;
        static constexpr bool preserve_vmm = false;
        const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                31, rhs_addr_reg, rhs_helper_reg, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, ktail_mask,
                use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp.post_ops, static_params);
    }
    copy_to_pbuffer_
            = utils::make_unique<jit_avx512_core_amx_copy_to_pbuffer_t>(jcp);
    if (jcp.is_relo)
        copy_to_wbuffer_
                = utils::make_unique<jit_avx512_core_amx_copy_to_wbuffer_t>(
                        jcp);
}

status_t jit_avx512_core_amx_fwd_kernel_t::create_kernel() {
    CHECK(jit_generator::create_kernel());
    CHECK(copy_to_pbuffer_->create_kernel());
    if (jcp.is_relo) CHECK(copy_to_wbuffer_->create_kernel());
    if (jcp.req_zero_point_buffer) {
        zp_pbuff_kernel_
                = utils::make_unique<jit_avx512_core_amx_compute_zp_pbuff_t>(
                        jcp);
        if (zp_pbuff_kernel_ == nullptr) return status::out_of_memory;
        CHECK(zp_pbuff_kernel_->create_kernel());
    }
    return status::success;
}

// Tile register decomposition
// { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
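// The 8 available tile registers are split as: tmm0-tmm3 hold output
// accumulators (C), tmm4-tmm5 hold input rows (I), and tmm6-tmm7 hold
// weights (W), matching the ranges asserted in the getters below.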
int jit_avx512_core_amx_fwd_kernel_t::get_out_tensor(
        int h, int i, bool is_h_tail) const {
    const int C_BASE = 0;
    const int C_LAST = 4;
    assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(C_LAST);
    const int tile = C_BASE
            + (jcp.nb_oh_blocking > 1
                            ? h * jcp.nb_oh_blocking + i
                            : (int)is_h_tail * jcp.nb_oc_blocking + i);
    assert(C_BASE <= tile && tile < C_LAST);
    return tile;
}
int jit_avx512_core_amx_fwd_kernel_t::get_inp_tensor(
        int h, bool is_h_tail) const {
    const int I_BASE = 4;
    const int I_LAST = 6;
    assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(I_LAST);
    const int tile = I_BASE + (jcp.nb_oh_blocking > 1 ? h : (int)is_h_tail);
    assert(I_BASE <= tile && tile < I_LAST);
    return tile;
}
int jit_avx512_core_amx_fwd_kernel_t::get_wei_tensor(int i) const {
    const int W_BASE = 6;
    const int W_LAST = 8;
    assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(W_LAST);
    const int tile = W_BASE + i;
    assert(W_BASE <= tile && tile < W_LAST);
    return tile;
}

// Shifts and offsets
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_icb_step() const {
    return (size_t)jcp.kd * get_inp_d_step();
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_icb_step() const {
    return (size_t)jcp.typesize_in * jcp.kd * jcp.kh * jcp.kw
            * jcp.ic_block_int_np * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_d_step() const {
    return (size_t)jcp.typesize_in
            * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_h_step() const {
    return (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np
            * (jcp.dilate_h + 1);
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_d_step() const {
    return (size_t)jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block_int_np
            * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_h_step() const {
    return (size_t)jcp.typesize_in * jcp.kw * jcp.ic_block_int_np
            * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_out_ocb_offset(
        int ohb, int ocb, size_t typesize) const {
    size_t el_offset = jcp.is_nspc
            ? (size_t)ocb * jcp.oc_block
                    + (size_t)ohb * jcp.ow * jcp.ngroups
                            * jcp.oc_without_padding
            : (size_t)ocb * jcp.oh * jcp.ow * jcp.oc_block
                    + (size_t)ohb * jcp.ow * jcp.oc_block;
    return (size_t)typesize * el_offset;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_out_row_offset(
        int ohb, int ocb, int j, size_t typesize) const {
    size_t offset_w = jcp.is_nspc
            ? (size_t)typesize * j * jcp.ngroups * jcp.oc_without_padding
            : (size_t)typesize * j * jcp.oc_block;
    return get_out_ocb_offset(ohb, ocb, typesize) + offset_w;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_out_shift(
        int width, size_t typesize) const {
    return jcp.is_nspc
            ? (size_t)typesize * width * jcp.ngroups * jcp.oc_without_padding
            : (size_t)typesize * width * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_ocb_offset(
        int ohb, int ocb) const {
    size_t el_offset = (size_t)ocb * prv_width_ * jcp.oc_block
            + (size_t)ohb * jcp.nb_oc_blocking * jcp.full_tile_width
                    * jcp.oc_block;
    return jcp.typesize_acc * el_offset;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_row_offset(
        int ohb, int ocb, int j) const {
    return get_wsp_ocb_offset(ohb, ocb)
            + (size_t)jcp.typesize_acc * j * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_shift() const {
    return (size_t)jcp.typesize_acc * jcp.nb_oh_blocking * jcp.full_tile_width
            * jcp.oc_block * jcp.nb_oc_blocking;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_offset(int ocb, int kw) const {
    size_t el_offset = (size_t)kw * jcp.ic_block_int_np * jcp.oc_block;
    size_t raw_oc_subblock_step
            = jcp.kd * jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
    size_t oc_subblock_step = jcp.is_relo
            ? rnd_up(raw_oc_subblock_step, jcp.ic_block_int * jcp.oc_block)
            : raw_oc_subblock_step;
    el_offset += (size_t)ocb * jcp.nb_ic_int * oc_subblock_step;
    return jcp.typesize_in * el_offset;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_shift() const {
    size_t w_step = (jcp.is_relo ? jcp.stride_w * jcp.kh
                                 : jcp.is_pbuffer_strided ? 1 : jcp.stride_w)
            * jcp.ic_block_int_np;
    return (size_t)jcp.typesize_in * jcp.tile_width * w_step;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_offset(int ohb, int kw) const {
    if (jcp.is_relo)
        return ohb * jcp.iwp * jcp.kh * jcp.ic_block_int_np * jcp.typesize_in;
    // calculate offset by height dimension
    const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
    const int gen_stride_h = nstl::min(jcp.stride_h, gen_kh);
    size_t el_offset = (size_t)ohb * jcp.oh_per_tile * gen_stride_h * jcp.iwp
            * jcp.ic_block_int_np;

    // add offset by width dimension
    if (IMPLICATION(jcp.is_pbuffer_strided, jcp.stride_w == 1)) {
        el_offset += (size_t)kw * (jcp.dilate_w + 1) * jcp.ic_block_int_np;
    } else if (jcp.dilate_w > 0) {
        el_offset += (size_t)kw * jcp.ow_block * jcp.ic_block_int_np;
    } else {
        // dilate_w == 0 && stride_w > 1
        // there are min(jcp.kw, jcp.stride_w) continuous sets of input data
        // (for each stride idx), they are placed one by one without additional
        // padding

        // calculate set idx for current kw value
        int set_idx = kw % jcp.stride_w;
        // calculate shift within set for current kw value
        int set_shift = kw / jcp.stride_w;

        // calculate the beginning of the current set along width, each set
        // with index set_i contains number of elements along width equal to
        // jcp.ow_block - 1 + jcp.kw / jcp.stride_w
        //     + (set_i < jcp.kw % jcp.stride_w)
        size_t set_start = (jcp.ow_block - 1 + jcp.kw / jcp.stride_w) * set_idx
                + nstl::min(set_idx, jcp.kw % jcp.stride_w);
        el_offset += (set_start + set_shift) * jcp.ic_block_int_np;
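        // e.g. kw = 3, stride_w = 2, ow_block = 8 (illustrative): set 0
        // starts at element 0 and set 1 at (8 - 1 + 3 / 2) * 1 + min(1, 1)
        // = 9; kernel position kw = 2 maps to set 0, shift 1, i.e. element 1.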
    }
    return jcp.typesize_in * el_offset;
}

size_t jit_avx512_core_amx_fwd_kernel_t::get_zp_comp_offset(
        int ocb, int zp_h, int zp_w) const {
    const size_t ocb_offset = (size_t)ocb * jcp.oc_block;
    const size_t sp_offset = (size_t)(zp_h * jcp.ow_pad + zp_w) * jcp.ngroups
            * jcp.oc_without_padding;
    return (ocb_offset + sp_offset) * sizeof(int32_t);
}

int jit_avx512_core_amx_fwd_kernel_t::get_zp_index_offset(
        int index, int mid, int s_pad_output, int e_pad_output) {
    using namespace nstl;
    const int mid_end = e_pad_output - 1;
    int zp_mid = min(mid, max(0, index - mid_end));
    int zp_pad_offset
            = accum_with_upper_bound(index, s_pad_output, e_pad_output);
    return zp_pad_offset + zp_mid;
}

// Code generation
void jit_avx512_core_amx_fwd_kernel_t::prepare_output(int tail) {
    for (int h = 0; h < jcp.nb_oh_blocking; h++)
        for (int i = 0; i < jcp.nb_oc_blocking; i++)
            tilezero(Tmm(get_out_tensor(h, i, tail)));
}

void jit_avx512_core_amx_fwd_kernel_t::init_runtime_counters(
        bool start_with_last_tile_block) {
    prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
            ? jcp.tile_tail
            : jcp.tile_width;

    row_count_ = 0;
    is_store_done_ = false;
    is_buffer_empty_ = true;
}

size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_block(
        const int block_size, const int pad_output) {
    return (size_t)(pad_output >= block_size ? block_size : 0)
            + (pad_output % block_size);
}
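
// e.g. block_size = 16, pad_output = 35 (illustrative): the result is
// 16 + 35 % 16 = 19, i.e. all full padded blocks collapse into a single
// representative block plus the remainder.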
1290 
reduce_to_blocked_dims(const int dim_size,const int block_size,const int s_pad_output,const int e_pad_output)1291 size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_blocked_dims(
1292         const int dim_size, const int block_size, const int s_pad_output,
1293         const int e_pad_output) {
1294     using namespace nstl;
1295 
1296     // start padding (s_pad)
1297     int s_pad_limit = reduce_to_block(block_size, s_pad_output);
1298     int s_pad_area_blk = rnd_up(s_pad_limit, block_size);
1299 
1300     // middle (no padding)
1301     int no_pad_area = max(
1302             0, dim_size - rnd_up(s_pad_output, block_size) - e_pad_output);
1303     int no_pad_limit = (no_pad_area >= block_size ? block_size : 0);
1304 
1305     // end padding (e_pad)
1306     int no_pad_area_shift = no_pad_area % block_size;
1307     int e_pad_area_overlap
1308             = no_pad_area_shift == 0 ? 0 : block_size - no_pad_area_shift;
1309     // middle and end padding shift
1310     int e_pad_shift_limit
1311             = no_pad_area_shift + min(e_pad_output, e_pad_area_overlap);
1312     int e_pad_area_blk = max(0, e_pad_output - e_pad_area_overlap);
1313     // full end padding block
1314     int e_pad_limit = reduce_to_block(block_size, e_pad_area_blk);
1315 
1316     // calculate reduced size of s_pad, middle and e_pad blocks.
1317     return min((size_t)dim_size,
1318             (size_t)s_pad_area_blk + no_pad_limit + e_pad_shift_limit
1319                     + e_pad_limit);
1320 }
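
// Worked example (hypothetical sizes): dim_size = 20, block_size = 4,
// s_pad_output = 6 and e_pad_output = 3 give s_pad_area_blk = 8,
// no_pad_limit = 4, e_pad_shift_limit = 1 + min(3, 3) = 4 and e_pad_limit = 0,
// so the blocked dimension reduces from 20 to min(20, 8 + 4 + 4 + 0) = 16.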
1321 
1322 Ymm jit_avx512_core_amx_fwd_kernel_t::ymm_mask(
1323         const Ymm &ymm_in, bool mask_flag, bool store) {
1324     return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
1325                      : ymm_in;
1326 }
1327 
1328 Zmm jit_avx512_core_amx_fwd_kernel_t::zmm_mask(
1329         const Zmm &zmm_in, bool mask_flag, bool store) {
1330     return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
1331                      : zmm_in;
1332 }
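
// Masking convention for ymm_mask/zmm_mask: loads combine the tail mask with
// T_z so lanes past the OC tail read as zero, while stores use merge-masking
// so only the tail lanes are written to memory.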
1333 
1334 void jit_avx512_core_amx_fwd_kernel_t::cvt2ps(data_type_t type_in,
1335         const Zmm &zmm_in, const Operand &op, bool mask_flag) {
1336     const Zmm zmm = zmm_mask(zmm_in, mask_flag);
1337     switch (type_in) {
1338         case data_type::f32:
1339         case data_type::s32: vmovups(zmm, op); break;
1340         case data_type::s8: vpmovsxbd(zmm, op); break;
1341         case data_type::u8: vpmovzxbd(zmm, op); break;
1342         default: assert(!"unsupported data type");
1343     }
1344     if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
1345 }
1346 
1347 void jit_avx512_core_amx_fwd_kernel_t::apply_sum(const Zmm &zmm_out,
1348         const float *p_sum_scale, const int32_t *p_sum_zp,
1349         const Xbyak::Address &addr, const bool mask_flag) {
1350     if (p_sum_scale) {
1351         const float p_sum_scale_val = *p_sum_scale;
1352         const int32_t p_sum_zp_val = *p_sum_zp;
1353         const auto sum_injector = [&, p_sum_scale_val, p_sum_zp_val,
1354                                           mask_flag]() {
1355             cvt2ps(jcp.sum_dt, zmm_prev_dst, addr, mask_flag);
1356             if (p_sum_zp_val != 0) {
1357                 vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
1358                 vsubps(zmm_prev_dst, zmm_sum_zp);
1359             }
1360             if (p_sum_scale_val == 1.f)
1361                 vaddps(zmm_out, zmm_prev_dst);
1362             else
1363                 vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
1364         };
1365         postops_injector_->set_lambda_injector(
1366                 primitive_kind::sum, sum_injector);
1367     }
1368 }
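
// The sum injector above computes dst_f32 += sum_scale * (prev_dst - sum_zp),
// taking a plain vaddps fast path when sum_scale == 1.f and skipping the
// zero-point subtraction when it is statically known to be 0.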
1369 
1370 void jit_avx512_core_amx_fwd_kernel_t::apply_postops(const Zmm &zmm_out,
1371         const float *p_sum_scale, const int32_t *p_sum_zp,
1372         const Xbyak::Address &addr, const size_t off, const bool mask_flag) {
1373     if (jcp.with_eltwise || jcp.with_binary
1374             || (jcp.with_sum && p_sum_scale != nullptr)) {
1375         apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag);
1376 
1377         const auto vmm_idx = zmm_out.getIdx();
1378         if (jcp.with_binary) {
1379             binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
1380             rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_out_ptr);
1381             rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off);
1382             if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
1383 
1384             postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
1385         } else {
1386             postops_injector_->compute_vector(vmm_idx);
1387         }
1388     }
1389 }
1390 
1391 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_bf16(
1392         const Zmm &zmm_out, int ocb, int h, int w) {
1393     const bool mask_flag = jcp.is_nspc && jcp.oc_without_padding != jcp.oc
1394             && ocb == (jcp.nb_oc_blocking - 1);
1395 
1396     const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1397     auto addr = EVEX_compress_addr(reg_out_ptr, off);
1398 
1399     const auto &p = attr_.post_ops_;
1400 
1401     const int sum_idx = p.find(primitive_kind::sum);
1402     if (sum_idx != -1) {
1403         if (jcp.dst_dt == data_type::bf16) {
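            // bf16 -> f32 upconversion: zero-extend each 16-bit element into
            // a 32-bit lane, then shift it into the high half; the result is
            // the f32 bit pattern of the original bf16 value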
1404             vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
1405             vpslld(zmm_prev_dst, zmm_prev_dst, 16);
1406             vaddps(zmm_out, zmm_prev_dst);
1407         } else {
1408             vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
1409             vaddps(zmm_out, zmm_prev_dst);
1410         }
1411     }
1412     if (jcp.with_bias) {
1413         int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
1414         auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1415         if (jcp.bia_dt == data_type::bf16) {
1416             vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
1417             vpslld(zmm_bias, zmm_bias, 16);
1418             vaddps(zmm_out, zmm_bias);
1419         } else
1420             vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
1421     }
1422 
1423     static constexpr auto skip_sum_injection = nullptr;
1424     apply_postops(zmm_out, skip_sum_injection, skip_sum_injection, addr, off,
1425             mask_flag);
1426 
1427     if (jcp.dst_dt == data_type::bf16) {
1428         Ymm ymm_out = Ymm(zmm_out.getIdx());
1429         vcvtneps2bf16(ymm_out, zmm_out);
1430         vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
1431     } else {
1432         vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
1433     }
1434 }
1435 
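// The int8 variant below walks the usual re-quantization pipeline on the s32
// accumulator: zero-point padding/source compensation (still in s32), convert
// to f32, add bias, apply per-oc scales and post-ops, add the dst zero-point,
// then saturate and down-convert to the destination data type.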
1436 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_int8(
1437         const Zmm &zmm_out, int ocb, int h, int w, const bool compute_zp,
1438         const int zp_h, const int zp_w) {
1439     const int nb_oc_block = jcp.nb_oc_blocking;
1440     const int oc_block = jcp.oc_block;
1441     const bool mask_flag = jcp.oc_without_padding != jcp.oc
1442             && ocb == (nb_oc_block - 1);
1443 
1444     const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1445     auto addr = EVEX_compress_addr(reg_out_ptr, off);
1446 
1447     const auto &p = attr_.post_ops_;
1448     const int sum_idx = p.find(primitive_kind::sum);
1449     const float *p_sum_scale = nullptr;
1450     const int32_t *p_sum_zp = nullptr;
1451     if (sum_idx != -1) {
1452         const auto &p_entry = p.entry_[sum_idx];
1453         p_sum_scale = &p_entry.sum.scale;
1454         p_sum_zp = &p_entry.sum.zero_point;
1455     }
1456 
1457     if (p_sum_scale) {
1458         if (*p_sum_scale != 1.f)
1459             mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
1460         if (*p_sum_zp != 0)
1461             mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
1462     }
1463 
1464     int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * oc_block);
1465     if (jcp.with_bias) {
1466         int bias_offset = jcp.typesize_bia * ocb * oc_block;
1467         auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1468         cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
1469     }
1470     if (compute_zp) {
1471         assert(jcp.req_zero_point_buffer);
1472         // add zero-point padding compensation when accum data is S32
1473         const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1474         vmovups(m_zmm_zp,
1475                 EVEX_compress_addr(reg_zero_point_pbuff,
1476                         get_zp_comp_offset(ocb, zp_h, zp_w)));
1477         const Zmm m_zmm_out = zmm_mask(zmm_out, mask_flag);
1478         vpaddd(m_zmm_out, zmm_out, zmm_zp);
1479     }
1480     if (jcp.src_zero_point) {
1481         // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
1482         int zp_offset = sizeof(int32_t) * ocb * oc_block;
1483         const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1484         vpmulld(m_zmm_zp, zmm_src_zp,
1485                 EVEX_compress_addr(reg_zp_compensation, zp_offset));
1486         vpaddd(zmm_out, zmm_out, zmm_zp);
1487     }
1488 
1489     /* add bias and zero-point to zmm_accum */
1490     vcvtdq2ps(zmm_out, zmm_out);
1491     if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
1492     const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
1493     vmulps(zmm_out_msk, zmm_out,
1494             EVEX_compress_addr(reg_ptr_scales, scale_offset));
1495 
1496     apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag);
1497 
1498     if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); }
1499 
1500     // Properly saturate the accumulators for integer datatypes
1501     if (one_of(jcp.dst_dt, u8, s8, s32)) {
1502         init_saturate_f32(
1503                 zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dst_dt);
1504         saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dst_dt);
1505         vcvtps2dq(zmm_out, zmm_out);
1506     }
1507 
1508     const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
1509 
1510     switch (jcp.dst_dt) {
1511         case data_type::f32:
1512         case data_type::s32: vmovups(addr, zmm_out_store); break;
1513         case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
1514         case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
1515         default: assert(!"unknown dst_dt");
1516     }
1517 }
1518 
1519 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector(const Zmm &zmm_out,
1520         int ocb, int h, int w, const bool compute_zp, const int zp_h,
1521         const int zp_w) {
1522     /*
1523     Output layouts:
1524               jcp.is_nspc                  !jcp.is_nspc
1525               ---------------------        ---------------------
1526         INT8: [N][H][W][NBOC][16OC]        (not supported)
1527         BF16: [N][H][W][NBOC][16OC]        [N][NBOC][H][W][16OC]
1528     */
1529     if (jcp.src_dt == data_type::bf16) {
1530         store_output_vector_bf16(zmm_out, ocb, h, w);
1531     } else {
1532         store_output_vector_int8(zmm_out, ocb, h, w, compute_zp, zp_h, zp_w);
1533     }
1534 }
1535 
1536 void jit_avx512_core_amx_fwd_kernel_t::store_output(int width, int tail,
1537         bool do_store, const bool handle_h_blk, const int t_pad_output,
1538         const int b_pad_output, const int l_pad_output, const int r_pad_output,
1539         const bool is_last_oh_block, const bool zp_3d_pad) {
1540     auto store_output_block = [=](int width, int tail, bool do_store,
1541                                       bool is_last_h = false) {
1542         // Calculate the number of oh blocks; it may differ on last call
1543         const int last_h_blks
1544                 = div_up(jcp.oh, jcp.oh_per_tile) % jcp.nb_oh_blocking;
1545         const int h_blks = is_last_h && last_h_blks != 0 ? last_h_blks
1546                                                          : jcp.nb_oh_blocking;
1547         // Calculate the number of oh rows per tile; it may differ on last call
1548         const int h_tail = is_last_h && jcp.oh % jcp.oh_per_tile != 0
1549                 ? (h_blks - 1) * jcp.oh_per_tile + jcp.oh % jcp.oh_per_tile
1550                 : h_blks * jcp.oh_per_tile;
1551         const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
1552         const int owp = gen_kw + jcp.ow - 1;
1553 
1554         if (jcp.src_zero_point) {
1555             mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1556             mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1557             vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1558         }
1559         if (jcp.dst_zero_point) {
1560             mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1561             vcvtdq2ps(zmm_dst_zp,
1562                     EVEX_compress_addr(reg_dst_zero_point, 0, true));
1563         }
1564 
1565         for_(int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
1566         for (int ohb = 0; ohb < h_blks; ohb++) {
1567             /* Formats: Workspace: [NBOC][W][16OC] */
1568             tilestored(ptr[reg_wsp_ptr + reg_wei_stride
1569                                + get_wsp_ocb_offset(ohb, ocb)],
1570                     Tmm(get_out_tensor(ohb, ocb, tail)));
1571             is_buffer_empty_ = false;
1572             is_store_done_ = false;
1573 
1574             // preserve registers used by binary post_ops injector
1575             const injector_utils::conditional_register_preserve_guard_t
1576                     cond_register_guard(jcp.with_binary, this,
1577                             {bin_injector_helper_reg_1,
1578                                     bin_injector_helper_reg_2});
1579 
1580             for (int tw = 0; tw < width && do_store; tw++) {
1581                 // height
1582                 const int oh_index = ohb * jcp.oh_per_tile + tw / owp;
1583                 const bool zp_h_pad
1584                         = oh_index < t_pad_output || oh_index >= b_pad_output;
1585                 const int zp_h = get_zp_index_offset(
1586                         oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1587                 // width
1588                 const int ow_index = tw % owp;
1589                 const bool zp_w_pad
1590                         = ow_index < l_pad_output || ow_index >= r_pad_output;
1591                 const int zp_w = get_zp_index_offset(
1592                         ow_index, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1593 
1594                 const bool compute_zp = jcp.req_zero_point_buffer
1595                         && (zp_3d_pad || zp_w_pad || zp_h_pad);
1596 
1597                 assert(IMPLICATION(jcp.oh_per_tile == 1,
1598                         ohb == oh_index && tw == ow_index));
1599                 if (oh_index < h_tail && ow_index < jcp.ow) {
1600                     Zmm zmm_r = zmm_out(tw);
1601                     vmovups(zmm_r,
1602                             ptr[reg_wsp_ptr
1603                                     + get_wsp_row_offset(ohb, ocb, tw)]);
1604                     store_output_vector(zmm_r, ocb, oh_index, ow_index,
1605                             compute_zp, zp_h, zp_w);
1606                 }
1607             }
1608         }
1609     };
1610 
1611     // adjustment in case interleave store is turned off
1612     do_store = do_store || jcp.per_one_pstore == 0;
1613     if (!do_store) { w_padding.emplace(l_pad_output, r_pad_output); }
1614     if (!handle_h_blk) {
1615         store_output_block(width, tail, do_store, is_last_oh_block);
1616     } else {
1617         if (jcp.oh % (jcp.oh_per_tile * jcp.nb_oh_blocking) == 0) {
1618             store_output_block(width, tail, do_store);
1619         } else {
1620             Label label_oh_oc_store, label_done;
1621             mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
1622             cmp(reg_last_h, 0);
1623             jne(label_oh_oc_store, T_NEAR);
1624             store_output_block(width, tail, do_store, true); // last h
1625             jmp(label_done, T_NEAR);
1626             L(label_oh_oc_store);
1627             store_output_block(width, tail, do_store, false);
1628             L(label_done);
1629         }
1630     }
1631     if (do_store) {
1632         add(reg_out_ptr, get_out_shift(width, jcp.typesize_out));
1633         if (jcp.req_zero_point_buffer) {
1634             const size_t sp_shift
1635                     = accum_with_upper_bound(width, l_pad_output, r_pad_output);
1636             add(reg_zero_point_pbuff, get_out_shift(sp_shift, sizeof(int32_t)));
1637         }
1638     }
1639 }
1640 
1641 void jit_avx512_core_amx_fwd_kernel_t::interleave_store(int width,
1642         int const t_pad_output, int const b_pad_output, const bool zp_3d_pad) {
1643     for (int c = 0;
1644             c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
1645             c++) {
1646         // row_count = ohb * OCB * TW + ocb * TW + tw
1647         int tw = row_count_ % prv_width_;
1648         int ocb = (row_count_ / prv_width_) % jcp.nb_oc_blocking;
1649         int ohb = (row_count_ / prv_width_) / jcp.nb_oc_blocking;
1650 
1651         // preserve registers used by binary post_ops injector
1652         const injector_utils::conditional_register_preserve_guard_t
1653                 cond_register_guard(jcp.with_binary, this,
1654                         {bin_injector_helper_reg_1, bin_injector_helper_reg_2});
1655 
1656         // height
1657         const int oh_index = ohb;
1658         const bool zp_h_pad
1659                 = oh_index < t_pad_output || oh_index >= b_pad_output;
1660         const int zp_h = get_zp_index_offset(
1661                 oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1662         // width
1663         const int l_pad_output
1664                 = w_padding.empty() ? 0 : w_padding.front().l_pad_output;
1665         const int r_pad_output
1666                 = w_padding.empty() ? jcp.ow : w_padding.front().r_pad_output;
1667 
1668         const bool zp_w_pad = tw < l_pad_output || tw >= r_pad_output;
1669         const int zp_w = get_zp_index_offset(
1670                 tw, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1671 
1672         const bool compute_zp = jcp.req_zero_point_buffer
1673                 && (zp_3d_pad || zp_w_pad || zp_h_pad);
1674 
1675         Zmm zmm_r = zmm_out(tw);
1676         vmovups(zmm_r, ptr[reg_wsp_ptr + get_wsp_row_offset(ohb, ocb, tw)]);
1677         store_output_vector(zmm_r, ocb, ohb, tw, compute_zp, zp_h, zp_w);
1678         row_count_++;
1679 
1680         if (row_count_
1681                 == prv_width_ * jcp.nb_oc_blocking * jcp.nb_oh_blocking) {
1682             add(reg_out_ptr, get_out_shift(prv_width_, jcp.typesize_out));
1683             if (jcp.req_zero_point_buffer) {
1684                 const size_t sp_shift = accum_with_upper_bound(
1685                         prv_width_, l_pad_output, r_pad_output);
1686                 add(reg_zero_point_pbuff,
1687                         get_out_shift(sp_shift, sizeof(int32_t)));
1688                 if (!w_padding.empty()) w_padding.pop();
1689             }
1690             row_count_ = 0;
1691             is_store_done_ = true;
1692             prv_width_ = width;
1693         }
1694     }
1695 }
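
// A sketch of the interleaving idea: rather than storing a whole tile after
// the accumulation loop finishes, up to jcp.per_one_pstore rows of the
// previously computed tile are flushed after each tdpbxxd issued by
// compute_icb_loop, hiding store latency behind the AMX compute.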
1696 
1697 void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width,
1698         bool do_store, const bool handle_h_blk, const int t_pad_output,
1699         const int b_pad_output, const int l_pad_output, const int r_pad_output,
1700         const bool zp_3d_pad, const bool is_last_oh_block) {
1701     const bool tail = width == jcp.tile_tail;
1702 
1703     auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
1704         if (jcp.src_dt == data_type::bf16 && jcp.wei_dt == data_type::bf16) {
1705             tdpbf16ps(x1, x2, x3);
1706         } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::u8) {
1707             tdpbuud(x1, x2, x3);
1708         } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::s8) {
1709             tdpbusd(x1, x2, x3);
1710         } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::u8) {
1711             tdpbsud(x1, x2, x3);
1712         } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::s8) {
1713             tdpbssd(x1, x2, x3);
1714         } else {
1715             assert(!"unsupported combination");
1716         }
1717     };
1718 
1719     prepare_output(tail);
1720 
1721     // prepare registers for when 'interleave_store()' is computed
1722     if (jcp.src_zero_point) {
1723         mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1724         mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1725         vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1726     }
1727     if (jcp.dst_zero_point) {
1728         mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1729         vcvtdq2ps(zmm_dst_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));
1730     }
1731 
1732     // reduced lowering path
1733     if (jcp.is_relo) {
1734         const int nreduce = jcp.nreduce;
1735         const int stride = jcp.ic_block_int; // i.e. 64 for int8, 32 for bf16
1736 
1737         push(reg_inp_ptr);
1738         push(reg_wei_ptr);
1739 
1740         for (int ireduce = 0; ireduce < nreduce; ireduce += stride) {
1741             for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1742                 tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1743                         ptr[reg_inp_ptr + get_inp_offset(ohb, 0)
1744                                 + reg_inp_stride]);
1745             }
1746             for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1747                 tileloadd(Tmm(get_wei_tensor(ocb)),
1748                         ptr[reg_wei_ptr + get_wei_offset(ocb, 0)
1749                                 + reg_wei_stride]);
1750                 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1751                     tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1752                             Tmm(get_inp_tensor(ohb, tail)),
1753                             Tmm(get_wei_tensor(ocb)));
1754                     interleave_store(width, t_pad_output, b_pad_output);
1755                 }
1756             }
1757             if (ireduce + stride < nreduce) {
1758                 add(reg_inp_ptr, stride * jcp.typesize_in);
1759                 add(reg_wei_ptr, stride * jcp.oc_block * jcp.typesize_in);
1760             }
1761         }
1762         pop(reg_wei_ptr);
1763         pop(reg_inp_ptr);
1764 
1765         store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1766                 b_pad_output, l_pad_output, r_pad_output, is_last_oh_block);
1767 
1768         add(reg_inp_ptr, get_inp_shift());
1769 
1770         return;
1771     }
1772 
1773     auto wei_offset = [&](int icb, int ocb, int kd, int kh, int kw) {
1774         return (size_t)icb * get_wei_icb_step() + kd * get_wei_d_step()
1775                 + kh * get_wei_h_step() + get_wei_offset(ocb, kw);
1776     };
1777 
1778     auto inp_offset = [&](int icb, int ohb, int kd, int kh, int kw) {
1779         return (size_t)icb * get_inp_icb_step() + kd * get_inp_d_step()
1780                 + kh * get_inp_h_step() + get_inp_offset(ohb, kw);
1781     };
1782 
1783     auto safe_tileloadd
1784             = [=](const Tmm &t1, const Xbyak::Reg64 &reg_ptr, size_t offset,
1785                       const Xbyak::Reg64 &reg_stride) {
1786                   if (offset <= INT32_MAX) {
1787                       tileloadd(t1, ptr[reg_ptr + offset + reg_stride]);
1788                   } else {
1789                       safe_add(reg_ptr, offset, reg_tmp);
1790                       tileloadd(t1, ptr[reg_ptr + reg_stride]);
1791                       safe_sub(reg_ptr, offset, reg_tmp);
1792                   }
1793               };
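    // tileloadd encodes only a 32-bit displacement, so larger offsets are
    // temporarily applied to the base register around the load
    // (safe_add/safe_sub) instead of being folded into the address.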
1794 
1795     // normal and k-remainders path
1796     const bool check_kd_padding
1797             = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
1798     for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
1799         Label kd_skip_compute;
1800         if (check_kd_padding) mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1801 
1802         for (int kd = 0; kd < jcp.kd; kd++) {
1803             if (check_kd_padding) {
1804                 dec(reg_kd);
1805                 jl(kd_skip_compute, T_NEAR);
1806                 push(reg_kd);
1807             }
1808             for (int kh = 0; kh < jcp.kh; kh++) {
1809                 for (int set_idx = 0; set_idx < jcp.n_stride_sets;
1810                         set_idx++) { // used to optimize input memory reuse in L1$
1811                     for (int kw = set_idx; kw < jcp.kw; kw += jcp.kw_step) {
1812                         for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1813                             const size_t inp_off
1814                                     = inp_offset(icb, ohb, kd, kh, kw);
1815                             safe_tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1816                                     reg_inp_ptr, inp_off, reg_inp_stride);
1817                         }
1818                         for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1819                             const size_t wei_off
1820                                     = wei_offset(icb, ocb, kd, kh, kw);
1821                             safe_tileloadd(Tmm(get_wei_tensor(ocb)),
1822                                     reg_wei_ptr, wei_off, reg_wei_stride);
1823                             for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1824                                 tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1825                                         Tmm(get_inp_tensor(ohb, tail)),
1826                                         Tmm(get_wei_tensor(ocb)));
1827                                 interleave_store(width, t_pad_output,
1828                                         b_pad_output, zp_3d_pad);
1829                             }
1830                         }
1831                     }
1832                 }
1833             }
1834             if (check_kd_padding) pop(reg_kd);
1835         }
1836         L(kd_skip_compute);
1837     }
1838 
1839     store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1840             b_pad_output, l_pad_output, r_pad_output, is_last_oh_block,
1841             zp_3d_pad);
1842 
1843     add(reg_inp_ptr, get_inp_shift());
1844 }
1845 
1846 void jit_avx512_core_amx_fwd_kernel_t::dispatch_icb_loop(int width,
1847         bool do_store, const int l_pad_output, const int r_pad_output,
1848         const bool zp_3d_pad) {
1849     if (jcp.req_zero_point_buffer
1850             && (jcp.t_pad_output > 0 || jcp.b_pad_output > 0)) {
1851         const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
1852         const size_t height_limit = reduce_to_blocked_dims(
1853                 jcp.oh, oh_step_size, jcp.t_pad_output, jcp.b_pad_output);
1854         const int ur_h = div_up(height_limit, oh_step_size);
1855         assert(6 >= ur_h);
1856 
1857         // Use a jump-table to execute the corresponding block
1858         Label h_blk_label[6], h_blk_end_label, jmp_table_label;
1859         mov(reg_jmp_blk, ptr[param1 + GET_OFF(ohb)]);
1860         mov(reg_tmp, jmp_table_label);
1861         jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1862         jmp(h_blk_end_label, T_NEAR); // error, shouldn't happen
1863 
1864         align(8);
1865         L(jmp_table_label);
1866         for (int u = 0; u < ur_h; ++u) {
1867             putL(h_blk_label[u]);
1868         }
1869 
1870         // Save values of the runtime counters for the next 'h_blk' iteration
1871         const int local_prv_width = prv_width_;
1872         const int local_row_count = row_count_;
1873         const bool local_is_store_done = is_store_done_;
1874         const bool local_is_buffer_empty = is_buffer_empty_;
1875 
1876         // Unroll oh_block with regard to t_pad_output and b_pad_output
1877         int cur_t_pad = reduce_to_block(oh_step_size, jcp.t_pad_output);
1878         int cur_b_pad = height_limit
1879                 - reduce_to_block(oh_step_size, jcp.b_pad_output);
1880         for (int u = 0; u < ur_h; u++) {
1881             bool last = u == ur_h - 1;
1882             L(h_blk_label[u]);
1883 
1884             // restore the counters to their state before this 'h_blk'
1885             prv_width_ = local_prv_width;
1886             row_count_ = local_row_count;
1887             is_store_done_ = local_is_store_done;
1888             is_buffer_empty_ = local_is_buffer_empty;
1889             compute_icb_loop(width, do_store, false, cur_t_pad, cur_b_pad,
1890                     l_pad_output, r_pad_output, zp_3d_pad, last);
1891             cur_t_pad = nstl::max(0, cur_t_pad - oh_step_size);
1892             cur_b_pad = nstl::max(0, cur_b_pad - oh_step_size);
1893             if (!last) jmp(h_blk_end_label, T_NEAR);
1894         }
1895         L(h_blk_end_label);
1896     } else {
1897         compute_icb_loop(width, do_store, true, 0, jcp.oh, l_pad_output,
1898                 r_pad_output, zp_3d_pad);
1899     }
1900 }
1901 
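// dispatch_zp_3d_compute emits the icb loop twice whenever 3D zero-point
// padding may occur: one version without the front/back padding compensation
// and one with it, selected at runtime by comparing kd_padding against jcp.kd.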
1902 void jit_avx512_core_amx_fwd_kernel_t::dispatch_zp_3d_compute(int width,
1903         bool do_store, const int l_pad_output, const int r_pad_output) {
1904     if (jcp.req_zero_point_buffer && (jcp.f_pad > 0 || jcp.back_pad > 0)) {
1905         Label compute_3d_zp_label, zp_d_end_label;
1906         mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1907         cmp(reg_kd, jcp.kd);
1908         jne(compute_3d_zp_label, T_NEAR);
1909 
1910         // Save values of the runtime counters for the next 'dispatch_icb_loop'
1911         const int local_prv_width = prv_width_;
1912         const int local_row_count = row_count_;
1913         const bool local_is_store_done = is_store_done_;
1914         const bool local_is_buffer_empty = is_buffer_empty_;
1915         dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1916 
1917         jmp(zp_d_end_label, T_NEAR);
1918         L(compute_3d_zp_label);
1919 
1920         prv_width_ = local_prv_width;
1921         row_count_ = local_row_count;
1922         is_store_done_ = local_is_store_done;
1923         is_buffer_empty_ = local_is_buffer_empty;
1924         dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, true);
1925 
1926         L(zp_d_end_label);
1927     } else
1928         dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1929 }
1930 
1931 void jit_avx512_core_amx_fwd_kernel_t::compute_ow_loop() {
1932     auto compute_ow_loop_body = [=](bool last_owb, int num_tile_blocks,
1933                                         const int l_pad_output,
1934                                         const int r_pad_output) {
1935         int cur_l_pad_output = l_pad_output;
1936         int cur_r_pad_output = r_pad_output;
1937         int gen_tile_tail = last_owb && jcp.tile_tail > 0 ? jcp.tile_tail
1938                                                           : jcp.tile_width;
1939         init_runtime_counters(last_owb && num_tile_blocks == 1);
1940         for (int owb = 0; owb < num_tile_blocks - 1; owb++) {
1941             dispatch_zp_3d_compute(
1942                     jcp.tile_width, false, cur_l_pad_output, cur_r_pad_output);
1943             cur_l_pad_output = nstl::max(0, cur_l_pad_output - jcp.tile_width);
1944             cur_r_pad_output = nstl::max(0, cur_r_pad_output - jcp.tile_width);
1945         }
1946         dispatch_zp_3d_compute(
1947                 gen_tile_tail, true, cur_l_pad_output, cur_r_pad_output);
1948     };
1949 
1950     assert(jcp.nb_ow > 0);
1951     if (jcp.nb_ow == 1) {
1952         const int ow_r_pad_start
1953                 = nstl::max(jcp.ow - jcp.r_pad_output, jcp.l_pad_output);
1954         compute_ow_loop_body(
1955                 true, jcp.ow_blocks, jcp.l_pad_output, ow_r_pad_start);
1956     } else if (jcp.req_zero_point_buffer
1957             && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0)) {
1958 
1959         const size_t zp_addr_shift
1960                 = jcp.ngroups * jcp.oc_without_padding * sizeof(int32_t);
1961         const int ow_step_size = jcp.ow_block;
1962         const int ow_blocks_per_call = div_up(ow_step_size, jcp.tile_width);
1963         const int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call == 0
1964                 ? ow_blocks_per_call
1965                 : jcp.ow_blocks % ow_blocks_per_call;
1966         const int width_limit = reduce_to_blocked_dims(
1967                 jcp.ow, ow_step_size, jcp.l_pad_output, jcp.r_pad_output);
1968         const int ur_w = div_up(width_limit, ow_step_size);
1969         assert(6 >= ur_w);
1970         // Use a jump-table to execute the corresponding block
1971         Label w_blk_label[6], w_blk_end_label, jmp_table_label;
1972         mov(reg_jmp_blk, ptr[param1 + GET_OFF(owb)]);
1973         mov(reg_tmp, jmp_table_label);
1974         jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1975         jmp(w_blk_end_label, T_NEAR); // error, shouldn't happen
1976 
1977         align(8);
1978         L(jmp_table_label);
1979         for (int u = 0; u < ur_w; ++u) {
1980             putL(w_blk_label[u]);
1981         }
1982 
1983         // Unroll ow_block with regard to l_pad_output and r_pad_output
1984         int cur_l_pad = reduce_to_block(ow_step_size, jcp.l_pad_output);
1985         int cur_r_pad
1986                 = width_limit - reduce_to_block(ow_step_size, jcp.r_pad_output);
1987         int zp_offset = 0;
1988         for (int u = 0; u < ur_w; u++) {
1989             const bool last = u == ur_w - 1;
1990             L(w_blk_label[u]);
1991             if (u > 0) add(reg_zero_point_pbuff, zp_offset * zp_addr_shift);
1992             compute_ow_loop_body(last,
1993                     last ? last_owb_tile_blocks : ow_blocks_per_call, cur_l_pad,
1994                     cur_r_pad);
1995             zp_offset += accum_with_upper_bound(
1996                     ow_step_size, cur_l_pad, cur_r_pad);
1997             cur_l_pad = nstl::max(0, cur_l_pad - ow_step_size);
1998             cur_r_pad = nstl::max(0, cur_r_pad - ow_step_size);
1999             if (!last) jmp(w_blk_end_label, T_NEAR);
2000         }
2001         L(w_blk_end_label);
2002 
2003     } else {
2004         assert(jcp.oh_per_tile == 1);
2005         Label label_done;
2006         int ow_blocks_per_call = utils::div_up(jcp.ow_block, jcp.tile_width);
2007         int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call;
2008         if (last_owb_tile_blocks == 0 && jcp.tile_tail > 0)
2009             last_owb_tile_blocks = ow_blocks_per_call;
2010         if (last_owb_tile_blocks > 0) {
2011             Label label_not_last_owb;
2012             mov(reg_tmp, ptr[param1 + GET_OFF(owb)]);
2013             cmp(reg_tmp, jcp.nb_ow - 1);
2014             jne(label_not_last_owb, T_NEAR);
2015 
2016             compute_ow_loop_body(true, last_owb_tile_blocks, 0, jcp.ow);
2017 
2018             jmp(label_done, T_NEAR);
2019 
2020             L(label_not_last_owb);
2021         }
2022         compute_ow_loop_body(false, ow_blocks_per_call, 0, jcp.ow);
2023 
2024         L(label_done);
2025     }
2026 }
2027 
2028 void jit_avx512_core_amx_fwd_kernel_t::generate() {
2029     preamble();
2030 
2031     mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
2032     mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]);
2033     mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
2034     mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
2035     if (jcp.req_zero_point_buffer)
2036         mov(reg_zero_point_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
2037 
2038     mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
2039     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
2040 
2041     const int fac = jcp.is_relo ? jcp.stride_w * jcp.kh
2042                                 : jcp.is_pbuffer_strided ? 1 : jcp.stride_w;
2043     const int inp_stride = fac * jcp.ic_block_int_np * jcp.typesize_in;
2044     const int wei_stride = jcp.oc_block * jcp.typesize_acc;
2045     mov(reg_inp_stride, inp_stride);
2046     mov(reg_wei_stride, wei_stride);
2047 
2048     if (jcp.is_nspc && jcp.oc_without_padding != jcp.oc) {
2049         // Use the full 16-lane mask (0xffff) by default for all output
2050         // data and post-ops loads / stores with block index
2051         // ocb = occ * jcp.nb_oc_blocking + (jcp.nb_oc_blocking - 1)
2052         // TODO: use masked loads / stores for the last occ only
2053         int current_block_size = jcp.oc_block;
2054         int mask = (1 << current_block_size) - 1;
2055         Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
2056         mov(regw_tmp, mask);
2057         kmovw(ktail_mask, regw_tmp);
2058         Xbyak::Label mask_is_set;
2059         mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
2060         cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
2061         jne(mask_is_set, T_NEAR);
2062         // Reset the mask
2063         current_block_size = jcp.oc_without_padding % jcp.oc_block;
2064         mask = (1 << current_block_size) - 1;
2065         mov(regw_tmp, mask);
2066         kmovw(ktail_mask, regw_tmp);
2067 
2068         L(mask_is_set);
2069     }
2070     compute_ow_loop();
2071 
2072     postamble();
2073 
2074     if (jcp.with_eltwise) postops_injector_->prepare_table();
2075 }
2076 
2077 void jit_avx512_core_amx_fwd_kernel_t::tile_configure(char *tcfg_buff) {
2078     const int vnni_width = jcp.src_dt == data_type::bf16 ? 2 : 4;
2079     // Input tile dimensions
2080     const int a_col = jcp.is_relo ? jcp.ic_block_int
2081                                   : jcp.ic_block_int_np * jcp.kw_per_tile;
2082     // Weights tile dimensions
2083     const int b_col = jcp.oc_block * vnni_width;
2084     const int b_row = a_col / vnni_width;
2085     // Accumulator tile dimensions
2086     const int c_col = 16;
2087 
2088     for (size_t i = 0; i < 64; i++)
2089         tcfg_buff[i] = 0;
2090 
2091     // Weights (W_BASE) Tensor Tiles
2092     for (int i = 0; i < jcp.nb_oc_blocking; i++)
2093         tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
2094                 b_row, b_col * jcp.typesize_in);
2095 
2096     // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
2097     for (int h = 0; h < jcp.nb_oh_blocking; h++) {
2098         tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
2099                 jcp.tile_width, a_col * jcp.typesize_in);
2100         for (int i = 0; i < jcp.nb_oc_blocking; i++)
2101             tc_configure_tile((palette_config_t *)tcfg_buff,
2102                     get_out_tensor(h, i), jcp.tile_width,
2103                     c_col * jcp.typesize_acc);
2104     }
2105     if (jcp.tile_tail != 0) {
2106         assert(jcp.nb_oh_blocking == 1);
2107         assert(jcp.oh_per_tile == 1);
2108         assert(jcp.ow > jcp.tile_width);
2109         tc_configure_tile((palette_config_t *)tcfg_buff,
2110                 get_inp_tensor(0, true), jcp.tile_tail,
2111                 a_col * jcp.typesize_in);
2112         for (int i = 0; i < jcp.nb_oc_blocking; i++)
2113             tc_configure_tile((palette_config_t *)tcfg_buff,
2114                     get_out_tensor(0, i, true), jcp.tile_tail,
2115                     c_col * jcp.typesize_acc);
2116     }
2117 
2118     ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
2119 }
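
// For instance (int8, hypothetical config with jcp.is_relo = false and
// kw_per_tile = 1, so a_col = ic_block_int_np = 64 and vnni_width = 4): input
// tiles are tile_width rows of 64 bytes, weights tiles are b_row = 16 rows of
// b_col * typesize_in = 64 bytes, and accumulator tiles are tile_width rows
// of 16 * 4 = 64 bytes, matching the 16-row x 64-byte hardware tile shape
// checked in init_conf.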
2120 
2121 void jit_avx512_core_amx_fwd_kernel_t::set_oh_blk_limits(jit_conv_conf_t &jcp) {
2122 
2123     constexpr int size = sizeof(jcp.h_blk_limits) / sizeof(jcp.h_blk_limits[0]);
2124     // set default values
2125     for (int i = 0; i < size; i++)
2126         jcp.h_blk_limits[i] = jcp.oh;
2127 
2128     const bool calculate_oh_limits
2129             = jcp.t_pad_output > 0 || jcp.b_pad_output > 0;
2130     if (jcp.req_zero_point_buffer && calculate_oh_limits) {
2131 
2132         int limit_idx = 0;
2133         const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2134 
2135         // full t_pad output block
2136         const int t_pad_blk_end = rnd_dn(jcp.t_pad_output, oh_step_size);
2137         if (jcp.t_pad_output >= oh_step_size) {
2138             jcp.h_blk_limits[limit_idx++] = t_pad_blk_end;
2139         }
2140         // t_pad output overlap with no padding
2141         const int t_pad_shift = jcp.t_pad_output % oh_step_size;
2142         if (t_pad_shift != 0) {
2143             jcp.h_blk_limits[limit_idx++] = t_pad_blk_end + t_pad_shift;
2144         }
2145         const int t_pad_next_blk = rnd_up(jcp.t_pad_output, oh_step_size);
2146         const int oh_blk_tail = jcp.oh % oh_step_size;
2147         const int b_pad_no_tail = nstl::max(0, jcp.b_pad_output - oh_blk_tail);
2148         const int b_pad_start
2149                 = nstl::max(jcp.t_pad_output, jcp.oh - jcp.b_pad_output);
2150         const int b_pad_blk_start = rnd_dn(b_pad_start, oh_step_size);
2151         // middle block without padding
2152         const int mid_blk = nstl::max(0, b_pad_blk_start - t_pad_next_blk);
2153         if (mid_blk >= oh_step_size) {
2154             jcp.h_blk_limits[limit_idx++] = b_pad_blk_start;
2155         }
2156         // no padding with b_pad overlap
2157         const int b_pad_shift = b_pad_no_tail % oh_step_size;
2158         if (b_pad_shift != 0) {
2159             jcp.h_blk_limits[limit_idx++] = rnd_up(b_pad_start, oh_step_size);
2160         }
2161         // full b_pad output block
2162         if (b_pad_no_tail >= oh_step_size) {
2163             jcp.h_blk_limits[limit_idx++] = jcp.oh - oh_blk_tail;
2164         }
2165         // b_pad tail block does not require a limit
2166     }
2167 }
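
// Worked example (hypothetical shapes): oh = 20, oh_step_size = 2,
// t_pad_output = 3 and b_pad_output = 4 yield h_blk_limits = {2, 3, 16, 20,
// ...}: the full t_pad block ends at 2, the t_pad/no-pad overlap at 3, the
// unpadded middle at 16 and the full b_pad block at 20; remaining entries
// keep the default jcp.oh.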
2168 
2169 void jit_avx512_core_amx_fwd_kernel_t::set_ow_blk_limits(jit_conv_conf_t &jcp) {
2170 
2171     jcp.l_pad_blk = 0;
2172     jcp.no_pad_w_blk = 0;
2173     jcp.r_pad_blk = 0;
2174 
2175     const bool calculate_ow_limits
2176             = jcp.nb_ow > 1 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0);
2177     if (jcp.req_zero_point_buffer && calculate_ow_limits) {
2178         const int ow_step_size = jcp.ow_block;
2179 
2180         // l_pad
2181         const int l_pad_limit
2182                 = (jcp.l_pad_output >= ow_step_size ? ow_step_size : 0)
2183                 + (jcp.l_pad_output % ow_step_size);
2184         const int l_pad_area_blk = rnd_up(l_pad_limit, ow_step_size);
2185         jcp.l_pad_blk = div_up(l_pad_limit, ow_step_size);
2186 
2187         // middle (area without padding)
2188         const int no_pad_area
2189                 = nstl::max(0, jcp.ow - l_pad_area_blk - jcp.r_pad_output);
2190         jcp.no_pad_w_blk = no_pad_area >= ow_step_size ? 1 : 0;
2191 
2192         // r_pad
2193         const int no_pad_area_shift = no_pad_area % ow_step_size;
2194         const int r_pad_area_overlap
2195                 = no_pad_area_shift == 0 ? 0 : ow_step_size - no_pad_area_shift;
2196         const int r_pad_area
2197                 = nstl::max(0, jcp.r_pad_output - r_pad_area_overlap);
2198         const int r_pad_limit = (r_pad_area >= ow_step_size ? ow_step_size : 0)
2199                 + (r_pad_area % ow_step_size);
2200         jcp.r_pad_blk = (r_pad_area_overlap > 0 ? 1 : 0)
2201                 + div_up(r_pad_limit, ow_step_size);
2202     }
2203 }
2204 
2205 status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
2206         const convolution_desc_t &cd, memory_desc_t &src_md,
2207         memory_desc_t &weights_md, memory_desc_t &dst_md,
2208         memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
2209     using namespace prop_kind;
2210 
2211     const memory_desc_wrapper src_d(&src_md);
2212     const memory_desc_wrapper weights_d(&weights_md);
2213     const memory_desc_wrapper dst_d(&dst_md);
2214     const memory_desc_wrapper bias_d(&bias_md);
2215 
2216     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
2217     int ndims = src_d.ndims();
2218     bool is_1d = ndims == 3;
2219     bool is_3d = ndims == 5;
2220 
2221     const bool is_bf16_convolution
2222             = everyone_is(true, src_d.data_type() == data_type::bf16,
2223                     weights_d.data_type() == data_type::bf16,
2224                     one_of(dst_d.data_type(), data_type::bf16, data_type::f32));
2225     const bool is_int8_convolution = everyone_is(true,
2226             (src_d.data_type() == data_type::u8
2227                     || src_d.data_type() == data_type::s8),
2228             weights_d.data_type() == data_type::s8,
2229             one_of(dst_d.data_type(), data_type::f32, data_type::s32,
2230                     data_type::s8, data_type::u8));
2231 
2232     bool supported = false
2233             || (is_bf16_convolution && mayiuse(avx512_core_bf16_amx_bf16))
2234             || (is_int8_convolution && mayiuse(avx512_core_bf16_amx_int8));
2235     if (!supported) return status::unimplemented;
2236 
2237     jcp = zero<decltype(jcp)>();
2238     jcp.isa = is_bf16_convolution ? avx512_core_bf16_amx_bf16
2239                                   : avx512_core_bf16_amx_int8;
2240     jcp.ndims = ndims;
2241     jcp.prop_kind = cd.prop_kind;
2242     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
2243 
2244     jcp.mb = src_d.dims()[0];
2245     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
2246     jcp.oc_without_padding = jcp.oc;
2247     jcp.ic = src_d.dims()[1] / jcp.ngroups;
2248     jcp.ic_without_padding = jcp.ic;
2249     jcp.id = is_3d ? src_d.dims()[2] : 1;
2250     jcp.ih = !is_1d ? src_d.dims()[ndims - 2] : 1;
2251     jcp.iw = src_d.dims()[ndims - 1];
2252     jcp.od = is_3d ? dst_d.dims()[2] : 1;
2253     jcp.oh = !is_1d ? dst_d.dims()[ndims - 2] : 1;
2254     jcp.ow = dst_d.dims()[ndims - 1];
2255     jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
2256     jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
2257     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
2258     jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
2259     jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
2260     jcp.l_pad = cd.padding[0][ndims - 3];
2261     jcp.stride_d = is_3d ? cd.strides[0] : 1;
2262     jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
2263     jcp.stride_w = cd.strides[ndims - 3];
2264     jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
2265 
2266     jcp.dilate_d = is_3d ? cd.dilates[ndims - 5] : 0;
2267     jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
2268     jcp.dilate_w = cd.dilates[ndims - 3];
2269 
2270     const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1;
2271     const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
2272     const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
2273     jcp.back_pad = calculate_end_padding(
2274             jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, gen_kd);
2275     jcp.b_pad = calculate_end_padding(
2276             jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
2277     jcp.r_pad = calculate_end_padding(
2278             jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
2279     if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
2280             || jcp.b_pad >= gen_kh || jcp.f_pad >= gen_kd
2281             || jcp.back_pad >= gen_kd)
2282         return status::unimplemented;
2283 
2284     const int max_pad = 28; // akin to maximum jcp.ur_w value in other jits
2285     if (jcp.l_pad > max_pad || jcp.r_pad > max_pad)
2286         return status::unimplemented; // TODO: relax this restriction
2287 
2288     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
2289     jcp.dst_dt = cd.dst_desc.data_type;
2290     jcp.src_dt = cd.src_desc.data_type;
2291     jcp.wei_dt = cd.weights_desc.data_type;
2292 
2293     jcp.is_depthwise = with_groups && everyone_is(1, jcp.ic, jcp.oc);
2294 
2295     if (jcp.is_depthwise)
2296         return status::unimplemented; // TODO: add support of DW convolution
2297 
2298     const auto zp = attr.zero_points_;
2299     jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
2300     jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
2301     jcp.zp_src_is_common = zp.common(
2302             DNNL_ARG_SRC); // otherwise, it's per-channel (not supported)
2303 
2304     if (!IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)
2305             || !IMPLICATION(jcp.dst_zero_point || jcp.src_zero_point,
2306                     is_int8_convolution))
2307         return status::unimplemented;
2308 
2309     // Calculate zero-point padding values outside of the main JIT-kernel
2310     // and store the results in an auxiliary buffer.
2311     jcp.req_zero_point_buffer = jcp.src_zero_point
2312             && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0 || jcp.t_pad > 0
2313                     || jcp.f_pad > 0 || jcp.back_pad > 0);
2314 
2315     format_tag_t dat_tag_ncsp = utils::pick(ndims - 3, format_tag::nCw16c,
2316             format_tag::nChw16c, format_tag::nCdhw16c);
2317     format_tag_t dat_tag_nspc = utils::pick(
2318             ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
2319     // To toggle the default data layout for BF16 between nChw16c and nhwc,
2320     // swap the following two variable definitions. Current choice: nhwc.
2321 
2324     format_tag_t dat_tag_opt = dat_tag_nspc;
2325     format_tag_t dat_tag_alt
2326             = is_bf16_convolution ? dat_tag_ncsp : dat_tag_nspc;
2327 
2328     if (src_d.format_kind() == format_kind::any) {
2329         CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
2330         jcp.src_tag = dat_tag_opt;
2331     } else
2332         jcp.src_tag = src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);
2333 
2334     if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
2335         return status::unimplemented;
2336 
2337     jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
2338     assert(IMPLICATION(is_int8_convolution, jcp.is_nspc));
2339 
2340     // TODO: remove all support for nChw16c from this implementation
2341     if (!jcp.is_nspc) return status::unimplemented;
2342 
2343     if (dst_d.format_kind() == format_kind::any) {
2344         CHECK(memory_desc_init_by_tag(dst_md, jcp.src_tag));
2345         jcp.dst_tag = jcp.src_tag;
2346     } else
2347         jcp.dst_tag = dst_d.matches_one_of_tag(jcp.src_tag);
2348 
2349     if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
2350 
2351     if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
2352         CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
2353 
2354     jcp.nthr = nthreads;
2355 
2356     jcp.ic_block = 16;
2357     jcp.oc_block = 16;
2358 
2359     if (jcp.ngroups == 1) {
2360         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
2361         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
2362     }
2363     bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
2364     if (!args_ok) return status::unimplemented;
2365 
2366     const int vnni_width = is_bf16_convolution ? 2 : 4;
2367     jcp.ic_block_int = jcp.ic_block * vnni_width; // 32 for bf16, 64 for int8
2368 
2369     // fallback to non-amx impl when accumulation is too small
2370     const dim_t total_k = jcp.ic_without_padding * jcp.kd * jcp.kh * jcp.kw;
2371     const bool is_tiny_k = total_k < jcp.ic_block_int / 2;
2372     if (is_tiny_k) return status::unimplemented;
2373 
2374     // small-ic parameters
2375     jcp.ic_block_int_np = jcp.is_nspc
2376             ? nstl::min(jcp.ic_block_int, jcp.ic_without_padding)
2377             : jcp.ic_block_int;
2378     bool is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2379 
2380     // reduced lowering
2381     jcp.is_relo = (!is_3d)
2382             && is_small_ic
2383             // no trivial cases
2384             && 1 < jcp.kh * jcp.kw
2385             // required for use of VPERMB instruction in weights copy kernel
2386             && IMPLICATION(is_int8_convolution,
2387                     cpu().has(Xbyak::util::Cpu::tAVX512_VBMI))
2388             // no dilation along the h- and w-directions
2389             && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
2390             // no excessive stride along the h- and w-directions
2391             && jcp.stride_h <= jcp.kh && jcp.stride_w <= jcp.kw;
2392     jcp.nreduce = jcp.kh * jcp.kw * jcp.ic_block_int_np;
2393 
2394     if (!jcp.is_relo) {
2395         jcp.ic_block_int_np = is_bf16_convolution
2396                 ? jcp.ic_block_int
2397                 : rnd_up(jcp.ic_block_int_np, vnni_width);
2398         is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2399     }
2400 
2401     // k-remainders
2402     jcp.kw_per_tile = is_small_ic && !jcp.is_relo && jcp.dilate_w == 0
2403                     && jcp.stride_w <= jcp.kw // TODO: relax this restriction
2404                     && jcp.kw * jcp.ic_block_int_np <= jcp.ic_block_int
2405             ? jcp.kw
2406             : 1;
2407     jcp.is_pbuffer_strided = (1 == jcp.kw_per_tile);
2408     jcp.n_stride_sets
2409             = jcp.is_pbuffer_strided ? nstl::min(jcp.stride_w, jcp.kw) : 1;
2410     jcp.kw_step = jcp.is_pbuffer_strided ? jcp.stride_w : jcp.kw_per_tile;
2411 
2412     if (attr.set_default_formats(&dst_md) != status::success)
2413         return status::unimplemented;
2414 
2415     const auto &p = attr.post_ops_;
2416 
2417     const int sum_ind = p.find(primitive_kind::sum);
2418     jcp.with_sum = sum_ind != -1;
2419     const int eltwise_ind = p.find(primitive_kind::eltwise);
2420     jcp.with_eltwise = eltwise_ind != -1;
2421     const int binary_ind = p.find(primitive_kind::binary);
2422     jcp.with_binary = binary_ind != -1;
2423     jcp.sum_dt = p.get_sum_dt(jcp.dst_dt);
2424 
2425     jcp.post_ops = p;
2426 
2427     using namespace injector;
2428     const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16);
2429     const bool sum_requires_scale_one = sum_at_pos_0_only;
2430     const bool sum_requires_zp_zero = sum_at_pos_0_only;
2431     const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
2432             jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
2433             sum_requires_zp_zero});
2434     if (!post_ops_ok_) return status::unimplemented;
2435 
2436     auto set_or_check_wei_format = [&]() {
2437         using namespace format_tag;
2438         using namespace memory_extra_flags;
2439         format_tag_t wei_tag;
2440         wei_tag = jcp.is_relo ? pick(with_groups + 2 * (ndims - 3), Owi16o,
2441                           gOwi16o, Owhi16o, gOwhi16o) // no 3d support
2442                               : is_bf16_convolution
2443                         ? pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
2444                                 gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i,
2445                                 OIdhw16i16o2i, gOIdhw16i16o2i)
2446                         : is_small_ic ? pick(with_groups + 2 * (ndims - 3),
2447                                   OwI16o4i, gOwI16o4i, OhwI16o4i, gOhwI16o4i,
2448                                   OdhwI16o4i, gOdhwI16o4i)
2449                                       : pick(with_groups + 2 * (ndims - 3),
2450                                               OIw16i16o4i, gOIw16i16o4i,
2451                                               OIhw16i16o4i, gOIhw16i16o4i,
2452                                               OIdhw16i16o4i, gOIdhw16i16o4i);
2453 
2454         memory_desc_t want_wei_md = weights_md;
2455         memory_desc_init_by_tag(want_wei_md, wei_tag);
2456 
2457         if (jcp.src_zero_point) {
2458             want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
2459             want_wei_md.extra.asymm_compensation_mask = (1 << 0)
2460                     + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
2461         }
2462         if (weights_md.format_kind == format_kind::any) {
2463             weights_md = want_wei_md;
2464             return true;
2465         }
2466         return weights_md == want_wei_md;
2467     };
2468 
2469     if (!set_or_check_wei_format()) return status::unimplemented;
2470 
2471     jcp.typesize_in = types::data_type_size(src_d.data_type());
2472     jcp.typesize_out = types::data_type_size(dst_d.data_type());
2473     jcp.typesize_bia
2474             = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
2475     jcp.typesize_acc = sizeof(int32_t);
2476 
2477     jcp.nb_ic = jcp.ic / jcp.ic_block;
2478     jcp.nb_oc = jcp.oc / jcp.oc_block;
2479     jcp.nb_ic_int = div_up(jcp.ic, jcp.ic_block_int);
2480 
2481     jcp.nb_oc_blocking_thr_chunk = 1;
2482 
2483     const int max_palette = amx::get_max_palette();
2484     jcp.max_tiles = amx::get_max_tiles(max_palette);
2485     jcp.full_tile_width = amx::get_max_rows(max_palette);
2486     if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
2487         return status::unimplemented;
2488 
2489     // Pack n rows per tile, such that:
2490     // ow + (ow + gen_kw - 1) * (n - 1) <= jcp.full_tile_width
2491     auto calculate_tile_width = [&](int n) {
2492         assert(n > 0);
2493         return jcp.ow + (gen_kw + jcp.ow - 1) * (n - 1);
2494     };
2495     const bool ok_to_pack_tile = !jcp.is_relo
2496             && (utils::everyone_is(1, jcp.kh, jcp.kw)
2497                     || utils::everyone_is(1, jcp.stride_h, jcp.stride_w));
2498     const int max_oh_per_tile
2499             = 1 + (jcp.full_tile_width - jcp.ow) / (jcp.ow + gen_kw - 1);
2500     jcp.oh_per_tile = ok_to_pack_tile
2501             ? nstl::min(jcp.oh, nstl::max(1, max_oh_per_tile))
2502             : 1;
2503     jcp.tile_width = nstl::min<int>(
2504             jcp.full_tile_width, calculate_tile_width(jcp.oh_per_tile));
2505     jcp.ow_blocks = utils::div_up(jcp.ow, jcp.tile_width);
2506 
2507     // Prefer to use a single tile width when possible
2508     // (e.g. ow = 28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
2509     if (jcp.oh_per_tile == 1 && jcp.ow % jcp.ow_blocks == 0)
2510         jcp.tile_width = jcp.ow / jcp.ow_blocks;
2511     jcp.tile_tail = jcp.oh_per_tile == 1 ? jcp.ow % jcp.tile_width : 0;
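    // Worked example (assuming oh_per_tile == 1): ow = 28, full_tile_width = 16
    // -> ow_blocks = 2 and, since 28 % 2 == 0, tile_width is lowered to 14 with
    // tile_tail = 0. For ow = 31: ow_blocks = 2, but 31 % 2 != 0, so tile_width
    // stays 16 and tile_tail = 31 % 16 = 15.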
2512 
2513     jcp.nb_oc_blocking = (jcp.nb_oc % 2 == 0) ? 2 : 1;
2514     jcp.nb_ic_blocking = 1;
2515     jcp.nb_oh_blocking
2516             = utils::everyone_is(true, jcp.tile_tail == 0,
2517                       // requirement for interleave stores
2518                       IMPLICATION(jcp.ow_blocks > 1, jcp.oh % 2 == 0),
2519                       // requirement for small spatial
2520                       utils::div_up(jcp.oh, jcp.oh_per_tile) > 1,
2521                       // choose maximal pbuffer overlap for reduced lowering
2522                       !jcp.is_relo)
2523             ? 2
2524             : 1;
2525 
2526     // TODO: tune oh blocking
2527     const int oh_blk_size_param = jcp.is_relo ? 1 : 10;
2528     const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2529     const int oh_blk_size = rnd_up(oh_blk_size_param, oh_step_size);
2530     jcp.oh_blk_size = rnd_up(nstl::min(jcp.oh, oh_blk_size), oh_step_size);
2531     // Here ihp means the input buffer height including padding (i.e. the number
2532     // of input rows required for the computation of jcp.oh_blk_size output rows).
2533     // If an input row doesn't participate in the computation of any output row,
2534     // it isn't copied to the buffer at all (e.g. when jcp.stride_h > gen_kh).
2535     jcp.ihp = jcp.is_relo
2536             ? jcp.oh_blk_size
2537             : (jcp.oh_blk_size - 1) * nstl::min(jcp.stride_h, gen_kh) + gen_kh;
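    // Example (assuming nb_oh_blocking = 2, oh_per_tile = 1 and jcp.oh >= 10):
    // oh_step_size = 2 and oh_blk_size = rnd_up(10, 2) = 10; with stride_h = 1
    // and gen_kh = 3 the non-relo pbuffer then spans
    // ihp = (10 - 1) * 1 + 3 = 12 input rows.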
2538 
2539     // TODO: tune ow blocking
2540     const int ow_blocks_per_call = jcp.is_relo ? 10 : 2;
2541     jcp.ow_block = nstl::min(jcp.ow, jcp.tile_width * ow_blocks_per_call);
2542     jcp.nb_ow = utils::div_up(jcp.ow, jcp.ow_block);
2543     // iwp includes all width elements that are actually used in the computation,
2544     // including left and right zero-padding
2545     const bool are_sets_interleaved
2546             = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
2547     jcp.iwp = are_sets_interleaved
2548             ? (jcp.ow_block - 1) * nstl::min(jcp.stride_w, jcp.kw) + gen_kw
2549             : jcp.ow_block * jcp.kw;
2550 
2551     // Number of ops per tile store
2552     int ops_tile_store = jcp.tile_width;
2553     // Number of ops per accumulation tile
2554     int available_ops = jcp.is_relo
2555             ? utils::div_up(jcp.nreduce, jcp.ic_block_int)
2556             : jcp.nb_ic_int * jcp.kh * (jcp.kw / jcp.kw_per_tile);
2557     // Number of vectors to store per tile operation
2558     // NOTE: set to zero to turn off interleave store (mostly for debugging)
2559     jcp.per_one_pstore = utils::div_up(ops_tile_store, available_ops);
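    // Example: tile_width = 14 and a non-relo 3x3 shape with nb_ic_int = 1 and
    // kw_per_tile = 1 gives available_ops = 1 * 3 * 3 = 9, hence
    // per_one_pstore = div_up(14, 9) = 2, i.e. two accumulator vectors are
    // flushed after each tile multiply to hide the store latency.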
2560 
2561     if (jcp.is_relo) {
2562         jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.ihp * jcp.iwp * jcp.kh
2563                         * jcp.ic_block_int_np
2564                 // pbuffer pointer shifts each oh step for reduced-lowering
2565                 + (jcp.oh - 1) * jcp.stride_h * jcp.ic_block_int_np
2566                 // extra cache line due to pbuffer writing a full Zmm
2567                 + jcp.ic_block_int;
2568     } else {
2569         jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.kd
2570                 * ((size_t)jcp.ihp * jcp.iwp * jcp.ic_block_int_np
2571                         // extra cache line due to pbuffer writing a full Zmm
2572                         + jcp.ic_block_int);
2573     }
2574     jcp.wei_buffer_size = (size_t)jcp.ngroups * jcp.nb_oc
2575             * rnd_up(jcp.kh * jcp.kw * jcp.ic * jcp.oc_block, 1024);
2576     jcp.wsp_buffer_size = (size_t)jcp.nb_oh_blocking * jcp.nb_oc_blocking
2577             * jcp.full_tile_width * jcp.oc_block;
2578 
2579     const auto &oscales = attr.output_scales_;
2580     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
2581 
2582     // Note: this case is currently unsupported and results in a seg-fault
2583     const int l_pad_output = nstl::min(jcp.ow, div_up(jcp.l_pad, jcp.stride_w));
2584     if (!jcp.is_relo && (l_pad_output > jcp.ow_block))
2585         return status::unimplemented;
2586 
2587     // Relevant to 'zero_point padding buffer' (pbuff) jit kernel
2588     if (jcp.req_zero_point_buffer) {
2589         auto calculate_output_padding_dims = [=](int o_dim, int s_pad,
2590                                                      int e_pad,
2591                                                      int &s_pad_output,
2592                                                      int &e_pad_output,
2593                                                      bool &o_mid, int &o_pad,
2594                                                      int stride,
2595                                                      bool req_mid_area) {
2596             s_pad_output = nstl::min(o_dim, div_up(s_pad, stride));
2597             e_pad_output = nstl::min(o_dim, div_up(e_pad, stride));
2598             o_mid = (o_dim - s_pad_output - e_pad_output > 0) && req_mid_area;
2599             o_pad = nstl::min(o_dim,
2600                     nstl::max(1, s_pad_output + e_pad_output + (int)o_mid));
2601         };
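        // Worked example: ow = 56, l_pad = r_pad = 2, stride_w = 1, with some
        // non-zero height/depth padding so req_mid_area holds:
        // s_pad_output = e_pad_output = div_up(2, 1) = 2, the middle area
        // exists, and ow_pad = min(56, 2 + 2 + 1) = 5 padded output columns.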
2602 
2603         const bool mid_w_area = (jcp.l_pad > 0 || jcp.r_pad > 0)
2604                 && (jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0
2605                         || jcp.back_pad > 0);
2606         const bool mid_h_area = (jcp.t_pad > 0 || jcp.b_pad > 0)
2607                 && (jcp.l_pad > 0 || jcp.r_pad > 0 || jcp.f_pad > 0
2608                         || jcp.back_pad > 0);
2609         const bool mid_d_area = (jcp.f_pad > 0 || jcp.back_pad > 0)
2610                 && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0
2611                         || jcp.t_pad > 0);
2612         calculate_output_padding_dims(jcp.ow, jcp.l_pad, jcp.r_pad,
2613                 jcp.l_pad_output, jcp.r_pad_output, jcp.ow_mid, jcp.ow_pad,
2614                 jcp.stride_w, mid_w_area);
2615         calculate_output_padding_dims(jcp.oh, jcp.t_pad, jcp.b_pad,
2616                 jcp.t_pad_output, jcp.b_pad_output, jcp.oh_mid, jcp.oh_pad,
2617                 jcp.stride_h, mid_h_area);
2618         calculate_output_padding_dims(jcp.od, jcp.f_pad, jcp.back_pad,
2619                 jcp.f_pad_output, jcp.back_pad_output, jcp.od_mid, jcp.od_pad,
2620                 jcp.stride_d, mid_d_area);
2621         jcp.zp_pbuff_size
2622                 = jcp.od_pad * jcp.oh_pad * jcp.ow_pad * jcp.oc * jcp.ngroups;
2623 
2624         // compute zero-point padding kernel outside of the main parallel
2625         // region when threads are more likely to parallelize work across mb
2626         // within the convolution compute block.
2627         jcp.zp_pbuff_outer_compute = jcp.mb > 1 || is_3d;
2628 
2629         const bool params_ok = ((jcp.ow_pad - (int)jcp.ow_mid) <= max_pad * 2);
2630         if (!params_ok) { return status::unimplemented; }
2631     }
2632 
2633     // Set default parameters for driver code, but mostly required for
2634     // 'zero_point padding buffer' (pbuff) accumulation over output tensor
2635     set_oh_blk_limits(jcp);
2636     set_ow_blk_limits(jcp);
2637 
2638     return status::success;
2639 }
2640 
2641 status_t jit_avx512_core_amx_fwd_kernel_t::init_scratchpad(
2642         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
2643         const primitive_attr_t &attr) {
2644 
2645     size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
2646     scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
2647     if (jcp.is_relo) {
2648         scratchpad.book(
2649                 key_conv_amx_wei_buffer, jcp.wei_buffer_size, jcp.typesize_in);
2650     }
2651 
2652     size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
2653     scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
2654     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
2655         assert(jcp.ngroups == 1);
2656         scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
2657     }
2658     scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
2659     if (jcp.req_zero_point_buffer) {
2660         const int nthr = jcp.zp_pbuff_outer_compute ? 1 : jcp.nthr;
2661         scratchpad.book(key_conv_zero_point_pad,
2662                 (size_t)nthr * jcp.zp_pbuff_size, sizeof(int32_t));
2663         if (!jcp.zp_pbuff_outer_compute) {
2664             const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
2665             scratchpad.book<bool>(key_conv_zero_point_flag,
2666                     (size_t)jcp.nthr * oc_chunks * jcp.ngroups);
2667         }
2668     }
2669 
2670     // Keep scratchpad memory footprint under control
2671     const size_t L2_size_per_core = platform::get_per_core_cache_size(2);
2672     const size_t L3_size_per_core = platform::get_per_core_cache_size(3);
2673     const size_t max_scratchpad_size
2674             = jcp.nthr * (L2_size_per_core + L3_size_per_core);
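    // For instance, with jcp.nthr = 28 and hypothetical per-core cache sizes
    // of 1.25 MB (L2) and 1.375 MB (L3), the scratchpad would be capped at
    // 28 * 2.625 MB = 73.5 MB.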
2675     // TODO: tune this relationship as needed
2676     if (scratchpad.size() > max_scratchpad_size) return status::unimplemented;
2677     return status::success;
2678 }
2679 
2680 void jit_avx512_core_amx_bwd_data_copy_kernel_t::copy_row(
2681         const bool is_masked) {
2682     assert(jcp.is_nspc && "no support for nChw16c in this copy kernel");
2683 
2684     const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
2685     const int inp_w_step
2686             = jcp.ngroups * jcp.oc_without_padding * jcp.typesize_in;
2687     const int inp_h_step = jcp.ow * inp_w_step;
2688     const int out_w_step = jcp.oc_block_int * jcp.typesize_in;
2689     const int out_h_step = jcp.owp * out_w_step;
2690 
2691     auto zero_it = [=](reg64_t tmp_out_ptr, int offset) {
2692         // no mask as output is a padded buffer
2693         if (is_bf16)
2694             vmovdqu16(ptr[tmp_out_ptr + offset], zmm_zero);
2695         else
2696             vmovdqu8(ptr[tmp_out_ptr + offset], zmm_zero);
2697     };
2698 
2699     auto copy_it = [=](reg64_t tmp_inp_ptr, int inp_off, reg64_t tmp_out_ptr,
2700                            int out_off) {
2701         Zmm zmm_load = is_masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
2702         Zmm zmm_stor = zmm_tmp; // no mask as output is a padded buffer
2703         if (is_bf16) {
2704             vmovdqu16(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2705             vmovdqu16(ptr[tmp_out_ptr + out_off], zmm_stor);
2706         } else {
2707             vmovdqu8(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2708             vmovdqu8(ptr[tmp_out_ptr + out_off], zmm_stor);
2709         }
2710     };
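    // Tail-handling sketch (hypothetical sizes): with oc_without_padding = 24
    // and oc_block_int = 64 (int8), generate() sets ktail_mask to
    // (1 << 24) - 1, so the masked load above reads 24 channels (the rest
    // zeroed via T_z) while the store always writes the full padded row.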
2711 
2712     mov(reg_ptr_aux_out, reg_ptr_out);
2713 
2714     { // Handle Top Overflow
2715         Label label_tov_loop, label_tov_skip;
2716         test(reg_tov, reg_tov);
2717         jz(label_tov_skip, T_NEAR);
2718         mov(reg_cnt_tmp, reg_tov);
2719         L(label_tov_loop);
2720         {
2721             for (int ow = 0; ow < jcp.owp; ow++) {
2722                 const int offset = ow * out_w_step;
2723                 zero_it(reg_ptr_aux_out, offset);
2724             }
2725             add(reg_ptr_aux_out, out_h_step);
2726             dec(reg_cnt_tmp);
2727             jnz(label_tov_loop, T_NEAR);
2728         }
2729         L(label_tov_skip);
2730     }
2731 
2732     mov(reg_ptr_aux_inp_h, reg_ptr_inp);
2733 
2734     // Handle Middle Loop
2735     Label label_khp_loop, label_khp_skip;
2736     test(reg_khp, reg_khp);
2737     jz(label_khp_skip, T_NEAR);
2738     mov(reg_cnt_khp, reg_khp);
2739     L(label_khp_loop);
2740     {
2741         Label label_lov, label_lov_skip;
2742         Label label_kwp, label_kwp_skip;
2743         Label label_rov, label_rov_skip;
2744         test(reg_lov, reg_lov);
2745         jnz(label_lov, T_NEAR);
2746         test(reg_kwp, reg_kwp);
2747         jnz(label_kwp, T_NEAR);
2748         test(reg_rov, reg_rov);
2749         jnz(label_rov, T_NEAR);
2750 
2751         test(reg_lov, reg_lov);
2752         jz(label_lov_skip, T_NEAR); // not really needed, but just to be safe
2753         L(label_lov); // Handle Left Overflow
2754         {
2755             Label label_lov_loop;
2756             mov(reg_cnt_tmp, reg_lov);
2757             L(label_lov_loop);
2758             {
2759                 zero_it(reg_ptr_aux_out, 0);
2760                 add(reg_ptr_aux_out, out_w_step);
2761                 dec(reg_cnt_tmp);
2762                 jnz(label_lov_loop, T_NEAR);
2763             }
2764         }
2765         L(label_lov_skip);
2766 
2767         test(reg_kwp, reg_kwp);
2768         jz(label_kwp_skip, T_NEAR);
2769         L(label_kwp); // Handle Center Loop
2770         {
2771             Label label_kwp_loop;
2772             mov(reg_ptr_aux_inp_w, reg_ptr_aux_inp_h);
2773             mov(reg_cnt_tmp, reg_kwp);
2774             L(label_kwp_loop);
2775             {
2776                 copy_it(reg_ptr_aux_inp_w, 0, reg_ptr_aux_out, 0);
2777                 add(reg_ptr_aux_out, out_w_step);
2778                 add(reg_ptr_aux_inp_w, inp_w_step);
2779                 dec(reg_cnt_tmp);
2780 
2781                 if (jcp.stride_w > 1) {
2782                     jz(label_kwp_skip, T_NEAR);
2783                     // Handle Dilation-by-Stride
2784                     for (int sw = 0; sw < jcp.stride_w - 1; sw++) {
2785                         const int offset = sw * out_w_step;
2786                         zero_it(reg_ptr_aux_out, offset);
2787                     }
2788                     add(reg_ptr_aux_out, (jcp.stride_w - 1) * out_w_step);
2789                     if (jcp.stride_w == 2)
2790                         dec(reg_cnt_tmp);
2791                     else
2792                         sub(reg_cnt_tmp, jcp.stride_w - 1);
2793                     jmp(label_kwp_loop, T_NEAR);
2794                 } else {
2795                     jnz(label_kwp_loop, T_NEAR);
2796                 }
2797             }
2798         }
2799         L(label_kwp_skip);
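        // Dilation-by-stride sketch: for stride_w = 2, each column copied in
        // the loop above is followed by one zeroed column in the padded
        // buffer, giving the unit-stride view the backward kernel expects;
        // the counter is decremented by the extra columns written.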
2800 
2801         test(reg_rov, reg_rov);
2802         jz(label_rov_skip, T_NEAR);
2803         L(label_rov); // Handle Right Overflow
2804         {
2805             Label label_rov_loop;
2806             mov(reg_cnt_tmp, reg_rov);
2807             L(label_rov_loop);
2808             {
2809                 zero_it(reg_ptr_aux_out, 0);
2810                 add(reg_ptr_aux_out, out_w_step);
2811                 dec(reg_cnt_tmp);
2812                 jnz(label_rov_loop, T_NEAR);
2813             }
2814         }
2815         L(label_rov_skip);
2816 
2817         add(reg_ptr_aux_inp_h, inp_h_step);
2818         dec(reg_cnt_khp);
2819 
2820         if (jcp.stride_h > 1) {
2821             jz(label_khp_skip, T_NEAR);
2822             // Handle Dilation-by-Stride
2823             for (int sh = 0; sh < jcp.stride_h - 1; sh++) {
2824                 for (int ow = 0; ow < jcp.owp; ow++) {
2825                     const int offset = sh * out_h_step + ow * out_w_step;
2826                     zero_it(reg_ptr_aux_out, offset);
2827                 }
2828             }
2829             add(reg_ptr_aux_out, (jcp.stride_h - 1) * out_h_step);
2830             if (jcp.stride_h == 2)
2831                 dec(reg_cnt_khp);
2832             else
2833                 sub(reg_cnt_khp, jcp.stride_h - 1);
2834             jmp(label_khp_loop, T_NEAR);
2835         } else {
2836             jnz(label_khp_loop, T_NEAR);
2837         }
2838     }
2839     L(label_khp_skip);
2840 
2841     { // Handle Bottom Overflow
2842         Label label_bov_loop, label_bov_skip;
2843         test(reg_bov, reg_bov);
2844         jz(label_bov_skip, T_NEAR);
2845         mov(reg_cnt_tmp, reg_bov);
2846         L(label_bov_loop);
2847         {
2848             for (int ow = 0; ow < jcp.owp; ow++) {
2849                 const int offset = ow * out_w_step;
2850                 zero_it(reg_ptr_aux_out, offset);
2851             }
2852             add(reg_ptr_aux_out, out_h_step);
2853             dec(reg_cnt_tmp);
2854             jnz(label_bov_loop, T_NEAR);
2855         }
2856         L(label_bov_skip);
2857     }
2858 }
2859 
2860 void jit_avx512_core_amx_bwd_data_copy_kernel_t::generate() {
2861 
2862     const int inp_c_step = jcp.oc_block_int * jcp.typesize_in;
2863     const int out_c_step = jcp.ohp * jcp.owp * inp_c_step;
2864     const int nb_oc_int_no_tail = jcp.oc_without_padding / jcp.oc_block_int;
2865     const int oc_block_int_tail = jcp.oc_without_padding % jcp.oc_block_int;
2866 
2867     preamble();
2868 
2869     // pointer to 1st needed element in src buffer
2870     mov(reg_ptr_inp, ptr[param1 + GET_OFF(src)]);
2871     // pointer to 1st needed element in dst buffer
2872     mov(reg_ptr_out, ptr[param1 + GET_OFF(dst)]);
2873 
2874     // number of rows of src buffer to copy
2875     mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
2876     // number of zero-padded rows above src buffer to copy
2877     mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
2878     // number of zero-padded rows below src buffer to copy
2879     mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);
2880 
2881     // number of columns of src buffer to copy
2882     mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
2883     // number of zero-padded columns before src buffer to copy
2884     mov(reg_lov, ptr[param1 + GET_OFF(l_overflow)]);
2885     // number of zero-padded columns after src buffer to copy
2886     mov(reg_rov, ptr[param1 + GET_OFF(r_overflow)]);
2887 
2888     vpxord(zmm_zero, zmm_zero, zmm_zero);
2889 
2890     if (oc_block_int_tail > 0) {
2891         uint64_t mask = (UINT64_C(1) << oc_block_int_tail) - 1;
2892         mov(reg_tmp, mask);
2893         kmovq(ktail_mask, reg_tmp);
2894     }
2895 
2896     if (nb_oc_int_no_tail == 0) {
2897         copy_row(true); // masked
2898     } else if (nb_oc_int_no_tail == 1) {
2899         copy_row(false); // unmasked!
2900         if (oc_block_int_tail > 0) {
2901             add(reg_ptr_inp, inp_c_step);
2902             add(reg_ptr_out, out_c_step);
2903             copy_row(true); // masked
2904         }
2905     } else if (nb_oc_int_no_tail > 1) {
2906         mov(reg_cnt_ocb, nb_oc_int_no_tail);
2907         Label label_ocb_loop;
2908         L(label_ocb_loop);
2909         {
2910             copy_row(false); // unmasked!
2911             add(reg_ptr_inp, inp_c_step);
2912             add(reg_ptr_out, out_c_step);
2913             dec(reg_cnt_ocb);
2914             jnz(label_ocb_loop);
2915         }
2916         if (oc_block_int_tail > 0) copy_row(true); // masked
2917     }
2918 
2919     postamble();
2920 }
2921 
2922 // Tile register decomposition
2923 // { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
2924 int jit_avx512_core_amx_bwd_data_kernel_t::get_out_tensor(int h, int i) const {
2925     const int C_BASE = 0;
2926     const int C_LAST = 4;
2927     assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
2928     MAYBE_UNUSED(C_LAST);
2929     const int tile = C_BASE + h * jcp.nb_ih_blocking + i;
2930     assert(C_BASE <= tile && tile < C_LAST);
2931     return tile;
2932 }
2933 int jit_avx512_core_amx_bwd_data_kernel_t::get_inp_tensor(int h) const {
2934     const int I_BASE = 4;
2935     const int I_LAST = 6;
2936     assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
2937     MAYBE_UNUSED(I_LAST);
2938     const int tile = I_BASE + h;
2939     assert(I_BASE <= tile && tile < I_LAST);
2940     return tile;
2941 }
2942 int jit_avx512_core_amx_bwd_data_kernel_t::get_wei_tensor(int i) const {
2943     const int W_BASE = 6;
2944     const int W_LAST = 8;
2945     assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
2946     MAYBE_UNUSED(W_LAST);
2947     const int tile = W_BASE + i;
2948     assert(W_BASE <= tile && tile < W_LAST);
2949     return tile;
2950 }
2951 
2952 // Strides, shifts and offsets
2953 // - inp is a padded buffer ~ [nb_oc_int][ohp][owp]{32c,64c}
2954 // - weights is user buffer ~ OIhw16o16i{2o,4o}
2955 // - output is tiled buffer ~ [NBIH][NBIC][tile_width][16c]
2956 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_kh_step() const {
2957     return (size_t)jcp.typesize_in * (jcp.dilate_h + 1) * jcp.owp
2958             * jcp.oc_block_int;
2959 }
2960 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_ocb_step() const {
2961     return (size_t)jcp.typesize_in * jcp.ohp * jcp.owp * jcp.oc_block_int;
2962 }
2963 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_shift() const {
2964     return (size_t)jcp.typesize_in * jcp.tile_width * jcp.oc_block_int;
2965 }
2966 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_offset(
2967         int ihb, int kh, int kw) const {
2968     // calculate offset by src height dimension
2969     size_t sp_offset = (size_t)ihb * jcp.owp;
2970     // add offset by kernel height dimension
2971     sp_offset += (size_t)(jcp.kh - 1 - kh) * (jcp.dilate_h + 1) * jcp.owp;
2972     // add offset by kernel width dimension
2973     sp_offset += (size_t)(jcp.kw - 1 - kw) * (jcp.dilate_w + 1);
2974     return jcp.typesize_in * sp_offset * jcp.oc_block_int;
2975 }
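// Example for get_inp_offset(): kh = kw = 3, no dilation, ihb = 0. The
// innermost tap (kh = 2, kw = 2) maps to sp_offset = 0, while (kh = 0,
// kw = 0) maps to sp_offset = 2 * jcp.owp + 2 -- the mirrored indexing that
// lets the kernel walk the padded diff_dst buffer in increasing address order.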
2976 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_kh_step() const {
2977     return (size_t)jcp.typesize_in * jcp.kw * jcp.oc_block_int * jcp.ic_block;
2978 }
2979 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_ocb_step() const {
2980     const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
2981     return (size_t)jcp.typesize_in * (is_deconv ? 1 : jcp.nb_ic) * jcp.kh
2982             * jcp.kw * jcp.oc_block_int * jcp.ic_block;
2983 }
2984 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_offset(
2985         int icb, int kh, int kw) const {
2986     const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
2987     const size_t wei_kw_stride = jcp.oc_block_int * jcp.ic_block;
2988     const size_t wei_kh_stride = jcp.kw * wei_kw_stride;
2989     const size_t wei_icb_stride
2990             = (is_deconv ? jcp.nb_oc_int : 1) * jcp.kh * wei_kh_stride;
2991     return jcp.typesize_in
2992             * (icb * wei_icb_stride + kh * wei_kh_stride + kw * wei_kw_stride);
2993 }
2994 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_icb_offset(
2995         int ihb, int icb) const {
2996     size_t el_offset = jcp.is_nspc
2997             ? (size_t)icb * jcp.ic_block
2998                     + (size_t)ihb * jcp.iw * jcp.ngroups
2999                             * jcp.ic_without_padding
3000             : (size_t)icb * jcp.ih * jcp.iw * jcp.ic_block
3001                     + (size_t)ihb * jcp.iw * jcp.ic_block;
3002     return (size_t)jcp.typesize_out * el_offset;
3003 }
3004 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_row_offset(
3005         int ihb, int icb, int j) const {
3006     size_t offset_w = jcp.is_nspc ? (size_t)jcp.typesize_out * j * jcp.ngroups
3007                     * jcp.ic_without_padding
3008                                   : (size_t)jcp.typesize_out * j * jcp.ic_block;
3009     return get_out_icb_offset(ihb, icb) + offset_w;
3010 }
3011 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_shift(int width) const {
3012     return jcp.is_nspc ? (size_t)jcp.typesize_out * width * jcp.ngroups
3013                     * jcp.ic_without_padding
3014                        : (size_t)jcp.typesize_out * width * jcp.ic_block;
3015 }
3016 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_icb_offset(
3017         int ihb, int icb) const {
3018     size_t el_offset = (size_t)icb * prv_width_ * jcp.ic_block
3019             + (size_t)ihb * jcp.nb_ic_blocking * jcp.full_tile_width
3020                     * jcp.ic_block;
3021     return jcp.typesize_acc * el_offset;
3022 }
3023 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_row_offset(
3024         int ihb, int icb, int j) const {
3025     return get_wsp_icb_offset(ihb, icb)
3026             + (size_t)jcp.typesize_acc * j * jcp.ic_block;
3027 }
3028 
3029 // Code generation
3030 void jit_avx512_core_amx_bwd_data_kernel_t::prepare_output() {
3031     for (int h = 0; h < jcp.nb_ih_blocking; h++)
3032         for (int i = 0; i < jcp.nb_ic_blocking; i++)
3033             tilezero(Tmm(get_out_tensor(h, i)));
3034 }
3035 
3036 void jit_avx512_core_amx_bwd_data_kernel_t::init_runtime_counters(
3037         bool start_with_last_tile_block) {
3038     prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
3039             ? jcp.tile_tail
3040             : jcp.tile_width;
3041 
3042     row_count_ = 0;
3043     is_store_done_ = false;
3044     is_buffer_empty_ = true;
3045 }
3046 
3047 bool jit_avx512_core_amx_bwd_data_kernel_t::maybe_eltwise(int position) {
3048     using namespace primitive_kind;
3049     const auto &p = attr_.post_ops_;
3050 
3051     if (position == 0) {
3052         /* eltwise before sum */
3053         return p.contain(eltwise, 0);
3054     } else if (position == 1) {
3055         /* eltwise after sum */
3056         return p.contain(sum, 0) && p.contain(eltwise, 1);
3057     }
3058 
3059     return false;
3060 }
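// Examples for maybe_eltwise(): a post-ops chain [eltwise, sum] makes
// position 0 true (eltwise applied before the sum accumulation), while
// [sum, eltwise] makes position 1 true (eltwise applied after the sum).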
3061 
3062 Ymm jit_avx512_core_amx_bwd_data_kernel_t::ymm_mask(
3063         const Ymm &ymm_in, bool mask_flag, bool store) {
3064     return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
3065                      : ymm_in;
3066 }
3067 
3068 Zmm jit_avx512_core_amx_bwd_data_kernel_t::zmm_mask(
3069         const Zmm &zmm_in, bool mask_flag, bool store) {
3070     return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
3071                      : zmm_in;
3072 }
3073 
3074 void jit_avx512_core_amx_bwd_data_kernel_t::cvt2ps(data_type_t type_in,
3075         const Zmm &zmm_in, const Operand &op, bool mask_flag) {
3076     const Zmm zmm = zmm_mask(zmm_in, mask_flag);
3077     switch (type_in) {
3078         case data_type::f32:
3079         case data_type::s32: vmovups(zmm, op); break;
3080         case data_type::s8: vpmovsxbd(zmm, op); break;
3081         case data_type::u8: vpmovzxbd(zmm, op); break;
3082         default: assert(!"unsupported data type");
3083     }
3084     if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
3085 }
3086 
3087 void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_bf16(
3088         const Zmm &zmm_out, int icb, int h, int w) {
3089     const bool mask_flag = jcp.is_nspc && jcp.ic_without_padding != jcp.ic
3090             && icb == (jcp.nb_ic_blocking - 1);
3091 
3092     auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));
3093 
3094     const auto &p = attr_.post_ops_;
3095 
3096     const int sum_idx = p.find(primitive_kind::sum);
3097     if (sum_idx != -1) {
3098         if (jcp.dsrc_dt == data_type::bf16) {
3099             vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
3100             vpslld(zmm_prev_dst, zmm_prev_dst, 16);
3101             vaddps(zmm_out, zmm_prev_dst);
3102         } else {
3103             vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
3104             vaddps(zmm_out, zmm_prev_dst);
3105         }
3106     }
3107     if (jcp.with_bias) {
3108         int bias_offset = jcp.typesize_bia * icb * jcp.ic_block;
3109         auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
3110         if (jcp.bia_dt == data_type::bf16) {
3111             vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
3112             vpslld(zmm_bias, zmm_bias, 16);
3113             vaddps(zmm_out, zmm_bias);
3114         } else
3115             vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
3116     }
3117 
3118     const int eltwise_ind = p.find(primitive_kind::eltwise);
3119     if (eltwise_ind != -1) eltwise_injector_->compute_vector(zmm_out.getIdx());
3120 
3121     if (jcp.dsrc_dt == data_type::bf16) {
3122         Ymm ymm_out = Ymm(zmm_out.getIdx());
3123         vcvtneps2bf16(ymm_out, zmm_out);
3124         vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
3125     } else {
3126         vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
3127     }
3128 }
3129 
3130 void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
3131         const Zmm &zmm_out, int icb, int h, int w) {
3132     const int nb_ic_block = jcp.nb_ic_blocking;
3133     const int ic_block = jcp.ic_block;
3134     const bool mask_flag = jcp.ic_without_padding != jcp.ic
3135             && icb == (nb_ic_block - 1);
3136 
3137     auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));
3138 
3139     const auto &p = attr_.post_ops_;
3140     const int sum_idx = p.find(primitive_kind::sum);
3141     const float *p_sum_scale = nullptr;
3142     const int32_t *p_sum_zp = nullptr;
3143     if (sum_idx != -1) {
3144         const auto &p_entry = p.entry_[sum_idx];
3145         p_sum_scale = &p_entry.sum.scale;
3146         p_sum_zp = &p_entry.sum.zero_point;
3147     }
3148 
3149     if (p_sum_scale) {
3150         if (*p_sum_scale != 1.f)
3151             mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
3152         if (*p_sum_zp != 0)
3153             mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
3154     }
3155 
3156     int scale_offset = jcp.is_ic_scale * (sizeof(float) * icb * ic_block);
3157     if (jcp.with_bias) {
3158         int bias_offset = jcp.typesize_bia * icb * ic_block;
3159         auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
3160         cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
3161     }
3162     /* add bias to zmm_accum */
3163     vcvtdq2ps(zmm_out, zmm_out);
3164     if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
3165     const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
3166     vmulps(zmm_out_msk, zmm_out,
3167             EVEX_compress_addr(reg_ptr_scales, scale_offset));
3168 
3169     /* Do post-ops */
3170     if (maybe_eltwise(0)) eltwise_injector_->compute_vector(zmm_out.getIdx());
3171     if (p_sum_scale) { // post_op: sum
3172         cvt2ps(jcp.dsrc_dt, zmm_prev_dst, addr, mask_flag);
3173         if (*p_sum_zp != 0) {
3174             vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
3175             vsubps(zmm_prev_dst, zmm_sum_zp);
3176         }
3177         if (*p_sum_scale == 1.f)
3178             vaddps(zmm_out, zmm_prev_dst);
3179         else
3180             vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
3181     }
3182     if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx());
3183 
3184     // Properly saturate the accumulators for integer datatypes
3185     if (one_of(jcp.dsrc_dt, u8, s8, s32)) {
3186         init_saturate_f32(
3187                 zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dsrc_dt);
3188         saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dsrc_dt);
3189         vcvtps2dq(zmm_out, zmm_out);
3190     }
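    // e.g. for dsrc_dt == u8 the f32 accumulator is clamped to [0, 255]
    // before the integer conversion, matching the unsigned saturating
    // downconvert done by vpmovusdb below.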
3191 
3192     const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
3193 
3194     switch (jcp.dsrc_dt) {
3195         case data_type::f32:
3196         case data_type::s32: vmovups(addr, zmm_out_store); break;
3197         case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
3198         case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
3199         default: assert(!"unknown dst_dt");
3200     }
3201 }
3202 
3203 void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector(
3204         const Zmm &zmm_out, int icb, int h, int w) {
3205     /*
3206     Output:
3207               jcp.is_nspc              !jcp.is_nspc
3208               ---------------------    ---------------------
3209         INT8: [N][H][W][NBIC][16IC]
3210         BF16: [N][H][W][NBIC][16IC] or [N][NBIC][H][W][16IC]
3211     */
3212     if (jcp.ddst_dt == data_type::bf16) {
3213         store_output_vector_bf16(zmm_out, icb, h, w);
3214     } else {
3215         store_output_vector_int8(zmm_out, icb, h, w);
3216     }
3217 }
3218 
3219 void jit_avx512_core_amx_bwd_data_kernel_t::store_output(
3220         int width, bool do_store) {
3221     auto store_output_block = [=](int width, bool do_store,
3222                                       bool is_last_ih_blks) {
3223         // Calculate the number of ih blocks; it may differ on last call
3224         const int n_ih_blks = is_last_ih_blks ? jcp.ih % jcp.nb_ih_blocking
3225                                               : jcp.nb_ih_blocking;
3226         for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
3227             for (int ihb = 0; ihb < n_ih_blks; ihb++) {
3228                 /* Formats: Workspace: [NBIH][NBIC][W][16OC] */
3229                 tilestored(ptr[reg_wsp_ptr + reg_wei_stride
3230                                    + get_wsp_icb_offset(ihb, icb)],
3231                         Tmm(get_out_tensor(ihb, icb)));
3232                 is_buffer_empty_ = false;
3233                 is_store_done_ = false;
3234                 for (int tw = 0; tw < width && do_store; tw++) {
3235                     Zmm zmm_out = Zmm(tw);
3236                     vmovups(zmm_out,
3237                             ptr[reg_wsp_ptr
3238                                     + get_wsp_row_offset(ihb, icb, tw)]);
3239                     store_output_vector(zmm_out, icb, ihb, tw);
3240                 }
3241             }
3242         }
3243     };
3244 
3245     // adjustment in case interleave store is turned off
3246     do_store = do_store || jcp.per_one_pstore == 0;
3247     if (jcp.ih % jcp.nb_ih_blocking == 0) {
3248         store_output_block(width, do_store, /* is_last_ih_blks = */ false);
3249     } else {
3250         Label label_full_store, label_done;
3251         cmp(reg_last_h, 0);
3252         jne(label_full_store, T_NEAR);
3253         store_output_block(width, do_store, /* is_last_ih_blks = */ true);
3254         jmp(label_done, T_NEAR);
3255         L(label_full_store);
3256         store_output_block(width, do_store, /* is_last_ih_blks = */ false);
3257         L(label_done);
3258     }
3259     if (do_store) add(reg_out_ptr, get_out_shift(width));
3260 }
3261 
3262 void jit_avx512_core_amx_bwd_data_kernel_t::interleave_store(int width) {
3263     for (int c = 0;
3264             c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
3265             c++) {
3266         // row_count = ihb * ICB * TW + icb * TW + tw
3267         int tw = row_count_ % prv_width_;
3268         int icb = (row_count_ / prv_width_) % jcp.nb_ic_blocking;
3269         int ihb = (row_count_ / prv_width_) / jcp.nb_ic_blocking;
3270 
3271         Zmm zmm_out = Zmm(tw);
3272         vmovups(zmm_out, ptr[reg_wsp_ptr + get_wsp_row_offset(ihb, icb, tw)]);
3273         store_output_vector(zmm_out, icb, ihb, tw);
3274         row_count_++;
3275 
3276         if (row_count_
3277                 == prv_width_ * jcp.nb_ic_blocking * jcp.nb_ih_blocking) {
3278             add(reg_out_ptr, get_out_shift(prv_width_));
3279             row_count_ = 0;
3280             is_store_done_ = true;
3281             prv_width_ = width;
3282         }
3283     }
3284 }
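// Decomposition example for interleave_store(): prv_width_ = 16,
// nb_ic_blocking = 2, nb_ih_blocking = 2 -> 64 buffered rows in total;
// row_count_ = 35 gives tw = 35 % 16 = 3, icb = (35 / 16) % 2 = 0 and
// ihb = (35 / 16) / 2 = 1.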
3285 
3286 void jit_avx512_core_amx_bwd_data_kernel_t::compute_ocb_loop(
3287         int width, bool do_store) {
3288 
3289     auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
3290         switch (jcp.ddst_dt) {
3291             using namespace data_type;
3292             case bf16: tdpbf16ps(x1, x2, x3); break;
3293             case s8: tdpbssd(x1, x2, x3); break;
3294             case u8: tdpbusd(x1, x2, x3); break;
3295             default: assert(!"unsupported data type");
3296         }
3297     };
3298 
3299     prepare_output();
3300 
3301     for (int ocb = 0; ocb < jcp.nb_oc_int; ocb++) {
3302         // reverse order through spatial components of weights so that
3303         // input buffer is accessed in a monotonically increasing fashion
3304         for (int kh = jcp.kh - 1; kh >= 0; kh--) {
3305             for (int kw = jcp.kw - 1; kw >= 0; kw--) {
3306                 for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
3307                     tileloadd(Tmm(get_inp_tensor(ihb)),
3308                             ptr[reg_inp_ptr + get_inp_offset(ihb, kh, kw)
3309                                     + reg_inp_stride]);
3310                 }
3311                 for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
3312                     tileloadd(Tmm(get_wei_tensor(icb)),
3313                             ptr[reg_wei_ptr + get_wei_offset(icb, kh, kw)
3314                                     + reg_wei_stride]);
3315                     for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
3316                         tdpbxxd(Tmm(get_out_tensor(ihb, icb)),
3317                                 Tmm(get_inp_tensor(ihb)),
3318                                 Tmm(get_wei_tensor(icb)));
3319                         interleave_store(width);
3320                     }
3321                 }
3322             }
3323         }
3324         add(reg_inp_ptr, get_inp_ocb_step());
3325         add(reg_wei_ptr, get_wei_ocb_step());
3326     }
3327     sub(reg_inp_ptr, get_inp_ocb_step() * jcp.nb_oc_int);
3328     sub(reg_wei_ptr, get_wei_ocb_step() * jcp.nb_oc_int);
3329 
3330     store_output(width, do_store);
3331 
3332     add(reg_inp_ptr, get_inp_shift());
3333 }
3334 
3335 void jit_avx512_core_amx_bwd_data_kernel_t::compute_iw_loop() {
3336     auto compute_iw_loop_body = [=](bool last_iwb, int num_tile_blocks) {
3337         int gen_tile_tail = last_iwb && jcp.tile_tail > 0 ? jcp.tile_tail
3338                                                           : jcp.tile_width;
3339         init_runtime_counters(last_iwb && num_tile_blocks == 1);
3340         for (int iwb = 0; iwb < num_tile_blocks - 1; iwb++)
3341             compute_ocb_loop(jcp.tile_width, false);
3342         compute_ocb_loop(gen_tile_tail, true);
3343     };
3344 
3345     if (jcp.nb_iw == 1) {
3346         compute_iw_loop_body(true, jcp.iw_blocks);
3347     } else {
3348         Label label_done;
3349         int iw_blocks_per_call = div_up(jcp.iw_block, jcp.tile_width);
3350         int last_iwb_tile_blocks = jcp.iw_blocks % iw_blocks_per_call;
3351         if (last_iwb_tile_blocks == 0 && jcp.tile_tail > 0)
3352             last_iwb_tile_blocks = iw_blocks_per_call;
3353         if (last_iwb_tile_blocks > 0) {
3354             Label label_not_last_iwb;
3355             mov(reg_tmp, ptr[param1 + GET_OFF(iwb)]);
3356             cmp(reg_tmp, jcp.nb_iw - 1);
3357             jne(label_not_last_iwb, T_NEAR);
3358 
3359             compute_iw_loop_body(true, last_iwb_tile_blocks);
3360 
3361             jmp(label_done, T_NEAR);
3362 
3363             L(label_not_last_iwb);
3364         }
3365         compute_iw_loop_body(false, iw_blocks_per_call);
3366 
3367         L(label_done);
3368     }
3369 }
3370 
3371 void jit_avx512_core_amx_bwd_data_kernel_t::generate() {
3372     preamble();
3373 
3374     mov(reg_inp_ptr, ptr[param1 + GET_OFF(dst)]); // padded buffer of diff_dst
3375     mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]); // weights
3376     mov(reg_out_ptr, ptr[param1 + GET_OFF(src)]); // diff_src
3377     mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
3378 
3379     if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
3380 
3381     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
3382 
3383     mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
3384 
3385     const int inp_stride = jcp.oc_block_int * jcp.typesize_in;
3386     const int wei_stride = jcp.ic_block * jcp.typesize_acc;
3387     mov(reg_inp_stride, inp_stride);
3388     mov(reg_wei_stride, wei_stride);
3389 
3390     if (jcp.is_nspc && jcp.ic_without_padding != jcp.ic) {
3391         // Use mask 0xF by default for all output data and post-ops
3392         // loads / stores with block index
3393         // icb = icc * jcp.nb_ic_blocking + (jcp.nb_ic_blocking - 1)
3394         // TODO: use masked loads / stores for the last icc only
3395         int current_block_size = jcp.ic_block;
3396         int mask = (1 << current_block_size) - 1;
3397         Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
3398         mov(regw_tmp, mask);
3399         kmovw(ktail_mask, regw_tmp);
3400         Xbyak::Label mask_is_set;
3401         mov(reg_ic_blocks, ptr[param1 + GET_OFF(ic_blocks)]);
3402         cmp(reg_ic_blocks, jcp.nb_ic - jcp.nb_ic_blocking);
3403         jne(mask_is_set, T_NEAR);
3404         // Reset the mask
3405         current_block_size = jcp.ic_without_padding % jcp.ic_block;
3406         mask = (1 << current_block_size) - 1;
3407         mov(regw_tmp, mask);
3408         kmovw(ktail_mask, regw_tmp);
3409 
3410         L(mask_is_set);
3411     }
3412     compute_iw_loop();
3413 
3414     postamble();
3415 
3416     if (jcp.with_eltwise) eltwise_injector_->prepare_table();
3417 }
3418 
3419 bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok(
3420         const jit_conv_conf_t &jcp, primitive_attr_t &attr) {
3421     using namespace primitive_kind;
3422     const auto &p = attr.post_ops_;
3423     const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
3424 
3425     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
3426 
3427     auto is_sum = [&](int idx) {
3428         if (is_bf16)
3429             return p.entry_[idx].is_sum();
3430         else
3431             return p.contain(sum, idx);
3432     };
3433 
3434     switch (p.len()) {
3435         case 0: return true;
3436         case 1: return is_eltwise(0) || is_sum(0);
3437         case 2:
3438             return (is_sum(0) && is_eltwise(1))
3439                     || (!is_bf16 && is_sum(1) && is_eltwise(0));
3440         default: return false;
3441     }
3442 
3443     return false;
3444 }
3445 
3446 void jit_avx512_core_amx_bwd_data_kernel_t::tile_configure(char *tcfg_buff) {
3447     const int vnni_width = jcp.ddst_dt == data_type::bf16 ? 2 : 4;
3448     // Input tile dimensions
3449     const int a_col = jcp.oc_block_int;
3450     const int a_row = jcp.tile_width;
3451     // Weights tile dimensions
3452     const int b_col = jcp.ic_block * vnni_width;
3453     const int b_row = a_col / vnni_width;
3454     // Accumulator tile dimensions
3455     const int c_col = jcp.ic_block;
3456     const int c_row = a_row;
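    // Example (bf16: vnni_width = 2, oc_block_int = 32, tile_width = 16):
    // input tiles are 16 rows x 64 bytes, weight tiles 16 rows x
    // 16 * 2 * 2 = 64 bytes, and accumulator tiles 16 rows x 64 bytes of
    // f32, exactly filling the 16x64-byte AMX tile shape.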
3457 
3458     for (size_t i = 0; i < 64; i++)
3459         tcfg_buff[i] = 0;
3460 
3461     // Weights (W_BASE) Tensor Tiles
3462     for (int i = 0; i < jcp.nb_ic_blocking; i++)
3463         tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
3464                 b_row, b_col * jcp.typesize_in);
3465 
3466     // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
3467     for (int h = 0; h < jcp.nb_ih_blocking; h++) {
3468         tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
3469                 a_row, a_col * jcp.typesize_in);
3470         for (int i = 0; i < jcp.nb_ic_blocking; i++)
3471             tc_configure_tile((palette_config_t *)tcfg_buff,
3472                     get_out_tensor(h, i), c_row, c_col * jcp.typesize_acc);
3473     }
3474 
3475     ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
3476 }
3477 
3478 status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
3479         const convolution_desc_t &cd, memory_desc_t &diff_src_md,
3480         memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
3481         memory_desc_t *bias_md, primitive_attr_t &attr, int nthreads) {
3482     using namespace prop_kind;
3483 
3484     const memory_desc_wrapper diff_src_d(&diff_src_md);
3485     const memory_desc_wrapper weights_d(&weights_md);
3486     const memory_desc_wrapper diff_dst_d(&diff_dst_md);
3487     const memory_desc_wrapper bias_d(bias_md);
3488 
3489     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
3490     int ndims = diff_src_d.ndims();
3491     bool is_1d = ndims == 3;
3492     bool is_3d = ndims == 5;
3493 
3494     if (is_3d) return status::unimplemented;
3495 
3496     using namespace data_type;
3497     const bool is_deconv = cd.prop_kind != prop_kind::backward_data;
3498     const bool is_bf16 = everyone_is(true, diff_dst_d.data_type() == bf16,
3499             weights_d.data_type() == bf16,
3500             one_of(diff_src_d.data_type(), bf16, f32));
3501     const bool is_bf16_convolution = is_bf16 && !is_deconv;
3502     const bool is_bf16_deconvolution = is_bf16 && is_deconv;
3503     const bool is_int8_deconvolution = is_deconv
3504             && everyone_is(true, one_of(diff_dst_d.data_type(), s8, u8),
3505                     weights_d.data_type() == s8,
3506                     one_of(diff_src_d.data_type(), f32, s32, s8, u8));
3507 
3508     bool supported = (is_bf16 && mayiuse(avx512_core_bf16_amx_bf16))
3509             || (is_int8_deconvolution && mayiuse(avx512_core_bf16_amx_int8));
3510     if (!supported) return status::unimplemented;
3511 
3512     jcp = zero<decltype(jcp)>();
3513     jcp.isa = is_bf16 ? avx512_core_bf16_amx_bf16 : avx512_core_bf16_amx_int8;
3514     jcp.ndims = ndims;
3515     jcp.prop_kind = cd.prop_kind;
3516     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
3517 
3518     jcp.mb = diff_src_d.dims()[0];
3519     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
3520     jcp.oc_without_padding = jcp.oc;
3521     jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
3522     jcp.ic_without_padding = jcp.ic;
3523     jcp.ih = !is_1d ? diff_src_d.dims()[ndims - 2] : 1;
3524     jcp.iw = diff_src_d.dims()[ndims - 1];
3525     jcp.oh = !is_1d ? diff_dst_d.dims()[ndims - 2] : 1;
3526     jcp.ow = diff_dst_d.dims()[ndims - 1];
3527     jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
3528     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
3529     jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
3530     jcp.l_pad = cd.padding[0][ndims - 3];
3531     jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
3532     jcp.stride_w = cd.strides[ndims - 3];
3533 
3534     // No bias for bf16 case to simplify integration with ref_deconvolution
3535     jcp.with_bias = bias_md && !is_bf16_convolution
3536             && cd.bias_desc.format_kind != format_kind::undef;
3537 
3538     jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
3539     jcp.dilate_w = cd.dilates[ndims - 3];
3540 
3541     const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
3542     const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
3543     jcp.b_pad = calculate_end_padding(
3544             jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
3545     jcp.r_pad = calculate_end_padding(
3546             jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
3547     if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
3548             || jcp.b_pad >= gen_kh)
3549         return status::unimplemented;
3550 
3551     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
3552     if (is_deconv) {
3553         jcp.ddst_dt = cd.src_desc.data_type;
3554         jcp.dsrc_dt = cd.dst_desc.data_type;
3555     } else {
3556         jcp.ddst_dt = cd.diff_dst_desc.data_type;
3557         jcp.dsrc_dt = cd.diff_src_desc.data_type;
3558     }
3559     jcp.wei_dt = cd.weights_desc.data_type;
3560 
3561     jcp.is_depthwise = with_groups && everyone_is(1, jcp.ic, jcp.oc);
3562 
3563     if (jcp.is_depthwise)
3564         return status::unimplemented; // TODO: add support of DW convolution
3565 
3566     format_tag_t dat_tag_ncsp
3567             = pick(ndims - 3, format_tag::nCw16c, format_tag::nChw16c);
3568     format_tag_t dat_tag_nspc
3569             = pick(ndims - 3, format_tag::nwc, format_tag::nhwc);
3570     // To toggle the default data layout for BF16 between nChw16c and nhwc,
3571     // swap the following two variable definitions. Current choice: nhwc.
3572     format_tag_t dat_tag_opt = dat_tag_nspc;
3573     format_tag_t dat_tag_alt = is_bf16 ? dat_tag_ncsp : dat_tag_nspc;
3574 
3575     if (diff_src_d.format_kind() == format_kind::any) {
3576         CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag_opt));
3577         jcp.src_tag = dat_tag_opt;
3578     } else
3579         jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);
3580 
3581     if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
3582         return status::unimplemented;
3583 
3584     jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
3585     assert(IMPLICATION(is_int8_deconvolution, jcp.is_nspc));
3586 
3587     // TODO: remove all support for nChw16c from this implementation
3588     if (!jcp.is_nspc) return status::unimplemented;
3589 
3590     if (diff_dst_d.format_kind() == format_kind::any) {
3591         CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
3592         jcp.dst_tag = jcp.src_tag;
3593     } else
3594         jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
3595 
3596     if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
3597 
3598     if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
3599         CHECK(memory_desc_init_by_tag(*bias_md, format_tag::x));
3600 
3601     jcp.nthr = nthreads;
3602 
3603     jcp.ic_block = 16;
3604     jcp.oc_block = 16;
3605 
3606     if (jcp.ngroups == 1) {
3607         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
3608         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
3609     }
3610     bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
3611     if (!args_ok) return status::unimplemented;
3612 
3613     const int vnni_width = is_bf16 ? 2 : 4;
3614     jcp.oc_block_int = jcp.oc_block * vnni_width; // 32 for bf16, 64 for int8
3615 
3616     if (attr.set_default_formats(&diff_src_md) != status::success)
3617         return status::unimplemented;
3618     if (!post_ops_ok(jcp, attr)) return status::unimplemented;
3619 
3620     const auto &p = attr.post_ops_;
3621     const int eltwise_ind = p.find(primitive_kind::eltwise);
3622     jcp.with_eltwise = eltwise_ind != -1;
3623     if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
3624 
3625     auto set_or_check_wei_format = [&]() {
3626         using namespace format_tag;
3627         format_tag_t wei_tag;
3628         if (is_bf16_convolution)
3629             wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16o16i2o,
3630                     gOIw16o16i2o, OIhw16o16i2o, gOIhw16o16i2o);
3631         else if (is_bf16_deconvolution)
3632             wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
3633                     gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i);
3634         else if (is_int8_deconvolution)
3635             wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o4i,
3636                     gOIw16i16o4i, OIhw16i16o4i, gOIhw16i16o4i);
3637         else {
3638             assert(!"unsupported combination");
3639             return false;
3640         }
3641 
3642         memory_desc_t want_wei_md = weights_md;
3643         memory_desc_init_by_tag(want_wei_md, wei_tag);
3644 
3645         if (weights_md.format_kind == format_kind::any) {
3646             weights_md = want_wei_md;
3647             return true;
3648         }
3649         return weights_md == want_wei_md;
3650     };
3651 
3652     if (!set_or_check_wei_format()) return status::unimplemented;
3653 
3654     jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
3655     jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
3656     jcp.typesize_bia
3657             = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
3658     jcp.typesize_acc = sizeof(int32_t);
3659 
3660     jcp.nb_ic = jcp.ic / jcp.ic_block;
3661     jcp.nb_oc = jcp.oc / jcp.oc_block;
3662     jcp.nb_oc_int = div_up(jcp.oc, jcp.oc_block_int);
3663 
3664     const int max_palette = amx::get_max_palette();
3665     jcp.max_tiles = amx::get_max_tiles(max_palette);
3666     jcp.full_tile_width = amx::get_max_rows(max_palette);
3667     if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
3668         return status::unimplemented;
3669 
3670     jcp.tile_width = nstl::min(jcp.full_tile_width, jcp.iw);
3671     jcp.iw_blocks = div_up(jcp.iw, jcp.tile_width);
3672 
3673     // Prefer to use a single tile width when possible
3674     // (eg iw28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
3675     if (jcp.iw % jcp.iw_blocks == 0) jcp.tile_width = jcp.iw / jcp.iw_blocks;
3676     jcp.tile_tail = jcp.iw % jcp.tile_width;
3677 
3678     jcp.nb_ic_blocking = (jcp.nb_ic % 2 == 0) ? 2 : 1;
3679     jcp.nb_ih_blocking
3680             = everyone_is(true, jcp.ih > 1,
3681                       // requirement for interleave stores
3682                       IMPLICATION(jcp.iw_blocks > 1, jcp.ih % 2 == 0))
3683             ? 2
3684             : 1;
3685 
3686     // TODO: tune ih blocking
3687     const int ih_blk_size_tmp = 10;
3688     const int ih_step = jcp.nb_ih_blocking;
3689     jcp.ih_blk_size = rnd_up(nstl::min(jcp.ih, ih_blk_size_tmp), ih_step);
3690     // ohp includes all rows that are actually used in the computation,
3691     // including zero-padded "dilate-by-stride" rows and top/bottom overflow
3692     jcp.ohp = jcp.ih_blk_size + gen_kh - 1;
3693 
3694     // TODO: tune iw blocking
3695     const int iw_blocks_per_call = 2;
3696     jcp.iw_block = jcp.tile_width * iw_blocks_per_call;
3697     jcp.nb_iw = div_up(jcp.iw, jcp.iw_block);
3698     // owp includes all columns that are actually used in the computation,
3699     // including zero-padded "dilate-by-stride" columns and left/right overflow
3700     jcp.owp = jcp.iw_block + gen_kw - 1;
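    // Example: tile_width = 16, iw_blocks_per_call = 2 and gen_kw = 3 give
    // iw_block = 32 and owp = 32 + 3 - 1 = 34 padded columns per row; with
    // ih_blk_size = 10 and gen_kh = 3, ohp = 10 + 3 - 1 = 12 padded rows.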
3701 
3702     // Number of ops per tile store
3703     int ops_tile_store = jcp.tile_width;
3704     // Number of ops per accumulation tile
3705     int available_ops = jcp.nb_oc_int * jcp.kh * jcp.kw;
3706     // Number of vectors to store per tile operation
3707     // NOTE: set to zero to turn off interleave store (mostly for debugging)
3708     jcp.per_one_pstore = div_up(ops_tile_store, available_ops);
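    // e.g., with tile_width = 16, nb_oc_int = 2 and a 3x3 filter (assumed
    // values): available_ops = 2 * 3 * 3 = 18 and per_one_pstore =
    // div_up(16, 18) = 1, so one row store is interleaved per accumulation.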
3709 
3710     jcp.inp_buffer_size
3711             = (size_t)jcp.nb_oc_int * jcp.ohp * jcp.owp * jcp.oc_block_int;
3712     jcp.wsp_buffer_size = (size_t)jcp.nb_ih_blocking * jcp.nb_ic_blocking
3713             * jcp.full_tile_width * jcp.ic_block;
3714 
3715     const auto &oscales = attr.output_scales_;
3716     jcp.is_ic_scale = oscales.mask_ == 1 << 1;
3717 
3718     return status::success;
3719 }
3720 
3721 void jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad(
3722         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
3723         const primitive_attr_t &attr) {
3724 
3725     size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
3726     scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
3727     size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
3728     scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
3729     if (jcp.with_bias && jcp.ic != jcp.ic_without_padding) {
3730         assert(jcp.ngroups == 1);
3731         scratchpad.book(key_conv_padded_bias, jcp.ic, jcp.typesize_bia);
3732     }
3733     scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
3734 }
3735 
3736 const int jit_avx512_core_amx_bwd_weights_kernel_t::max_ur_w = 32;
3737 
3738 // Tile register decomposition
3739 // { C_BASE = 0, A_BASE = 4, B_BASE = 6, }
3740 int jit_avx512_core_amx_bwd_weights_kernel_t::get_wei_tensor(
3741         int ocb, int icb) const {
3742     const int C_BASE = 0;
3743     const int C_LAST = 4;
3744     assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
3745     MAYBE_UNUSED(C_LAST);
3746     const int tile = C_BASE + ocb * jcp.nb_oc_blocking + icb;
3747     assert(C_BASE <= tile && tile < C_LAST);
3748     return tile;
3749 }
3750 int jit_avx512_core_amx_bwd_weights_kernel_t::get_src_tensor(int icb) const {
3751     const int A_BASE = 4;
3752     const int A_LAST = 6;
3753     assert(0 <= A_BASE && A_BASE < A_LAST && A_LAST <= jcp.max_tiles);
3754     MAYBE_UNUSED(A_LAST);
3755     const int tile = A_BASE + icb;
3756     assert(A_BASE <= tile && tile < A_LAST);
3757     return tile;
3758 }
3759 int jit_avx512_core_amx_bwd_weights_kernel_t::get_ddst_tensor(int ocb) const {
3760     const int B_BASE = 6;
3761     const int B_LAST = 8;
3762     assert(0 <= B_BASE && B_BASE < B_LAST && B_LAST <= jcp.max_tiles);
3763     MAYBE_UNUSED(B_LAST);
3764     const int tile = B_BASE + ocb;
3765     assert(B_BASE <= tile && tile < B_LAST);
3766     return tile;
3767 }
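// When nb_oc_blocking = nb_ic_blocking = 2, the accessors above cover all
// 8 available tiles: C (weights accumulators) in tmm0-tmm3, A (transposed
// src) in tmm4-tmm5, and B (transposed diff_dst) in tmm6-tmm7.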
3768 
3769 void jit_avx512_core_amx_bwd_weights_kernel_t::tile_configure(char *tcfg_buff) {
3770     // Input tile dimensions
3771     const int a_col = jcp.ur_w;
3772     const int a_row = jcp.ic_block;
3773     // Weights tile dimensions
3774     const int b_col = jcp.oc_block * 2;
3775     const int b_row = a_col / 2;
3776     // Accumulator tile dimensions
3777     const int c_col = jcp.oc_block;
3778     const int c_row = a_row;
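    // Intended geometry for C += A * B via tdpbf16ps: A (src) is ic_block
    // rows by ur_w bf16 columns, B (diff_dst) is ur_w / 2 rows of oc_block
    // VNNI-interleaved bf16 pairs, and C accumulates ic_block x oc_block
    // in f32.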
3779 
3780     for (size_t i = 0; i < 64; i++)
3781         tcfg_buff[i] = 0;
3782 
3783     for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
3784         tc_configure_tile((palette_config_t *)tcfg_buff, get_src_tensor(icb),
3785                 a_row, a_col * jcp.typesize_in);
3786 
3787     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
3788         tc_configure_tile((palette_config_t *)tcfg_buff, get_ddst_tensor(ocb),
3789                 b_row, b_col * jcp.typesize_in);
3790 
3791     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
3792         for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
3793             tc_configure_tile((palette_config_t *)tcfg_buff,
3794                     get_wei_tensor(ocb, icb), c_row, c_col * jcp.typesize_out);
3795 
3796     ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
3797 }
3798 
3799 void jit_avx512_core_amx_bwd_weights_kernel_t::od_step_comeback_pointers() {
3800     Label kd_comeback_label;
3801     mov(kj, reg_kd_count);
3802     L(kd_comeback_label);
3803     {
3804         sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
3805         sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
3806         dec(kj);
3807         jnz(kd_comeback_label, T_NEAR);
3808     }
3809 }
3810 
3811 void jit_avx512_core_amx_bwd_weights_kernel_t::oh_step_comeback_pointers() {
3812     Label kh_comeback_label;
3813     mov(kj, reg_kh);
3814     L(kh_comeback_label);
3815     {
3816         sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
3817         sub(reg_kernel, get_kernel_offset(0, jcp.kw));
3818         dec(kj);
3819         jnz(kh_comeback_label, T_NEAR);
3820     }
3821 }
3822 
3823 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_full_spat_loop(
3824         int nb_ic_blocking, int nb_oc_blocking) {
3825     // General code layout:
3826     //
3827     // Blocking over OH -- top level
3828     // (Reduces L2 pressure; not very useful right now)
3829     //  Loop over all KHxKW kernel -- emit_kh_kw_loop()
3830     //    Loop over OH block -- emit_h_loop()
3831     //      Loop over OW blocks -- emit_fma_block()
3832     //      (Supports both fully unrolled and partially unrolled
3833     //      versions to reduce code size)
3834     //          Loop over OW block -- emit_fma_step()
3835 
3836     auto src_row_size = get_src_offset(0, 0, 1);
3837     auto ddst_row_size = get_ddst_offset(0, 1);
3838     auto row_size = src_row_size + ddst_row_size;
3839 
3840     int h_block_size = jcp.oh;
3841     int h_last_block_size = h_block_size;
3842     int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
3843     auto working_set_size = row_size * h_block_size;
3844 
3845     if (working_set_size > full_spat_max_working_set_size) {
3846         assert(full_spat_opt_working_set_size < full_spat_max_working_set_size);
3847 
3848         while (working_set_size > full_spat_opt_working_set_size
3849                 && h_block_size >= min_h_block_size) {
3850             for (int i = 2; i <= h_block_size; i++)
3851                 if (i == h_block_size)
3852                     h_block_size = h_block_size / 2;
3853                 else if (h_block_size % i == 0) {
3854                     h_block_size = h_block_size / i;
3855                     break;
3856                 }
3857             working_set_size = row_size * h_block_size;
3858         }
3859         h_block_size = nstl::max(min_h_block_size, h_block_size);
3860         h_last_block_size = jcp.oh % h_block_size;
3861         if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
3862     }
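    // The loop above shrinks h_block_size by dividing it by its smallest
    // divisor >= 2 (halving it, rounding down, when it is prime) until the
    // working set fits full_spat_opt_working_set_size or the block reaches
    // min_h_block_size, to which it is then clamped.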
3863 
3864     Opmask reg_h_block = k1;
3865     Reg64 reg_kh = rax;
3866     Reg64 reg_kw = rbx;
3867     Reg64 reg_tmp = abi_not_param1;
3868     Reg32 reg_tmp_w = reg_tmp.cvt32();
3869     Reg64 reg_ohs = rdx;
3870     Reg64 reg_ihs = rsi;
3871     Reg64 reg_h = r8;
3872     Reg64 reg_j = r10;
3873 
3874     Reg64 reg_src = r13;
3875     Reg64 reg_ddst = r14;
3876     Reg64 reg_ker = r15;
3877 
3878     Reg64 reg_dense_stride = abi_param1;
3879     Reg64 reg_a_stride = reg_tmp;
3880 
3881     auto emit_block = [&]() {
3882         mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
3883         for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
3884             dim_t ur_w_src_offset = ur_w_b * get_src_offset(0, jcp.ur_w);
3885             dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
3886 
3887             for (int icb = 0; icb < nb_ic_blocking; icb++) {
3888                 dim_t icb_offset = jcp.typesize_in * icb * jcp.tr_src_buf_size;
3889                 tileloadd(Tmm(get_src_tensor(icb)),
3890                         ptr[reg_src + reg_a_stride + icb_offset
3891                                 + ur_w_src_offset]);
3892             }
3893             for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
3894                 tileloadd(Tmm(get_ddst_tensor(ocb)),
3895                         ptr[reg_ddst + reg_dense_stride
3896                                 + jcp.typesize_in * ocb
3897                                         * jcp.tr_diff_dst_buf_size
3898                                 + ur_w_ddst_offset]);
3899                 for (int icb = 0; icb < nb_ic_blocking; icb++)
3900                     tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
3901                             Tmm(get_src_tensor(icb)),
3902                             Tmm(get_ddst_tensor(ocb)));
3903             }
3904         }
3905     };
3906 
3907     auto emit_h_loop = [&]() {
3908         Label h_loop, skip_h_loop;
3909         mov(reg_j, 1);
3910         cmp(reg_j, reg_h);
3911         je(skip_h_loop, T_NEAR);
3912         L(h_loop);
3913         {
3914             emit_block();
3915 
3916             add(reg_src, get_src_offset(0, 0, 1));
3917             add(reg_ddst, get_ddst_offset(0, 1));
3918             add(reg_j, 1);
3919             cmp(reg_j, reg_h);
3920             jb(h_loop);
3921         }
3922         L(skip_h_loop);
3923 
3924         emit_block();
3925     };
3926 
3927     auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
3928         xor_(reg_kh, reg_kh);
3929         Label kh_loop, kh_loop_end;
3930 
3931         int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
3932         // NB: this is correct because we only support t_pad = kh / 2 and thus
3933         // ih == oh
3934         int ih_block_size = oh_block_size
3935                 + (!is_first_block + !is_last_block) * jcp.t_pad;
3936 
3937         L(kh_loop);
3938         {
3939             if (is_first_block) {
3940                 xor_(reg_tmp, reg_tmp);
3941                 mov(reg_ohs, jcp.t_pad);
3942                 sub(reg_ohs, reg_kh);
3943                 cmovb(reg_ohs, reg_tmp);
3944 
3945                 mov(reg_ihs, reg_ohs);
3946                 sub(reg_ihs, jcp.t_pad);
3947                 add(reg_ihs, reg_kh);
3948             } else {
3949                 xor_(reg_ohs, reg_ohs);
3950                 mov(reg_ihs, reg_kh);
3951             }
3952 
3953             mov(reg_tmp, oh_block_size);
3954             sub(reg_tmp, reg_ohs);
3955             mov(reg_h, ih_block_size);
3956             sub(reg_h, reg_ihs);
3957             cmp(reg_tmp, reg_h);
3958             cmovb(reg_h, reg_tmp);
3959 
3960             Label kh_loop_work;
3961             cmp(reg_h, 0);
3962             jg(kh_loop_work, T_NEAR);
3963 
3964             // empty h loop for this jcp.kh:
3965             // - zero the weights if necessary
3966             // - advance the ker ptr
3967             // - jump to the end
3968             sub(reg_h, 1);
3969             Label skip_ker_zeroing;
3970 
3971             // The reg_ker ptr has its lowest bit set if the weights need to
3972             // be zeroed. Those who have byte-aligned their data will suffer
3973             // the consequences :(
3974             // TODO: move the flag to a mask register? (Roma)
3975             test(reg_ker, 1);
3976             jz(skip_ker_zeroing, T_NEAR);
3977 
3978             Label zeroing_loop;
3979             vpxord(zmm0, zmm0, zmm0);
3980             and_(reg_ker, ~1); // temporarily clear the zeroing flag
3981 
3982             mov(reg_dense_stride, 64);
3983             tilezero(Tmm(get_wei_tensor(0, 0)));
3984             for (int kw = 0; kw < jcp.kw; kw++) {
3985                 // dim_t kw_offset = kw * get_kernel_offset(jcp.ic_block, 0);
3986                 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
3987                 for (int icb = 0; icb < nb_ic_blocking; icb++)
3988                     tilestored(
3989                             ptr[reg_ker + reg_dense_stride
3990                                     + get_full_kernel_offset(ocb, icb, 0, kw)],
3991                             Tmm(get_wei_tensor(0, 0)));
3992             }
3993             // restore the zeroing flag (it will be cleared after the end of
3994             // emit_kh_kw_loop, but we may need it until then)
3995             or_(reg_ker, 1);
3996             jmp(kh_loop_end, T_NEAR);
3997 
3998             L(skip_ker_zeroing);
3999             add(reg_ker, get_kernel_offset(0, jcp.kw));
4000             jmp(kh_loop_end, T_NEAR);
4001 
4002             L(kh_loop_work);
4003 
4004             mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
4005             mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));
4006 
4007             add(reg_src, reg_ihs);
4008             add(reg_ddst, reg_ohs);
4009 
4010             Label kw_loop;
4011             xor_(reg_kw, reg_kw);
4012 
4013             mov(reg_dense_stride, 64);
4014             L(kw_loop);
4015             {
4016                 Label do_zero, ker_init_done;
4017                 test(reg_ker, 1);
4018                 jnz(do_zero, T_NEAR);
4019 
4020                 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4021                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4022                     tileloadd(Tmm(get_wei_tensor(ocb, icb)),
4023                             ptr[reg_ker + reg_dense_stride
4024                                     + get_full_kernel_offset(ocb, icb, 0, 0)]);
4025                 jmp(ker_init_done);
4026                 L(do_zero);
4027                 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4028                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4029                     tilezero(Tmm(get_wei_tensor(ocb, icb)));
4030 
4031                 L(ker_init_done);
4032 
4033                 mov(ptr[rsp + ddst_save_offset], reg_ddst);
4034                 mov(ptr[rsp + src_save_offset], reg_src);
4035 
4036                 lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
4037                 emit_h_loop();
4038 
4039                 mov(reg_ddst, ptr[rsp + ddst_save_offset]);
4040                 mov(reg_src, ptr[rsp + src_save_offset]);
4041 
4042                 // The reg_ker ptr has its lowest bit set if the weights
4043                 // need to be zeroed. Those who have byte-aligned their data
4044                 // will suffer the consequences :(
4045                 mov(reg_tmp, reg_ker);
4046                 and_(reg_ker, ~1);
4047 
4048                 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4049                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4050                     tilestored(
4051                             ptr[reg_ker + reg_dense_stride
4052                                     + get_full_kernel_offset(ocb, icb, 0, 0)],
4053                             Tmm(get_wei_tensor(ocb, icb)));
4054 
4055                 mov(reg_ker, reg_tmp);
4056                 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
4057                 add(reg_kw, 1);
4058                 cmp(reg_kw, jcp.kw);
4059                 jl(kw_loop);
4060             }
4061 
4062             sub(reg_src, reg_ihs);
4063             sub(reg_ddst, reg_ohs);
4064 
4065             L(kh_loop_end);
4066             add(reg_kh, 1);
4067             cmp(reg_kh, jcp.kh);
4068             jl(kh_loop);
4069         }
4070     };
4071 
4072     mov(reg_src, ptr[param + GET_OFF(src)]);
4073     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4074     mov(reg_ker, ptr[param + GET_OFF(filt)]);
4075     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
4076     or_(reg_ker, reg_tmp);
4077 
4078     bool single_kh_kw_loop = (h_last_block_size == jcp.oh);
4079 
4080     auto src_row_step = get_src_offset(0, 0, 1);
4081     auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
4082     auto ddst_block_step = get_ddst_offset(0, h_block_size);
4083 
4084     emit_kh_kw_loop(true, single_kh_kw_loop);
4085 
4086     if (!single_kh_kw_loop) {
4087         auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
4088         sub(reg_ker, ker_reset_offset);
4089         and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
4090 
4091         add(reg_src, first_src_block_step);
4092         add(reg_ddst, ddst_block_step);
4093 
4094         int num_innermost_iters
4095                 = (jcp.oh - h_last_block_size) / h_block_size - 1;
4096         if (num_innermost_iters > 0) {
4097             Label h_block_loop;
4098 
4099             mov(reg_tmp_w, num_innermost_iters);
4100             kmovw(reg_h_block, reg_tmp_w);
4101             L(h_block_loop);
4102             {
4103                 emit_kh_kw_loop(false, false);
4104                 sub(reg_ker, ker_reset_offset);
4105                 add(reg_src, src_row_step * h_block_size);
4106                 add(reg_ddst, ddst_block_step);
4107 
4108                 kmovw(reg_tmp_w, reg_h_block);
4109                 sub(reg_tmp_w, 1);
4110                 kmovw(reg_h_block, reg_tmp_w);
4111                 jnz(h_block_loop);
4112             }
4113         }
4114 
4115         emit_kh_kw_loop(false, true);
4116     }
4117 }
4118 
4119 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_ic_loop(
4120         int ic_block, int nb_ic_blocking, int nb_oc_blocking) {
4121     assert(jcp.ur_w % 2 == 0);
4122     const int str_w = jcp.stride_w;
4123     assert(jcp.tr_iw % str_w == 0);
4124     const int src_stride_w_shift = jcp.tr_iw / str_w;
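    // Inferred layout note: with stride_w = s, the transposed src buffer
    // groups pixels by w % s, each group src_stride_w_shift elements wide;
    // a filter column i_kw with i_kw % s == s' reads group s' at element
    // offset i_kw / s (see src_offset_l below).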
4125 
4126     mov(reg_b_stride, 64);
4127     mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
4128 
4129     for (int s = 0; s < str_w; s++) {
4130         for (int i_kw = s; i_kw < jcp.kw; i_kw += str_w) {
4131 
4132             for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
4133                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4134                     tileloadd(Tmm(get_wei_tensor(ocb, icb)),
4135                             ptr[reg_kernel + reg_b_stride
4136                                     + get_full_kernel_offset(
4137                                             ocb, icb, 0, i_kw)]);
4138 
4139             int src_offset_l = (i_kw * (jcp.dilate_w + 1)) / str_w
4140                     + s * src_stride_w_shift;
4141 
4142             for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
4143                 dim_t ur_w_src_offset = ur_w_b
4144                         * get_src_offset(0, filter_w_to_src(0, jcp.ur_w, 0));
4145                 dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
4146                 for (int icb = 0; icb < nb_ic_blocking; icb++) {
4147                     dim_t icb_offset = icb * jcp.tr_src_buf_size;
4148                     tileloadd(Tmm(get_src_tensor(icb)),
4149                             ptr[reg_src
4150                                     + jcp.typesize_in
4151                                             * (src_offset_l + icb_offset)
4152                                     + ur_w_src_offset + reg_a_stride]);
4153                 }
4154                 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4155                     tileloadd(Tmm(get_ddst_tensor(ocb)),
4156                             ptr[reg_ddst
4157                                     + jcp.typesize_in * ocb
4158                                             * jcp.tr_diff_dst_buf_size
4159                                     + ur_w_ddst_offset + reg_b_stride]);
4160                     for (int icb = 0; icb < nb_ic_blocking; icb++)
4161                         tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
4162                                 Tmm(get_src_tensor(icb)),
4163                                 Tmm(get_ddst_tensor(ocb)));
4164                 }
4165             }
4166 
4167             for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
4168                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4169                     tilestored(ptr[reg_kernel + reg_b_stride
4170                                        + get_full_kernel_offset(
4171                                                ocb, icb, 0, i_kw)],
4172                             Tmm(get_wei_tensor(ocb, icb)));
4173         }
4174     }
4175     safe_add(reg_src, get_src_offset(ic_block, 0), reg_long_offt);
4176     add(reg_kernel, get_kernel_offset(ic_block, 0));
4177 }
4178 
4179 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_init(int ocb) {
4180     auto reg_unit_val = reg_tmp.cvt16();
4181     mov(reg_unit_val, 0x3f80); // bf16 value of 1.
4182     vpbroadcastw(vreg_bias_unit, reg_unit_val);
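    // vreg_bias_unit feeds the reduction in compute_diff_bias_row():
    // vdpbf16ps with bf16 ones sums each adjacent bf16 pair of diff_dst
    // into its f32 lane, i.e. acc[i] += ddst[2i] + ddst[2i + 1], reducing
    // two spatial points per instruction.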
4183 
4184     mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4185     vmovups(vreg_bias_acc, ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block]);
4186 }
4187 
4188 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_row(
4189         bool is_partial, int ocb) {
4190     if (!jcp.with_bias) return;
4191     mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4192     Label skip_label;
4193     test(reg_tmp, FLAG_IC_FIRST);
4194     jz(skip_label, T_NEAR);
4195 
4196     if (is_partial) { compute_diff_bias_init(ocb); }
4197     auto compute_step = [&]() {
4198         vmovups(vreg_bias_ddst, ptr[reg_ddst]);
4199         vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
4200     };
4201 
4202     Label ow_loop, ow_tail;
4203     int niters = jcp.tr_ow / 2;
4204     if (niters > 0) {
4205         mov(reg_tmp, jcp.tr_ow / 2);
4206         L(ow_loop);
4207         compute_step();
4208         add(reg_ddst, get_ddst_offset(2));
4209         sub(reg_tmp, 1);
4210         jnz(ow_loop, T_NEAR);
4211     }
4212     if (jcp.tr_ow % 2) compute_step();
4213 
4214     if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));
4215 
4216     if (is_partial) {
4217         mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4218         vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
4219                 vreg_bias_acc);
4220     }
4221 
4222     L(skip_label);
4223 }
4224 
4225 void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_compute_diff_bias(
4226         int nb_oc_blocking) {
4227     // In the harness_3d_reduction case the diff_bias calculation is invoked
4228     // for every output row separately, to stay aligned with the od loop in
4229     // compute_od_loop_common()
4230     if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
4231     mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4232 
4233     Label skip_label;
4234     test(reg_tmp, FLAG_IC_FIRST);
4235     jz(skip_label, T_NEAR);
4236 
4237     for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4238         Label bias_loop, skip_label_local;
4239 
4240         mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4241         add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size);
4242 
4243         switch (jcp.harness) {
4244             case harness_2d_reduction:
4245                 mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4246                 sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
4247                 break;
4248             case harness_mb_reduction:
4249             case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
4250             case harness_3d_reduction:
4251             default: assert(!"Invalid harness type");
4252         }
4253 
4254         cmp(reg_oj, 0);
4255         jle(skip_label_local, T_NEAR); // nothing to do
4256 
4257         compute_diff_bias_init(ocb);
4258         L(bias_loop);
4259         {
4260             compute_diff_bias_row(false, ocb);
4261             add(reg_ddst, get_ddst_offset(0, 1));
4262 
4263             sub(reg_oj, 1);
4264             jnz(bias_loop, T_NEAR);
4265         }
4266 
4267         mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4268         vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
4269                 vreg_bias_acc);
4270 
4271         L(skip_label_local);
4272     }
4273     // restore reg_ddst value
4274     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4275 
4276     L(skip_label);
4277 }
4278 
4279 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_step_common(
4280         int nb_ic_blocking, int nb_oc_blocking) {
4281     Label kh_label, ic_block_label, ow_block_label, kd_label;
4282 
4283     int ic_block = jcp.ic_block;
4284     int ic_tail = jcp.ic_tail;
4285 
4286     auto ic_loop = [&](int nb_ic_blocking, int nb_oc_blocking) {
4287         Label ic_tail_label, ic_loop_done_label;
4288 
4289         if (ic_tail) {
4290             mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
4291             cmp(reg_icb, jcp.ic_tail);
4292             jne(ic_tail_label, T_NEAR);
4293 
4294             compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
4295             jmp(ic_loop_done_label, T_NEAR);
4296 
4297             L(ic_tail_label);
4298             compute_ic_loop(ic_tail, nb_ic_blocking, nb_oc_blocking);
4299             add(reg_kernel, get_kernel_offset(jcp.ic_block - ic_tail, 0));
4300             safe_add(reg_src,
4301                     get_src_offset(0, 0, filter_h_to_src(1))
4302                             - get_src_offset(ic_tail, 0),
4303                     reg_long_offt);
4304             L(ic_loop_done_label);
4305         } else {
4306             compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
4307         }
4308     };
4309 
4310     if (jcp.ndims == 5) {
4311         /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
4312          * 'movs' must be guaranteed. */
4313         mov(ki, reg_kd_count);
4314         mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
4315         mov(aux_reg_src, reg_src);
4316         mov(aux_reg_kernel, reg_kernel);
4317 
4318         L(kd_label);
4319         mov(reg_src, aux_reg_src);
4320         mov(reg_kernel, aux_reg_kernel);
4321     }
4322 
4323     mov(kj, reg_kh);
4324     L(kh_label);
4325     {
4326         ic_loop(nb_ic_blocking, nb_oc_blocking);
4327 
4328         if (jcp.dilate_h > 0) {
4329             add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
4330         }
4331         // subtract the pointer shift made within the ic block loop
4332         // and move to the next kh index
4333         add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
4334         dec(kj);
4335         cmp(kj, 0);
4336         jg(kh_label, T_NEAR);
4337     }
4338     if (jcp.ndims == 5) {
4339         add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
4340         add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
4341         dec(ki);
4342         cmp(ki, 0);
4343         jg(kd_label, T_NEAR);
4344     }
4345     // In the harness_3d_reduction case the diff_bias calculation is invoked
4346     // for every output row separately, to stay aligned with the od loop in
4347     // compute_od_loop_common()
4348     if (jcp.harness == harness_3d_reduction) {
4349         auto reg_save_ddst = reg_a_stride;
4350         mov(reg_save_ddst, reg_ddst);
4351         for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4352             safe_add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size,
4353                     reg_long_offt);
4354             compute_diff_bias_row(true, ocb);
4355         }
4356         mov(reg_ddst, reg_save_ddst);
4357     }
4358 
4359     if (jcp.ndims == 5) {
4360         mov(reg_src, aux_reg_src);
4361         mov(reg_kernel, aux_reg_kernel);
4362         mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
4363         od_step_comeback_pointers();
4364     } else {
4365         oh_step_comeback_pointers();
4366     }
4367 }
4368 
4369 void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_zero_kernel(
4370         int nb_ic_blocking, int nb_oc_blocking) {
4371     if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
4372     Label skip_zeroing, zeroing_loop;
4373 
4374     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
4375     cmp(reg_tmp, 0);
4376     jz(skip_zeroing, T_NEAR);
4377 
4378     Zmm zero = Zmm(0);
4379     vpxord(zero, zero, zero);
4380     if (jcp.with_bias) {
4381         Label skip_bias_zeroing;
4382         mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4383         test(reg_tmp, FLAG_IC_FIRST);
4384         jz(skip_bias_zeroing, T_NEAR);
4385         for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4386             mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4387             vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block], zero);
4388         }
4389         L(skip_bias_zeroing);
4390         if (jcp.harness == harness_compute_full_spatial)
4391             jmp(skip_zeroing, T_NEAR);
4392     }
4393 
4394     mov(reg_b_stride, 64);
4395     tilezero(Tmm(get_wei_tensor(0, 0)));
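    // A single zeroed tile (the (0, 0) weights tile) is stored to every
    // (icb, ocb) block at every kernel spatial position below, so one
    // tilezero suffices.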
4396     for (dim_t shift = 0;
4397             shift < get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
4398             shift += get_kernel_offset(jcp.ic_block, 0)) {
4399         for_(int icb = 0; icb < nb_ic_blocking; icb++)
4400         for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4401             tilestored(
4402                     ptr[reg_kernel + reg_b_stride
4403                             + get_full_kernel_offset(ocb, icb, 0, 0) + shift],
4404                     Tmm(get_wei_tensor(0, 0)));
4405         }
4406     }
4407     L(skip_zeroing);
4408 }
4409 
4410 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_loop_common(
4411         int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
4412     int b_pad = jcp.b_pad;
4413     int t_pad = jcp.t_pad;
4414 
4415     bool is_dilated = jcp.dilate_h != 0;
4416     int dilate_h = jcp.dilate_h + 1;
4417     int stride_h = jcp.stride_h;
4418     auto filter_step_size = get_kernel_offset(0, jcp.kw);
4419     auto src_step_size = get_src_offset(0, 0, 1);
4420     auto ddst_step_size = get_ddst_offset(0, 1);
4421     Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
4422             oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
4423             oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
4424             oh_dilate_label_end, oh_dilate_setup_label_shift,
4425             oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;
4426 
4427     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4428     int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
4429     int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
4430     int oh_head_overflow_end = div_up(t_pad, stride_h);
4431     int oh_tail_end = jcp.oh;
4432 
4433     int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
4434     int ih_body_end
4435             = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);
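    // e.g., for an assumed shape with kh = 3, t_pad = b_pad = 1,
    // stride_h = 1, no dilation and ih == oh: oh_head_end = 1 (top-padded
    // rows with a reduced kernel), oh_body_end = oh - 1 (full-kernel body),
    // and the last row is handled by the bottom-edge loop.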
4436 
4437     if (is_partial)
4438         mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
4439     else
4440         xor_(reg_oj, reg_oj);
4441 
4442     /* Compute 'top' edge */
4443     if (t_pad > 0) {
4444         if (is_partial) {
4445             cmp(reg_oj, oh_head_overflow_end);
4446             jge(oh_tpad_tail_label_end, T_NEAR);
4447         }
4448         const int overflow
4449                 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
4450         const int underflow = div_up(t_pad, dilate_h);
4451         const int initial_kh = jcp.kh - overflow - underflow;
4452 
4453         // Setup reg_kh, reg_kernel, and reg_src
4454         mov(reg_kh, initial_kh);
4455         add(reg_kernel, filter_step_size * underflow);
4456         if (is_dilated) {
4457             const int tail = t_pad % dilate_h;
4458             const int shift = tail == 0 ? 0 : dilate_h - tail;
4459             mov(reg_ih_shift, shift);
4460             if (!is_partial) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4461             add(reg_src, src_step_size * shift);
4462         }
4463 
4464         if (is_partial) {
4465             Label head_setup, head_setup_finish;
4466             cmp(reg_oj, 0);
4467             je(head_setup_finish, T_NEAR);
4468             mov(reg_oj_setup, reg_oj);
4469 
4470             L(head_setup);
4471             if (is_dilated) {
4472                 inc(reg_ih_shift);
4473                 cmp(reg_ih_shift, dilate_h);
4474                 jl(oh_dilate_setup_label_shift, T_NEAR);
4475                 // unshift src as new kernel element enters
4476                 sub(reg_src, src_step_size * (dilate_h - 1));
4477                 xor_(reg_ih_shift, reg_ih_shift);
4478             }
4479             // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4480             add(reg_kh, stride_h);
4481             sub(reg_kernel, filter_step_size * stride_h);
4482             if (is_dilated) {
4483                 jmp(oh_dilate_setup_label_noshift, T_NEAR);
4484                 L(oh_dilate_setup_label_shift);
4485                 // shift src as old kernel element progresses
4486                 add(reg_src, src_step_size * stride_h);
4487                 L(oh_dilate_setup_label_noshift);
4488             }
4489             sub(reg_oj_setup, 1);
4490             jg(head_setup, T_NEAR);
4491             L(head_setup_finish);
4492 
4493             if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4494             if (oh_head_end < oh_head_overflow_end) {
4495                 cmp(reg_oj, oh_head_end);
4496                 jge(oh_tpad_label_end, T_NEAR);
4497             }
4498         }
4499 
4500         // Setup reg_kernel
4501         // If dilated, shift src ptr
4502         // Loop
4503         L(oh_tpad_label);
4504         compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4505         add(reg_ddst, ddst_step_size);
4506         if (is_dilated) {
4507             mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
4508             inc(reg_ih_shift);
4509             mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4510             cmp(reg_ih_shift, dilate_h);
4511             jl(oh_dilate_label_shift, T_NEAR);
4512             // unshift src as new kernel element enters
4513             sub(reg_src, src_step_size * (dilate_h - 1));
4514             xor_(reg_ih_shift, reg_ih_shift);
4515             mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4516         }
4517         // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4518         add(reg_kh, stride_h);
4519         sub(reg_kernel, filter_step_size * stride_h);
4520         if (is_dilated) {
4521             jmp(oh_dilate_label_noshift, T_NEAR);
4522             L(oh_dilate_label_shift);
4523             // shift src as old kernel element progresses
4524             add(reg_src, src_step_size * stride_h);
4525             L(oh_dilate_label_noshift);
4526         }
4527         inc(reg_oj);
4528 
4529         if (is_partial) {
4530             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4531             jge(oh_bpad_label_end, T_NEAR);
4532         }
4533         cmp(reg_oj, oh_head_end);
4534         jl(oh_tpad_label, T_NEAR);
4535 
4536         L(oh_tpad_label_end);
4537         // a second loop is needed to process the kernel if it is taller
4538         // than the src (does not apply to dilations as they must have unit stride)
4539         if (oh_head_end < oh_head_overflow_end) {
4540             assert(!is_dilated);
4541 
4542             cmp(reg_oj, oh_head_overflow_end);
4543             jge(oh_tpad_tail_label_end, T_NEAR);
4544 
4545             mov(reg_kh, jcp.ih);
4546             L(oh_tpad_tail_label);
4547             {
4548                 compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4549                 add(reg_ddst, ddst_step_size);
4550                 sub(reg_kernel, filter_step_size * stride_h);
4551 
4552                 inc(reg_oj);
4553 
4554                 if (is_partial) {
4555                     cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4556                     jge(oh_bpad_label_end, T_NEAR);
4557                 }
4558                 cmp(reg_oj, oh_head_overflow_end);
4559                 jl(oh_tpad_tail_label, T_NEAR);
4560             }
4561         }
4562         if (body_src_start_offset != 0) {
4563             add(reg_kernel, filter_step_size * body_src_start_offset);
4564             add(reg_src, src_step_size * body_src_start_offset);
4565         }
4566         L(oh_tpad_tail_label_end);
4567     }
4568 
4569     if (is_partial) {
4570         cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4571         jge(oh_bpad_label_end, T_NEAR);
4572     }
4573     cmp(reg_oj, oh_body_end);
4574     jge(oh_label_end, T_NEAR);
4575 
4576     /* Compute middle block(s) */
4577     mov(reg_kh, jcp.kh);
4578     L(oh_label);
4579     {
4580         compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4581         add(reg_src, src_step_size * stride_h);
4582         add(reg_ddst, ddst_step_size);
4583 
4584         inc(reg_oj);
4585 
4586         if (is_partial) {
4587             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4588             jge(oh_bpad_label_end, T_NEAR);
4589         }
4590 
4591         cmp(reg_oj, oh_body_end);
4592         jl(oh_label, T_NEAR);
4593     }
4594     L(oh_label_end);
4595 
4596     /* Compute bottom edge */
4597     if (b_pad > 0) {
4598         if (is_partial) {
4599             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4600             jge(oh_bpad_label_end, T_NEAR);
4601         }
4602         cmp(reg_oj, jcp.oh);
4603         jge(oh_bpad_label_end, T_NEAR);
4604 
4605         if (is_dilated) {
4606             // Assumes unit stride for dilations
4607             mov(reg_kh, jcp.kh - 1);
4608             xor_(reg_ih_shift, reg_ih_shift);
4609         } else {
4610             assert(jcp.dilate_h == 0);
4611             mov(reg_kh, jcp.ih - ih_body_end);
4612         }
4613         if (is_partial) {
4614             lea(reg_oj_setup,
4615                     ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
4616             if (stride_h == 1 && !is_dilated) {
4617                 sub(reg_kh, reg_oj_setup);
4618             } else {
4619                 Label body_setup, body_setup_finish, dilate_skip;
4620                 cmp(reg_oj_setup, 0);
4621                 je(body_setup_finish, T_NEAR);
4622 
4623                 L(body_setup);
4624                 if (is_dilated) {
4625                     inc(reg_ih_shift);
4626                     cmp(reg_ih_shift, dilate_h);
4627                     jl(dilate_skip, T_NEAR);
4628                     xor_(reg_ih_shift, reg_ih_shift);
4629                 }
4630                 sub(reg_kh, stride_h);
4631                 L(dilate_skip);
4632                 sub(reg_oj_setup, 1);
4633                 jg(body_setup, T_NEAR);
4634                 L(body_setup_finish);
4635             }
4636         }
4637 
4638         if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4639         L(oh_bpad_label);
4640         {
4641             compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4642             add(reg_src, src_step_size * stride_h);
4643             add(reg_ddst, ddst_step_size);
4644 
4645             if (is_dilated) {
4646                 mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
4647                 inc(reg_ih_shift);
4648                 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4649                 cmp(reg_ih_shift, dilate_h);
4650                 jl(oh_dilate_label_end, T_NEAR);
4651                 xor_(reg_ih_shift, reg_ih_shift);
4652                 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4653             }
4654             sub(reg_kh, stride_h);
4655             L(oh_dilate_label_end);
4656             inc(reg_oj);
4657             if (is_partial) {
4658                 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4659                 jge(oh_bpad_label_end, T_NEAR);
4660             }
4661             cmp(reg_oj, oh_tail_end);
4662             jl(oh_bpad_label, T_NEAR);
4663         }
4664     }
4665     L(oh_bpad_label_end);
4666 }
4667 
4668 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_od_loop_common(
4669         int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
4670     assert(jcp.harness == harness_3d_reduction);
4671 
4672     const int src_backpad_overlap
4673             = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
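    // src_backpad_overlap is the first od index whose filter footprint
    // reaches into the back padding, i.e. the smallest d such that
    // d * stride_d + (kd - 1) >= f_pad + id.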
4674 
4675     const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
4676     const auto src_shift = get_src_offset(0, 0, jcp.ih);
4677     const auto ddst_shift = get_ddst_offset(0, jcp.oh);
4678 
4679     const int kd_front_pad = nstl::max(0, jcp.f_pad);
4680     const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);
4681 
4682     Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
4683             backpad_end_label, backpad_label;
4684 
4685     /* initially offset 'kd' by f_pad */
4686     mov(reg_src_d, ptr[param + GET_OFF(src)]);
4687     mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);
4688 
4689     if (is_partial) {
4690         add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
4691         mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
4692         mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
4693     } else {
4694         const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
4695         const int kd_offset = get_kernel_offset(
4696                 0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
4697         add(reg_kernel, kd_offset);
4698         xor_(reg_d_index, reg_d_index);
4699         mov(reg_kd_count, kd_padding);
4700     }
4701 
4702     cmp(reg_kd_count, 0);
4703     jle(loop_end_label, T_NEAR); // no iterations along kd
4704     if (is_partial)
4705         cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
4706     else
4707         cmp(reg_d_index, jcp.od);
4708     jge(loop_end_label, T_NEAR); // no iterations along depth dimension
4709 
4710     L(d_loop_label);
4711 
4712     mov(reg_src, reg_src_d);
4713     mov(reg_ddst, reg_ddst_d);
4714 
4715     mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
4716     mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
4717     mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);
4718 
4719     compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);
4720 
4721     mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
4722     mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
4723     mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));
4724 
4725     /* Compute 'front' edge */
4726     if (jcp.f_pad > 0) {
4727         /* Check if within fpad region */
4728         cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
4729         jge(fpad_end_label, T_NEAR);
4730 
4731         /* Fpad steps */
4732         sub(reg_kernel, filter_shift * jcp.stride_d);
4733         add(reg_kd_count, jcp.stride_d);
4734 
4735         /* Final number of kernel elements that overlap with src */
4736         const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
4737         cmp(reg_kd_count, src_ker_overlap);
4738         jle(common_block_label, T_NEAR);
4739 
4740         /* Correct any excess shifts to kernel and src */
4741         if (jcp.f_pad <= jcp.od * jcp.stride_d) {
4742             /* Filter has moved beyond padding (adjust for stride effects) */
4743             if (jcp.f_pad % jcp.stride_d != 0) {
4744                 int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
4745                 add(reg_kernel, filter_shift * src_corr);
4746                 add(reg_src_d, src_shift * src_corr);
4747             }
4748         } else {
4749             /* Filter still overlaps padding (complete reset) */
4750             sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
4751         }
4752 
4753         /* Apply correction */
4754         mov(reg_kd_count, src_ker_overlap);
4755         jmp(common_block_label);
4756 
4757         L(fpad_end_label);
4758     }
4759 
4760     /* Compute 'back' edge */
4761     if (jcp.back_pad > 0) {
4762 
4763         /* Check if within back_pad region */
4764         cmp(reg_d_index, src_backpad_overlap - 1);
4765         jl(backpad_end_label, T_NEAR);
4766         jg(backpad_label, T_NEAR);
4767 
4768         /* Execute overlap correction between the filter and the initial
4769          * back_pad region. */
4770         mov(reg_kd_count,
4771                 jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
4772         jmp(backpad_end_label, T_NEAR);
4773 
4774         L(backpad_label);
4775         sub(reg_kd_count, jcp.stride_d);
4776         cmp(reg_kd_count, 0);
4777         jle(loop_end_label, T_NEAR);
4778 
4779         L(backpad_end_label);
4780     }
4781 
4782     /* Compute middle block */
4783     add(reg_src_d, src_shift * jcp.stride_d);
4784 
4785     /* Execute common block and loop */
4786     L(common_block_label);
4787     add(reg_ddst_d, ddst_shift);
4788     inc(reg_d_index);
4789     if (is_partial)
4790         cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
4791     else
4792         cmp(reg_d_index, jcp.od);
4793     jl(d_loop_label, T_NEAR);
4794 
4795     L(loop_end_label);
4796 }
4797 
4798 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_loop(
4799         int nb_ic_blocking, int nb_oc_blocking) {
4800     mov(reg_src, ptr[param + GET_OFF(src)]);
4801     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4802     mov(reg_kernel, ptr[param + GET_OFF(filt)]);
4803 
4804     maybe_zero_kernel(nb_ic_blocking, nb_oc_blocking);
4805     maybe_compute_diff_bias(nb_oc_blocking);
4806 
4807     switch (jcp.harness) {
4808         case harness_3d_reduction:
4809             compute_od_loop_common(nb_ic_blocking, nb_oc_blocking, true);
4810             break;
4811         case harness_2d_reduction:
4812             compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking, true);
4813             break;
4814         case harness_mb_reduction:
4815             compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);
4816             break;
4817         case harness_compute_full_spatial:
4818             compute_full_spat_loop(nb_ic_blocking, nb_oc_blocking);
4819             break;
4820         default: assert(!"Invalid harness type");
4821     }
4822 }
4823 
4824 void jit_avx512_core_amx_bwd_weights_kernel_t::setup_stack_space() {
4825     kd_count_offset = ic_block_step_stack_size;
4826     src_d_offset = ic_block_step_stack_size + 8;
4827     ddst_d_offset = ic_block_step_stack_size + 16;
4828     d_index_offset = ic_block_step_stack_size + 24;
4829     ih_dilate_offset = ic_block_step_stack_size + 32;
4830     src_save_offset = ic_block_step_stack_size + 40;
4831     ddst_save_offset = ic_block_step_stack_size + 48;
4832     stack_space_needed = ic_block_step_stack_size + 56;
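    // Seven 8-byte spill slots are appended after the ic_block_step area.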
4833 }
4834 
4835 void jit_avx512_core_amx_bwd_weights_kernel_t::generate() {
4836     preamble();
4837 
4838     setup_stack_space();
4839 
4840     sub(rsp, stack_space_needed);
4841 
4842     Label last_ic_block_label, last_blocks_done_label;
4843 
4844     mov(reg_tmp, ptr[param + GET_OFF(last_ic_block)]);
4845     cmp(reg_tmp, 0);
4846     jne(last_ic_block_label, T_NEAR);
4847     { // full nb_ic_blocking
4848         Label last_oc_block_label;
4849         mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
4850         cmp(reg_tmp, 0);
4851         jne(last_oc_block_label, T_NEAR);
4852         { // full nb_oc_blocking
4853             compute_loop(jcp.nb_ic_blocking, jcp.nb_oc_blocking);
4854             jmp(last_blocks_done_label, T_NEAR);
4855         }
4856         L(last_oc_block_label);
4857         { // tail of nb_oc_blocking
4858             compute_loop(jcp.nb_ic_blocking, 1);
4859             jmp(last_blocks_done_label, T_NEAR);
4860         }
4861     }
4862     L(last_ic_block_label);
4863     { // tail nb_ic_blocking
4864         Label last_oc_block_label;
4865         mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
4866         cmp(reg_tmp, 0);
4867         jne(last_oc_block_label, T_NEAR);
4868         { // full nb_oc_blocking
4869             compute_loop(1, jcp.nb_oc_blocking);
4870             jmp(last_blocks_done_label, T_NEAR);
4871         }
4872         L(last_oc_block_label);
4873         { // tail of nb_oc_blocking
4874             compute_loop(1, 1);
4875             jmp(last_blocks_done_label, T_NEAR);
4876         }
4877     }
4878 
4879     L(last_blocks_done_label);
4880     add(rsp, stack_space_needed);
4881 
4882     postamble();
4883 }
4884 
4885 status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_conf(
4886         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
4887         memory_desc_t &src_md, memory_desc_t &diff_weights_md,
4888         memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
4889     const memory_desc_wrapper src_d(&src_md);
4890     const memory_desc_wrapper diff_weights_d(&diff_weights_md);
4891     const memory_desc_wrapper diff_dst_d(&diff_dst_md);
4892     const memory_desc_wrapper diff_bias_d(&diff_bias_md);
4893 
4894     jcp = zero<decltype(jcp)>();
4895 
4896     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
4897     int ndims = src_d.ndims();
4898 
4899     if (!mayiuse(avx512_core_bf16_amx_bf16)) return status::unimplemented;
4900     jcp.isa = avx512_core_bf16_amx_bf16;
4901 
4902     jcp.ver = ver_vnni; // Needed for transpose routines
4903     jcp.nthr = nthreads;
4904 
4905     jcp.ndims = ndims;
4906     jcp.prop_kind = cd.prop_kind;
4907 
4908     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
4909     jcp.mb = src_d.dims()[0];
4910 
4911     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
4912     jcp.oc_without_padding = jcp.oc;
4913     jcp.ic = src_d.dims()[1] / jcp.ngroups;
4914 
4915     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
4916     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
4917     jcp.iw = src_d.dims()[ndims - 1];
4918     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
4919     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
4920     jcp.ow = diff_dst_d.dims()[ndims - 1];
4921 
4922     jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
4923     jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
4924     jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];
4925 
4926     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
4927     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
4928     jcp.l_pad = cd.padding[0][ndims - 3];
4929 
4930     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
4931     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
4932     jcp.stride_w = cd.strides[ndims - 3];
4933 
4934     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
4935     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
4936     jcp.dilate_w = cd.dilates[ndims - 3];
4937 
4938     int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
4939     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4940     int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
4941 
4942     bool ok = true
4943             // general condition to simplify dilations
4944             && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
4945             && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
4946             && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
4947             // special condition to simplify dilations in compute_oh_loop_common
4948             && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
4949     if (!ok) return status::unimplemented;
4950 
4951     ok = true && one_of(ndims, 3, 4, 5)
4952             && everyone_is(
4953                     data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
4954             && one_of(diff_weights_d.data_type(), data_type::f32,
4955                     data_type::bf16);
4956     if (!ok) return status::unimplemented;
4957 
4958     jcp.transform_to_vnni = diff_weights_d.data_type() == data_type::bf16;
4959 
4960     jcp.r_pad = nstl::max(0,
4961             calculate_end_padding(
4962                     jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
4963     jcp.b_pad = nstl::max(0,
4964             calculate_end_padding(
4965                     jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
4966     jcp.back_pad = nstl::max(0,
4967             calculate_end_padding(
4968                     jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));
4969 
4970     /* XXX: no support for padding when dilation_d > 0 */
4971     if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
4972         return status::unimplemented;
4973 
4974     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
4975     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
    if (jcp.is_depthwise)
        return status::unimplemented; // TODO: add support of DW convolution

    const int dat_format_tag = ndims - 3;
    format_tag_t dat_tag_nspc = utils::pick(dat_format_tag, format_tag::nwc,
            format_tag::nhwc, format_tag::ndhwc);
    format_tag_t dat_tag_opt = dat_tag_nspc;

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
        jcp.src_tag = dat_tag_opt;
    } else
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag_opt);
    if (!one_of(jcp.src_tag, dat_tag_opt)) return status::unimplemented;
    jcp.is_nspc = jcp.src_tag == dat_tag_nspc;

    if (diff_dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
        jcp.dst_tag = jcp.src_tag;
    } else
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
    if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;

    if (!jcp.is_nspc) return status::unimplemented;

    const int wei_format_tag = 2 * ndims - 6 + with_groups;
    format_tag_t wei_tag;
    if (jcp.transform_to_vnni)
        wei_tag = pick(wei_format_tag, format_tag::OIw16i16o2i,
                format_tag::gOIw16i16o2i, format_tag::OIhw16i16o2i,
                format_tag::gOIhw16i16o2i, format_tag::OIdhw16i16o2i,
                format_tag::gOIdhw16i16o2i);
    else
        wei_tag = pick(wei_format_tag, format_tag::OIw16i16o,
                format_tag::gOIw16i16o, format_tag::OIhw16i16o,
                format_tag::gOIhw16i16o, format_tag::OIdhw16i16o,
                format_tag::gOIdhw16i16o);
    if (diff_weights_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }
    jcp.wei_dt = diff_weights_d.data_type();

    /* conditions on bias memory */
    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
    if (jcp.with_bias) {
        if (diff_bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_bias_md, format_tag::x));
    }
    jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
    jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    const int max_pad_h = ext_kh / 2;
    const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
            && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
            && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd;
    if (!boundaries_ok) return status::unimplemented;

    jcp.ic_block = 16;
    jcp.oc_block = 16;

    jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
    jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);

    jcp.ic_tail = jcp.ic % jcp.ic_block;
    jcp.oc_tail = jcp.oc % jcp.oc_block;

    jcp.nb_oc_blocking = (jcp.nb_oc > 1) ? 2 : 1;
    jcp.nb_ic_blocking = (jcp.nb_ic > 1) ? 2 : 1;
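    // Illustrative example (values assumed, not tied to any shape): with
    // jcp.ic = 40 and ic_block = 16, nb_ic = div_up(40, 16) = 3 and
    // ic_tail = 40 % 16 = 8, so the last input-channel block is half full
    // and needs masked loads; since nb_ic > 1, double blocking kicks in
    // (nb_ic_blocking = 2) and two 16-channel blocks are carried per
    // kernel invocation.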

    int max_palette = amx::get_max_palette();
    jcp.max_tiles = amx::get_max_tiles(max_palette);
    jcp.full_tile_width = amx::get_max_rows(max_palette);

    if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
        return status::unimplemented;
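    // Note: these values correspond to AMX tile palette 1, which provides
    // eight tile registers (tmm0..tmm7) of up to 16 rows x 64 bytes each;
    // the check above rejects any other configuration, so the rest of the
    // kernel can assume that exact tile shape.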

    const bool is_2d = (ndims == 4);
    const bool is_3d = (ndims == 5);
    jcp.typesize_in = sizeof(bfloat16_t);
    jcp.typesize_out = sizeof(float);

    // TODO: Find more shapes (especially 3D with large spatials) for which
    // local transposition will be beneficial. Furthermore, for TBB threads
    // more shapes can potentially benefit from spatial blocking
    int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;

    jcp.global_transpose = dnnl_thr_syncable();
    jcp.spatial_blk_size = optimal_blk_size;

    const int tr_round = 32; // To load full tile register
    int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
    jcp.tr_iw = rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
            * jcp.stride_w;
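    // Worked example (values assumed for illustration): for jcp.iw = 56,
    // jcp.stride_w = 1, jcp.l_pad = jcp.r_pad = 1:
    //   tr_pad = rnd_up(max(1, 1 + 1), 32) = 32
    //   tr_iw  = rnd_up(div_up(56, 1) + 32, 32) * 1 = rnd_up(88, 32) = 96
    // i.e. the transposed source row is padded up so that every load of
    // 32 bf16 elements (64 bytes, one full tile row) stays in bounds.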

    jcp.tr_src_num_guard_elems = tr_pad; // upper bound
    jcp.tr_ow = rnd_up(jcp.ow, 2);

    if (jcp.tr_ow <= max_ur_w) {
        jcp.ur_w = jcp.tr_ow;
        jcp.ur_w_blocks = 1;
    } else {
        jcp.ur_w = 1;
        for (int i = max_ur_w; i >= 1; i -= 2) {
            if (jcp.tr_ow % i == 0) {
                jcp.ur_w = i;
                break;
            }
        }
        jcp.ur_w_blocks = jcp.tr_ow / jcp.ur_w;
    }
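    // Illustrative example (assuming max_ur_w = 32; the constant is defined
    // elsewhere in this kernel): for jcp.ow = 53, tr_ow = rnd_up(53, 2) = 54
    // exceeds 32, so the search walks i = 32, 30, 28, ... and stops at the
    // first divisor, i = 18; the row is then processed as
    // ur_w_blocks = 54 / 18 = 3 blocks of 18 output columns each.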

    bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    bool use_full_spat_loop = jcp.ndims < 5 && jcp.ih == jcp.oh
            && jcp.iw == jcp.ow && everyone_is(1, jcp.stride_h, jcp.stride_w)
            && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
            // TODO: Remove this constraint: only 3x3 kernel works now
            && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
            && one_of(1, jcp.l_pad, jcp.r_pad) && jcp.kh == jcp.kw
            && jcp.ih >= jcp.kh && jcp.iw >= jcp.kw;
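    // Example of a qualifying shape (illustrative): a 2D convolution with
    // kh = kw = 3, stride 1, no dilation, and "same" padding
    // (l_pad = t_pad = 1) so that ih == oh and iw == ow; per the TODO above,
    // larger square kernels are not yet handled by the full-spatial path.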

    jcp.harness = ndims == 5
            ? harness_3d_reduction
            : (use_full_spat_loop ? harness_compute_full_spatial
                                  : (ndims == 4) ? harness_2d_reduction
                                                 : harness_mb_reduction);
    switch (jcp.harness) {
        case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
        case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
        case harness_compute_full_spatial:
        case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
        default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
    }
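    // Illustrative example: a 2D problem with mb = 4 and oh = 28 exposes
    // nthr_mb_work = 4 * 28 = 112 independent reduction units, giving the
    // balancer below far finer granularity than the 4 units a plain
    // minibatch split would offer.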
    { // balancing
        int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
        balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
        jcp.nthr = nthr;
        jcp.nthr_mb = nthr_mb;
        jcp.nthr_g = nthr_g;
        jcp.nthr_oc_b = nthr_oc_b;
        jcp.nthr_ic_b = nthr_ic_b;

        // TODO: Optimize memory allocation when threaded on height and depth
        jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
        jcp.tr_src_buf_count = jcp.global_transpose
                ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
                : jcp.nthr;

        jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
        jcp.tr_diff_dst_buf_count = jcp.global_transpose
                ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
                : jcp.nthr;
    }

    return status::success;
}

status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_dst_md) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);

    // XXX: See the comment about tr_iw and guarding elements in
    // jit_avx512_core_amx_bwd_weights_kernel_t::init_conf()
    const size_t tr_src_size
            = (jcp.tr_src_buf_count * jcp.tr_src_buf_size * jcp.nb_ic_blocking)
            + jcp.tr_src_num_guard_elems;
    scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);
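    // Size sketch (illustrative numbers): with tr_iw = 96, ic_block = 16,
    // ih = 56, id = 1, each transpose buffer holds 96 * 16 * 56 = 86016
    // bf16 elements; tr_src_size multiplies that by the buffer count and by
    // nb_ic_blocking, plus tr_pad guard elements so the last tile-row load
    // cannot run past the end of the allocation.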

    /* prepare synchronization contexts */
    if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
        const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_tr_src_bctx, tr_src_bctx_size);
    }

    const size_t tr_diff_dst_size = jcp.tr_diff_dst_buf_count
            * jcp.tr_diff_dst_buf_size * jcp.nb_oc_blocking;

    const size_t min_align = 64;
    scratchpad.book(
            key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in, min_align);

    /* prepare synchronization contexts */
    if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
        const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
    }

    if (IMPLICATION(jcp.nthr_mb == 1,
                (jcp.with_bias && jcp.bia_dt == data_type::bf16)
                        || jcp.wei_dt == data_type::bf16)) {
        const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
                * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
        const size_t bia_size
                = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;

        const int num_wei_buffers
                = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
        const int num_bia_buffers = jcp.with_bias
                ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
                                                 : jcp.nthr_mb - 1)
                : 0;

        const size_t wei_bia_reduction_size
                = wei_size * num_wei_buffers + bia_size * num_bia_buffers;

        scratchpad.book<float>(
                key_conv_wei_bia_reduction, wei_bia_reduction_size);

        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx, 1);
    }
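    // Rationale (as the buffer counts above suggest): accumulation happens
    // in f32, so when the destination weights/bias are f32 one of the
    // jcp.nthr_mb partial results can live directly in the user's buffer
    // and only nthr_mb - 1 scratch copies are booked; bf16 destinations
    // need all nthr_mb f32 copies plus a final down-conversion.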

    if (jcp.with_bias
            && ((jcp.oc_without_padding % jcp.oc_block != 0)
                    && jcp.bia_dt == data_type::f32)) {
        scratchpad.book(key_conv_padded_bias,
                jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
    }
    scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline

    constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
            << 30; // 32 GB - TODO: maybe it's too large?
    const size_t scratchpad_limit_by_tensor_sizes = (size_t)32 * jcp.nthr
            * (src_d.size() + diff_weights_d.size() + diff_dst_d.size());
    const size_t scratchpad_limit
            = nstl::min(scratchpad_limit_by_absolute_value,
                    scratchpad_limit_by_tensor_sizes);
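    // Example (illustrative): on a 16-thread run with 100 MB of combined
    // tensor data, the tensor-size-based limit is 32 * 16 * 100 MB = 50 GB,
    // so the absolute 32 GB cap is what actually binds; for small problems
    // the tensor-proportional limit dominates instead.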
    if (scratchpad.size() > scratchpad_limit)
        return status::unimplemented;
    else
        return status::success;
}

void jit_avx512_core_amx_bwd_weights_kernel_t::balance(const jit_conv_conf_t &j,
        int &nthr_, int &nthr_mb_, int &nthr_g_, int &nthr_oc_b_,
        int &nthr_ic_b_) {
    nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;

    const int max_threads = dnnl_get_max_threads();

    if (max_threads < j.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        nthr_ = nthr_g_ = max_threads;
        return;
    }

    nthr_g_ = j.ngroups;
    const int nthr = max_threads / nthr_g_;

    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per thread memory cost (read/write). high level optimizer
         * tries to minimize memory consumption. few notes:
         *  (n1) if weights tensor size is less than source and destination
         *       tensors we apply the ratio of the source and destination
         *       tensor sizes to weights one as compensation coefficient to
         *       avoid parallelization across batch size only, otherwise we
         *       apply additional coefficient to source component based on
         *       performance measurements
         *  (n2) use scales based on output vs input channels ratio for source
         *       and destination components to improve threading balance across
         *       input and output channels */
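        // Coefficient example (numbers assumed): with nb_oc / nb_oc_blocking
        // = 4 and nb_ic / nb_ic_blocking = 2, oi_channels_ratio = 2, so
        // dst_coef = max(2, 1) = 2 while src_coef starts at
        // max(1 / 2, 1) = 1; if the weights outweigh half of src + dst
        // (wei_compensation_scale < 1), src_coef is further scaled by 4 per
        // note (n1).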

        const dim_t src_type_size = 2;
        const dim_t wei_type_size = 4;

        dim_t src_size
                = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
        dim_t dst_size
                = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
        dim_t wei_size
                = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;

        float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
        float oi_channels_ratio = (float)(j.nb_oc / j.nb_oc_blocking)
                / (j.nb_ic / j.nb_ic_blocking);
        auto get_src_coef = [=]() {
            float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
            if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;

            return src_coef;
        };

        auto get_dst_coef
                = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };

        auto get_wei_coef
                = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };

        const float src_coef = get_src_coef();
        const float dst_coef = get_dst_coef();
        const float wei_coef = get_wei_coef();

        float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.mb
                * (j.ic_block * j.nb_ic_blocking) * j.id * j.ih * j.tr_iw
                / j.nthr_mb_work / j.stride_d / j.stride_h / j.stride_w;
        float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_oc / j.nb_oc_blocking),
                        (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
                * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.kh * j.kw
                * j.kd * (j.ic_block * j.nb_ic_blocking)
                * (j.oc_block * j.nb_oc_blocking);
        float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_oc / j.nb_oc_blocking),
                        (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
                * j.mb * (j.oc_block * j.nb_oc_blocking) * j.od * j.oh * j.tr_ow
                / j.nthr_mb_work;

        return src_v + dst_v + wei_v;
    };

    float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);

    /* find the best thread distribution with lowest memory cost */

    const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par,
                (j.nb_oc / j.nb_oc_blocking)); // Amount of nb_oc_blocks
        for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            int nthr_ic_b = nstl::min(
                    nthr_par / nthr_oc_b, (j.nb_ic / j.nb_ic_blocking));

            float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                nthr_mb_ = nthr_mb;
                nthr_oc_b_ = nthr_oc_b;
                nthr_ic_b_ = nthr_ic_b;
            }
        }
    }

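    // The search above exhaustively enumerates (nthr_mb, nthr_oc_b) splits,
    // deriving nthr_ic_b from whatever threads remain; the adjustment below
    // then saturates the minibatch/spatial reduction dimension whenever the
    // best split already claims more than half of the available threads.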
    if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
        nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
    nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;

    assert(nthr_ <= max_threads);
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s