1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "common/c_types_map.hpp"
18 #include "common/memory_tracking.hpp"
19 #include "common/nstl.hpp"
20 #include "common/type_helpers.hpp"
21 #include "common/utils.hpp"
22 
23 #include "cpu/platform.hpp"
24 #include "cpu/x64/cpu_barrier.hpp"
25 #include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
26 #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
27 #include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp"
28 
29 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
30 
31 namespace dnnl {
32 namespace impl {
33 namespace cpu {
34 namespace x64 {
35 
36 using namespace dnnl::impl::memory_tracking::names;
37 using namespace dnnl::impl::data_type;
38 using namespace dnnl::impl::utils;
39 using namespace Xbyak;
40 
prepare_output(int ur_w)41 void jit_avx512_core_amx_compute_zp_pbuff_t::prepare_output(int ur_w) {
42     for (int oc = 0; oc < jcp.nb_oc_blocking; oc++)
43         for (int ur = 0; ur < ur_w; ur++) {
44             const Zmm zmm = zmm_out(ur, oc);
45             vpxord(zmm, zmm, zmm);
46         }
47 }
48 
store_output(int ur_w,bool last_oc_block_flag)49 void jit_avx512_core_amx_compute_zp_pbuff_t::store_output(
50         int ur_w, bool last_oc_block_flag) {
51     assert(jcp.is_nspc);
52 
53     const int nb_oc_block = jcp.nb_oc_blocking;
54     const int oc_block = jcp.oc_block;
55 
56     const auto src_zp_addr = EVEX_compress_addr(reg_src_zero_point, 0, true);
57 
58     /* write out register to output_addr */
59     for (int oc = 0; oc < nb_oc_block; oc++) {
60         const bool mask_flag = last_oc_block_flag && oc == nb_oc_block - 1;
61         for (int ur = 0; ur < ur_w; ur++) {
62             const int output_offset = sizeof(int32_t)
63                     * (oc * oc_block
64                             + ur * jcp.oc_without_padding * jcp.ngroups);
65             const Zmm zmm_dst = zmm_out(ur, oc);
66             const Zmm m_zmm_dst = mask_flag ? zmm_dst | ktail_mask : zmm_dst;
67             // multiply dst by src_zero_point
68             vpmulld(m_zmm_dst, zmm_dst, src_zp_addr);
69             vmovups(EVEX_compress_addr(reg_zp_pbuff, output_offset), m_zmm_dst);
70         }
71     }
72 }
73 
compute_ker(int ur_w,int pad_l,int pad_r,ic_block_t last_ic_block_flag,bool padded)74 void jit_avx512_core_amx_compute_zp_pbuff_t::compute_ker(int ur_w, int pad_l,
75         int pad_r, ic_block_t last_ic_block_flag, bool padded) {
76 
77     const int kw = jcp.kw;
78     const int ic_block = jcp.ic_block_int_np;
79     const int oc_block = jcp.oc_block;
80     const int nb_oc_block = jcp.nb_oc_blocking;
81 
82     const bool ic_tail
83             = (jcp.ic_without_padding % (jcp.ic_block / ic_inner_block)) > 0;
84     const bool masked_write = ic_tail && last_ic_block_flag == last_ic_block;
85 
86     /* Skip the last loads of input
87             if (ic%16)/ic_sub_step < ic_block/ic_sub_step */
88     const int icb = (last_ic_block_flag == last_ic_block)
89             ? div_up(
90                     (jcp.ic_without_padding % jcp.ic_block_int), ic_inner_block)
91             : ic_block / ic_inner_block;
92 
93     auto get_filter_offset = [=](int ocb, int ic, int ki) {
94         size_t w_step = jcp.is_relo ? jcp.kh : 1;
95         size_t kw_offset = static_cast<size_t>(ki) * w_step
96                 * jcp.ic_block_int_np * jcp.oc_block;
97         size_t oc_subblock_step = static_cast<size_t>(jcp.kd) * jcp.kh * jcp.kw
98                 * jcp.ic_block_int_np * jcp.oc_block;
99         size_t offset = kw_offset
100                 + static_cast<size_t>(ocb) * jcp.nb_ic_int * oc_subblock_step
101                 + static_cast<size_t>(ic) * oc_block * ic_inner_block;
102         return sizeof(char) * offset;
103     };
104     auto compute_fma = [=](const Zmm zmm_accum, const int ic,
105                                const Address addr) {
106         if (jcp.is_relo) {
107             vmovups(zmm_permb, ptr[reg_scratch]); // get permute index table
108             const Zmm r_zmm = masked_write && ic == icb - 1
109                     ? zmm_permb | kmask_ic_block | T_z
110                     : zmm_permb;
111             // only values from 'src2' are used to write dst
112             vpermi2b(r_zmm, zmm_permb, addr);
113             vpdpbusd(zmm_accum, zmm_one,
114                     zmm_permb); // XXX - using the same register for all ur_w
115         } else {
116             vpdpbusd(zmm_accum, zmm_one, addr);
117         }
118     };
119 
120     if (jcp.is_relo && last_ic_block_flag == last_ic_block && ic_tail) {
121         const Reg64 reg_tmp = reg_scratch;
122         mov(reg_tmp, ic_mask_label);
123         kmovq(kmask_ic_block, qword[reg_tmp]);
124     }
125     if (jcp.is_relo) mov(reg_scratch, permb_idx_label);
126 
127     for (int ki = 0; ki < kw; ki++) {
128         const int ur_start = get_ow_start(ki, pad_l);
129         const int ur_end = get_ow_end(ur_w, ki, pad_r);
130         for (int ur = 0; ur < ur_w; ur++) {
131             // Calculate zero_point padding as:
132             // accum = is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0)
133             if (ur < ur_start || ur >= ur_end || padded) {
134                 for (int oc = 0; oc < nb_oc_block; oc++) {
135                     const Zmm zmm_accum = zmm_out(ur, oc);
136                     for (int ic = 0; ic < icb; ic++) {
137                         const auto addr_filt = EVEX_compress_addr(
138                                 aux_reg_filt, get_filter_offset(oc, ic, ki));
139                         compute_fma(zmm_accum, ic, addr_filt);
140                     }
141                 }
142             }
143         }
144     }
145 }
146 
kh_loop(int ur_w,int pad_l,int pad_r,ic_block_t last_ic_block_flag,bool handle_h_pad)147 void jit_avx512_core_amx_compute_zp_pbuff_t::kh_loop(int ur_w, int pad_l,
148         int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {
149 
150     Label kh_label, skip_kh_loop;
151     const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
152     const size_t shift_wei_h_step = sizeof(char)
153             * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
154             * jcp.oc_block;
155 
156     // Compute zero_point compensation for the padded region. Total compute
157     // area is 'overflow * kw' where 'overflow' indicates the overlap
158     // between the filter and either top_pad or bottom_pad region.
159     auto compute_kh_loop = [=](size_t param_overflow) {
160         Label overflow_label, no_overflow_label;
161 
162         mov(reg_overflow, ptr[param1 + param_overflow]);
163         cmp(reg_overflow, 0);
164         je(no_overflow_label, T_NEAR);
165         L(overflow_label);
166         {
167             compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
168             add(aux_reg_filt, shift_wei_h_step);
169             dec(reg_overflow);
170             jne(overflow_label, T_NEAR);
171         }
172         L(no_overflow_label);
173     };
174 
175     if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(t_overflow));
176 
177     // check for holes and skip computation due to dilation
178     mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
179     if ((jcp.dilate_h >= jcp.ih)) {
180         cmp(reg_kj, 0);
181         je(skip_kh_loop, T_NEAR);
182     }
183 
184     L(kh_label);
185     {
186         compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
187 
188         add(aux_reg_filt, shift_wei_h_step);
189         dec(reg_kj);
190         jne(kh_label, T_NEAR);
191     }
192 
193     L(skip_kh_loop);
194 
195     if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(b_overflow));
196 }
197 
kd_loop(int ur_w,int pad_l,int pad_r,ic_block_t last_ic_block_flag,bool handle_h_pad)198 void jit_avx512_core_amx_compute_zp_pbuff_t::kd_loop(int ur_w, int pad_l,
199         int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {
200 
201     Label kd_label, skip_kd_loop;
202     const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
203     const size_t shift_wei_h_step = sizeof(char)
204             * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
205             * jcp.oc_block;
206 
207     // Compute zero_point compensation for the padded region. Total compute
208     // area is 'overflow * kh * kw' where 'overflow' indicates the overlap
209     // between the filter and either front_pad or back_pad region.
210     auto compute_kd_loop = [=](size_t param_overflow) {
211         Label kh_loop_label;
212         Label no_overflow_label, overflow_label;
213 
214         mov(reg_ki, ptr[param1 + param_overflow]);
215         cmp(reg_ki, 0);
216         je(no_overflow_label, T_NEAR);
217         L(overflow_label);
218         {
219             mov(aux_reg_filt, aux_reg_filt_d);
220             mov(reg_kj, jcp.kh);
221             L(kh_loop_label);
222             {
223                 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
224                 add(aux_reg_filt, shift_wei_h_step);
225                 dec(reg_kj);
226                 jne(kh_loop_label, T_NEAR);
227             }
228             add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
229             dec(reg_ki);
230             jne(overflow_label, T_NEAR);
231         }
232         L(no_overflow_label);
233     };
234 
235     const bool zp_d_padding
236             = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
237     if (zp_d_padding) {
238         mov(aux_reg_filt_d, reg_filt);
239         compute_kd_loop(GET_OFF(f_overflow));
240 
241         // check for holes and skip computation due to dilation
242         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
243         if (jcp.dilate_d >= jcp.id) {
244             cmp(reg_ki, 0);
245             je(skip_kd_loop, T_NEAR);
246         }
247         L(kd_label);
248         mov(aux_reg_filt, aux_reg_filt_d);
249 
250     } else {
251         mov(aux_reg_filt, reg_filt);
252     }
253 
254     kh_loop(ur_w, pad_l, pad_r, last_ic_block_flag, handle_h_pad);
255 
256     if (zp_d_padding) {
257         add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
258         dec(reg_ki);
259         jne(kd_label, T_NEAR);
260 
261         L(skip_kd_loop);
262 
263         compute_kd_loop(GET_OFF(back_overflow));
264     }
265 }
266 
icb_loop(int ur_w,int pad_l,int pad_r,bool handle_h_pad)267 void jit_avx512_core_amx_compute_zp_pbuff_t::icb_loop(
268         int ur_w, int pad_l, int pad_r, bool handle_h_pad) {
269 
270     Label icb_label;
271     const size_t nb_ic = jcp.nb_ic_int;
272     const bool do_icb_loop = nb_ic > 1;
273 
274     /* Initialize zmm_one for weight accumulation */
275     xor_(reg_scratch, reg_scratch);
276     const Reg8 _t8 = reg_scratch.cvt8();
277     mov(_t8, 0x1);
278     vpbroadcastb(zmm_one, _t8);
279 
280     prepare_output(ur_w);
281 
282     mov(reg_icb, nb_ic);
283 
284     L(icb_label);
285     if (jcp.ic_without_padding != jcp.ic) {
286         Label common_ker, end_ker;
287         if (do_icb_loop) {
288             cmp(reg_icb, 1); // The last ic block
289             jne(common_ker, T_NEAR);
290         }
291         kd_loop(ur_w, pad_l, pad_r, last_ic_block, handle_h_pad);
292         if (do_icb_loop) {
293             jmp(end_ker, T_NEAR);
294 
295             L(common_ker);
296             kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);
297 
298             L(end_ker);
299         }
300     } else {
301         kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);
302     }
303     // End of IC Loop
304     if (do_icb_loop) {
305         const size_t shift_wei_icb_step = static_cast<size_t>(jcp.kd) * jcp.kh
306                 * jcp.kw * jcp.oc_block * jcp.ic_block_int_np;
307         add(reg_filt, sizeof(char) * shift_wei_icb_step);
308 
309         dec(reg_icb);
310         cmp(reg_icb, 0);
311         jg(icb_label, T_NEAR);
312 
313         sub(reg_filt, sizeof(char) * shift_wei_icb_step * nb_ic);
314     }
315 
316     if (jcp.oc_without_padding != jcp.oc) {
317         Label common_store, end_store;
318 
319         cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
320         jne(common_store, T_NEAR);
321 
322         store_output(ur_w, true); // last oc block
323         jmp(end_store, T_NEAR);
324 
325         L(common_store);
326         store_output(ur_w, false);
327 
328         L(end_store);
329     } else {
330         store_output(ur_w, false);
331     }
332 }
333 
unroll_width(const bool h_padding)334 void jit_avx512_core_amx_compute_zp_pbuff_t::unroll_width(
335         const bool h_padding) {
336 
337     auto ur_w_shift = [&](const int ur_w) {
338         return sizeof(int32_t) * (ur_w * jcp.oc_without_padding * jcp.ngroups);
339     };
340 
341     const int max_ur_w = jit_avx512_core_amx_compute_zp_pbuff_t::max_regs_ur
342             / (jcp.nb_oc_blocking);
343     const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
344     int l_pad = jcp.l_pad;
345 
346     const int l_pad_output = jcp.l_pad_output;
347     const int r_pad_output = jcp.r_pad_output;
348 
349     // a single middle element (if required) containing only height padding
350     const int no_pad = nstl::max(0, jcp.ow - l_pad_output - r_pad_output);
351 
352     const int ow_start = nstl::max(jcp.ow - r_pad_output, l_pad_output);
353     const int r_pad_start = nstl::min(jcp.ow_pad - l_pad_output, r_pad_output);
354 
355     int ow = 0;
356     int cur_l_pad_output = l_pad_output;
357     while (cur_l_pad_output > 0) {
358         const int ur_w = nstl::min(cur_l_pad_output, max_ur_w);
359         ow += ur_w;
360         const int cur_r_pad = calculate_end_padding(
361                 jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
362         icb_loop(ur_w, l_pad, cur_r_pad, h_padding);
363         add(reg_zp_pbuff, ur_w_shift(ur_w));
364 
365         l_pad = nstl::max(l_pad - ur_w * jcp.stride_w, 0);
366         cur_l_pad_output = nstl::max(cur_l_pad_output - ur_w, 0);
367     }
368 
369     if (no_pad > 0) {
370         const int ur_w = 1;
371         if (h_padding) icb_loop(ur_w, 0, 0, true);
372         if (h_padding || jcp.ow_mid) add(reg_zp_pbuff, ur_w_shift(ur_w));
373     }
374     assert(ow + no_pad == ow_start);
375 
376     ow = ow_start;
377     int cur_r_pad_output = r_pad_start;
378     while (cur_r_pad_output > 0 && ow < jcp.ow) {
379         const int ur_w = nstl::min(cur_r_pad_output, max_ur_w);
380         ow += ur_w;
381         const int cur_r_pad = calculate_end_padding(
382                 jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
383         icb_loop(ur_w, 0, cur_r_pad, h_padding);
384         add(reg_zp_pbuff, ur_w_shift(ur_w));
385 
386         cur_r_pad_output = nstl::max(cur_r_pad_output - ur_w, 0);
387     }
388 }
389 
generate()390 void jit_avx512_core_amx_compute_zp_pbuff_t::generate() {
391     Label h_pad_label, end_label;
392 
393     assert(jcp.req_zero_point_buffer);
394     assert(jcp.typesize_in == sizeof(char));
395 
396     preamble();
397 
398     mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
399     mov(reg_zp_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
400     mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
401 
402     if (jcp.oc_without_padding != jcp.oc) {
403         const Reg32 reg_tmp = reg_scratch.cvt32();
404         const int tail_size = jcp.oc_without_padding % jcp.oc_block;
405         const int mask = (1 << tail_size) - 1;
406         mov(reg_tmp, mask);
407         kmovw(ktail_mask, reg_tmp);
408         mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
409     }
410 
411     mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
412     cmp(reg_overflow, 0);
413     jne(h_pad_label, T_NEAR);
414     mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
415     cmp(reg_overflow, 0);
416     jne(h_pad_label, T_NEAR);
417     if (jcp.ndims == 5 && (jcp.f_pad_output > 0 || jcp.back_pad_output > 0)) {
418         mov(reg_overflow, ptr[param1 + GET_OFF(kd_padding)]);
419         cmp(reg_overflow, jcp.kd);
420         jne(h_pad_label, T_NEAR);
421     }
422 
423     // Handle width padding region
424     unroll_width(false);
425     jmp(end_label, T_NEAR);
426 
427     // handle height padding region
428     L(h_pad_label);
429     unroll_width(true);
430 
431     L(end_label);
432 
433     postamble();
434 
435     // reduced-lowering ('is_relo' == true) weights format is '..i16o', so
436     // permute elements through permb into the VNNI layout '...16i4i'.
437     if (jcp.is_relo) {
438         align(64);
439         L(permb_idx_label);
440         // permb: id-bit for table selection is bit[6]
441         const uint8_t select_src2_bit = 0x40;
442         // permb: bits [5:0] select the element within each input table
443         const uint8_t permb_idx_table[64] = {0, 16, 32, 48, 1, 17, 33, 49, 2,
444                 18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22,
445                 38, 54, 7, 23, 39, 55, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42,
446                 58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46,
447                 62, 15, 31, 47, 63};
448         for (size_t i = 0; i < 64; ++i)
449             db(select_src2_bit | permb_idx_table[i]);
450 
451         // write zero-mask (permb) for ic_tail in VNNI format '..16o4i'
452         const int ic_tail_size
453                 = jcp.ic_without_padding % (jcp.ic_block / ic_inner_block);
454         if (jcp.ic_without_padding != jcp.ic && ic_tail_size > 0) {
455             align(64);
456             L(ic_mask_label);
457 
458             assert(4 > ic_tail_size);
459             // mask is on a 4-bit basis from the 4 ic elements in a zmm
460             const int nibble = (1 << ic_tail_size) - 1;
461             for (int i = 0; i < 16; ++i) {
462                 db(nibble | (nibble << 4));
463             }
464         }
465     }
466 }
467 
generate()468 void jit_avx512_core_amx_copy_to_wbuffer_t::generate() {
469 
470     const bool is_bf16 = jcp.src_dt == data_type::bf16;
471 
472     // required for use of VPERMB instruction
473     assert(IMPLICATION(!is_bf16, cpu().has(Xbyak::util::Cpu::tAVX512_VBMI)));
474     assert(jcp.ic_block_int * jcp.typesize_in == 64);
475 
476     preamble();
477 
478     mov(reg_src, ptr[param1 + GET_OFF(src)]);
479     mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
480 
481     // load permute indices from data section
482     Label permute_index_table;
483     mov(reg_tmp, permute_index_table);
484     if (is_bf16)
485         vmovdqu16(zmm_idx, ptr[reg_tmp]);
486     else
487         vmovdqu8(zmm_idx, ptr[reg_tmp]);
488 
489     const int vnni_width = is_bf16 ? 2 : 4;
490     const int r = jcp.kh * jcp.kw * jcp.ic_without_padding;
491     const int nb_r = div_up(r, vnni_width);
492     const int rtail = (r % vnni_width) * jcp.oc_block;
493     if (rtail > 0) {
494         uint64_t mask = (UINT64_C(1) << rtail) - 1;
495         mov(reg_tmp, mask);
496         kmovq(kmask_load, reg_tmp);
497     }
498     const int nb_z = rnd_up(nb_r, jcp.ic_block);
499     if (nb_r < nb_z) vpxord(zmm_zero, zmm_zero, zmm_zero);
500 
501     const int tile_size = jcp.ic_block_int * jcp.oc_block * jcp.typesize_in;
502     const int ocb_src_step = r * jcp.oc_block * jcp.typesize_in;
503     const int ocb_dst_step = rnd_up(ocb_src_step, tile_size);
504 
505     // reorder from ~Owhi16o -> ~OR16oVr with r := whi and V := vnni_width
506     for (int g = 0; g < jcp.ngroups; g++) {
507         for (int ocb = 0; ocb < jcp.nb_oc; ocb++) {
508             int offset = 0;
509             int rb = 0;
510             for (; rb < nb_r; offset += 64, rb++) {
511                 auto zmm_src_tmp = (rtail > 0 && rb == nb_r - 1)
512                         ? zmm_src | kmask_load | T_z
513                         : zmm_src;
514                 if (is_bf16) {
515                     vmovdqu16(zmm_src_tmp, ptr[reg_src + offset]);
516                     vpermw(zmm_dst, zmm_idx, zmm_src);
517                     vmovdqu16(ptr[reg_dst + offset], zmm_dst);
518                 } else {
519                     vmovdqu8(zmm_src_tmp, ptr[reg_src + offset]);
520                     vpermb(zmm_dst, zmm_idx, zmm_src);
521                     vmovdqu8(ptr[reg_dst + offset], zmm_dst);
522                 }
523             }
524             for (; rb < nb_z; offset += 64, rb++) {
525                 if (is_bf16)
526                     vmovdqu16(ptr[reg_dst + offset], zmm_zero);
527                 else
528                     vmovdqu8(ptr[reg_dst + offset], zmm_zero);
529             }
530             add(reg_src, ocb_src_step);
531             add(reg_dst, ocb_dst_step);
532         }
533     }
534 
535     postamble();
536 
537     align(64);
538     L(permute_index_table);
539     const uint8_t no = 16; // 16o
540     const uint8_t nr = is_bf16 ? 2 : 4; // 2r or 4r
541     for (uint8_t o = 0; o < no; ++o) {
542         for (uint8_t r = 0; r < nr; r++) {
543             const uint8_t index = o + r * no;
544             if (is_bf16)
545                 dw(index);
546             else
547                 db(index);
548         }
549     }
550 }
551 
copy_row_body(int lpad,int iw_len,int icb)552 void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_body(
553         int lpad, int iw_len, int icb) {
554 
555     const bool is_bf16 = jcp.src_dt == data_type::bf16;
556     int iwp_idx = 0;
557     // there are min(gen_kw, jcp.stride_w) continuous sets of input
558     // data (for each stride idx), they are placed one by one
559     // without additional padding
560     const bool are_sets_interleaved
561             = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
562     const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
563     const int num_sets = are_sets_interleaved ? jcp.n_stride_sets : jcp.kw;
564     for (int set_idx = 0; set_idx < num_sets; set_idx++) {
565         int set_width_padded = !jcp.is_pbuffer_strided
566                 ? (jcp.ow_block - 1) * jcp.stride_w + gen_kw
567                 : are_sets_interleaved ? jcp.ow_block - 1 + gen_kw / num_sets
568                                 + (set_idx < gen_kw % num_sets ? 1 : 0)
569                                        : jcp.ow_block;
570         for (int set_shift = 0; set_shift < set_width_padded;
571                 set_shift++, iwp_idx++) {
572             int iw_idx = set_idx * (jcp.dilate_w + 1)
573                     + set_shift * (jcp.is_pbuffer_strided ? jcp.stride_w : 1)
574                     - lpad;
575             size_t out_base_offset
576                     = (size_t)jcp.typesize_in * iwp_idx * jcp.ic_block_int_np;
577             if (iw_idx < 0 || iw_idx >= iw_len) {
578                 // left or right padding
579                 vmovups(ptr[reg_aux_out_ptr + out_base_offset], zmm_zero);
580             } else if (jcp.is_nspc) {
581                 size_t inp_w_offset = (size_t)jcp.typesize_in * iw_idx
582                         * jcp.ngroups * jcp.ic_without_padding;
583                 int ic = icb * jcp.ic_block_int_np;
584                 // TODO: use Xmm or Ymm moves for better small ic efficiency
585                 auto zmm_tmp_mask
586                         = ic + jcp.ic_block_int <= jcp.ic_without_padding
587                         ? zmm_tmp
588                         : zmm_tmp | ktail_mask | T_z;
589                 if (is_bf16) {
590                     vmovdqu16(
591                             zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
592                     vmovdqu16(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
593                 } else {
594                     vmovdqu8(zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
595                     vmovdqu8(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
596                 }
597             } else {
598                 assert(is_bf16);
599                 size_t inp_w_offset
600                         = (size_t)jcp.typesize_in * iw_idx * jcp.ic_block;
601                 for (int j = 0; j < jcp.ic_block_int_np / jcp.ic_block; j++) {
602                     int ic = icb * jcp.ic_block_int_np + j * jcp.ic_block;
603                     size_t inp_c_w_offset = (size_t)jcp.typesize_in * j * jcp.ih
604                                     * jcp.iw * jcp.ic_block
605                             + inp_w_offset;
606                     if (ic + jcp.ic_block <= jcp.ic) {
607                         vmovdqu16(
608                                 ymm_tmp, ptr[reg_aux_inp_ptr + inp_c_w_offset]);
609                     } else {
610                         vpxord(ymm_tmp, ymm_tmp, ymm_tmp);
611                     }
612                     size_t out_offset = out_base_offset
613                             + (size_t)jcp.typesize_in * j * jcp.ic_block;
614                     vmovdqu16(ptr[reg_aux_out_ptr + out_offset], ymm_tmp);
615                 }
616             }
617         }
618     }
619 }
620 
copy_row(int icb)621 void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row(int icb) {
622     if (jcp.nb_ow == 1) {
623         copy_row_body(jcp.l_pad, jcp.iw, icb);
624     } else {
625         auto get_iw_len_required = [&](int cur_ow_block, int cur_lpad) {
626             return (cur_ow_block - 1) * jcp.stride_w
627                     + (jcp.kw - 1) * (jcp.dilate_w + 1) + 1 - cur_lpad;
628         };
629 
630         auto get_iw_len_limited = [&](int owb, int cur_ow_block, int cur_lpad) {
631             auto len_req = get_iw_len_required(cur_ow_block, cur_lpad);
632             if (owb < 0) return len_req;
633             int ow_block_start = nstl::max(
634                     0, owb * jcp.ow_block * jcp.stride_w - jcp.l_pad);
635             return nstl::min(jcp.iw - ow_block_start, len_req);
636         };
637 
638         int general_owb_cases = jcp.nb_ow;
639         Xbyak::Label copy_row_done_label;
640         bool special_first_block_case = jcp.l_pad > 0;
641         if (special_first_block_case) {
642             general_owb_cases--;
643             Xbyak::Label skip_first_block_case_label;
644             cmp(reg_owb, 0);
645             jne(skip_first_block_case_label, T_NEAR);
646             copy_row_body(jcp.l_pad,
647                     get_iw_len_limited(0, jcp.ow_block, jcp.l_pad), icb);
648             jmp(copy_row_done_label, T_NEAR);
649             L(skip_first_block_case_label);
650         }
651         bool special_last_block_case = false
652                 // has ow_block_tail
653                 || jcp.ow % jcp.ow_block != 0
654                 // there is no ow_block_tail but right padding exists
655                 || get_iw_len_limited(jcp.nb_ow - 1, jcp.ow_block, 0)
656                         != get_iw_len_required(jcp.ow_block, 0);
657         if (special_last_block_case) {
658             general_owb_cases--;
659             Xbyak::Label skip_last_block_case_label;
660             cmp(reg_owb, jcp.nb_ow - 1);
661             jne(skip_last_block_case_label, T_NEAR);
662             int ow_block_tail = jcp.ow % jcp.ow_block;
663             int cur_ow_block = ow_block_tail > 0 ? ow_block_tail : jcp.ow_block;
664             copy_row_body(
665                     0, get_iw_len_limited(jcp.nb_ow - 1, cur_ow_block, 0), icb);
666             jmp(copy_row_done_label, T_NEAR);
667             L(skip_last_block_case_label);
668         }
669 
670         bool special_penult_block_case = true
671                 // if nb_ow = 2 and l_pad > 0 it's the same as
672                 // special_first_block_case
673                 && jcp.nb_ow >= (special_first_block_case ? 3 : 2)
674                 // right padding exists in penult block
675                 && get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0)
676                         != get_iw_len_required(jcp.ow_block, 0);
677         if (special_penult_block_case) {
678             general_owb_cases--;
679             Xbyak::Label skip_penult_block_case_label;
680             cmp(reg_owb, jcp.nb_ow - 2);
681             jne(skip_penult_block_case_label, T_NEAR);
682             copy_row_body(
683                     0, get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0), icb);
684             jmp(copy_row_done_label, T_NEAR);
685             L(skip_penult_block_case_label);
686         }
687 
688         if (general_owb_cases > 0) // general case
689             copy_row_body(0, get_iw_len_required(jcp.ow_block, 0), icb);
690 
691         L(copy_row_done_label);
692     }
693 }
694 
copy_row_reduced_lowering()695 void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_reduced_lowering() {
696     assert(jcp.nb_ic_int == 1);
697     assert(jcp.ic_block_int * jcp.typesize_in == 64);
698     assert(jcp.is_nspc);
699 
700     auto load_mask = [=](int tail, Opmask kmask) {
701         uint64_t mask = (UINT64_C(1) << tail) - 1;
702         mov(reg_tmp, mask);
703         kmovq(kmask, reg_tmp);
704     };
705 
706     const bool is_bf16 = jcp.src_dt == data_type::bf16;
707     const int inp_w_step
708             = jcp.ngroups * jcp.ic_without_padding * jcp.typesize_in;
709     const int inp_h_step = jcp.iw * inp_w_step;
710     const int out_h_step = jcp.ic_without_padding * jcp.typesize_in;
711     const int out_w_step = jcp.kh * out_h_step;
712     const int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
713     if (tail_size > 0) load_mask(tail_size, ktail_mask);
714 
715     auto zero_it = [=](reg64_t tmp_out_ptr) {
716         for (int ic = 0; ic < jcp.ic_without_padding; ic += jcp.ic_block_int) {
717             const int offset = ic * jcp.typesize_in;
718             const bool masked = ic + jcp.ic_block_int > jcp.ic_without_padding;
719             Zmm zmm = masked ? zmm_zero | ktail_mask : zmm_zero;
720             if (is_bf16)
721                 vmovdqu16(ptr[tmp_out_ptr + offset], zmm);
722             else
723                 vmovdqu8(ptr[tmp_out_ptr + offset], zmm);
724         }
725     };
726 
727     // pointer to 1st needed element in src buffer
728     mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
729     // pointer to 1st needed element in dst buffer
730     mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
731 
732     // total number of rows to copy
733     mov(reg_kht, ptr[param1 + GET_OFF(kh_offset)]);
734 
735     // number of rows of src buffer to copy
736     mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
737     // number of zero-padded rows above src buffer to copy
738     mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
739     // number of zero-padded rows below src buffer to copy
740     mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);
741 
742     // number of columns of src buffer to copy
743     mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
744     // number of zero-padded columns before src buffer to copy
745     mov(reg_lov, ptr[param1 + GET_OFF(f_overflow)]);
746     // number of zero-padded columns before src buffer to copy
747     mov(reg_rov, ptr[param1 + GET_OFF(back_overflow)]);
748 
749     vpxord(zmm_zero, zmm_zero, zmm_zero);
750 
751     { // Handle Left Overflow
752         Label label_lov, label_lov_skip;
753         test(reg_lov, reg_lov);
754         jz(label_lov_skip, T_NEAR);
755         L(label_lov); // handle left or right overflow
756         {
757             Label label_lov_inner;
758             mov(reg_aux_out_ptr, reg_out_ptr);
759             mov(reg_cnt, reg_kht);
760             L(label_lov_inner);
761             {
762                 zero_it(reg_aux_out_ptr);
763                 add(reg_aux_out_ptr, out_h_step);
764                 dec(reg_cnt);
765                 jnz(label_lov_inner, T_NEAR);
766             }
767             add(reg_out_ptr, out_w_step);
768             dec(reg_lov);
769             jnz(label_lov, T_NEAR);
770         }
771         L(label_lov_skip);
772     }
773 
774     // save output pointer for later use
775     mov(reg_save_out_ptr, reg_out_ptr);
776 
777     // just in case there is no meat...
778     Label label_kwp_end;
779     test(reg_kwp, reg_kwp);
780     jz(label_kwp_end, T_NEAR);
781 
782     // Unroll over W-dimension in powers of 2
783     Label label_tov;
784     Label label_khp, label_no_khp;
785     Label label_bov;
786     test(reg_tov, reg_tov);
787     jnz(label_tov, T_NEAR);
788     test(reg_khp, reg_khp);
789     jnz(label_khp, T_NEAR);
790     test(reg_bov, reg_bov);
791     jnz(label_bov, T_NEAR);
792     jmp(label_kwp_end, T_NEAR); // safe exit in case of bad parameters
793 
794     L(label_tov); // handle top overflow
795     {
796         Label label_tov_inner;
797         mov(reg_aux_out_ptr, reg_out_ptr);
798         mov(reg_cnt, reg_kwp);
799         L(label_tov_inner);
800         {
801             zero_it(reg_aux_out_ptr);
802             add(reg_aux_out_ptr, out_w_step);
803             dec(reg_cnt);
804             jnz(label_tov_inner, T_NEAR);
805         }
806         add(reg_out_ptr, out_h_step);
807         dec(reg_tov);
808         jnz(label_tov, T_NEAR);
809     }
810     test(reg_khp, reg_khp);
811     jz(label_no_khp, T_NEAR);
812     L(label_khp); // handle kh padding (not fully unrolled)
813     {
814         Label label_khp_inner;
815         mov(reg_aux_inp_ptr, reg_inp_ptr);
816         mov(reg_aux_out_ptr, reg_out_ptr);
817         mov(reg_cnt, reg_kwp);
818         L(label_khp_inner);
819         {
820             for (int ic = 0; ic < jcp.ic_without_padding;
821                     ic += jcp.ic_block_int) {
822                 const int offset = ic * jcp.typesize_in;
823                 const bool masked
824                         = ic + jcp.ic_block_int > jcp.ic_without_padding;
825                 // zero masking is needed to avoid dependency on destination
826                 Zmm zmm_load = masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
827                 Zmm zmm_store = masked ? zmm_tmp | ktail_mask : zmm_tmp;
828                 if (is_bf16) {
829                     vmovdqu16(zmm_load, ptr[reg_aux_inp_ptr + offset]);
830                     vmovdqu16(ptr[reg_aux_out_ptr + offset], zmm_store);
831                 } else {
832                     vmovdqu8(zmm_load, ptr[reg_aux_inp_ptr + offset]);
833                     vmovdqu8(ptr[reg_aux_out_ptr + offset], zmm_store);
834                 }
835             }
836             add(reg_aux_inp_ptr, inp_w_step);
837             add(reg_aux_out_ptr, out_w_step);
838             dec(reg_cnt);
839             jnz(label_khp_inner, T_NEAR);
840         }
841         add(reg_inp_ptr, inp_h_step);
842         add(reg_out_ptr, out_h_step);
843         dec(reg_khp);
844         jnz(label_khp, T_NEAR);
845     }
846     L(label_no_khp);
847     test(reg_bov, reg_bov);
848     jz(label_kwp_end, T_NEAR);
849     L(label_bov); // handle bottom overflow
850     {
851         Label label_bov_inner;
852         mov(reg_aux_out_ptr, reg_out_ptr);
853         mov(reg_cnt, reg_kwp);
854         L(label_bov_inner);
855         {
856             zero_it(reg_aux_out_ptr);
857             add(reg_aux_out_ptr, out_w_step);
858             dec(reg_cnt);
859             jnz(label_bov_inner, T_NEAR);
860         }
861         add(reg_out_ptr, out_h_step);
862         dec(reg_bov);
863         jnz(label_bov, T_NEAR);
864     }
865     L(label_kwp_end);
866 
867     { // Handle Right Overflow
868         Label label_rov, label_rov_skip;
869         // retrieve output pointer
870         mov(reg_out_ptr, reg_save_out_ptr);
871         // calculate the shift
872         imul(reg_tmp, reg_kwp, out_w_step);
873         // shift past the body
874         add(reg_out_ptr, reg_tmp);
875         // skip if no right overflow
876         test(reg_rov, reg_rov);
877         jz(label_rov_skip, T_NEAR);
878 
879         L(label_rov); // handle left or right overflow
880         {
881             Label label_rov_inner;
882             mov(reg_aux_out_ptr, reg_out_ptr);
883             mov(reg_cnt, reg_kht);
884             L(label_rov_inner);
885             {
886                 zero_it(reg_aux_out_ptr);
887                 add(reg_aux_out_ptr, out_h_step);
888                 dec(reg_cnt);
889                 jnz(label_rov_inner, T_NEAR);
890             }
891             add(reg_out_ptr, out_w_step);
892             dec(reg_rov);
893             jnz(label_rov, T_NEAR);
894         }
895         L(label_rov_skip);
896     }
897 
898     // For bf16, zero-pad an extra cacheline to avoid NaNs
899     // For int8, it is sufficient to zero-pad the weights only
900     if (is_bf16) {
901         // shift forward to align h index to end of needed buffer
902         imul(reg_tmp, reg_kht, out_h_step);
903         add(reg_out_ptr, reg_tmp);
904         // shift backward to align w index to end of needed buffer
905         sub(reg_out_ptr, out_w_step);
906         vmovdqu16(ptr[reg_out_ptr], zmm_zero);
907     }
908 }
909 
generate()910 void jit_avx512_core_amx_copy_to_pbuffer_t::generate() {
911 
912     // Special copy kernel for reduced lowering
913     if (jcp.is_relo) {
914         assert(jcp.nb_ic_int == 1);
915         preamble();
916         copy_row_reduced_lowering();
917         postamble();
918         return;
919     }
920 
921     preamble();
922 
923     const bool is_3d = jcp.ndims == 5;
924     mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
925     mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
926     if (is_3d) mov(reg_kdp, ptr[param1 + GET_OFF(kd_padding)]);
927     mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
928     mov(reg_tover, ptr[param1 + GET_OFF(t_overflow)]);
929     mov(reg_bover, ptr[param1 + GET_OFF(b_overflow)]);
930     mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
931 
932     vpxord(zmm_zero, zmm_zero, zmm_zero);
933 
934     if (jcp.is_nspc && jcp.ic_without_padding % jcp.ic_block_int) {
935         int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
936         uint64_t mask = (UINT64_C(1) << tail_size) - 1;
937         mov(reg_tmp, mask);
938         kmovq(ktail_mask, reg_tmp);
939     }
940 
941     for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
942         Xbyak::Label kd_label, no_kd_label;
943         Xbyak::Label kh_label, no_kh_label, icb_label;
944         Xbyak::Label kh_tover_label, kh_bover_label;
945         Xbyak::Label no_kh_tover_label, no_kh_bover_label;
946 
947         mov(reg_aux_inp_ptr, reg_inp_ptr);
948         mov(reg_aux_out_ptr, reg_out_ptr);
949         if (is_3d) {
950             cmp(reg_kdp, 0);
951             jle(no_kd_label, T_NEAR);
952             mov(reg_kdc, reg_kdp);
953             L(kd_label);
954             push(reg_aux_inp_ptr);
955             push(reg_aux_out_ptr);
956         }
957         cmp(reg_khp, 0);
958         jle(no_kh_bover_label, T_NEAR); // nothing to do
959         mov(reg_khc, reg_khp);
960 
961         cmp(reg_tover, 0);
962         jle(no_kh_tover_label, T_NEAR);
963 
964         mov(reg_kh_over, reg_tover);
965         L(kh_tover_label);
966         {
967             // TODO: adjust step to improve zeroing efficiency for small ic
968             for (int iw = 0; iw < jcp.iwp; iw++)
969                 vmovups(ptr[reg_aux_out_ptr
970                                 + jcp.typesize_in * iw * jcp.ic_block_int_np],
971                         zmm_zero);
972             int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
973             add(reg_aux_out_ptr, out_h_offset);
974 
975             dec(reg_kh_over);
976             jnz(kh_tover_label, T_NEAR);
977         }
978         sub(reg_khc, reg_tover);
979         L(no_kh_tover_label);
980 
981         cmp(reg_khc, reg_bover);
982         jle(no_kh_label, T_NEAR);
983 
984         L(kh_label);
985         {
986             copy_row(icb);
987             size_t inp_h_offset = !jcp.is_nspc
988                     ? (size_t)jcp.typesize_in * jcp.iw * jcp.ic_block
989                     : (size_t)jcp.typesize_in * jcp.iw * jcp.ngroups
990                             * jcp.ic_without_padding;
991             size_t out_h_offset
992                     = (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
993 
994             add(reg_aux_inp_ptr, inp_h_offset);
995             add(reg_aux_out_ptr, out_h_offset);
996 
997             dec(reg_khc);
998             cmp(reg_khc, reg_bover);
999             jg(kh_label, T_NEAR);
1000         }
1001         L(no_kh_label);
1002 
1003         cmp(reg_khc, 0);
1004         jle(no_kh_bover_label, T_NEAR);
1005 
1006         L(kh_bover_label);
1007         {
1008             // TODO: adjust step to improve zeroing efficiency for small ic
1009             for (int iw = 0; iw < jcp.iwp; iw++)
1010                 vmovups(ptr[reg_aux_out_ptr
1011                                 + jcp.typesize_in * iw * jcp.ic_block_int_np],
1012                         zmm_zero);
1013             int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
1014             add(reg_aux_out_ptr, out_h_offset);
1015 
1016             dec(reg_khc);
1017             jnz(kh_bover_label, T_NEAR);
1018         }
1019         size_t out_d_offset = (size_t)jcp.typesize_in
1020                 * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
1021         L(no_kh_bover_label);
1022         if (is_3d) {
1023             size_t inp_d_offset = !jcp.is_nspc
1024                     ? (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block
1025                             * (jcp.dilate_d + 1)
1026                     : (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ngroups
1027                             * jcp.ic_without_padding * (jcp.dilate_d + 1);
1028             pop(reg_aux_out_ptr);
1029             pop(reg_aux_inp_ptr);
1030             add(reg_aux_inp_ptr, inp_d_offset);
1031             add(reg_aux_out_ptr, out_d_offset);
1032             dec(reg_kdc);
1033             jnz(kd_label, T_NEAR);
1034             L(no_kd_label);
1035         }
1036         // End IC Loop
1037         size_t inp_cb_offset = !jcp.is_nspc
1038                 ? (size_t)jcp.typesize_in * (jcp.ic_block_int_np / jcp.ic_block)
1039                         * jcp.id * jcp.ih * jcp.iw * jcp.ic_block
1040                 : (size_t)jcp.typesize_in * jcp.ic_block_int_np;
1041         size_t out_cb_offset = (size_t)jcp.kd * out_d_offset;
1042 
1043         add(reg_inp_ptr, inp_cb_offset);
1044         add(reg_out_ptr, out_cb_offset);
1045     }
1046 
1047     postamble();
1048 }
1049 
jit_avx512_core_amx_fwd_kernel_t(const jit_conv_conf_t & ajcp,const primitive_attr_t & attr,const memory_desc_t & dst_md)1050 jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t(
1051         const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
1052         const memory_desc_t &dst_md)
1053     : jit_generator(nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
1054     , jcp(ajcp)
1055     , attr_(attr) {
1056     if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
1057         using namespace binary_injector;
1058         const auto &rhs_addr_reg = bin_injector_helper_reg_1;
1059         const auto &rhs_helper_reg = bin_injector_helper_reg_2;
1060         static constexpr bool preserve_gpr = false;
1061         static constexpr bool preserve_vmm = false;
1062         const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
1063         static constexpr bool use_exact_tail_scalar_bcast = true;
1064 
1065         const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
1066                 31, rhs_addr_reg, rhs_helper_reg, preserve_gpr, preserve_vmm,
1067                 GET_OFF(post_ops_binary_rhs_arg_vec),
1068                 memory_desc_wrapper(dst_md), tail_size, ktail_mask,
1069                 use_exact_tail_scalar_bcast};
1070         const binary_injector::static_params_t static_params {
1071                 this->param1, rhs_arg_static_params};
1072 
1073         postops_injector_ = utils::make_unique<
1074                 injector::jit_uni_postops_injector_t<avx512_core>>(
1075                 this, jcp.post_ops, static_params);
1076     }
1077     copy_to_pbuffer_
1078             = utils::make_unique<jit_avx512_core_amx_copy_to_pbuffer_t>(jcp);
1079     if (jcp.is_relo)
1080         copy_to_wbuffer_
1081                 = utils::make_unique<jit_avx512_core_amx_copy_to_wbuffer_t>(
1082                         jcp);
1083 }
1084 
create_kernel()1085 status_t jit_avx512_core_amx_fwd_kernel_t::create_kernel() {
1086     CHECK(jit_generator::create_kernel());
1087     CHECK(copy_to_pbuffer_->create_kernel());
1088     if (jcp.is_relo) CHECK(copy_to_wbuffer_->create_kernel());
1089     if (jcp.req_zero_point_buffer) {
1090         zp_pbuff_kernel_
1091                 = utils::make_unique<jit_avx512_core_amx_compute_zp_pbuff_t>(
1092                         jcp);
1093         if (zp_pbuff_kernel_ == nullptr) return status::out_of_memory;
1094         CHECK(zp_pbuff_kernel_->create_kernel());
1095     }
1096     return status::success;
1097 }
1098 
1099 // Tile register decomposition
1100 // { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
get_out_tensor(int h,int i,bool is_h_tail) const1101 int jit_avx512_core_amx_fwd_kernel_t::get_out_tensor(
1102         int h, int i, bool is_h_tail) const {
1103     const int C_BASE = 0;
1104     const int C_LAST = 4;
1105     assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
1106     MAYBE_UNUSED(C_LAST);
1107     const int tile = C_BASE
1108             + (jcp.nb_oh_blocking > 1
1109                             ? h * jcp.nb_oh_blocking + i
1110                             : (int)is_h_tail * jcp.nb_oc_blocking + i);
1111     assert(C_BASE <= tile && tile < C_LAST);
1112     return tile;
1113 }
get_inp_tensor(int h,bool is_h_tail) const1114 int jit_avx512_core_amx_fwd_kernel_t::get_inp_tensor(
1115         int h, bool is_h_tail) const {
1116     const int I_BASE = 4;
1117     const int I_LAST = 6;
1118     assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
1119     MAYBE_UNUSED(I_LAST);
1120     const int tile = I_BASE + (jcp.nb_oh_blocking > 1 ? h : (int)is_h_tail);
1121     assert(I_BASE <= tile && tile < I_LAST);
1122     return tile;
1123 }
get_wei_tensor(int i) const1124 int jit_avx512_core_amx_fwd_kernel_t::get_wei_tensor(int i) const {
1125     const int W_BASE = 6;
1126     const int W_LAST = 8;
1127     assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
1128     MAYBE_UNUSED(W_LAST);
1129     const int tile = W_BASE + i;
1130     assert(W_BASE <= tile && tile < W_LAST);
1131     return tile;
1132 }
1133 
1134 // Shifts and offsets
get_inp_icb_step() const1135 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_icb_step() const {
1136     return (size_t)jcp.kd * get_inp_d_step();
1137 }
get_wei_icb_step() const1138 size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_icb_step() const {
1139     return (size_t)jcp.typesize_in * jcp.kd * jcp.kh * jcp.kw
1140             * jcp.ic_block_int_np * jcp.oc_block;
1141 }
get_inp_d_step() const1142 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_d_step() const {
1143     return (size_t)jcp.typesize_in
1144             * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
1145 }
get_inp_h_step() const1146 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_h_step() const {
1147     return (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np
1148             * (jcp.dilate_h + 1);
1149 }
get_wei_d_step() const1150 size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_d_step() const {
1151     return (size_t)jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block_int_np
1152             * jcp.oc_block;
1153 }
get_wei_h_step() const1154 size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_h_step() const {
1155     return (size_t)jcp.typesize_in * jcp.kw * jcp.ic_block_int_np
1156             * jcp.oc_block;
1157 }
get_out_ocb_offset(int ohb,int ocb,size_t typesize) const1158 size_t jit_avx512_core_amx_fwd_kernel_t::get_out_ocb_offset(
1159         int ohb, int ocb, size_t typesize) const {
1160     size_t el_offset = jcp.is_nspc
1161             ? (size_t)ocb * jcp.oc_block
1162                     + (size_t)ohb * jcp.ow * jcp.ngroups
1163                             * jcp.oc_without_padding
1164             : (size_t)ocb * jcp.oh * jcp.ow * jcp.oc_block
1165                     + (size_t)ohb * jcp.ow * jcp.oc_block;
1166     return (size_t)typesize * el_offset;
1167 }
get_out_row_offset(int ohb,int ocb,int j,size_t typesize) const1168 size_t jit_avx512_core_amx_fwd_kernel_t::get_out_row_offset(
1169         int ohb, int ocb, int j, size_t typesize) const {
1170     size_t offset_w = jcp.is_nspc
1171             ? (size_t)typesize * j * jcp.ngroups * jcp.oc_without_padding
1172             : (size_t)typesize * j * jcp.oc_block;
1173     return get_out_ocb_offset(ohb, ocb, typesize) + offset_w;
1174 }
get_out_shift(int width,size_t typesize) const1175 size_t jit_avx512_core_amx_fwd_kernel_t::get_out_shift(
1176         int width, size_t typesize) const {
1177     return jcp.is_nspc
1178             ? (size_t)typesize * width * jcp.ngroups * jcp.oc_without_padding
1179             : (size_t)typesize * width * jcp.oc_block;
1180 }
get_wsp_ocb_offset(int ohb,int ocb) const1181 size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_ocb_offset(
1182         int ohb, int ocb) const {
1183     size_t el_offset = (size_t)ocb * prv_width_ * jcp.oc_block
1184             + (size_t)ohb * jcp.nb_oc_blocking * jcp.full_tile_width
1185                     * jcp.oc_block;
1186     return jcp.typesize_acc * el_offset;
1187 }
get_wsp_row_offset(int ohb,int ocb,int j) const1188 size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_row_offset(
1189         int ohb, int ocb, int j) const {
1190     return get_wsp_ocb_offset(ohb, ocb)
1191             + (size_t)jcp.typesize_acc * j * jcp.oc_block;
1192 }
get_wsp_shift() const1193 size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_shift() const {
1194     return (size_t)jcp.typesize_acc * jcp.nb_oh_blocking * jcp.full_tile_width
1195             * jcp.oc_block * jcp.nb_oc_blocking;
1196 }
get_wei_offset(int ocb,int kw) const1197 size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_offset(int ocb, int kw) const {
1198     size_t el_offset = (size_t)kw * jcp.ic_block_int_np * jcp.oc_block;
1199     size_t raw_oc_subblock_step
1200             = jcp.kd * jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
1201     size_t oc_subblock_step = jcp.is_relo
1202             ? rnd_up(raw_oc_subblock_step, jcp.ic_block_int * jcp.oc_block)
1203             : raw_oc_subblock_step;
1204     el_offset += (size_t)ocb * jcp.nb_ic_int * oc_subblock_step;
1205     return jcp.typesize_in * el_offset;
1206 }
get_inp_shift() const1207 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_shift() const {
1208     size_t w_step = (jcp.is_relo ? jcp.stride_w * jcp.kh
1209                                  : jcp.is_pbuffer_strided ? 1 : jcp.stride_w)
1210             * jcp.ic_block_int_np;
1211     return (size_t)jcp.typesize_in * jcp.tile_width * w_step;
1212 }
get_inp_offset(int ohb,int kw) const1213 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_offset(int ohb, int kw) const {
1214     if (jcp.is_relo)
1215         return ohb * jcp.iwp * jcp.kh * jcp.ic_block_int_np * jcp.typesize_in;
1216     // calculate offset by height dimension
1217     const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
1218     const int gen_stride_h = nstl::min(jcp.stride_h, gen_kh);
1219     size_t el_offset = (size_t)ohb * jcp.oh_per_tile * gen_stride_h * jcp.iwp
1220             * jcp.ic_block_int_np;
1221 
1222     // add offset by width dimension
1223     if (IMPLICATION(jcp.is_pbuffer_strided, jcp.stride_w == 1)) {
1224         el_offset += (size_t)kw * (jcp.dilate_w + 1) * jcp.ic_block_int_np;
1225     } else if (jcp.dilate_w > 0) {
1226         el_offset += (size_t)kw * jcp.ow_block * jcp.ic_block_int_np;
1227     } else {
1228         // dilate_w == 0 && stride_w > 1
1229         // there are min(jcp.kw, jcp.stride_w) continuous sets of input data
1230         // (foreach stride idx), they are placed one by one without additional
1231         // padding
1232 
1233         // calculate set idx for current kw value
1234         int set_idx = kw % jcp.stride_w;
1235         // calculate shift within set for current kw value
1236         int set_shift = kw / jcp.stride_w;
1237 
1238         // calculate the beginning of the current set along width, each set
1239         // with index set_i contains number of elements along width equal to
1240         // jcp.ow - 1 + jcp.kw / jcp.stride_w
1241         //     + (set_i < jcp.kw % jcp.stride_w)
1242         size_t set_start = (jcp.ow_block - 1 + jcp.kw / jcp.stride_w) * set_idx
1243                 + nstl::min(set_idx, jcp.kw % jcp.stride_w);
1244         el_offset += (set_start + set_shift) * jcp.ic_block_int_np;
1245     }
1246     return jcp.typesize_in * el_offset;
1247 }
1248 
get_zp_comp_offset(int ocb,int zp_h,int zp_w) const1249 size_t jit_avx512_core_amx_fwd_kernel_t::get_zp_comp_offset(
1250         int ocb, int zp_h, int zp_w) const {
1251     const size_t ocb_offset = (size_t)ocb * jcp.oc_block;
1252     const size_t sp_offset = (size_t)(zp_h * jcp.ow_pad + zp_w) * jcp.ngroups
1253             * jcp.oc_without_padding;
1254     return (ocb_offset + sp_offset) * sizeof(int32_t);
1255 }
1256 
get_zp_index_offset(int index,int mid,int s_pad_output,int e_pad_output)1257 int jit_avx512_core_amx_fwd_kernel_t::get_zp_index_offset(
1258         int index, int mid, int s_pad_output, int e_pad_output) {
1259     using namespace nstl;
1260     const int mid_end = e_pad_output - 1;
1261     int zp_mid = min(mid, max(0, index - mid_end));
1262     int zp_pad_offset
1263             = accum_with_upper_bound(index, s_pad_output, e_pad_output);
1264     return zp_pad_offset + zp_mid;
1265 }
1266 
1267 // Code generation
prepare_output(int tail)1268 void jit_avx512_core_amx_fwd_kernel_t::prepare_output(int tail) {
1269     for (int h = 0; h < jcp.nb_oh_blocking; h++)
1270         for (int i = 0; i < jcp.nb_oc_blocking; i++)
1271             tilezero(Tmm(get_out_tensor(h, i, tail)));
1272 }
1273 
init_runtime_counters(bool start_with_last_tile_block)1274 void jit_avx512_core_amx_fwd_kernel_t::init_runtime_counters(
1275         bool start_with_last_tile_block) {
1276     prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
1277             ? jcp.tile_tail
1278             : jcp.tile_width;
1279 
1280     row_count_ = 0;
1281     is_store_done_ = false;
1282     is_buffer_empty_ = true;
1283 }
1284 
reduce_to_block(const int block_size,const int pad_output)1285 size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_block(
1286         const int block_size, const int pad_output) {
1287     return (size_t)(pad_output >= block_size ? block_size : 0)
1288             + (pad_output % block_size);
1289 }
1290 
reduce_to_blocked_dims(const int dim_size,const int block_size,const int s_pad_output,const int e_pad_output)1291 size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_blocked_dims(
1292         const int dim_size, const int block_size, const int s_pad_output,
1293         const int e_pad_output) {
1294     using namespace nstl;
1295 
1296     // start padding (s_pad)
1297     int s_pad_limit = reduce_to_block(block_size, s_pad_output);
1298     int s_pad_area_blk = rnd_up(s_pad_limit, block_size);
1299 
1300     // middle (no padding)
1301     int no_pad_area = max(
1302             0, dim_size - rnd_up(s_pad_output, block_size) - e_pad_output);
1303     int no_pad_limit = (no_pad_area >= block_size ? block_size : 0);
1304 
1305     // end padding (e_pad)
1306     int no_pad_area_shift = no_pad_area % block_size;
1307     int e_pad_area_overlap
1308             = no_pad_area_shift == 0 ? 0 : block_size - no_pad_area_shift;
1309     // middle and end padding shift
1310     int e_pad_shift_limit
1311             = no_pad_area_shift + min(e_pad_output, e_pad_area_overlap);
1312     int e_pad_area_blk = max(0, e_pad_output - e_pad_area_overlap);
1313     // full end padding block
1314     int e_pad_limit = reduce_to_block(block_size, e_pad_area_blk);
1315 
1316     // calculate reduced size of s_pad, middle and e_pad blocks.
1317     return min((size_t)dim_size,
1318             (size_t)s_pad_area_blk + no_pad_limit + e_pad_shift_limit
1319                     + e_pad_limit);
1320 }
1321 
ymm_mask(const Ymm & ymm_in,bool mask_flag,bool store)1322 Ymm jit_avx512_core_amx_fwd_kernel_t::ymm_mask(
1323         const Ymm &ymm_in, bool mask_flag, bool store) {
1324     return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
1325                      : ymm_in;
1326 }
1327 
zmm_mask(const Zmm & zmm_in,bool mask_flag,bool store)1328 Zmm jit_avx512_core_amx_fwd_kernel_t::zmm_mask(
1329         const Zmm &zmm_in, bool mask_flag, bool store) {
1330     return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
1331                      : zmm_in;
1332 }
1333 
cvt2ps(data_type_t type_in,const Zmm & zmm_in,const Operand & op,bool mask_flag)1334 void jit_avx512_core_amx_fwd_kernel_t::cvt2ps(data_type_t type_in,
1335         const Zmm &zmm_in, const Operand &op, bool mask_flag) {
1336     const Zmm zmm = zmm_mask(zmm_in, mask_flag);
1337     switch (type_in) {
1338         case data_type::f32:
1339         case data_type::s32: vmovups(zmm, op); break;
1340         case data_type::s8: vpmovsxbd(zmm, op); break;
1341         case data_type::u8: vpmovzxbd(zmm, op); break;
1342         default: assert(!"unsupported data type");
1343     }
1344     if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
1345 }
1346 
apply_sum(const Zmm & zmm_out,const float * p_sum_scale,const Xbyak::Address & addr,const bool mask_flag)1347 void jit_avx512_core_amx_fwd_kernel_t::apply_sum(const Zmm &zmm_out,
1348         const float *p_sum_scale, const Xbyak::Address &addr,
1349         const bool mask_flag) {
1350     if (p_sum_scale) {
1351         const float p_sum_scale_val = *p_sum_scale;
1352         const auto sum_injector = [&, p_sum_scale_val, mask_flag]() {
1353             cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
1354             if (p_sum_scale_val == 1.f)
1355                 vaddps(zmm_out, zmm_prev_dst);
1356             else
1357                 vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
1358         };
1359         postops_injector_->set_lambda_injector(
1360                 primitive_kind::sum, sum_injector);
1361     }
1362 }
1363 
apply_postops(const Zmm & zmm_out,const float * p_sum_scale,const Xbyak::Address & addr,const bool mask_flag,const size_t off,const int ocb)1364 void jit_avx512_core_amx_fwd_kernel_t::apply_postops(const Zmm &zmm_out,
1365         const float *p_sum_scale, const Xbyak::Address &addr,
1366         const bool mask_flag, const size_t off, const int ocb) {
1367     if (jcp.with_eltwise || jcp.with_binary
1368             || (jcp.with_sum && p_sum_scale != nullptr)) {
1369         binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
1370 
1371         apply_sum(zmm_out, p_sum_scale, addr, mask_flag);
1372 
1373         const auto vmm_idx = zmm_out.getIdx();
1374         if (jcp.with_binary) {
1375             const int oc_l_offset = ocb * jcp.oc_block;
1376             rhs_arg_params.vmm_idx_to_oc_elem_off_addr.emplace(
1377                     vmm_idx, ptr[param1 + GET_OFF(oc_l_off)]);
1378             rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(
1379                     vmm_idx, oc_l_offset);
1380             rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
1381                     vmm_idx, static_cast<int>(off));
1382             if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
1383         }
1384 
1385         postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
1386     }
1387 }
1388 
store_output_vector_bf16(const Zmm & zmm_out,int ocb,int h,int w)1389 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_bf16(
1390         const Zmm &zmm_out, int ocb, int h, int w) {
1391     const bool mask_flag = jcp.is_nspc && jcp.oc_without_padding != jcp.oc
1392             && ocb == (jcp.nb_oc_blocking - 1);
1393 
1394     const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1395     auto addr = EVEX_compress_addr(reg_out_ptr, off);
1396 
1397     const auto &p = attr_.post_ops_;
1398 
1399     const int sum_idx = p.find(primitive_kind::sum);
1400     if (sum_idx != -1) {
1401         if (jcp.dst_dt == data_type::bf16) {
1402             vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
1403             vpslld(zmm_prev_dst, zmm_prev_dst, 16);
1404             vaddps(zmm_out, zmm_prev_dst);
1405         } else {
1406             vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
1407             vaddps(zmm_out, zmm_prev_dst);
1408         }
1409     }
1410     if (jcp.with_bias) {
1411         int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
1412         auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1413         if (jcp.bia_dt == data_type::bf16) {
1414             vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
1415             vpslld(zmm_bias, zmm_bias, 16);
1416             vaddps(zmm_out, zmm_bias);
1417         } else
1418             vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
1419     }
1420 
1421     static constexpr auto skip_sum_injection = nullptr;
1422     apply_postops(zmm_out, skip_sum_injection, addr, mask_flag, off, ocb);
1423 
1424     if (jcp.dst_dt == data_type::bf16) {
1425         Ymm ymm_out = Ymm(zmm_out.getIdx());
1426         vcvtneps2bf16(ymm_out, zmm_out);
1427         vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
1428     } else {
1429         vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
1430     }
1431 }
1432 
store_output_vector_int8(const Zmm & zmm_out,int ocb,int h,int w,const bool compute_zp,const int zp_h,const int zp_w)1433 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_int8(
1434         const Zmm &zmm_out, int ocb, int h, int w, const bool compute_zp,
1435         const int zp_h, const int zp_w) {
1436     const int nb_oc_block = jcp.nb_oc_blocking;
1437     const int oc_block = jcp.oc_block;
1438     const bool mask_flag = true && jcp.oc_without_padding != jcp.oc
1439             && ocb == (nb_oc_block - 1);
1440 
1441     const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1442     auto addr = EVEX_compress_addr(reg_out_ptr, off);
1443 
1444     const auto &p = attr_.post_ops_;
1445     const int sum_idx = p.find(primitive_kind::sum);
1446     const float *p_sum_scale = nullptr;
1447     if (sum_idx != -1) {
1448         const auto &p_entry = p.entry_[sum_idx];
1449         p_sum_scale = &p_entry.sum.scale;
1450     }
1451 
1452     if (p_sum_scale && *p_sum_scale != 1.f)
1453         mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
1454 
1455     int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * oc_block);
1456     if (jcp.with_bias) {
1457         int bias_offset = jcp.typesize_bia * ocb * oc_block;
1458         auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1459         cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
1460     }
1461     if (compute_zp) {
1462         assert(jcp.req_zero_point_buffer);
1463         // add zero-point padding compensation when accum data is S32
1464         const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1465         vmovups(m_zmm_zp,
1466                 EVEX_compress_addr(reg_zero_point_pbuff,
1467                         get_zp_comp_offset(ocb, zp_h, zp_w)));
1468         const Zmm m_zmm_out = zmm_mask(zmm_out, mask_flag);
1469         vpaddd(m_zmm_out, zmm_out, zmm_zp);
1470     }
1471     if (jcp.src_zero_point) {
1472         // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
1473         int zp_offset = sizeof(int32_t) * ocb * oc_block;
1474         const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1475         vpmulld(m_zmm_zp, zmm_src_zp,
1476                 EVEX_compress_addr(reg_zp_compensation, zp_offset));
1477         vpaddd(zmm_out, zmm_out, zmm_zp);
1478     }
1479 
1480     /* add bias and zero-point to zmm_accum */
1481     vcvtdq2ps(zmm_out, zmm_out);
1482     if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
1483     const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
1484     vmulps(zmm_out_msk, zmm_out,
1485             EVEX_compress_addr(reg_ptr_scales, scale_offset));
1486 
1487     apply_postops(zmm_out, p_sum_scale, addr, mask_flag, off, ocb);
1488 
1489     if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); }
1490 
1491     // Properly saturate the accumulators for integer datatypes
1492     if (one_of(jcp.dst_dt, u8, s8, s32)) {
1493         init_saturate_f32(
1494                 zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dst_dt);
1495         saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dst_dt);
1496         vcvtps2dq(zmm_out, zmm_out);
1497     }
1498 
1499     const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
1500 
1501     switch (jcp.dst_dt) {
1502         case data_type::f32:
1503         case data_type::s32: vmovups(addr, zmm_out_store); break;
1504         case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
1505         case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
1506         default: assert(!"unknown dst_dt");
1507     }
1508 }
1509 
store_output_vector(const Zmm & zmm_out,int ocb,int h,int w,const bool compute_zp,const int zp_h,const int zp_w)1510 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector(const Zmm &zmm_out,
1511         int ocb, int h, int w, const bool compute_zp, const int zp_h,
1512         const int zp_w) {
1513     /*
1514     Output:
1515               jcp.is_nspc              !jcp.is_nspc
1516               ---------------------    ---------------------
1517         INT8: [N][H][W][NBOC][16OC]
1518         BF16: [N][H][W][NBOC][16OC] or [N][NBOC][H][W][16OC]
1519     */
1520     if (jcp.src_dt == data_type::bf16) {
1521         store_output_vector_bf16(zmm_out, ocb, h, w);
1522     } else {
1523         store_output_vector_int8(zmm_out, ocb, h, w, compute_zp, zp_h, zp_w);
1524     }
1525 }
1526 
store_output(int width,int tail,bool do_store,const bool handle_h_blk,const int t_pad_output,const int b_pad_output,const int l_pad_output,const int r_pad_output,const bool is_last_oh_block,const bool zp_3d_pad)1527 void jit_avx512_core_amx_fwd_kernel_t::store_output(int width, int tail,
1528         bool do_store, const bool handle_h_blk, const int t_pad_output,
1529         const int b_pad_output, const int l_pad_output, const int r_pad_output,
1530         const bool is_last_oh_block, const bool zp_3d_pad) {
1531     auto store_output_block = [=](int width, int tail, bool do_store,
1532                                       bool is_last_h = false) {
1533         // Calculate the number of oh blocks; it may differ on last call
1534         const int last_h_blks
1535                 = div_up(jcp.oh, jcp.oh_per_tile) % jcp.nb_oh_blocking;
1536         const int h_blks = is_last_h && last_h_blks != 0 ? last_h_blks
1537                                                          : jcp.nb_oh_blocking;
1538         // Calculate the number of oh rows per tile; it may differ on last call
1539         const int h_tail = is_last_h && jcp.oh % jcp.oh_per_tile != 0
1540                 ? (h_blks - 1) * jcp.oh_per_tile + jcp.oh % jcp.oh_per_tile
1541                 : h_blks * jcp.oh_per_tile;
1542         const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
1543         const int owp = gen_kw + jcp.ow - 1;
1544 
1545         if (jcp.src_zero_point) {
1546             mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1547             mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1548             vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1549         }
1550         if (jcp.dst_zero_point) {
1551             mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1552             vcvtdq2ps(zmm_dst_zp,
1553                     EVEX_compress_addr(reg_dst_zero_point, 0, true));
1554         }
1555 
1556         for_(int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
1557         for (int ohb = 0; ohb < h_blks; ohb++) {
1558             /* Formats: Workspace: [NBOC][W][16OC] */
1559             tilestored(ptr[reg_wsp_ptr + reg_wei_stride
1560                                + get_wsp_ocb_offset(ohb, ocb)],
1561                     Tmm(get_out_tensor(ohb, ocb, tail)));
1562             is_buffer_empty_ = false;
1563             is_store_done_ = false;
1564 
1565             // preserve registers used by binary post_ops injector
1566             const injector_utils::conditional_register_preserve_guard_t
1567                     cond_register_guard(jcp.with_binary, this,
1568                             {bin_injector_helper_reg_1,
1569                                     bin_injector_helper_reg_2});
1570 
1571             for (int tw = 0; tw < width && do_store; tw++) {
1572                 // height
1573                 const int oh_index = ohb * jcp.oh_per_tile + tw / owp;
1574                 const bool zp_h_pad
1575                         = oh_index < t_pad_output || oh_index >= b_pad_output;
1576                 const int zp_h = get_zp_index_offset(
1577                         oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1578                 // width
1579                 const int ow_index = tw % owp;
1580                 const bool zp_w_pad
1581                         = ow_index < l_pad_output || ow_index >= r_pad_output;
1582                 const int zp_w = get_zp_index_offset(
1583                         ow_index, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1584 
1585                 const bool compute_zp = jcp.req_zero_point_buffer
1586                         && (zp_3d_pad || zp_w_pad || zp_h_pad);
1587 
1588                 assert(IMPLICATION(jcp.oh_per_tile == 1,
1589                         ohb == oh_index && tw == ow_index));
1590                 if (oh_index < h_tail && ow_index < jcp.ow) {
1591                     Zmm zmm_r = zmm_out(tw);
1592                     vmovups(zmm_r,
1593                             ptr[reg_wsp_ptr
1594                                     + get_wsp_row_offset(ohb, ocb, tw)]);
1595                     store_output_vector(zmm_r, ocb, oh_index, ow_index,
1596                             compute_zp, zp_h, zp_w);
1597                 }
1598             }
1599         }
1600     };
1601 
1602     // adjustment in case interleave store is turned off
1603     do_store = do_store || jcp.per_one_pstore == 0;
1604     if (!do_store) { w_padding.emplace(l_pad_output, r_pad_output); }
1605     if (!handle_h_blk) {
1606         store_output_block(width, tail, do_store, is_last_oh_block);
1607     } else {
1608         if (jcp.oh % (jcp.oh_per_tile * jcp.nb_oh_blocking) == 0) {
1609             store_output_block(width, tail, do_store);
1610         } else {
1611             Label label_oh_oc_store, label_done;
1612             mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
1613             cmp(reg_last_h, 0);
1614             jne(label_oh_oc_store, T_NEAR);
1615             store_output_block(width, tail, do_store, true); // last h
1616             jmp(label_done, T_NEAR);
1617             L(label_oh_oc_store);
1618             store_output_block(width, tail, do_store, false);
1619             L(label_done);
1620         }
1621     }
1622     if (do_store) {
1623         add(reg_out_ptr, get_out_shift(width, jcp.typesize_out));
1624         if (jcp.req_zero_point_buffer) {
1625             const size_t sp_shift
1626                     = accum_with_upper_bound(width, l_pad_output, r_pad_output);
1627             add(reg_zero_point_pbuff, get_out_shift(sp_shift, sizeof(int32_t)));
1628         }
1629     }
1630 }
1631 
interleave_store(int width,int const t_pad_output,int const b_pad_output,const bool zp_3d_pad)1632 void jit_avx512_core_amx_fwd_kernel_t::interleave_store(int width,
1633         int const t_pad_output, int const b_pad_output, const bool zp_3d_pad) {
1634     for (int c = 0;
1635             c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
1636             c++) {
1637         // row_count = ohb * OCB * TW + ocb * TW + tw
1638         int tw = row_count_ % prv_width_;
1639         int ocb = (row_count_ / prv_width_) % jcp.nb_oc_blocking;
1640         int ohb = (row_count_ / prv_width_) / jcp.nb_oc_blocking;
1641 
1642         // preserve registers used by binary post_ops injector
1643         const injector_utils::conditional_register_preserve_guard_t
1644                 cond_register_guard(jcp.with_binary, this,
1645                         {bin_injector_helper_reg_1, bin_injector_helper_reg_2});
1646 
1647         // height
1648         const int oh_index = ohb;
1649         const bool zp_h_pad
1650                 = oh_index < t_pad_output || oh_index >= b_pad_output;
1651         const int zp_h = get_zp_index_offset(
1652                 oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1653         // width
1654         const int l_pad_output
1655                 = w_padding.empty() ? 0 : w_padding.front().l_pad_output;
1656         const int r_pad_output
1657                 = w_padding.empty() ? jcp.ow : w_padding.front().r_pad_output;
1658 
1659         const bool zp_w_pad = tw < l_pad_output || tw >= r_pad_output;
1660         const int zp_w = get_zp_index_offset(
1661                 tw, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1662 
1663         const bool compute_zp = jcp.req_zero_point_buffer
1664                 && (zp_3d_pad || zp_w_pad || zp_h_pad);
1665 
1666         Zmm zmm_r = zmm_out(tw);
1667         vmovups(zmm_r, ptr[reg_wsp_ptr + get_wsp_row_offset(ohb, ocb, tw)]);
1668         store_output_vector(zmm_r, ocb, ohb, tw, compute_zp, zp_h, zp_w);
1669         row_count_++;
1670 
1671         if (row_count_
1672                 == prv_width_ * jcp.nb_oc_blocking * jcp.nb_oh_blocking) {
1673             add(reg_out_ptr, get_out_shift(prv_width_, jcp.typesize_out));
1674             if (jcp.req_zero_point_buffer) {
1675                 const size_t sp_shift = accum_with_upper_bound(
1676                         prv_width_, l_pad_output, r_pad_output);
1677                 add(reg_zero_point_pbuff,
1678                         get_out_shift(sp_shift, sizeof(int32_t)));
1679                 if (!w_padding.empty()) w_padding.pop();
1680             }
1681             row_count_ = 0;
1682             is_store_done_ = true;
1683             prv_width_ = width;
1684         }
1685     }
1686 }
1687 
compute_icb_loop(int width,bool do_store,const bool handle_h_blk,const int t_pad_output,const int b_pad_output,const int l_pad_output,const int r_pad_output,const bool zp_3d_pad,const bool is_last_oh_block)1688 void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width,
1689         bool do_store, const bool handle_h_blk, const int t_pad_output,
1690         const int b_pad_output, const int l_pad_output, const int r_pad_output,
1691         const bool zp_3d_pad, const bool is_last_oh_block) {
1692     const bool tail = width == jcp.tile_tail;
1693 
1694     auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
1695         if (jcp.src_dt == data_type::bf16 && jcp.wei_dt == data_type::bf16) {
1696             tdpbf16ps(x1, x2, x3);
1697         } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::u8) {
1698             tdpbuud(x1, x2, x3);
1699         } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::s8) {
1700             tdpbusd(x1, x2, x3);
1701         } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::u8) {
1702             tdpbsud(x1, x2, x3);
1703         } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::s8) {
1704             tdpbssd(x1, x2, x3);
1705         } else {
1706             assert(!"unsupported combination");
1707         }
1708     };
1709 
1710     prepare_output(tail);
1711 
1712     // prepare registers for when 'interleave_store()' is computed
1713     if (jcp.src_zero_point) {
1714         mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1715         mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1716         vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1717     }
1718     if (jcp.dst_zero_point) {
1719         mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1720         vcvtdq2ps(zmm_dst_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));
1721     }
1722 
1723     // reduced lowering path
1724     if (jcp.is_relo) {
1725         const int nreduce = jcp.nreduce;
1726         const int stride = jcp.ic_block_int; // ie 64 (32) for int8 (bf16)
1727 
1728         push(reg_inp_ptr);
1729         push(reg_wei_ptr);
1730 
1731         for (int ireduce = 0; ireduce < nreduce; ireduce += stride) {
1732             for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1733                 tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1734                         ptr[reg_inp_ptr + get_inp_offset(ohb, 0)
1735                                 + reg_inp_stride]);
1736             }
1737             for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1738                 tileloadd(Tmm(get_wei_tensor(ocb)),
1739                         ptr[reg_wei_ptr + get_wei_offset(ocb, 0)
1740                                 + reg_wei_stride]);
1741                 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1742                     tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1743                             Tmm(get_inp_tensor(ohb, tail)),
1744                             Tmm(get_wei_tensor(ocb)));
1745                     interleave_store(width, t_pad_output, b_pad_output);
1746                 }
1747             }
1748             if (ireduce + stride < nreduce) {
1749                 add(reg_inp_ptr, stride * jcp.typesize_in);
1750                 add(reg_wei_ptr, stride * jcp.oc_block * jcp.typesize_in);
1751             }
1752         }
1753         pop(reg_wei_ptr);
1754         pop(reg_inp_ptr);
1755 
1756         store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1757                 b_pad_output, l_pad_output, r_pad_output, is_last_oh_block);
1758 
1759         add(reg_inp_ptr, get_inp_shift());
1760 
1761         return;
1762     }
1763 
1764     auto wei_offset = [&](int icb, int ocb, int kd, int kh, int kw) {
1765         return (size_t)icb * get_wei_icb_step() + kd * get_wei_d_step()
1766                 + kh * get_wei_h_step() + get_wei_offset(ocb, kw);
1767     };
1768 
1769     auto inp_offset = [&](int icb, int ohb, int kd, int kh, int kw) {
1770         return (size_t)icb * get_inp_icb_step() + kd * get_inp_d_step()
1771                 + kh * get_inp_h_step() + get_inp_offset(ohb, kw);
1772     };
1773 
1774     auto safe_tileloadd
1775             = [=](const Tmm &t1, const Xbyak::Reg64 &reg_ptr, size_t offset,
1776                       const Xbyak::Reg64 &reg_stride) {
1777                   if (offset <= INT32_MAX) {
1778                       tileloadd(t1, ptr[reg_ptr + offset + reg_stride]);
1779                   } else {
1780                       safe_add(reg_ptr, offset, reg_tmp);
1781                       tileloadd(t1, ptr[reg_ptr + reg_stride]);
1782                       safe_sub(reg_ptr, offset, reg_tmp);
1783                   }
1784               };
1785 
1786     // normal and k-remainders path
1787     const bool check_kd_padding
1788             = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
1789     for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
1790         Label kd_skip_compute;
1791         if (check_kd_padding) mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1792 
1793         for (int kd = 0; kd < jcp.kd; kd++) {
1794             if (check_kd_padding) {
1795                 dec(reg_kd);
1796                 jl(kd_skip_compute, T_NEAR);
1797                 push(reg_kd);
1798             }
1799             for (int kh = 0; kh < jcp.kh; kh++) {
1800                 for (int set_idx = 0; set_idx < jcp.n_stride_sets;
1801                         set_idx++) { // used to optimize input memory reuse in L1$
1802                     for (int kw = set_idx; kw < jcp.kw; kw += jcp.kw_step) {
1803                         for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1804                             const size_t inp_off
1805                                     = inp_offset(icb, ohb, kd, kh, kw);
1806                             safe_tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1807                                     reg_inp_ptr, inp_off, reg_inp_stride);
1808                         }
1809                         for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1810                             const size_t wei_off
1811                                     = wei_offset(icb, ocb, kd, kh, kw);
1812                             safe_tileloadd(Tmm(get_wei_tensor(ocb)),
1813                                     reg_wei_ptr, wei_off, reg_wei_stride);
1814                             for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1815                                 tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1816                                         Tmm(get_inp_tensor(ohb, tail)),
1817                                         Tmm(get_wei_tensor(ocb)));
1818                                 interleave_store(width, t_pad_output,
1819                                         b_pad_output, zp_3d_pad);
1820                             }
1821                         }
1822                     }
1823                 }
1824             }
1825             if (check_kd_padding) pop(reg_kd);
1826         }
1827         L(kd_skip_compute);
1828     }
1829 
1830     store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1831             b_pad_output, l_pad_output, r_pad_output, is_last_oh_block,
1832             zp_3d_pad);
1833 
1834     add(reg_inp_ptr, get_inp_shift());
1835 }
1836 
dispatch_icb_loop(int width,bool do_store,const int l_pad_output,const int r_pad_output,const bool zp_3d_pad)1837 void jit_avx512_core_amx_fwd_kernel_t::dispatch_icb_loop(int width,
1838         bool do_store, const int l_pad_output, const int r_pad_output,
1839         const bool zp_3d_pad) {
1840     if (jcp.req_zero_point_buffer
1841             && (jcp.t_pad_output > 0 || jcp.b_pad_output > 0)) {
1842         const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
1843         const size_t height_limit = reduce_to_blocked_dims(
1844                 jcp.oh, oh_step_size, jcp.t_pad_output, jcp.b_pad_output);
1845         const int ur_h = div_up(height_limit, oh_step_size);
1846         assert(6 >= ur_h);
1847 
1848         // Use a jump-table to execute the corresponding block
1849         Label h_blk_label[6], h_blk_end_label, jmp_table_label;
1850         mov(reg_jmp_blk, ptr[param1 + GET_OFF(ohb)]);
1851         mov(reg_tmp, jmp_table_label);
1852         jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1853         jmp(h_blk_end_label, T_NEAR); // error, shouldn't happen
1854 
1855         align(8);
1856         L(jmp_table_label);
1857         for (int u = 0; u < ur_h; ++u) {
1858             putL(h_blk_label[u]);
1859         }
1860 
1861         // Save value of global variables for the next 'h_blk' iteration
1862         const int local_prv_width = prv_width_;
1863         const int local_row_count = row_count_;
1864         const bool local_is_store_done = is_store_done_;
1865         const bool local_is_buffer_empty = is_buffer_empty_;
1866 
1867         // Unroll ow_block with regards to l_pad_output and r_pad_output
1868         int cur_t_pad = reduce_to_block(oh_step_size, jcp.t_pad_output);
1869         int cur_b_pad = height_limit
1870                 - reduce_to_block(oh_step_size, jcp.b_pad_output);
1871         for (int u = 0; u < ur_h; u++) {
1872             bool last = u == ur_h - 1;
1873             L(h_blk_label[u]);
1874 
1875             // restore to previous 'h_blk' state of variables
1876             prv_width_ = local_prv_width;
1877             row_count_ = local_row_count;
1878             is_store_done_ = local_is_store_done;
1879             is_buffer_empty_ = local_is_buffer_empty;
1880             compute_icb_loop(width, do_store, false, cur_t_pad, cur_b_pad,
1881                     l_pad_output, r_pad_output, zp_3d_pad, last);
1882             cur_t_pad = nstl::max(0, cur_t_pad - oh_step_size);
1883             cur_b_pad = nstl::max(0, cur_b_pad - oh_step_size);
1884             if (!last) jmp(h_blk_end_label, T_NEAR);
1885         }
1886         L(h_blk_end_label);
1887     } else {
1888         compute_icb_loop(width, do_store, true, 0, jcp.oh, l_pad_output,
1889                 r_pad_output, zp_3d_pad);
1890     }
1891 }
1892 
dispatch_zp_3d_compute(int width,bool do_store,const int l_pad_output,const int r_pad_output)1893 void jit_avx512_core_amx_fwd_kernel_t::dispatch_zp_3d_compute(int width,
1894         bool do_store, const int l_pad_output, const int r_pad_output) {
1895     if (jcp.req_zero_point_buffer && (jcp.f_pad > 0 || jcp.back_pad > 0)) {
1896         Label compute_3d_zp_label, zp_d_end_label;
1897         mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1898         cmp(reg_kd, jcp.kd);
1899         jne(compute_3d_zp_label, T_NEAR);
1900 
1901         // Save value of global variables for next 'dispatch_icb_loop'
1902         const int local_prv_width = prv_width_;
1903         const int local_row_count = row_count_;
1904         const bool local_is_store_done = is_store_done_;
1905         const bool local_is_buffer_empty = is_buffer_empty_;
1906         dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1907 
1908         jmp(zp_d_end_label, T_NEAR);
1909         L(compute_3d_zp_label);
1910 
1911         prv_width_ = local_prv_width;
1912         row_count_ = local_row_count;
1913         is_store_done_ = local_is_store_done;
1914         is_buffer_empty_ = local_is_buffer_empty;
1915         dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, true);
1916 
1917         L(zp_d_end_label);
1918     } else
1919         dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1920 }
1921 
compute_ow_loop()1922 void jit_avx512_core_amx_fwd_kernel_t::compute_ow_loop() {
1923     auto compute_ow_loop_body = [=](bool last_owb, int num_tile_blocks,
1924                                         const int l_pad_output,
1925                                         const int r_pad_output) {
1926         int cur_l_pad_output = l_pad_output;
1927         int cur_r_pad_output = r_pad_output;
1928         int gen_tile_tail = last_owb && jcp.tile_tail > 0 ? jcp.tile_tail
1929                                                           : jcp.tile_width;
1930         init_runtime_counters(last_owb && num_tile_blocks == 1);
1931         for (int owb = 0; owb < num_tile_blocks - 1; owb++) {
1932             dispatch_zp_3d_compute(
1933                     jcp.tile_width, false, cur_l_pad_output, cur_r_pad_output);
1934             cur_l_pad_output = nstl::max(0, cur_l_pad_output - jcp.tile_width);
1935             cur_r_pad_output = nstl::max(0, cur_r_pad_output - jcp.tile_width);
1936         }
1937         dispatch_zp_3d_compute(
1938                 gen_tile_tail, true, cur_l_pad_output, cur_r_pad_output);
1939     };
1940 
1941     assert(jcp.nb_ow > 0);
1942     if (jcp.nb_ow == 1) {
1943         const int ow_r_pad_start
1944                 = nstl::max(jcp.ow - jcp.r_pad_output, jcp.l_pad_output);
1945         compute_ow_loop_body(
1946                 true, jcp.ow_blocks, jcp.l_pad_output, ow_r_pad_start);
1947     } else if (jcp.req_zero_point_buffer
1948             && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0)) {
1949 
1950         const size_t zp_addr_shift
1951                 = jcp.ngroups * jcp.oc_without_padding * sizeof(int32_t);
1952         const int ow_step_size = jcp.ow_block;
1953         const int ow_blocks_per_call = div_up(ow_step_size, jcp.tile_width);
1954         const int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call == 0
1955                 ? ow_blocks_per_call
1956                 : jcp.ow_blocks % ow_blocks_per_call;
1957         const int width_limit = reduce_to_blocked_dims(
1958                 jcp.ow, ow_step_size, jcp.l_pad_output, jcp.r_pad_output);
1959         const int ur_w = div_up(width_limit, ow_step_size);
1960         assert(6 >= ur_w);
1961         // Use a jump-table to execute the corresponding block
1962         Label w_blk_label[6], w_blk_end_label, jmp_table_label;
1963         mov(reg_jmp_blk, ptr[param1 + GET_OFF(owb)]);
1964         mov(reg_tmp, jmp_table_label);
1965         jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1966         jmp(w_blk_end_label, T_NEAR); // error, shouldn't happen
1967 
1968         align(8);
1969         L(jmp_table_label);
1970         for (int u = 0; u < ur_w; ++u) {
1971             putL(w_blk_label[u]);
1972         }
1973 
1974         // Unroll ow_block with regards to l_pad_output and r_pad_output
1975         int cur_l_pad = reduce_to_block(ow_step_size, jcp.l_pad_output);
1976         int cur_r_pad
1977                 = width_limit - reduce_to_block(ow_step_size, jcp.r_pad_output);
1978         int zp_offset = 0;
1979         for (int u = 0; u < ur_w; u++) {
1980             const bool last = u == ur_w - 1;
1981             L(w_blk_label[u]);
1982             if (u > 0) add(reg_zero_point_pbuff, zp_offset * zp_addr_shift);
1983             compute_ow_loop_body(last,
1984                     last ? last_owb_tile_blocks : ow_blocks_per_call, cur_l_pad,
1985                     cur_r_pad);
1986             zp_offset += accum_with_upper_bound(
1987                     ow_step_size, cur_l_pad, cur_r_pad);
1988             cur_l_pad = nstl::max(0, cur_l_pad - ow_step_size);
1989             cur_r_pad = nstl::max(0, cur_r_pad - ow_step_size);
1990             if (!last) jmp(w_blk_end_label, T_NEAR);
1991         }
1992         L(w_blk_end_label);
1993 
1994     } else {
1995         assert(jcp.oh_per_tile == 1);
1996         Label label_done;
1997         int ow_blocks_per_call = utils::div_up(jcp.ow_block, jcp.tile_width);
1998         int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call;
1999         if (last_owb_tile_blocks == 0 && jcp.tile_tail > 0)
2000             last_owb_tile_blocks = ow_blocks_per_call;
2001         if (last_owb_tile_blocks > 0) {
2002             Label label_not_last_owb;
2003             mov(reg_tmp, ptr[param1 + GET_OFF(owb)]);
2004             cmp(reg_tmp, jcp.nb_ow - 1);
2005             jne(label_not_last_owb, T_NEAR);
2006 
2007             compute_ow_loop_body(true, last_owb_tile_blocks, 0, jcp.ow);
2008 
2009             jmp(label_done, T_NEAR);
2010 
2011             L(label_not_last_owb);
2012         }
2013         compute_ow_loop_body(false, ow_blocks_per_call, 0, jcp.ow);
2014 
2015         L(label_done);
2016     }
2017 }
2018 
generate()2019 void jit_avx512_core_amx_fwd_kernel_t::generate() {
2020     preamble();
2021 
2022     mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
2023     mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]);
2024     mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
2025     mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
2026     if (jcp.req_zero_point_buffer)
2027         mov(reg_zero_point_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
2028 
2029     mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
2030     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
2031 
2032     const int fac = jcp.is_relo ? jcp.stride_w * jcp.kh
2033                                 : jcp.is_pbuffer_strided ? 1 : jcp.stride_w;
2034     const int inp_stride = fac * jcp.ic_block_int_np * jcp.typesize_in;
2035     const int wei_stride = jcp.oc_block * jcp.typesize_acc;
2036     mov(reg_inp_stride, inp_stride);
2037     mov(reg_wei_stride, wei_stride);
2038 
2039     if (jcp.is_nspc && jcp.oc_without_padding != jcp.oc) {
2040         // Use mask 0xF by default for all output data and post-ops
2041         // loads / stores with block index
2042         // ocb = occ * jcp.nb_oc_blocking + (jcp.nb_oc_blocking - 1)
2043         // TODO: use masked loads / stores for the last occ only
2044         int current_block_size = jcp.oc_block;
2045         int mask = (1 << current_block_size) - 1;
2046         Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
2047         mov(regw_tmp, mask);
2048         kmovw(ktail_mask, regw_tmp);
2049         Xbyak::Label mask_is_set;
2050         mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
2051         cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
2052         jne(mask_is_set, T_NEAR);
2053         // Reset the mask
2054         current_block_size = jcp.oc_without_padding % jcp.oc_block;
2055         mask = (1 << current_block_size) - 1;
2056         mov(regw_tmp, mask);
2057         kmovw(ktail_mask, regw_tmp);
2058 
2059         L(mask_is_set);
2060     }
2061     compute_ow_loop();
2062 
2063     postamble();
2064 
2065     if (jcp.with_eltwise) postops_injector_->prepare_table();
2066 }
2067 
tile_configure(char * tcfg_buff)2068 void jit_avx512_core_amx_fwd_kernel_t::tile_configure(char *tcfg_buff) {
2069     const int vnni_width = jcp.src_dt == data_type::bf16 ? 2 : 4;
2070     // Input tile dimensions
2071     const int a_col = jcp.is_relo ? jcp.ic_block_int
2072                                   : jcp.ic_block_int_np * jcp.kw_per_tile;
2073     // Weights tile dimensions
2074     const int b_col = jcp.oc_block * vnni_width;
2075     const int b_row = a_col / vnni_width;
2076     // Accumulator tile dimensions
2077     const int c_col = 16;
2078 
2079     for (size_t i = 0; i < 64; i++)
2080         tcfg_buff[i] = 0;
2081 
2082     // Weights (W_BASE) Tensor Tiles
2083     for (int i = 0; i < jcp.nb_oc_blocking; i++)
2084         tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
2085                 b_row, b_col * jcp.typesize_in);
2086 
2087     // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
2088     for (int h = 0; h < jcp.nb_oh_blocking; h++) {
2089         tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
2090                 jcp.tile_width, a_col * jcp.typesize_in);
2091         for (int i = 0; i < jcp.nb_oc_blocking; i++)
2092             tc_configure_tile((palette_config_t *)tcfg_buff,
2093                     get_out_tensor(h, i), jcp.tile_width,
2094                     c_col * jcp.typesize_acc);
2095     }
2096     if (jcp.tile_tail != 0) {
2097         assert(jcp.nb_oh_blocking == 1);
2098         assert(jcp.oh_per_tile == 1);
2099         assert(jcp.ow > jcp.tile_width);
2100         tc_configure_tile((palette_config_t *)tcfg_buff,
2101                 get_inp_tensor(0, true), jcp.tile_tail,
2102                 a_col * jcp.typesize_in);
2103         for (int i = 0; i < jcp.nb_oc_blocking; i++)
2104             tc_configure_tile((palette_config_t *)tcfg_buff,
2105                     get_out_tensor(0, i, true), jcp.tile_tail,
2106                     c_col * jcp.typesize_acc);
2107     }
2108 
2109     ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
2110 }
2111 
set_oh_blk_limits(jit_conv_conf_t & jcp)2112 void jit_avx512_core_amx_fwd_kernel_t::set_oh_blk_limits(jit_conv_conf_t &jcp) {
2113 
2114     constexpr int size = sizeof(jcp.h_blk_limits) / sizeof(jcp.h_blk_limits[0]);
2115     // set default values
2116     for (int i = 0; i < size; i++)
2117         jcp.h_blk_limits[i] = jcp.oh;
2118 
2119     const bool calculate_oh_limits
2120             = jcp.t_pad_output > 0 || jcp.b_pad_output > 0;
2121     if (jcp.req_zero_point_buffer && calculate_oh_limits) {
2122 
2123         int limit_idx = 0;
2124         const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2125 
2126         // full t_pad output block
2127         const int t_pad_blk_end = rnd_dn(jcp.t_pad_output, oh_step_size);
2128         if (jcp.t_pad_output >= oh_step_size) {
2129             jcp.h_blk_limits[limit_idx++] = t_pad_blk_end;
2130         }
2131         // t_pad output overlap with no padding
2132         const int t_pad_shift = jcp.t_pad_output % oh_step_size;
2133         if (t_pad_shift != 0) {
2134             jcp.h_blk_limits[limit_idx++] = t_pad_blk_end + t_pad_shift;
2135         }
2136         const int t_pad_next_blk = rnd_up(jcp.t_pad_output, oh_step_size);
2137         const int oh_blk_tail = jcp.oh % oh_step_size;
2138         const int b_pad_no_tail = nstl::max(0, jcp.b_pad_output - oh_blk_tail);
2139         const int b_pad_start
2140                 = nstl::max(jcp.t_pad_output, jcp.oh - jcp.b_pad_output);
2141         const int b_pad_blk_start = rnd_dn(b_pad_start, oh_step_size);
2142         // middle block without padding
2143         const int mid_blk = nstl::max(0, b_pad_blk_start - t_pad_next_blk);
2144         if (mid_blk >= oh_step_size) {
2145             jcp.h_blk_limits[limit_idx++] = b_pad_blk_start;
2146         }
2147         // no padding with b_pad overlap
2148         const int b_pad_shift = b_pad_no_tail % oh_step_size;
2149         if (b_pad_shift != 0) {
2150             jcp.h_blk_limits[limit_idx++] = rnd_up(b_pad_start, oh_step_size);
2151         }
2152         // full b_pad output block
2153         if (b_pad_no_tail >= oh_step_size) {
2154             jcp.h_blk_limits[limit_idx++] = jcp.oh - oh_blk_tail;
2155         }
2156         // b_pad tail block does not require a limit
2157     }
2158 }
2159 
set_ow_blk_limits(jit_conv_conf_t & jcp)2160 void jit_avx512_core_amx_fwd_kernel_t::set_ow_blk_limits(jit_conv_conf_t &jcp) {
2161 
2162     jcp.l_pad_blk = 0;
2163     jcp.no_pad_w_blk = 0;
2164     jcp.r_pad_blk = 0;
2165 
2166     const bool calculate_ow_limits
2167             = jcp.nb_ow > 1 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0);
2168     if (jcp.req_zero_point_buffer && calculate_ow_limits) {
2169         const int ow_step_size = jcp.ow_block;
2170 
2171         // l_pad
2172         const int l_pad_limit
2173                 = (jcp.l_pad_output >= ow_step_size ? ow_step_size : 0)
2174                 + (jcp.l_pad_output % ow_step_size);
2175         const int l_pad_area_blk = rnd_up(l_pad_limit, ow_step_size);
2176         jcp.l_pad_blk = div_up(l_pad_limit, ow_step_size);
2177 
2178         // middle (area without padding)
2179         const int no_pad_area
2180                 = nstl::max(0, jcp.ow - l_pad_area_blk - jcp.r_pad_output);
2181         jcp.no_pad_w_blk = no_pad_area >= ow_step_size ? 1 : 0;
2182 
2183         // r_pad
2184         const int no_pad_area_shift = no_pad_area % ow_step_size;
2185         const int r_pad_area_overlap
2186                 = no_pad_area_shift == 0 ? 0 : ow_step_size - no_pad_area_shift;
2187         const int r_pad_area
2188                 = nstl::max(0, jcp.r_pad_output - r_pad_area_overlap);
2189         const int r_pad_limit = (r_pad_area >= ow_step_size ? ow_step_size : 0)
2190                 + (r_pad_area % ow_step_size);
2191         jcp.r_pad_blk = (r_pad_area_overlap > 0 ? 1 : 0)
2192                 + div_up(r_pad_limit, ow_step_size);
2193     }
2194 }
2195 
init_conf(jit_conv_conf_t & jcp,const convolution_desc_t & cd,memory_desc_t & src_md,memory_desc_t & weights_md,memory_desc_t & dst_md,memory_desc_t & bias_md,const primitive_attr_t & attr,int nthreads)2196 status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
2197         const convolution_desc_t &cd, memory_desc_t &src_md,
2198         memory_desc_t &weights_md, memory_desc_t &dst_md,
2199         memory_desc_t &bias_md, const primitive_attr_t &attr, int nthreads) {
2200     using namespace prop_kind;
2201 
2202     const memory_desc_wrapper src_d(&src_md);
2203     const memory_desc_wrapper weights_d(&weights_md);
2204     const memory_desc_wrapper dst_d(&dst_md);
2205     const memory_desc_wrapper bias_d(&bias_md);
2206 
2207     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
2208     int ndims = src_d.ndims();
2209     bool is_1d = ndims == 3;
2210     bool is_3d = ndims == 5;
2211 
2212     const bool is_bf16_convolution
2213             = everyone_is(true, src_d.data_type() == data_type::bf16,
2214                     weights_d.data_type() == data_type::bf16,
2215                     one_of(dst_d.data_type(), data_type::bf16, data_type::f32));
2216     const bool is_int8_convolution = everyone_is(true,
2217             (src_d.data_type() == data_type::u8
2218                     || src_d.data_type() == data_type::s8),
2219             weights_d.data_type() == data_type::s8,
2220             one_of(dst_d.data_type(), data_type::f32, data_type::s32,
2221                     data_type::s8, data_type::u8));
2222 
2223     bool supported = false
2224             || (is_bf16_convolution && mayiuse(avx512_core_bf16_amx_bf16))
2225             || (is_int8_convolution && mayiuse(avx512_core_bf16_amx_int8));
2226     if (!supported) return status::unimplemented;
2227 
2228     jcp = zero<decltype(jcp)>();
2229     jcp.isa = is_bf16_convolution ? avx512_core_bf16_amx_bf16
2230                                   : avx512_core_bf16_amx_int8;
2231     jcp.ndims = ndims;
2232     jcp.prop_kind = cd.prop_kind;
2233     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
2234 
2235     jcp.mb = src_d.dims()[0];
2236     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
2237     jcp.oc_without_padding = jcp.oc;
2238     jcp.ic = src_d.dims()[1] / jcp.ngroups;
2239     jcp.ic_without_padding = jcp.ic;
2240     jcp.id = is_3d ? src_d.dims()[2] : 1;
2241     jcp.ih = !is_1d ? src_d.dims()[ndims - 2] : 1;
2242     jcp.iw = src_d.dims()[ndims - 1];
2243     jcp.od = is_3d ? dst_d.dims()[2] : 1;
2244     jcp.oh = !is_1d ? dst_d.dims()[ndims - 2] : 1;
2245     jcp.ow = dst_d.dims()[ndims - 1];
2246     jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
2247     jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
2248     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
2249     jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
2250     jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
2251     jcp.l_pad = cd.padding[0][ndims - 3];
2252     jcp.stride_d = is_3d ? cd.strides[0] : 1;
2253     jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
2254     jcp.stride_w = cd.strides[ndims - 3];
2255     jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
2256 
2257     jcp.dilate_d = is_3d ? cd.dilates[ndims - 5] : 0;
2258     jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
2259     jcp.dilate_w = cd.dilates[ndims - 3];
2260 
2261     const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1;
2262     const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
2263     const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
2264     jcp.back_pad = calculate_end_padding(
2265             jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, gen_kd);
2266     jcp.b_pad = calculate_end_padding(
2267             jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
2268     jcp.r_pad = calculate_end_padding(
2269             jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
2270     if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
2271             || jcp.b_pad >= gen_kh || jcp.f_pad >= gen_kd
2272             || jcp.back_pad >= gen_kd)
2273         return status::unimplemented;
2274 
2275     const int max_pad = 28; // akin to maximum jcp.ur_w value in other jits
2276     if (jcp.l_pad > max_pad || jcp.r_pad > max_pad)
2277         return status::unimplemented; // TODO: relax this restriction
2278 
2279     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
2280     jcp.dst_dt = cd.dst_desc.data_type;
2281     jcp.src_dt = cd.src_desc.data_type;
2282     jcp.wei_dt = cd.weights_desc.data_type;
2283 
2284     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
2285 
2286     if (jcp.is_depthwise)
2287         return status::unimplemented; // TODO: add support of DW convolution
2288 
2289     const auto zp = attr.zero_points_;
2290     jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
2291     jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
2292     jcp.zp_src_is_common = zp.common(
2293             DNNL_ARG_SRC); // otherwise, it's per-channel (not supported)
2294 
2295     if (!IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)
2296             || !IMPLICATION(jcp.dst_zero_point || jcp.src_zero_point,
2297                     is_int8_convolution))
2298         return status::unimplemented;
2299 
2300     // Calculate zero-point padding values outside of the main JIT-kernel
2301     // and store the results in an auxiliary buffer.
2302     jcp.req_zero_point_buffer = jcp.src_zero_point
2303             && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0 || jcp.t_pad > 0
2304                     || jcp.f_pad > 0 || jcp.back_pad > 0);
2305 
2306     format_tag_t dat_tag_ncsp = utils::pick(ndims - 3, format_tag::nCw16c,
2307             format_tag::nChw16c, format_tag::nCdhw16c);
2308     format_tag_t dat_tag_nspc = utils::pick(
2309             ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
2310     // To toggle the default data layout for BF16 between nChw16c and nhwc,
2311     // swap the following two variable definitions. Current choice: nhwc.
2312 
2313     // Clang-tidy change - if it was intentional please revert it and
2314     // put `NOLINTNEXTLINE` to suppress the warning.
2315     format_tag_t dat_tag_opt = dat_tag_nspc;
2316     format_tag_t dat_tag_alt
2317             = is_bf16_convolution ? dat_tag_ncsp : dat_tag_nspc;
2318 
2319     if (src_d.format_kind() == format_kind::any) {
2320         CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
2321         jcp.src_tag = dat_tag_opt;
2322     } else
2323         jcp.src_tag = src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);
2324 
2325     if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
2326         return status::unimplemented;
2327 
2328     jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
2329     assert(IMPLICATION(is_int8_convolution, jcp.is_nspc));
2330 
2331     // TODO: remove all support for nChw16c from this implementation
2332     if (!jcp.is_nspc) return status::unimplemented;
2333 
2334     if (dst_d.format_kind() == format_kind::any) {
2335         CHECK(memory_desc_init_by_tag(dst_md, jcp.src_tag));
2336         jcp.dst_tag = jcp.src_tag;
2337     } else
2338         jcp.dst_tag = dst_d.matches_one_of_tag(jcp.src_tag);
2339 
2340     if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
2341 
2342     if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
2343         CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
2344 
2345     jcp.nthr = nthreads;
2346 
2347     jcp.ic_block = 16;
2348     jcp.oc_block = 16;
2349 
2350     if (jcp.ngroups == 1) {
2351         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
2352         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
2353     }
2354     bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
2355     if (!args_ok) return status::unimplemented;
2356 
2357     const int vnni_width = is_bf16_convolution ? 2 : 4;
2358     jcp.ic_block_int = jcp.ic_block * vnni_width; // 32 for bf16, 64 for int8
2359 
2360     // fallback to non-amx impl when accumulation is too small
2361     const dim_t total_k = jcp.ic_without_padding * jcp.kd * jcp.kh * jcp.kw;
2362     const bool is_tiny_k = total_k < jcp.ic_block_int / 2;
2363     if (is_tiny_k) return status::unimplemented;
2364 
2365     // small-ic parameters
2366     jcp.ic_block_int_np = jcp.is_nspc
2367             ? nstl::min(jcp.ic_block_int, jcp.ic_without_padding)
2368             : jcp.ic_block_int;
2369     bool is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2370 
2371     // reduced lowering
2372     jcp.is_relo = (!is_3d)
2373             && is_small_ic
2374             // no trivial cases
2375             && 1 < jcp.kh * jcp.kw
2376             // required for use of VPERMB instruction in weights copy kernel
2377             && IMPLICATION(is_int8_convolution,
2378                     cpu().has(Xbyak::util::Cpu::tAVX512_VBMI))
2379             // no dilation or excessive stride along w-direction
2380             && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
2381             // no dilation or excessive stride along h-direction
2382             && jcp.stride_h <= jcp.kh && jcp.stride_w <= jcp.kw;
2383     jcp.nreduce = jcp.kh * jcp.kw * jcp.ic_block_int_np;
2384 
2385     if (!jcp.is_relo) {
2386         jcp.ic_block_int_np = is_bf16_convolution
2387                 ? jcp.ic_block_int
2388                 : rnd_up(jcp.ic_block_int_np, vnni_width);
2389         is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2390     }
2391 
2392     // k-remainders
2393     jcp.kw_per_tile = is_small_ic && !jcp.is_relo && jcp.dilate_w == 0
2394                     && jcp.stride_w <= jcp.kw // TODO: relax this restriction
2395                     && jcp.kw * jcp.ic_block_int_np <= jcp.ic_block_int
2396             ? jcp.kw
2397             : 1;
2398     jcp.is_pbuffer_strided = (1 == jcp.kw_per_tile);
2399     jcp.n_stride_sets
2400             = jcp.is_pbuffer_strided ? nstl::min(jcp.stride_w, jcp.kw) : 1;
2401     jcp.kw_step = jcp.is_pbuffer_strided ? jcp.stride_w : jcp.kw_per_tile;
2402 
2403     const auto &p = attr.post_ops_;
2404 
2405     const int sum_ind = p.find(primitive_kind::sum);
2406     jcp.with_sum = sum_ind != -1;
2407     const int eltwise_ind = p.find(primitive_kind::eltwise);
2408     jcp.with_eltwise = eltwise_ind != -1;
2409     const int binary_ind = p.find(primitive_kind::binary);
2410     jcp.with_binary = binary_ind != -1;
2411 
2412     jcp.post_ops = p;
2413 
2414     using namespace injector;
2415     const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16);
2416     const bool sum_requires_scale_one = sum_at_pos_0_only;
2417     const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
2418             jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
2419             {broadcasting_strategy_t::scalar,
2420                     broadcasting_strategy_t::per_oc}});
2421     if (!post_ops_ok_) return status::unimplemented;
2422 
2423     auto set_or_check_wei_format = [&]() {
2424         using namespace format_tag;
2425         using namespace memory_extra_flags;
2426         format_tag_t wei_tag;
2427         wei_tag = jcp.is_relo ? pick(with_groups + 2 * (ndims - 3), Owi16o,
2428                           gOwi16o, Owhi16o, gOwhi16o) // no 3d support
2429                               : is_bf16_convolution
2430                         ? pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
2431                                 gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i,
2432                                 OIdhw16i16o2i, gOIdhw16i16o2i)
2433                         : is_small_ic ? pick(with_groups + 2 * (ndims - 3),
2434                                   OwI16o4i, gOwI16o4i, OhwI16o4i, gOhwI16o4i,
2435                                   OdhwI16o4i, gOdhwI16o4i)
2436                                       : pick(with_groups + 2 * (ndims - 3),
2437                                               OIw16i16o4i, gOIw16i16o4i,
2438                                               OIhw16i16o4i, gOIhw16i16o4i,
2439                                               OIdhw16i16o4i, gOIdhw16i16o4i);
2440 
2441         memory_desc_t want_wei_md = weights_md;
2442         memory_desc_init_by_tag(want_wei_md, wei_tag);
2443 
2444         if (jcp.src_zero_point) {
2445             want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
2446             want_wei_md.extra.asymm_compensation_mask = (1 << 0)
2447                     + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
2448         }
2449         if (weights_md.format_kind == format_kind::any) {
2450             weights_md = want_wei_md;
2451             return true;
2452         }
2453         return weights_md == want_wei_md;
2454     };
2455 
2456     if (!set_or_check_wei_format()) return status::unimplemented;
2457 
2458     jcp.typesize_in = types::data_type_size(src_d.data_type());
2459     jcp.typesize_out = types::data_type_size(dst_d.data_type());
2460     jcp.typesize_bia
2461             = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
2462     jcp.typesize_acc = sizeof(int32_t);
2463 
2464     jcp.nb_ic = jcp.ic / jcp.ic_block;
2465     jcp.nb_oc = jcp.oc / jcp.oc_block;
2466     jcp.nb_ic_int = div_up(jcp.ic, jcp.ic_block_int);
2467 
2468     jcp.nb_oc_blocking_thr_chunk = 1;
2469 
2470     const int max_palette = amx::get_max_palette();
2471     jcp.max_tiles = amx::get_max_tiles(max_palette);
2472     jcp.full_tile_width = amx::get_max_rows(max_palette);
2473     if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
2474         return status::unimplemented;
2475 
2476     // Pack n rows per tile, such that:
2477     // ow + (ow + gen_kw - 1) * (n - 1) <= jcp.full_tile_width
2478     auto calculate_tile_width = [&](int n) {
2479         assert(n > 0);
2480         return jcp.ow + (gen_kw + jcp.ow - 1) * (n - 1);
2481     };
2482     const bool ok_to_pack_tile = !jcp.is_relo
2483             && (utils::everyone_is(1, jcp.kh, jcp.kw)
2484                     || utils::everyone_is(1, jcp.stride_h, jcp.stride_w));
2485     const int max_oh_per_tile
2486             = 1 + (jcp.full_tile_width - jcp.ow) / (jcp.ow + gen_kw - 1);
2487     jcp.oh_per_tile = ok_to_pack_tile
2488             ? nstl::min(jcp.oh, nstl::max(1, max_oh_per_tile))
2489             : 1;
2490     jcp.tile_width = nstl::min<int>(
2491             jcp.full_tile_width, calculate_tile_width(jcp.oh_per_tile));
2492     jcp.ow_blocks = utils::div_up(jcp.ow, jcp.tile_width);
2493 
2494     // Prefer to use a single tile width when possible
2495     // (eg ow28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
2496     if (jcp.oh_per_tile == 1 && jcp.ow % jcp.ow_blocks == 0)
2497         jcp.tile_width = jcp.ow / jcp.ow_blocks;
2498     jcp.tile_tail = jcp.oh_per_tile == 1 ? jcp.ow % jcp.tile_width : 0;
2499 
2500     jcp.nb_oc_blocking = (jcp.nb_oc % 2 == 0) ? 2 : 1;
2501     jcp.nb_ic_blocking = 1;
2502     jcp.nb_oh_blocking
2503             = utils::everyone_is(true, jcp.tile_tail == 0,
2504                       // requirement for interleave stores
2505                       IMPLICATION(jcp.ow_blocks > 1, jcp.oh % 2 == 0),
2506                       // requirement for small spatial
2507                       utils::div_up(jcp.oh, jcp.oh_per_tile) > 1,
2508                       // choose maximal pbuffer overlap for reduced lowering
2509                       !jcp.is_relo)
2510             ? 2
2511             : 1;
2512 
2513     // TODO: tune oh blocking
2514     const int oh_blk_size_param = jcp.is_relo ? 1 : 10;
2515     const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2516     const int oh_blk_size = rnd_up(oh_blk_size_param, oh_step_size);
2517     jcp.oh_blk_size = rnd_up(nstl::min(jcp.oh, oh_blk_size), oh_step_size);
2518     // Here ihp means the input buffer height including padding (ie the number
2519     // of input rows required for computation of jcp.oh_blk_size output rows.
2520     // If an input row doesn't participate in the computation of any output row,
2521     // it isn't copied to the buffer at all (eg jcp.stride_h > gen_kh).
2522     jcp.ihp = jcp.is_relo
2523             ? jcp.oh_blk_size
2524             : (jcp.oh_blk_size - 1) * nstl::min(jcp.stride_h, gen_kh) + gen_kh;
2525 
2526     // TODO: tune ow blocking
2527     const int ow_blocks_per_call = jcp.is_relo ? 10 : 2;
2528     jcp.ow_block = nstl::min(jcp.ow, jcp.tile_width * ow_blocks_per_call);
2529     jcp.nb_ow = utils::div_up(jcp.ow, jcp.ow_block);
2530     // iwp includes all width elements that are really used in calculation
2531     // including left and right zero padding
2532     const bool are_sets_interleaved
2533             = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
2534     jcp.iwp = are_sets_interleaved
2535             ? (jcp.ow_block - 1) * nstl::min(jcp.stride_w, jcp.kw) + gen_kw
2536             : jcp.ow_block * jcp.kw;
2537 
2538     // Number of ops per tile store
2539     int ops_tile_store = jcp.tile_width;
2540     // Number of ops per accumulation tile
2541     int avaliable_ops = jcp.is_relo
2542             ? utils::div_up(jcp.nreduce, jcp.ic_block_int)
2543             : jcp.nb_ic_int * jcp.kh * (jcp.kw / jcp.kw_per_tile);
2544     // Number of vectors to store per tile operation
2545     // NOTE: set to zero to turn off interleave store (mostly for debugging)
2546     jcp.per_one_pstore = utils::div_up(ops_tile_store, avaliable_ops);
2547 
2548     if (jcp.is_relo) {
2549         jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.ihp * jcp.iwp * jcp.kh
2550                         * jcp.ic_block_int_np
2551                 // pbuffer pointer shifts each oh step for reduced-lowering
2552                 + (jcp.oh - 1) * jcp.stride_h * jcp.ic_block_int_np
2553                 // extra $line due to pbuffer writing full Zmm
2554                 + jcp.ic_block_int;
2555     } else {
2556         jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.kd
2557                 * ((size_t)jcp.ihp * jcp.iwp * jcp.ic_block_int_np
2558                         // extra $line due to pbuffer writing full Zmm
2559                         + jcp.ic_block_int);
2560     }
2561     jcp.wei_buffer_size = (size_t)jcp.ngroups * jcp.nb_oc
2562             * rnd_up(jcp.kh * jcp.kw * jcp.ic * jcp.oc_block, 1024);
2563     jcp.wsp_buffer_size = (size_t)jcp.nb_oh_blocking * jcp.nb_oc_blocking
2564             * jcp.full_tile_width * jcp.oc_block;
2565 
2566     const auto &oscales = attr.output_scales_;
2567     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
2568 
2569     // Note: currently unsupported, results in seg-fault
2570     const int l_pad_output = nstl::min(jcp.ow, div_up(jcp.l_pad, jcp.stride_w));
2571     if (!jcp.is_relo && (l_pad_output > jcp.ow_block))
2572         return status::unimplemented;
2573 
2574     // Relevant to 'zero_point padding buffer' (pbuff) jit kernel
2575     if (jcp.req_zero_point_buffer) {
2576         auto calculate_output_padding_dims = [=](int o_dim, int s_pad,
2577                                                      int e_pad,
2578                                                      int &s_pad_output,
2579                                                      int &e_pad_output,
2580                                                      bool &o_mid, int &o_pad,
2581                                                      int stride,
2582                                                      bool req_mid_area) {
2583             s_pad_output = nstl::min(o_dim, div_up(s_pad, stride));
2584             e_pad_output = nstl::min(o_dim, div_up(e_pad, stride));
2585             o_mid = (o_dim - s_pad_output - e_pad_output > 0) && req_mid_area;
2586             o_pad = nstl::min(o_dim,
2587                     nstl::max(1, s_pad_output + e_pad_output + (int)o_mid));
2588         };
2589 
2590         const bool mid_w_area = (jcp.l_pad > 0 || jcp.r_pad > 0)
2591                 && (jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0
2592                         || jcp.back_pad > 0);
2593         const bool mid_h_area = (jcp.t_pad > 0 || jcp.b_pad > 0)
2594                 && (jcp.l_pad > 0 || jcp.r_pad > 0 || jcp.f_pad > 0
2595                         || jcp.back_pad > 0);
2596         const bool mid_d_area = (jcp.f_pad > 0 || jcp.back_pad > 0)
2597                 && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0
2598                         || jcp.t_pad > 0);
2599         calculate_output_padding_dims(jcp.ow, jcp.l_pad, jcp.r_pad,
2600                 jcp.l_pad_output, jcp.r_pad_output, jcp.ow_mid, jcp.ow_pad,
2601                 jcp.stride_w, mid_w_area);
2602         calculate_output_padding_dims(jcp.oh, jcp.t_pad, jcp.b_pad,
2603                 jcp.t_pad_output, jcp.b_pad_output, jcp.oh_mid, jcp.oh_pad,
2604                 jcp.stride_h, mid_h_area);
2605         calculate_output_padding_dims(jcp.od, jcp.f_pad, jcp.back_pad,
2606                 jcp.f_pad_output, jcp.back_pad_output, jcp.od_mid, jcp.od_pad,
2607                 jcp.stride_d, mid_d_area);
2608         jcp.zp_pbuff_size
2609                 = jcp.od_pad * jcp.oh_pad * jcp.ow_pad * jcp.oc * jcp.ngroups;
2610 
2611         // compute zero-point padding kernel outside of the main parallel
2612         // region when threads are more likely to parallelize work across mb
2613         // within the convolution compute block.
2614         jcp.zp_pbuff_outer_compute = jcp.mb > 1 || is_3d;
2615 
2616         const bool params_ok = ((jcp.ow_pad - (int)jcp.ow_mid) <= max_pad * 2);
2617         if (!params_ok) { return status::unimplemented; }
2618     }
2619 
2620     // Set default parameters for driver code, but mostly required for
2621     // 'zero_point padding buffer' (pbuff) accumulation over output tensor
2622     set_oh_blk_limits(jcp);
2623     set_ow_blk_limits(jcp);
2624 
2625     return status::success;
2626 }
2627 
init_scratchpad(memory_tracking::registrar_t & scratchpad,const jit_conv_conf_t & jcp,const primitive_attr_t & attr)2628 status_t jit_avx512_core_amx_fwd_kernel_t::init_scratchpad(
2629         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
2630         const primitive_attr_t &attr) {
2631 
2632     size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
2633     scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
2634     if (jcp.is_relo) {
2635         scratchpad.book(
2636                 key_conv_amx_wei_buffer, jcp.wei_buffer_size, jcp.typesize_in);
2637     }
2638 
2639     size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
2640     scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
2641     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
2642         assert(jcp.ngroups == 1);
2643         scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
2644     }
2645     scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
2646     if (jcp.req_zero_point_buffer) {
2647         const int nthr = jcp.zp_pbuff_outer_compute ? 1 : jcp.nthr;
2648         scratchpad.book(key_conv_zero_point_pad,
2649                 (size_t)nthr * jcp.zp_pbuff_size, sizeof(int32_t));
2650         if (!jcp.zp_pbuff_outer_compute) {
2651             const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
2652             scratchpad.book<bool>(key_conv_zero_point_flag,
2653                     (size_t)jcp.nthr * oc_chunks * jcp.ngroups);
2654         }
2655     }
2656 
2657     // Keep scratchpad memory footprint under control
2658     const size_t L2_size_per_core = platform::get_per_core_cache_size(2);
2659     const size_t L3_size_per_core = platform::get_per_core_cache_size(3);
2660     const size_t max_scratchpad_size
2661             = jcp.nthr * (L2_size_per_core + L3_size_per_core);
2662     // TODO: tune this relationship as needed
2663     if (scratchpad.size() > max_scratchpad_size) return status::unimplemented;
2664     return status::success;
2665 }
2666 
copy_row(const bool is_masked)2667 void jit_avx512_core_amx_bwd_data_copy_kernel_t::copy_row(
2668         const bool is_masked) {
2669     assert(jcp.is_nspc && "no support for nChw16c in this copy kernel");
2670 
2671     const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
2672     const int inp_w_step
2673             = jcp.ngroups * jcp.oc_without_padding * jcp.typesize_in;
2674     const int inp_h_step = jcp.ow * inp_w_step;
2675     const int out_w_step = jcp.oc_block_int * jcp.typesize_in;
2676     const int out_h_step = jcp.owp * out_w_step;
2677 
2678     auto zero_it = [=](reg64_t tmp_out_ptr, int offset) {
2679         // no mask as output is a padded buffer
2680         if (is_bf16)
2681             vmovdqu16(ptr[tmp_out_ptr + offset], zmm_zero);
2682         else
2683             vmovdqu8(ptr[tmp_out_ptr + offset], zmm_zero);
2684     };
2685 
2686     auto copy_it = [=](reg64_t tmp_inp_ptr, int inp_off, reg64_t tmp_out_ptr,
2687                            int out_off) {
2688         Zmm zmm_load = is_masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
2689         Zmm zmm_stor = zmm_tmp; // no mask as output is padded buffer
2690         if (is_bf16) {
2691             vmovdqu16(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2692             vmovdqu16(ptr[tmp_out_ptr + out_off], zmm_stor);
2693         } else {
2694             vmovdqu8(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2695             vmovdqu8(ptr[tmp_out_ptr + out_off], zmm_stor);
2696         }
2697     };
2698 
2699     mov(reg_ptr_aux_out, reg_ptr_out);
2700 
2701     { // Handle Top Overflow
2702         Label label_tov_loop, label_tov_skip;
2703         test(reg_tov, reg_tov);
2704         jz(label_tov_skip, T_NEAR);
2705         mov(reg_cnt_tmp, reg_tov);
2706         L(label_tov_loop);
2707         {
2708             for (int ow = 0; ow < jcp.owp; ow++) {
2709                 const int offset = ow * out_w_step;
2710                 zero_it(reg_ptr_aux_out, offset);
2711             }
2712             add(reg_ptr_aux_out, out_h_step);
2713             dec(reg_cnt_tmp);
2714             jnz(label_tov_loop, T_NEAR);
2715         }
2716         L(label_tov_skip);
2717     }
2718 
2719     mov(reg_ptr_aux_inp_h, reg_ptr_inp);
2720 
2721     // Handle Middle Loop
2722     Label label_khp_loop, label_khp_skip;
2723     test(reg_khp, reg_khp);
2724     jz(label_khp_skip, T_NEAR);
2725     mov(reg_cnt_khp, reg_khp);
2726     L(label_khp_loop);
2727     {
2728         Label label_lov, label_lov_skip;
2729         Label label_kwp, label_kwp_skip;
2730         Label label_rov, label_rov_skip;
2731         test(reg_lov, reg_lov);
2732         jnz(label_lov, T_NEAR);
2733         test(reg_kwp, reg_kwp);
2734         jnz(label_kwp, T_NEAR);
2735         test(reg_rov, reg_rov);
2736         jnz(label_rov, T_NEAR);
2737 
2738         test(reg_lov, reg_lov);
2739         jz(label_lov_skip, T_NEAR); // not really needed, but just to be safe
2740         L(label_lov); // Handle Left Overflow
2741         {
2742             Label label_lov_loop;
2743             mov(reg_cnt_tmp, reg_lov);
2744             L(label_lov_loop);
2745             {
2746                 zero_it(reg_ptr_aux_out, 0);
2747                 add(reg_ptr_aux_out, out_w_step);
2748                 dec(reg_cnt_tmp);
2749                 jnz(label_lov_loop, T_NEAR);
2750             }
2751         }
2752         L(label_lov_skip);
2753 
2754         test(reg_kwp, reg_kwp);
2755         jz(label_kwp_skip, T_NEAR);
2756         L(label_kwp); // Handle Center Loop
2757         {
2758             Label label_kwp_loop;
2759             mov(reg_ptr_aux_inp_w, reg_ptr_aux_inp_h);
2760             mov(reg_cnt_tmp, reg_kwp);
2761             L(label_kwp_loop);
2762             {
2763                 copy_it(reg_ptr_aux_inp_w, 0, reg_ptr_aux_out, 0);
2764                 add(reg_ptr_aux_out, out_w_step);
2765                 add(reg_ptr_aux_inp_w, inp_w_step);
2766                 dec(reg_cnt_tmp);
2767 
2768                 if (jcp.stride_w > 1) {
2769                     jz(label_kwp_skip, T_NEAR);
2770                     // Handle Dilation-by-Stride
2771                     for (int sw = 0; sw < jcp.stride_w - 1; sw++) {
2772                         const int offset = sw * out_w_step;
2773                         zero_it(reg_ptr_aux_out, offset);
2774                     }
2775                     add(reg_ptr_aux_out, (jcp.stride_w - 1) * out_w_step);
2776                     if (jcp.stride_w == 2)
2777                         dec(reg_cnt_tmp);
2778                     else
2779                         sub(reg_cnt_tmp, jcp.stride_w - 1);
2780                     jmp(label_kwp_loop, T_NEAR);
2781                 } else {
2782                     jnz(label_kwp_loop, T_NEAR);
2783                 }
2784             }
2785         }
2786         L(label_kwp_skip);
2787 
2788         test(reg_rov, reg_rov);
2789         jz(label_rov_skip, T_NEAR);
2790         L(label_rov); // Handle Right Overflow
2791         {
2792             Label label_rov_loop;
2793             mov(reg_cnt_tmp, reg_rov);
2794             L(label_rov_loop);
2795             {
2796                 zero_it(reg_ptr_aux_out, 0);
2797                 add(reg_ptr_aux_out, out_w_step);
2798                 dec(reg_cnt_tmp);
2799                 jnz(label_rov_loop, T_NEAR);
2800             }
2801         }
2802         L(label_rov_skip);
2803 
2804         add(reg_ptr_aux_inp_h, inp_h_step);
2805         dec(reg_cnt_khp);
2806 
2807         if (jcp.stride_h > 1) {
2808             jz(label_khp_skip, T_NEAR);
2809             // Handle Dilation-by-Stride
2810             for (int sh = 0; sh < jcp.stride_h - 1; sh++) {
2811                 for (int ow = 0; ow < jcp.owp; ow++) {
2812                     const int offset = sh * out_h_step + ow * out_w_step;
2813                     zero_it(reg_ptr_aux_out, offset);
2814                 }
2815             }
2816             add(reg_ptr_aux_out, (jcp.stride_h - 1) * out_h_step);
2817             if (jcp.stride_h == 2)
2818                 dec(reg_cnt_khp);
2819             else
2820                 sub(reg_cnt_khp, jcp.stride_h - 1);
2821             jmp(label_khp_loop, T_NEAR);
2822         } else {
2823             jnz(label_khp_loop, T_NEAR);
2824         }
2825     }
2826     L(label_khp_skip);
2827 
2828     { // Handle Bottom Overflow
2829         Label label_bov_loop, label_bov_skip;
2830         test(reg_bov, reg_bov);
2831         jz(label_bov_skip, T_NEAR);
2832         mov(reg_cnt_tmp, reg_bov);
2833         L(label_bov_loop);
2834         {
2835             for (int ow = 0; ow < jcp.owp; ow++) {
2836                 const int offset = ow * out_w_step;
2837                 zero_it(reg_ptr_aux_out, offset);
2838             }
2839             add(reg_ptr_aux_out, out_h_step);
2840             dec(reg_cnt_tmp);
2841             jnz(label_bov_loop, T_NEAR);
2842         }
2843         L(label_bov_skip);
2844     }
2845 }
2846 
generate()2847 void jit_avx512_core_amx_bwd_data_copy_kernel_t::generate() {
2848 
2849     const int inp_c_step = jcp.oc_block_int * jcp.typesize_in;
2850     const int out_c_step = jcp.ohp * jcp.owp * inp_c_step;
2851     const int nb_oc_int_no_tail = jcp.oc_without_padding / jcp.oc_block_int;
2852     const int oc_block_int_tail = jcp.oc_without_padding % jcp.oc_block_int;
2853 
2854     preamble();
2855 
2856     // pointer to 1st needed element in src buffer
2857     mov(reg_ptr_inp, ptr[param1 + GET_OFF(src)]);
2858     // pointer to 1st needed element in dst buffer
2859     mov(reg_ptr_out, ptr[param1 + GET_OFF(dst)]);
2860 
2861     // number of rows of src buffer to copy
2862     mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
2863     // number of zero-padded rows above src buffer to copy
2864     mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
2865     // number of zero-padded rows below src buffer to copy
2866     mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);
2867 
2868     // number of columns of src buffer to copy
2869     mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
2870     // number of zero-padded columns before src buffer to copy
2871     mov(reg_lov, ptr[param1 + GET_OFF(l_overflow)]);
2872     // number of zero-padded columns before src buffer to copy
2873     mov(reg_rov, ptr[param1 + GET_OFF(r_overflow)]);
2874 
2875     vpxord(zmm_zero, zmm_zero, zmm_zero);
2876 
2877     if (oc_block_int_tail > 0) {
2878         uint64_t mask = (UINT64_C(1) << oc_block_int_tail) - 1;
2879         mov(reg_tmp, mask);
2880         kmovq(ktail_mask, reg_tmp);
2881     }
2882 
2883     if (nb_oc_int_no_tail == 0) {
2884         copy_row(true); // masked
2885     } else if (nb_oc_int_no_tail == 1) {
2886         copy_row(false); // unmasked!
2887         if (oc_block_int_tail > 0) {
2888             add(reg_ptr_inp, inp_c_step);
2889             add(reg_ptr_out, out_c_step);
2890             copy_row(true); // masked
2891         }
2892     } else if (nb_oc_int_no_tail > 1) {
2893         mov(reg_cnt_ocb, nb_oc_int_no_tail);
2894         Label label_ocb_loop;
2895         L(label_ocb_loop);
2896         {
2897             copy_row(false); // unmasked!
2898             add(reg_ptr_inp, inp_c_step);
2899             add(reg_ptr_out, out_c_step);
2900             dec(reg_cnt_ocb);
2901             jnz(label_ocb_loop);
2902         }
2903         if (oc_block_int_tail > 0) copy_row(true); // masked
2904     }
2905 
2906     postamble();
2907 }
2908 
2909 // Tile register decomposition
2910 // { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
get_out_tensor(int h,int i) const2911 int jit_avx512_core_amx_bwd_data_kernel_t::get_out_tensor(int h, int i) const {
2912     const int C_BASE = 0;
2913     const int C_LAST = 4;
2914     assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
2915     MAYBE_UNUSED(C_LAST);
2916     const int tile = C_BASE + h * jcp.nb_ih_blocking + i;
2917     assert(C_BASE <= tile && tile < C_LAST);
2918     return tile;
2919 }
get_inp_tensor(int h) const2920 int jit_avx512_core_amx_bwd_data_kernel_t::get_inp_tensor(int h) const {
2921     const int I_BASE = 4;
2922     const int I_LAST = 6;
2923     assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
2924     MAYBE_UNUSED(I_LAST);
2925     const int tile = I_BASE + h;
2926     assert(I_BASE <= tile && tile < I_LAST);
2927     return tile;
2928 }
get_wei_tensor(int i) const2929 int jit_avx512_core_amx_bwd_data_kernel_t::get_wei_tensor(int i) const {
2930     const int W_BASE = 6;
2931     const int W_LAST = 8;
2932     assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
2933     MAYBE_UNUSED(W_LAST);
2934     const int tile = W_BASE + i;
2935     assert(W_BASE <= tile && tile < W_LAST);
2936     return tile;
2937 }
2938 
2939 // Strides, shifts and offsets
2940 // - inp is a padded buffer ~ [nb_oc_int][ohp][owp]{32c,64c}
2941 // - weights is user buffer ~ OIhw16o16i{2o,4o}
2942 // - output is tiled buffer ~ [NBIH][NBIC][tile_width][16c]
get_inp_kh_step() const2943 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_kh_step() const {
2944     return (size_t)jcp.typesize_in * (jcp.dilate_h + 1) * jcp.owp
2945             * jcp.oc_block_int;
2946 }
get_inp_ocb_step() const2947 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_ocb_step() const {
2948     return (size_t)jcp.typesize_in * jcp.ohp * jcp.owp * jcp.oc_block_int;
2949 }
get_inp_shift() const2950 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_shift() const {
2951     return (size_t)jcp.typesize_in * jcp.tile_width * jcp.oc_block_int;
2952 }
get_inp_offset(int ihb,int kh,int kw) const2953 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_offset(
2954         int ihb, int kh, int kw) const {
2955     // calculate offset by src height dimension
2956     size_t sp_offset = (size_t)ihb * jcp.owp;
2957     // add offset by kernel height dimension
2958     sp_offset += (size_t)(jcp.kh - 1 - kh) * (jcp.dilate_h + 1) * jcp.owp;
2959     // add offset by kernel width dimension
2960     sp_offset += (size_t)(jcp.kw - 1 - kw) * (jcp.dilate_w + 1);
2961     return jcp.typesize_in * sp_offset * jcp.oc_block_int;
2962 }
get_wei_kh_step() const2963 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_kh_step() const {
2964     return (size_t)jcp.typesize_in * jcp.kw * jcp.oc_block_int * jcp.ic_block;
2965 }
get_wei_ocb_step() const2966 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_ocb_step() const {
2967     const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
2968     return (size_t)jcp.typesize_in * (is_deconv ? 1 : jcp.nb_ic) * jcp.kh
2969             * jcp.kw * jcp.oc_block_int * jcp.ic_block;
2970 }
get_wei_offset(int icb,int kh,int kw) const2971 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_offset(
2972         int icb, int kh, int kw) const {
2973     const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
2974     const size_t wei_kw_stride = jcp.oc_block_int * jcp.ic_block;
2975     const size_t wei_kh_stride = jcp.kw * wei_kw_stride;
2976     const size_t wei_icb_stride
2977             = (is_deconv ? jcp.nb_oc_int : 1) * jcp.kh * wei_kh_stride;
2978     return jcp.typesize_in
2979             * (icb * wei_icb_stride + kh * wei_kh_stride + kw * wei_kw_stride);
2980 }
get_out_icb_offset(int ihb,int icb) const2981 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_icb_offset(
2982         int ihb, int icb) const {
2983     size_t el_offset = jcp.is_nspc
2984             ? (size_t)icb * jcp.ic_block
2985                     + (size_t)ihb * jcp.iw * jcp.ngroups
2986                             * jcp.ic_without_padding
2987             : (size_t)icb * jcp.ih * jcp.iw * jcp.ic_block
2988                     + (size_t)ihb * jcp.iw * jcp.ic_block;
2989     return (size_t)jcp.typesize_out * el_offset;
2990 }
get_out_row_offset(int ihb,int icb,int j) const2991 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_row_offset(
2992         int ihb, int icb, int j) const {
2993     size_t offset_w = jcp.is_nspc ? (size_t)jcp.typesize_out * j * jcp.ngroups
2994                     * jcp.ic_without_padding
2995                                   : (size_t)jcp.typesize_out * j * jcp.ic_block;
2996     return get_out_icb_offset(ihb, icb) + offset_w;
2997 }
get_out_shift(int width) const2998 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_shift(int width) const {
2999     return jcp.is_nspc ? (size_t)jcp.typesize_out * width * jcp.ngroups
3000                     * jcp.ic_without_padding
3001                        : (size_t)jcp.typesize_out * width * jcp.ic_block;
3002 }
get_wsp_icb_offset(int ihb,int icb) const3003 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_icb_offset(
3004         int ihb, int icb) const {
3005     size_t el_offset = (size_t)icb * prv_width_ * jcp.ic_block
3006             + (size_t)ihb * jcp.nb_ic_blocking * jcp.full_tile_width
3007                     * jcp.ic_block;
3008     return jcp.typesize_acc * el_offset;
3009 }
get_wsp_row_offset(int ihb,int icb,int j) const3010 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_row_offset(
3011         int ihb, int icb, int j) const {
3012     return get_wsp_icb_offset(ihb, icb)
3013             + (size_t)jcp.typesize_acc * j * jcp.ic_block;
3014 }
3015 
3016 // Code generation
prepare_output()3017 void jit_avx512_core_amx_bwd_data_kernel_t::prepare_output() {
3018     for (int h = 0; h < jcp.nb_ih_blocking; h++)
3019         for (int i = 0; i < jcp.nb_ic_blocking; i++)
3020             tilezero(Tmm(get_out_tensor(h, i)));
3021 }
3022 
init_runtime_counters(bool start_with_last_tile_block)3023 void jit_avx512_core_amx_bwd_data_kernel_t::init_runtime_counters(
3024         bool start_with_last_tile_block) {
3025     prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
3026             ? jcp.tile_tail
3027             : jcp.tile_width;
3028 
3029     row_count_ = 0;
3030     is_store_done_ = false;
3031     is_buffer_empty_ = true;
3032 }
3033 
maybe_eltwise(int position)3034 bool jit_avx512_core_amx_bwd_data_kernel_t::maybe_eltwise(int position) {
3035     using namespace primitive_kind;
3036     const auto &p = attr_.post_ops_;
3037 
3038     if (position == 0) {
3039         /* eltwise before sum */
3040         return p.contain(eltwise, 0);
3041     } else if (position == 1) {
3042         /* eltwise after sum */
3043         return p.contain(sum, 0) && p.contain(eltwise, 1);
3044     }
3045 
3046     return false;
3047 }
3048 
ymm_mask(const Ymm & ymm_in,bool mask_flag,bool store)3049 Ymm jit_avx512_core_amx_bwd_data_kernel_t::ymm_mask(
3050         const Ymm &ymm_in, bool mask_flag, bool store) {
3051     return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
3052                      : ymm_in;
3053 }
3054 
zmm_mask(const Zmm & zmm_in,bool mask_flag,bool store)3055 Zmm jit_avx512_core_amx_bwd_data_kernel_t::zmm_mask(
3056         const Zmm &zmm_in, bool mask_flag, bool store) {
3057     return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
3058                      : zmm_in;
3059 }
3060 
cvt2ps(data_type_t type_in,const Zmm & zmm_in,const Operand & op,bool mask_flag)3061 void jit_avx512_core_amx_bwd_data_kernel_t::cvt2ps(data_type_t type_in,
3062         const Zmm &zmm_in, const Operand &op, bool mask_flag) {
3063     const Zmm zmm = zmm_mask(zmm_in, mask_flag);
3064     switch (type_in) {
3065         case data_type::f32:
3066         case data_type::s32: vmovups(zmm, op); break;
3067         case data_type::s8: vpmovsxbd(zmm, op); break;
3068         case data_type::u8: vpmovzxbd(zmm, op); break;
3069         default: assert(!"unsupported data type");
3070     }
3071     if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
3072 }
3073 
store_output_vector_bf16(const Zmm & zmm_out,int icb,int h,int w)3074 void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_bf16(
3075         const Zmm &zmm_out, int icb, int h, int w) {
3076     const bool mask_flag = jcp.is_nspc && jcp.ic_without_padding != jcp.ic
3077             && icb == (jcp.nb_ic_blocking - 1);
3078 
3079     auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));
3080 
3081     const auto &p = attr_.post_ops_;
3082 
3083     const int sum_idx = p.find(primitive_kind::sum);
3084     if (sum_idx != -1) {
3085         if (jcp.dsrc_dt == data_type::bf16) {
3086             vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
3087             vpslld(zmm_prev_dst, zmm_prev_dst, 16);
3088             vaddps(zmm_out, zmm_prev_dst);
3089         } else {
3090             vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
3091             vaddps(zmm_out, zmm_prev_dst);
3092         }
3093     }
3094     if (jcp.with_bias) {
3095         int bias_offset = jcp.typesize_bia * icb * jcp.ic_block;
3096         auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
3097         if (jcp.bia_dt == data_type::bf16) {
3098             vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
3099             vpslld(zmm_bias, zmm_bias, 16);
3100             vaddps(zmm_out, zmm_bias);
3101         } else
3102             vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
3103     }
3104 
3105     const int eltwise_ind = p.find(primitive_kind::eltwise);
3106     if (eltwise_ind != -1) eltwise_injector_->compute_vector(zmm_out.getIdx());
3107 
3108     if (jcp.dsrc_dt == data_type::bf16) {
3109         Ymm ymm_out = Ymm(zmm_out.getIdx());
3110         vcvtneps2bf16(ymm_out, zmm_out);
3111         vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
3112     } else {
3113         vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
3114     }
3115 }
3116 
store_output_vector_int8(const Zmm & zmm_out,int icb,int h,int w)3117 void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
3118         const Zmm &zmm_out, int icb, int h, int w) {
3119     const int nb_ic_block = jcp.nb_ic_blocking;
3120     const int ic_block = jcp.ic_block;
3121     const bool mask_flag = true && jcp.ic_without_padding != jcp.ic
3122             && icb == (nb_ic_block - 1);
3123 
3124     auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));
3125 
3126     const auto &p = attr_.post_ops_;
3127     const int sum_idx = p.find(primitive_kind::sum);
3128     const float *p_sum_scale = nullptr;
3129     if (sum_idx != -1) {
3130         const auto &p_entry = p.entry_[sum_idx];
3131         p_sum_scale = &p_entry.sum.scale;
3132     }
3133 
3134     if (p_sum_scale && *p_sum_scale != 1.f)
3135         mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
3136 
3137     int scale_offset = jcp.is_ic_scale * (sizeof(float) * icb * ic_block);
3138     if (jcp.with_bias) {
3139         int bias_offset = jcp.typesize_bia * icb * ic_block;
3140         auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
3141         cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
3142     }
3143     /* add bias to zmm_accum */
3144     vcvtdq2ps(zmm_out, zmm_out);
3145     if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
3146     const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
3147     vmulps(zmm_out_msk, zmm_out,
3148             EVEX_compress_addr(reg_ptr_scales, scale_offset));
3149 
3150     /* Do post-ops */
3151     if (maybe_eltwise(0)) eltwise_injector_->compute_vector(zmm_out.getIdx());
3152     if (p_sum_scale) { // post_op: sum
3153         cvt2ps(jcp.dsrc_dt, zmm_prev_dst, addr, mask_flag);
3154         if (*p_sum_scale == 1.f)
3155             vaddps(zmm_out, zmm_prev_dst);
3156         else
3157             vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
3158     }
3159     if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx());
3160 
3161     // Properly saturate the accumulators for integer datatypes
3162     if (one_of(jcp.dsrc_dt, u8, s8, s32)) {
3163         init_saturate_f32(
3164                 zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dsrc_dt);
3165         saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dsrc_dt);
3166         vcvtps2dq(zmm_out, zmm_out);
3167     }
3168 
3169     const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
3170 
3171     switch (jcp.dsrc_dt) {
3172         case data_type::f32:
3173         case data_type::s32: vmovups(addr, zmm_out_store); break;
3174         case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
3175         case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
3176         default: assert(!"unknown dst_dt");
3177     }
3178 }
3179 
store_output_vector(const Zmm & zmm_out,int icb,int h,int w)3180 void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector(
3181         const Zmm &zmm_out, int icb, int h, int w) {
3182     /*
3183     Output:
3184               jcp.is_nspc              !jcp.is_nspc
3185               ---------------------    ---------------------
3186         INT8: [N][H][W][NBIC][16IC]
3187         BF16: [N][H][W][NBIC][16IC] or [N][NBIC][H][W][16IC]
3188     */
3189     if (jcp.ddst_dt == data_type::bf16) {
3190         store_output_vector_bf16(zmm_out, icb, h, w);
3191     } else {
3192         store_output_vector_int8(zmm_out, icb, h, w);
3193     }
3194 }
3195 
store_output(int width,bool do_store)3196 void jit_avx512_core_amx_bwd_data_kernel_t::store_output(
3197         int width, bool do_store) {
3198     auto store_output_block = [=](int width, bool do_store,
3199                                       bool is_last_ih_blks) {
3200         // Calculate the number of ih blocks; it may differ on last call
3201         const int n_ih_blks = is_last_ih_blks ? jcp.ih % jcp.nb_ih_blocking
3202                                               : jcp.nb_ih_blocking;
3203         for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
3204             for (int ihb = 0; ihb < n_ih_blks; ihb++) {
3205                 /* Formats: Workspace: [NBIH][NBIC][W][16OC] */
3206                 tilestored(ptr[reg_wsp_ptr + reg_wei_stride
3207                                    + get_wsp_icb_offset(ihb, icb)],
3208                         Tmm(get_out_tensor(ihb, icb)));
3209                 is_buffer_empty_ = false;
3210                 is_store_done_ = false;
3211                 for (int tw = 0; tw < width && do_store; tw++) {
3212                     Zmm zmm_out = Zmm(tw);
3213                     vmovups(zmm_out,
3214                             ptr[reg_wsp_ptr
3215                                     + get_wsp_row_offset(ihb, icb, tw)]);
3216                     store_output_vector(zmm_out, icb, ihb, tw);
3217                 }
3218             }
3219         }
3220     };
3221 
3222     // adjustment in case interleave store is turned off
3223     do_store = do_store || jcp.per_one_pstore == 0;
3224     if (jcp.ih % jcp.nb_ih_blocking == 0) {
3225         store_output_block(width, do_store, /* is_last_ih_blks = */ false);
3226     } else {
3227         Label label_full_store, label_done;
3228         cmp(reg_last_h, 0);
3229         jne(label_full_store, T_NEAR);
3230         store_output_block(width, do_store, /* is_last_ih_blks = */ true);
3231         jmp(label_done, T_NEAR);
3232         L(label_full_store);
3233         store_output_block(width, do_store, /* is_last_ih_blks = */ false);
3234         L(label_done);
3235     }
3236     if (do_store) add(reg_out_ptr, get_out_shift(width));
3237 }
3238 
interleave_store(int width)3239 void jit_avx512_core_amx_bwd_data_kernel_t::interleave_store(int width) {
3240     for (int c = 0;
3241             c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
3242             c++) {
3243         // row_count = ihb * ICB * TW + icb * TW + tw
3244         int tw = row_count_ % prv_width_;
3245         int icb = (row_count_ / prv_width_) % jcp.nb_ic_blocking;
3246         int ihb = (row_count_ / prv_width_) / jcp.nb_ic_blocking;
3247 
3248         Zmm zmm_out = Zmm(tw);
3249         vmovups(zmm_out, ptr[reg_wsp_ptr + get_wsp_row_offset(ihb, icb, tw)]);
3250         store_output_vector(zmm_out, icb, ihb, tw);
3251         row_count_++;
3252 
3253         if (row_count_
3254                 == prv_width_ * jcp.nb_ic_blocking * jcp.nb_ih_blocking) {
3255             add(reg_out_ptr, get_out_shift(prv_width_));
3256             row_count_ = 0;
3257             is_store_done_ = true;
3258             prv_width_ = width;
3259         }
3260     }
3261 }
3262 
compute_ocb_loop(int width,bool do_store)3263 void jit_avx512_core_amx_bwd_data_kernel_t::compute_ocb_loop(
3264         int width, bool do_store) {
3265 
3266     auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
3267         switch (jcp.ddst_dt) {
3268             using namespace data_type;
3269             case bf16: tdpbf16ps(x1, x2, x3); break;
3270             case s8: tdpbssd(x1, x2, x3); break;
3271             case u8: tdpbusd(x1, x2, x3); break;
3272             default: assert(!"unsupported data type");
3273         }
3274     };
3275 
3276     prepare_output();
3277 
3278     for (int ocb = 0; ocb < jcp.nb_oc_int; ocb++) {
3279         // reverse order through spatial components of weights so that
3280         // input buffer is accessed in a monotonically increasing fashion
3281         for (int kh = jcp.kh - 1; kh >= 0; kh--) {
3282             for (int kw = jcp.kw - 1; kw >= 0; kw--) {
3283                 for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
3284                     tileloadd(Tmm(get_inp_tensor(ihb)),
3285                             ptr[reg_inp_ptr + get_inp_offset(ihb, kh, kw)
3286                                     + reg_inp_stride]);
3287                 }
3288                 for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
3289                     tileloadd(Tmm(get_wei_tensor(icb)),
3290                             ptr[reg_wei_ptr + get_wei_offset(icb, kh, kw)
3291                                     + reg_wei_stride]);
3292                     for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
3293                         tdpbxxd(Tmm(get_out_tensor(ihb, icb)),
3294                                 Tmm(get_inp_tensor(ihb)),
3295                                 Tmm(get_wei_tensor(icb)));
3296                         interleave_store(width);
3297                     }
3298                 }
3299             }
3300         }
3301         add(reg_inp_ptr, get_inp_ocb_step());
3302         add(reg_wei_ptr, get_wei_ocb_step());
3303     }
3304     sub(reg_inp_ptr, get_inp_ocb_step() * jcp.nb_oc_int);
3305     sub(reg_wei_ptr, get_wei_ocb_step() * jcp.nb_oc_int);
3306 
3307     store_output(width, do_store);
3308 
3309     add(reg_inp_ptr, get_inp_shift());
3310 }
3311 
compute_iw_loop()3312 void jit_avx512_core_amx_bwd_data_kernel_t::compute_iw_loop() {
3313     auto compute_iw_loop_body = [=](bool last_iwb, int num_tile_blocks) {
3314         int gen_tile_tail = last_iwb && jcp.tile_tail > 0 ? jcp.tile_tail
3315                                                           : jcp.tile_width;
3316         init_runtime_counters(last_iwb && num_tile_blocks == 1);
3317         for (int iwb = 0; iwb < num_tile_blocks - 1; iwb++)
3318             compute_ocb_loop(jcp.tile_width, false);
3319         compute_ocb_loop(gen_tile_tail, true);
3320     };
3321 
3322     if (jcp.nb_iw == 1) {
3323         compute_iw_loop_body(true, jcp.iw_blocks);
3324     } else {
3325         Label label_done;
3326         int iw_blocks_per_call = div_up(jcp.iw_block, jcp.tile_width);
3327         int last_iwb_tile_blocks = jcp.iw_blocks % iw_blocks_per_call;
3328         if (last_iwb_tile_blocks == 0 && jcp.tile_tail > 0)
3329             last_iwb_tile_blocks = iw_blocks_per_call;
3330         if (last_iwb_tile_blocks > 0) {
3331             Label label_not_last_iwb;
3332             mov(reg_tmp, ptr[param1 + GET_OFF(iwb)]);
3333             cmp(reg_tmp, jcp.nb_iw - 1);
3334             jne(label_not_last_iwb, T_NEAR);
3335 
3336             compute_iw_loop_body(true, last_iwb_tile_blocks);
3337 
3338             jmp(label_done, T_NEAR);
3339 
3340             L(label_not_last_iwb);
3341         }
3342         compute_iw_loop_body(false, iw_blocks_per_call);
3343 
3344         L(label_done);
3345     }
3346 }
3347 
generate()3348 void jit_avx512_core_amx_bwd_data_kernel_t::generate() {
3349     preamble();
3350 
3351     mov(reg_inp_ptr, ptr[param1 + GET_OFF(dst)]); // padded buffer of diff_dst
3352     mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]); // weights
3353     mov(reg_out_ptr, ptr[param1 + GET_OFF(src)]); // diff_src
3354     mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
3355 
3356     if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
3357 
3358     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
3359 
3360     mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
3361 
3362     const int inp_stride = jcp.oc_block_int * jcp.typesize_in;
3363     const int wei_stride = jcp.ic_block * jcp.typesize_acc;
3364     mov(reg_inp_stride, inp_stride);
3365     mov(reg_wei_stride, wei_stride);
3366 
3367     if (jcp.is_nspc && jcp.ic_without_padding != jcp.ic) {
3368         // Use mask 0xF by default for all output data and post-ops
3369         // loads / stores with block index
3370         // icb = icc * jcp.nb_ic_blocking + (jcp.nb_ic_blocking - 1)
3371         // TODO: use masked loads / stores for the last icc only
3372         int current_block_size = jcp.ic_block;
3373         int mask = (1 << current_block_size) - 1;
3374         Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
3375         mov(regw_tmp, mask);
3376         kmovw(ktail_mask, regw_tmp);
3377         Xbyak::Label mask_is_set;
3378         mov(reg_ic_blocks, ptr[param1 + GET_OFF(ic_blocks)]);
3379         cmp(reg_ic_blocks, jcp.nb_ic - jcp.nb_ic_blocking);
3380         jne(mask_is_set, T_NEAR);
3381         // Reset the mask
3382         current_block_size = jcp.ic_without_padding % jcp.ic_block;
3383         mask = (1 << current_block_size) - 1;
3384         mov(regw_tmp, mask);
3385         kmovw(ktail_mask, regw_tmp);
3386 
3387         L(mask_is_set);
3388     }
3389     compute_iw_loop();
3390 
3391     postamble();
3392 
3393     if (jcp.with_eltwise) eltwise_injector_->prepare_table();
3394 }
3395 
post_ops_ok(const jit_conv_conf_t & jcp,const primitive_attr_t & attr)3396 bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok(
3397         const jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
3398     using namespace primitive_kind;
3399     const auto &p = attr.post_ops_;
3400     const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
3401 
3402     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
3403 
3404     auto is_sum = [&](int idx) {
3405         if (is_bf16)
3406             return p.entry_[idx].is_sum();
3407         else
3408             return p.contain(sum, idx);
3409     };
3410 
3411     switch (p.len()) {
3412         case 0: return true;
3413         case 1: return is_eltwise(0) || is_sum(0);
3414         case 2:
3415             return (is_sum(0) && is_eltwise(1))
3416                     || (!is_bf16 && is_sum(1) && is_eltwise(0));
3417         default: return false;
3418     }
3419 
3420     return false;
3421 }
3422 
tile_configure(char * tcfg_buff)3423 void jit_avx512_core_amx_bwd_data_kernel_t::tile_configure(char *tcfg_buff) {
3424     const int vnni_width = jcp.ddst_dt == data_type::bf16 ? 2 : 4;
3425     // Input tile dimensions
3426     const int a_col = jcp.oc_block_int;
3427     const int a_row = jcp.tile_width;
3428     // Weights tile dimensions
3429     const int b_col = jcp.ic_block * vnni_width;
3430     const int b_row = a_col / vnni_width;
3431     // Accumulator tile dimensions
3432     const int c_col = jcp.ic_block;
3433     const int c_row = a_row;
3434 
3435     for (size_t i = 0; i < 64; i++)
3436         tcfg_buff[i] = 0;
3437 
3438     // Weights (W_BASE) Tensor Tiles
3439     for (int i = 0; i < jcp.nb_ic_blocking; i++)
3440         tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
3441                 b_row, b_col * jcp.typesize_in);
3442 
3443     // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
3444     for (int h = 0; h < jcp.nb_ih_blocking; h++) {
3445         tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
3446                 a_row, a_col * jcp.typesize_in);
3447         for (int i = 0; i < jcp.nb_ic_blocking; i++)
3448             tc_configure_tile((palette_config_t *)tcfg_buff,
3449                     get_out_tensor(h, i), c_row, c_col * jcp.typesize_acc);
3450     }
3451 
3452     ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
3453 }
3454 
init_conf(jit_conv_conf_t & jcp,const convolution_desc_t & cd,memory_desc_t & diff_src_md,memory_desc_t & weights_md,memory_desc_t & diff_dst_md,memory_desc_t * bias_md,const primitive_attr_t & attr,int nthreads)3455 status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
3456         const convolution_desc_t &cd, memory_desc_t &diff_src_md,
3457         memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
3458         memory_desc_t *bias_md, const primitive_attr_t &attr, int nthreads) {
3459     using namespace prop_kind;
3460 
3461     const memory_desc_wrapper diff_src_d(&diff_src_md);
3462     const memory_desc_wrapper weights_d(&weights_md);
3463     const memory_desc_wrapper diff_dst_d(&diff_dst_md);
3464     const memory_desc_wrapper bias_d(bias_md);
3465 
3466     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
3467     int ndims = diff_src_d.ndims();
3468     bool is_1d = ndims == 3;
3469     bool is_3d = ndims == 5;
3470 
3471     if (is_3d) return status::unimplemented;
3472 
3473     using namespace data_type;
3474     const bool is_deconv = cd.prop_kind != prop_kind::backward_data;
3475     const bool is_bf16_convolution = !is_deconv
3476             && everyone_is(true, diff_dst_d.data_type() == bf16,
3477                     weights_d.data_type() == bf16,
3478                     one_of(diff_src_d.data_type(), bf16, f32));
3479     const bool is_int8_deconvolution = is_deconv
3480             && everyone_is(true, one_of(diff_dst_d.data_type(), s8, u8),
3481                     weights_d.data_type() == s8,
3482                     one_of(diff_src_d.data_type(), f32, s32, s8, u8));
3483 
3484     bool supported = false
3485             || (is_bf16_convolution && mayiuse(avx512_core_bf16_amx_bf16))
3486             || (is_int8_deconvolution && mayiuse(avx512_core_bf16_amx_int8));
3487     if (!supported) return status::unimplemented;
3488 
3489     jcp = zero<decltype(jcp)>();
3490     jcp.isa = is_bf16_convolution ? avx512_core_bf16_amx_bf16
3491                                   : avx512_core_bf16_amx_int8;
3492     jcp.ndims = ndims;
3493     jcp.prop_kind = cd.prop_kind;
3494     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
3495 
3496     jcp.mb = diff_src_d.dims()[0];
3497     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
3498     jcp.oc_without_padding = jcp.oc;
3499     jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
3500     jcp.ic_without_padding = jcp.ic;
3501     jcp.ih = !is_1d ? diff_src_d.dims()[ndims - 2] : 1;
3502     jcp.iw = diff_src_d.dims()[ndims - 1];
3503     jcp.oh = !is_1d ? diff_dst_d.dims()[ndims - 2] : 1;
3504     jcp.ow = diff_dst_d.dims()[ndims - 1];
3505     jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
3506     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
3507     jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
3508     jcp.l_pad = cd.padding[0][ndims - 3];
3509     jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
3510     jcp.stride_w = cd.strides[ndims - 3];
3511 
3512     // No bias for bf16 case to simplify integration with ref_deconvolution
3513     jcp.with_bias = bias_md && !is_bf16_convolution
3514             && cd.bias_desc.format_kind != format_kind::undef;
3515 
3516     jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
3517     jcp.dilate_w = cd.dilates[ndims - 3];
3518 
3519     const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
3520     const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
3521     jcp.b_pad = calculate_end_padding(
3522             jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
3523     jcp.r_pad = calculate_end_padding(
3524             jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
3525     if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
3526             || jcp.b_pad >= gen_kh)
3527         return status::unimplemented;
3528 
3529     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
3530     if (is_deconv) {
3531         jcp.ddst_dt = cd.src_desc.data_type;
3532         jcp.dsrc_dt = cd.dst_desc.data_type;
3533     } else {
3534         jcp.ddst_dt = cd.diff_dst_desc.data_type;
3535         jcp.dsrc_dt = cd.diff_src_desc.data_type;
3536     }
3537     jcp.wei_dt = cd.weights_desc.data_type;
3538 
3539     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
3540 
3541     if (jcp.is_depthwise)
3542         return status::unimplemented; // TODO: add support of DW convolution
3543 
3544     format_tag_t dat_tag_ncsp
3545             = pick(ndims - 3, format_tag::nCw16c, format_tag::nChw16c);
3546     format_tag_t dat_tag_nspc
3547             = pick(ndims - 3, format_tag::nwc, format_tag::nhwc);
3548     // To toggle the default data layout for BF16 between nChw16c and nhwc,
3549     // swap the following two variable definitions. Current choice: nhwc.
3550     format_tag_t dat_tag_opt = dat_tag_nspc;
3551     format_tag_t dat_tag_alt
3552             = is_bf16_convolution ? dat_tag_ncsp : dat_tag_nspc;
3553 
3554     if (diff_src_d.format_kind() == format_kind::any) {
3555         CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag_opt));
3556         jcp.src_tag = dat_tag_opt;
3557     } else
3558         jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);
3559 
3560     if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
3561         return status::unimplemented;
3562 
3563     jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
3564     assert(IMPLICATION(is_int8_deconvolution, jcp.is_nspc));
3565 
3566     // TODO: remove all support for nChw16c from this implementation
3567     if (!jcp.is_nspc) return status::unimplemented;
3568 
3569     if (diff_dst_d.format_kind() == format_kind::any) {
3570         CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
3571         jcp.dst_tag = jcp.src_tag;
3572     } else
3573         jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
3574 
3575     if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
3576 
3577     if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
3578         CHECK(memory_desc_init_by_tag(*bias_md, format_tag::x));
3579 
3580     jcp.nthr = nthreads;
3581 
3582     jcp.ic_block = 16;
3583     jcp.oc_block = 16;
3584 
3585     if (jcp.ngroups == 1) {
3586         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
3587         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
3588     }
3589     bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
3590     if (!args_ok) return status::unimplemented;
3591 
3592     const int vnni_width = is_bf16_convolution ? 2 : 4;
3593     jcp.oc_block_int = jcp.oc_block * vnni_width; // 32 for bf16, 64 for int8
3594 
3595     if (!post_ops_ok(jcp, attr)) return status::unimplemented;
3596 
3597     const auto &p = attr.post_ops_;
3598     const int eltwise_ind = p.find(primitive_kind::eltwise);
3599     jcp.with_eltwise = eltwise_ind != -1;
3600     if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
3601 
3602     auto set_or_check_wei_format = [&]() {
3603         using namespace format_tag;
3604         format_tag_t wei_tag;
3605         if (is_bf16_convolution)
3606             wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16o16i2o,
3607                     gOIw16o16i2o, OIhw16o16i2o, gOIhw16o16i2o);
3608         else if (is_int8_deconvolution)
3609             wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o4i,
3610                     gOIw16i16o4i, OIhw16i16o4i, gOIhw16i16o4i);
3611         else {
3612             assert(!"unsupported combination");
3613             return false;
3614         }
3615 
3616         memory_desc_t want_wei_md = weights_md;
3617         memory_desc_init_by_tag(want_wei_md, wei_tag);
3618 
3619         if (weights_md.format_kind == format_kind::any) {
3620             weights_md = want_wei_md;
3621             return true;
3622         }
3623         return weights_md == want_wei_md;
3624     };
3625 
3626     if (!set_or_check_wei_format()) return status::unimplemented;
3627 
3628     jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
3629     jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
3630     jcp.typesize_bia
3631             = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
3632     jcp.typesize_acc = sizeof(int32_t);
3633 
3634     jcp.nb_ic = jcp.ic / jcp.ic_block;
3635     jcp.nb_oc = jcp.oc / jcp.oc_block;
3636     jcp.nb_oc_int = div_up(jcp.oc, jcp.oc_block_int);
3637 
3638     const int max_palette = amx::get_max_palette();
3639     jcp.max_tiles = amx::get_max_tiles(max_palette);
3640     jcp.full_tile_width = amx::get_max_rows(max_palette);
3641     if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
3642         return status::unimplemented;
3643 
3644     jcp.tile_width = nstl::min(jcp.full_tile_width, jcp.iw);
3645     jcp.iw_blocks = div_up(jcp.iw, jcp.tile_width);
3646 
3647     // Prefer to use a single tile width when possible
3648     // (eg iw28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
3649     if (jcp.iw % jcp.iw_blocks == 0) jcp.tile_width = jcp.iw / jcp.iw_blocks;
3650     jcp.tile_tail = jcp.iw % jcp.tile_width;
3651 
3652     jcp.nb_ic_blocking = (jcp.nb_ic % 2 == 0) ? 2 : 1;
3653     jcp.nb_ih_blocking
3654             = everyone_is(true, jcp.ih > 1,
3655                       // requirement for interleave stores
3656                       IMPLICATION(jcp.iw_blocks > 1, jcp.ih % 2 == 0))
3657             ? 2
3658             : 1;
3659 
3660     // TODO: tune ih blocking
3661     const int ih_blk_size_tmp = 10;
3662     const int ih_step = jcp.nb_ih_blocking;
3663     jcp.ih_blk_size = rnd_up(nstl::min(jcp.ih, ih_blk_size_tmp), ih_step);
3664     // ohp includes all elements that are really used in calculation,
3665     // including zero-padded "dilate-by-strides" and top and bottom overflow
3666     jcp.ohp = jcp.ih_blk_size + gen_kh - 1;
3667 
3668     // TODO: tune iw blocking
3669     const int iw_blocks_per_call = 2;
3670     jcp.iw_block = jcp.tile_width * iw_blocks_per_call;
3671     jcp.nb_iw = div_up(jcp.iw, jcp.iw_block);
3672     // owp includes all elements that are really used in calculation,
3673     // including zero-padded "dilate-by-strides" and left and right overflow
3674     jcp.owp = jcp.iw_block + gen_kw - 1;
3675 
3676     // Number of ops per tile store
3677     int ops_tile_store = jcp.tile_width;
3678     // Number of ops per accumulation tile
3679     int avaliable_ops = jcp.nb_oc_int * jcp.kh * jcp.kw;
3680     // Number of vectors to store per tile operation
3681     // NOTE: set to zero to turn off interleave store (mostly for debugging)
3682     jcp.per_one_pstore = div_up(ops_tile_store, avaliable_ops);
3683 
3684     jcp.inp_buffer_size
3685             = (size_t)jcp.nb_oc_int * jcp.ohp * jcp.owp * jcp.oc_block_int;
3686     jcp.wsp_buffer_size = (size_t)jcp.nb_ih_blocking * jcp.nb_ic_blocking
3687             * jcp.full_tile_width * jcp.ic_block;
3688 
3689     const auto &oscales = attr.output_scales_;
3690     jcp.is_ic_scale = oscales.mask_ == 1 << 1;
3691 
3692     return status::success;
3693 }
3694 
init_scratchpad(memory_tracking::registrar_t & scratchpad,const jit_conv_conf_t & jcp,const primitive_attr_t & attr)3695 void jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad(
3696         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
3697         const primitive_attr_t &attr) {
3698 
3699     size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
3700     scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
3701     size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
3702     scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
3703     if (jcp.with_bias && jcp.ic != jcp.ic_without_padding) {
3704         assert(jcp.ngroups == 1);
3705         scratchpad.book(key_conv_padded_bias, jcp.ic, jcp.typesize_bia);
3706     }
3707     scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
3708 }
3709 
3710 const int jit_avx512_core_amx_bwd_weights_kernel_t::max_ur_w = 32;
3711 
3712 // Tile register decomposition
3713 // { C_BASE = 0, A_BASE = 4, B_BASE = 6, }
get_wei_tensor(int ocb,int icb) const3714 int jit_avx512_core_amx_bwd_weights_kernel_t::get_wei_tensor(
3715         int ocb, int icb) const {
3716     const int C_BASE = 0;
3717     const int C_LAST = 4;
3718     assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
3719     MAYBE_UNUSED(C_LAST);
3720     const int tile = C_BASE + ocb * jcp.nb_oc_blocking + icb;
3721     assert(C_BASE <= tile && tile < C_LAST);
3722     return tile;
3723 }
get_src_tensor(int icb) const3724 int jit_avx512_core_amx_bwd_weights_kernel_t::get_src_tensor(int icb) const {
3725     const int A_BASE = 4;
3726     const int A_LAST = 6;
3727     assert(0 <= A_BASE && A_BASE < A_LAST && A_LAST <= jcp.max_tiles);
3728     MAYBE_UNUSED(A_LAST);
3729     const int tile = A_BASE + icb;
3730     assert(A_BASE <= tile && tile < A_LAST);
3731     return tile;
3732 }
get_ddst_tensor(int ocb) const3733 int jit_avx512_core_amx_bwd_weights_kernel_t::get_ddst_tensor(int ocb) const {
3734     const int B_BASE = 6;
3735     const int B_LAST = 8;
3736     assert(0 <= B_BASE && B_BASE < B_LAST && B_LAST <= jcp.max_tiles);
3737     MAYBE_UNUSED(B_LAST);
3738     const int tile = B_BASE + ocb;
3739     assert(B_BASE <= tile && tile < B_LAST);
3740     return tile;
3741 }
3742 
tile_configure(char * tcfg_buff)3743 void jit_avx512_core_amx_bwd_weights_kernel_t::tile_configure(char *tcfg_buff) {
3744     // Input tile dimensions
3745     const int a_col = jcp.ur_w;
3746     const int a_row = jcp.ic_block;
3747     // Weights tile dimensions
3748     const int b_col = jcp.oc_block * 2;
3749     const int b_row = a_col / 2;
3750     // Accumulator tile dimensions
3751     const int c_col = jcp.oc_block;
3752     const int c_row = a_row;
3753 
3754     for (size_t i = 0; i < 64; i++)
3755         tcfg_buff[i] = 0;
3756 
3757     for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
3758         tc_configure_tile((palette_config_t *)tcfg_buff, get_src_tensor(icb),
3759                 a_row, a_col * jcp.typesize_in);
3760 
3761     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
3762         tc_configure_tile((palette_config_t *)tcfg_buff, get_ddst_tensor(ocb),
3763                 b_row, b_col * jcp.typesize_in);
3764 
3765     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
3766         for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
3767             tc_configure_tile((palette_config_t *)tcfg_buff,
3768                     get_wei_tensor(ocb, icb), c_row, c_col * jcp.typesize_out);
3769 
3770     ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
3771 }
3772 
od_step_comeback_pointers()3773 void jit_avx512_core_amx_bwd_weights_kernel_t::od_step_comeback_pointers() {
3774     Label kd_comeback_label;
3775     mov(kj, reg_kd_count);
3776     L(kd_comeback_label);
3777     {
3778         sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
3779         sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
3780         dec(kj);
3781         jnz(kd_comeback_label, T_NEAR);
3782     }
3783 }
3784 
oh_step_comeback_pointers()3785 void jit_avx512_core_amx_bwd_weights_kernel_t::oh_step_comeback_pointers() {
3786     Label kh_comeback_label;
3787     mov(kj, reg_kh);
3788     L(kh_comeback_label);
3789     {
3790         sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
3791         sub(reg_kernel, get_kernel_offset(0, jcp.kw));
3792         dec(kj);
3793         jnz(kh_comeback_label, T_NEAR);
3794     }
3795 }
3796 
compute_full_spat_loop(int nb_ic_blocking,int nb_oc_blocking)3797 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_full_spat_loop(
3798         int nb_ic_blocking, int nb_oc_blocking) {
3799     // General code layout:
3800     //
3801     // Blocking over OH -- top level
3802     // (Reduces L2 pressure; not very useful right now)
3803     //  Loop over all KHxKW kernel -- emit_kh_kw_loop()
3804     //    Loop over OH block -- emit_h_loop()
3805     //      Loop over OW blocks -- emit_fma_block()
3806     //      (Supports both fully unrolled and partially unrolled
3807     //      versions to reduce code size)
3808     //          Loop over OW block -- emit_fma_step()
3809 
3810     auto src_row_size = get_src_offset(0, 0, 1);
3811     auto ddst_row_size = get_ddst_offset(0, 1);
3812     auto row_size = src_row_size + ddst_row_size;
3813 
3814     int h_block_size = jcp.oh;
3815     int h_last_block_size = h_block_size;
3816     int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
3817     auto working_set_size = row_size * h_block_size;
3818 
3819     if (working_set_size > full_spat_max_working_set_size) {
3820         assert(full_spat_opt_working_set_size < full_spat_max_working_set_size);
3821 
3822         while (working_set_size > full_spat_opt_working_set_size
3823                 && h_block_size >= min_h_block_size) {
3824             for (int i = 2; i <= h_block_size; i++)
3825                 if (i == h_block_size)
3826                     h_block_size = h_block_size / 2;
3827                 else if (h_block_size % i == 0) {
3828                     h_block_size = h_block_size / i;
3829                     break;
3830                 }
3831             working_set_size = row_size * h_block_size;
3832         }
3833         h_block_size = nstl::max(min_h_block_size, h_block_size);
3834         h_last_block_size = jcp.oh % h_block_size;
3835         if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
3836     }
3837 
3838     Opmask reg_h_block = k1;
3839     Reg64 reg_kh = rax;
3840     Reg64 reg_kw = rbx;
3841     Reg64 reg_tmp = abi_not_param1;
3842     Reg32 reg_tmp_w = reg_tmp.cvt32();
3843     Reg64 reg_ohs = rdx;
3844     Reg64 reg_ihs = rsi;
3845     Reg64 reg_h = r8;
3846     Reg64 reg_j = r10;
3847 
3848     Reg64 reg_src = r13;
3849     Reg64 reg_ddst = r14;
3850     Reg64 reg_ker = r15;
3851 
3852     Reg64 reg_dense_stride = abi_param1;
3853     Reg64 reg_a_stride = reg_tmp;
3854 
3855     auto emit_block = [&]() {
3856         mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
3857         for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
3858             dim_t ur_w_src_offset = ur_w_b * get_src_offset(0, jcp.ur_w);
3859             dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
3860 
3861             for (int icb = 0; icb < nb_ic_blocking; icb++) {
3862                 dim_t icb_offset = jcp.typesize_in * icb * jcp.tr_src_buf_size;
3863                 tileloadd(Tmm(get_src_tensor(icb)),
3864                         ptr[reg_src + reg_a_stride + icb_offset
3865                                 + ur_w_src_offset]);
3866             }
3867             for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
3868                 tileloadd(Tmm(get_ddst_tensor(ocb)),
3869                         ptr[reg_ddst + reg_dense_stride
3870                                 + jcp.typesize_in * ocb
3871                                         * jcp.tr_diff_dst_buf_size
3872                                 + ur_w_ddst_offset]);
3873                 for (int icb = 0; icb < nb_ic_blocking; icb++)
3874                     tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
3875                             Tmm(get_src_tensor(icb)),
3876                             Tmm(get_ddst_tensor(ocb)));
3877             }
3878         }
3879     };
3880 
3881     auto emit_h_loop = [&]() {
3882         Label h_loop, skip_h_loop;
3883         mov(reg_j, 1);
3884         cmp(reg_j, reg_h);
3885         je(skip_h_loop, T_NEAR);
3886         L(h_loop);
3887         {
3888             emit_block();
3889 
3890             add(reg_src, get_src_offset(0, 0, 1));
3891             add(reg_ddst, get_ddst_offset(0, 1));
3892             add(reg_j, 1);
3893             cmp(reg_j, reg_h);
3894             jb(h_loop);
3895         }
3896         L(skip_h_loop);
3897 
3898         emit_block();
3899     };
3900 
3901     auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
3902         xor_(reg_kh, reg_kh);
3903         Label kh_loop, kh_loop_end;
3904 
3905         int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
3906         // NB: this is correct because we only support t_pad = kh / 2 and thus
3907         // ih == oh
3908         int ih_block_size = oh_block_size
3909                 + (!is_first_block + !is_last_block) * jcp.t_pad;
3910 
3911         L(kh_loop);
3912         {
3913             if (is_first_block) {
3914                 xor_(reg_tmp, reg_tmp);
3915                 mov(reg_ohs, jcp.t_pad);
3916                 sub(reg_ohs, reg_kh);
3917                 cmovb(reg_ohs, reg_tmp);
3918 
3919                 mov(reg_ihs, reg_ohs);
3920                 sub(reg_ihs, jcp.t_pad);
3921                 add(reg_ihs, reg_kh);
3922             } else {
3923                 xor_(reg_ohs, reg_ohs);
3924                 mov(reg_ihs, reg_kh);
3925             }
3926 
3927             mov(reg_tmp, oh_block_size);
3928             sub(reg_tmp, reg_ohs);
3929             mov(reg_h, ih_block_size);
3930             sub(reg_h, reg_ihs);
3931             cmp(reg_tmp, reg_h);
3932             cmovb(reg_h, reg_tmp);
3933 
3934             Label kh_loop_work;
3935             cmp(reg_h, 0);
3936             jg(kh_loop_work, T_NEAR);
3937 
3938             // empty h loop for this jcp.kh:
3939             // - set the ddst to 0 if necessary
3940             // - move ker pt
3941             // - jump to the end
3942             sub(reg_h, 1);
3943             Label skip_ker_zeroing;
3944 
3945             // The reg_ker ptr has highest bit set if the ddst needs to be
3946             // zeroed. Those who have byte-aligned their data will suffer the
3947             // consequences :(
3948             // TODO: move the flag to a mask register? (Roma)
3949             test(reg_ker, 1);
3950             jz(skip_ker_zeroing, T_NEAR);
3951 
3952             Label zeroing_loop;
3953             vpxord(zmm0, zmm0, zmm0);
3954             and_(reg_ker, ~1); // temporarily clear the zeroing flag
3955 
3956             mov(reg_dense_stride, 64);
3957             tilezero(Tmm(get_wei_tensor(0, 0)));
3958             for (int kw = 0; kw < jcp.kw; kw++) {
3959                 // dim_t kw_offset = kw * get_kernel_offset(jcp.ic_block, 0);
3960                 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
3961                 for (int icb = 0; icb < nb_ic_blocking; icb++)
3962                     tilestored(
3963                             ptr[reg_ker + reg_dense_stride
3964                                     + get_full_kernel_offset(ocb, icb, 0, kw)],
3965                             Tmm(get_wei_tensor(0, 0)));
3966             }
3967             // restore the zeroing flag (it will be cleared after the end of
3968             // emit_kh_kw_loop, but we may need it until then)
3969             or_(reg_ker, 1);
3970             jmp(kh_loop_end, T_NEAR);
3971 
3972             L(skip_ker_zeroing);
3973             add(reg_ker, get_kernel_offset(0, jcp.kw));
3974             jmp(kh_loop_end, T_NEAR);
3975 
3976             L(kh_loop_work);
3977 
3978             mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
3979             mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));
3980 
3981             add(reg_src, reg_ihs);
3982             add(reg_ddst, reg_ohs);
3983 
3984             Label kw_loop;
3985             xor_(reg_kw, reg_kw);
3986 
3987             mov(reg_dense_stride, 64);
3988             L(kw_loop);
3989             {
3990                 Label do_zero, ker_init_done;
3991                 test(reg_ker, 1);
3992                 jnz(do_zero, T_NEAR);
3993 
3994                 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
3995                 for (int icb = 0; icb < nb_ic_blocking; icb++)
3996                     tileloadd(Tmm(get_wei_tensor(ocb, icb)),
3997                             ptr[reg_ker + reg_dense_stride
3998                                     + get_full_kernel_offset(ocb, icb, 0, 0)]);
3999                 jmp(ker_init_done);
4000                 L(do_zero);
4001                 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4002                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4003                     tilezero(Tmm(get_wei_tensor(ocb, icb)));
4004 
4005                 L(ker_init_done);
4006 
4007                 mov(ptr[rsp + ddst_save_offset], reg_ddst);
4008                 mov(ptr[rsp + src_save_offset], reg_src);
4009 
4010                 lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
4011                 emit_h_loop();
4012 
4013                 mov(reg_ddst, ptr[rsp + ddst_save_offset]);
4014                 mov(reg_src, ptr[rsp + src_save_offset]);
4015 
4016                 // The reg_ker ptr has highest bit set if the ddst needs to
4017                 // be zeroed. Those who have byte-aligned their data will
4018                 // suffer the consiquences :(
4019                 mov(reg_tmp, reg_ker);
4020                 and_(reg_ker, ~1);
4021 
4022                 for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
4023                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4024                     tilestored(
4025                             ptr[reg_ker + reg_dense_stride
4026                                     + get_full_kernel_offset(ocb, icb, 0, 0)],
4027                             Tmm(get_wei_tensor(ocb, icb)));
4028 
4029                 mov(reg_ker, reg_tmp);
4030                 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
4031                 add(reg_kw, 1);
4032                 cmp(reg_kw, jcp.kw);
4033                 jl(kw_loop);
4034             }
4035 
4036             sub(reg_src, reg_ihs);
4037             sub(reg_ddst, reg_ohs);
4038 
4039             L(kh_loop_end);
4040             add(reg_kh, 1);
4041             cmp(reg_kh, jcp.kh);
4042             jl(kh_loop);
4043         }
4044     };
4045 
4046     mov(reg_src, ptr[param + GET_OFF(src)]);
4047     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4048     mov(reg_ker, ptr[param + GET_OFF(filt)]);
4049     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
4050     or_(reg_ker, reg_tmp);
4051 
4052     bool single_kh_kw_loop = (h_last_block_size == jcp.oh);
4053 
4054     auto src_row_step = get_src_offset(0, 0, 1);
4055     auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
4056     auto ddst_block_step = get_ddst_offset(0, h_block_size);
4057 
4058     emit_kh_kw_loop(true, single_kh_kw_loop);
4059 
4060     if (!single_kh_kw_loop) {
4061         auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
4062         sub(reg_ker, ker_reset_offset);
4063         and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
4064 
4065         add(reg_src, first_src_block_step);
4066         add(reg_ddst, ddst_block_step);
4067 
4068         int num_innermost_iters
4069                 = (jcp.oh - h_last_block_size) / h_block_size - 1;
4070         if (num_innermost_iters > 0) {
4071             Label h_block_loop;
4072 
4073             mov(reg_tmp_w, num_innermost_iters);
4074             kmovw(reg_h_block, reg_tmp_w);
4075             L(h_block_loop);
4076             {
4077                 emit_kh_kw_loop(false, false);
4078                 sub(reg_ker, ker_reset_offset);
4079                 add(reg_src, src_row_step * h_block_size);
4080                 add(reg_ddst, ddst_block_step);
4081 
4082                 kmovw(reg_tmp_w, reg_h_block);
4083                 sub(reg_tmp_w, 1);
4084                 kmovw(reg_h_block, reg_tmp_w);
4085                 jnz(h_block_loop);
4086             }
4087         }
4088 
4089         emit_kh_kw_loop(false, true);
4090     }
4091 }
4092 
compute_ic_loop(int ic_block,int nb_ic_blocking,int nb_oc_blocking)4093 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_ic_loop(
4094         int ic_block, int nb_ic_blocking, int nb_oc_blocking) {
4095     assert(jcp.ur_w % 2 == 0);
4096     const int str_w = jcp.stride_w;
4097     assert(jcp.tr_iw % str_w == 0);
4098     const int src_stride_w_shift = jcp.tr_iw / str_w;
4099 
4100     mov(reg_b_stride, 64);
4101     mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
4102 
4103     for (int s = 0; s < str_w; s++) {
4104         for (int i_kw = s; i_kw < jcp.kw; i_kw += str_w) {
4105 
4106             for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
4107                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4108                     tileloadd(Tmm(get_wei_tensor(ocb, icb)),
4109                             ptr[reg_kernel + reg_b_stride
4110                                     + get_full_kernel_offset(
4111                                             ocb, icb, 0, i_kw)]);
4112 
4113             int src_offset_l = (i_kw * (jcp.dilate_w + 1)) / str_w
4114                     + s * src_stride_w_shift;
4115 
4116             for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
4117                 dim_t ur_w_src_offset = ur_w_b
4118                         * get_src_offset(0, filter_w_to_src(0, jcp.ur_w, 0));
4119                 dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
4120                 for (int icb = 0; icb < nb_ic_blocking; icb++) {
4121                     dim_t icb_offset = icb * jcp.tr_src_buf_size;
4122                     tileloadd(Tmm(get_src_tensor(icb)),
4123                             ptr[reg_src
4124                                     + jcp.typesize_in
4125                                             * (src_offset_l + icb_offset)
4126                                     + ur_w_src_offset + reg_a_stride]);
4127                 }
4128                 for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4129                     tileloadd(Tmm(get_ddst_tensor(ocb)),
4130                             ptr[reg_ddst
4131                                     + jcp.typesize_in * ocb
4132                                             * jcp.tr_diff_dst_buf_size
4133                                     + ur_w_ddst_offset + reg_b_stride]);
4134                     for (int icb = 0; icb < nb_ic_blocking; icb++)
4135                         tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
4136                                 Tmm(get_src_tensor(icb)),
4137                                 Tmm(get_ddst_tensor(ocb)));
4138                 }
4139             }
4140 
4141             for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
4142                 for (int icb = 0; icb < nb_ic_blocking; icb++)
4143                     tilestored(ptr[reg_kernel + reg_b_stride
4144                                        + get_full_kernel_offset(
4145                                                ocb, icb, 0, i_kw)],
4146                             Tmm(get_wei_tensor(ocb, icb)));
4147         }
4148     }
4149     safe_add(reg_src, get_src_offset(ic_block, 0), reg_long_offt);
4150     add(reg_kernel, get_kernel_offset(ic_block, 0));
4151 }
4152 
compute_diff_bias_init(int ocb)4153 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_init(int ocb) {
4154     auto reg_unit_val = reg_tmp.cvt16();
4155     mov(reg_unit_val, 0x3f80); // bf16 value of 1.
4156     vpbroadcastw(vreg_bias_unit, reg_unit_val);
4157 
4158     mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4159     vmovups(vreg_bias_acc, ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block]);
4160 }
4161 
compute_diff_bias_row(bool is_partial,int ocb)4162 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_row(
4163         bool is_partial, int ocb) {
4164     if (!jcp.with_bias) return;
4165     mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4166     Label skip_label;
4167     test(reg_tmp, FLAG_IC_FIRST);
4168     jz(skip_label, T_NEAR);
4169 
4170     if (is_partial) { compute_diff_bias_init(ocb); }
4171     auto compute_step = [&]() {
4172         vmovups(vreg_bias_ddst, ptr[reg_ddst]);
4173         vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
4174     };
4175 
4176     Label ow_loop, ow_tail;
4177     int niters = jcp.tr_ow / 2;
4178     if (niters > 0) {
4179         mov(reg_tmp, jcp.tr_ow / 2);
4180         L(ow_loop);
4181         compute_step();
4182         add(reg_ddst, get_ddst_offset(2));
4183         sub(reg_tmp, 1);
4184         jnz(ow_loop, T_NEAR);
4185     }
4186     if (jcp.tr_ow % 2) compute_step();
4187 
4188     if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));
4189 
4190     if (is_partial) {
4191         mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4192         vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
4193                 vreg_bias_acc);
4194     }
4195 
4196     L(skip_label);
4197 }
4198 
maybe_compute_diff_bias(int nb_oc_blocking)4199 void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_compute_diff_bias(
4200         int nb_oc_blocking) {
4201     // In harness_3d_reduction case calculation of diff_bias is called
4202     // for every ow row separately to be aligned with od loop in
4203     // compute_od_loop_common()
4204     if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
4205     mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4206 
4207     Label skip_label;
4208     test(reg_tmp, FLAG_IC_FIRST);
4209     jz(skip_label, T_NEAR);
4210 
4211     for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4212         Label bias_loop, skip_label_local;
4213 
4214         mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4215         add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size);
4216 
4217         switch (jcp.harness) {
4218             case harness_2d_reduction:
4219                 mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4220                 sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
4221                 break;
4222             case harness_mb_reduction:
4223             case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
4224             case harness_3d_reduction:
4225             default: assert(!"Invalid harness type");
4226         }
4227 
4228         cmp(reg_oj, 0);
4229         jle(skip_label_local, T_NEAR); // nothing to do
4230 
4231         compute_diff_bias_init(ocb);
4232         L(bias_loop);
4233         {
4234             compute_diff_bias_row(false, ocb);
4235             add(reg_ddst, get_ddst_offset(0, 1));
4236 
4237             sub(reg_oj, 1);
4238             jnz(bias_loop, T_NEAR);
4239         }
4240 
4241         mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4242         vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
4243                 vreg_bias_acc);
4244 
4245         L(skip_label_local);
4246     }
4247     // restore reg_ddst value
4248     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4249 
4250     L(skip_label);
4251 }
4252 
compute_oh_step_common(int nb_ic_blocking,int nb_oc_blocking)4253 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_step_common(
4254         int nb_ic_blocking, int nb_oc_blocking) {
4255     Label kh_label, ic_block_label, ow_block_label, kd_label;
4256 
4257     int ic_block = jcp.ic_block;
4258     int ic_tail = jcp.ic_tail;
4259 
4260     auto ic_loop = [&](int nb_ic_blocking, int nb_oc_blocking) {
4261         Label ic_tail_label, ic_loop_done_label;
4262 
4263         if (ic_tail) {
4264             mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
4265             cmp(reg_icb, jcp.ic_tail);
4266             jne(ic_tail_label, T_NEAR);
4267 
4268             compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
4269             jmp(ic_loop_done_label, T_NEAR);
4270 
4271             L(ic_tail_label);
4272             compute_ic_loop(ic_tail, nb_ic_blocking, nb_oc_blocking);
4273             add(reg_kernel, get_kernel_offset(jcp.ic_block - ic_tail, 0));
4274             safe_add(reg_src,
4275                     get_src_offset(0, 0, filter_h_to_src(1))
4276                             - get_src_offset(ic_tail, 0),
4277                     reg_long_offt);
4278             L(ic_loop_done_label);
4279         } else {
4280             compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
4281         }
4282     };
4283 
4284     if (jcp.ndims == 5) {
4285         /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
4286          * 'movs' must be guaranteed. */
4287         mov(ki, reg_kd_count);
4288         mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
4289         mov(aux_reg_src, reg_src);
4290         mov(aux_reg_kernel, reg_kernel);
4291 
4292         L(kd_label);
4293         mov(reg_src, aux_reg_src);
4294         mov(reg_kernel, aux_reg_kernel);
4295     }
4296 
4297     mov(kj, reg_kh);
4298     L(kh_label);
4299     {
4300         ic_loop(nb_ic_blocking, nb_oc_blocking);
4301 
4302         if (jcp.dilate_h > 0) {
4303             add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
4304         }
4305         // substract pointer shift made within ic block loop
4306         // and move to next kh index
4307         add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
4308         dec(kj);
4309         cmp(kj, 0);
4310         jg(kh_label, T_NEAR);
4311     }
4312     if (jcp.ndims == 5) {
4313         add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
4314         add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
4315         dec(ki);
4316         cmp(ki, 0);
4317         jg(kd_label, T_NEAR);
4318     }
4319     // In harness_3d_reduction case calculation of diff_bias is called
4320     // for every ow row separately to be aligned with od loop in
4321     // compute_od_loop_common()
4322     if (jcp.harness == harness_3d_reduction) {
4323         auto reg_save_ddst = reg_a_stride;
4324         mov(reg_save_ddst, reg_ddst);
4325         for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4326             safe_add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size,
4327                     reg_long_offt);
4328             compute_diff_bias_row(true, ocb);
4329         }
4330         mov(reg_ddst, reg_save_ddst);
4331     }
4332 
4333     if (jcp.ndims == 5) {
4334         mov(reg_src, aux_reg_src);
4335         mov(reg_kernel, aux_reg_kernel);
4336         mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
4337         od_step_comeback_pointers();
4338     } else {
4339         oh_step_comeback_pointers();
4340     }
4341 }
4342 
maybe_zero_kernel(int nb_ic_blocking,int nb_oc_blocking)4343 void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_zero_kernel(
4344         int nb_ic_blocking, int nb_oc_blocking) {
4345     if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
4346     Label skip_zeroing, zeroing_loop;
4347 
4348     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
4349     cmp(reg_tmp, 0);
4350     jz(skip_zeroing, T_NEAR);
4351 
4352     Zmm zero = Zmm(0);
4353     vpxord(zero, zero, zero);
4354     if (jcp.with_bias) {
4355         Label skip_bias_zeroing;
4356         mov(reg_tmp, ptr[param + GET_OFF(flags)]);
4357         test(reg_tmp, FLAG_IC_FIRST);
4358         jz(skip_bias_zeroing, T_NEAR);
4359         for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4360             mov(reg_tmp, ptr[param + GET_OFF(bias)]);
4361             vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block], zero);
4362         }
4363         L(skip_bias_zeroing);
4364         if (jcp.harness == harness_compute_full_spatial)
4365             jmp(skip_zeroing, T_NEAR);
4366     }
4367 
4368     mov(reg_b_stride, 64);
4369     tilezero(Tmm(get_wei_tensor(0, 0)));
4370     for (dim_t shift = 0;
4371             shift < get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
4372             shift += get_kernel_offset(jcp.ic_block, 0)) {
4373         for_(int icb = 0; icb < nb_ic_blocking; icb++)
4374         for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
4375             tilestored(
4376                     ptr[reg_kernel + reg_b_stride
4377                             + get_full_kernel_offset(ocb, icb, 0, 0) + shift],
4378                     Tmm(get_wei_tensor(0, 0)));
4379         }
4380     }
4381     L(skip_zeroing);
4382 }
4383 
compute_oh_loop_common(int nb_ic_blocking,int nb_oc_blocking,bool is_partial)4384 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_loop_common(
4385         int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
4386     int b_pad = jcp.b_pad;
4387     int t_pad = jcp.t_pad;
4388 
4389     bool is_dilated = jcp.dilate_h != 0;
4390     int dilate_h = jcp.dilate_h + 1;
4391     int stride_h = jcp.stride_h;
4392     auto filter_step_size = get_kernel_offset(0, jcp.kw);
4393     auto src_step_size = get_src_offset(0, 0, 1);
4394     auto ddst_step_size = get_ddst_offset(0, 1);
4395     Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
4396             oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
4397             oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
4398             oh_dilate_label_end, oh_dilate_setup_label_shift,
4399             oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;
4400 
4401     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4402     int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
4403     int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
4404     int oh_head_overflow_end = div_up(t_pad, stride_h);
4405     int oh_tail_end = jcp.oh;
4406 
4407     int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
4408     int ih_body_end
4409             = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);
4410 
4411     if (is_partial)
4412         mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
4413     else
4414         xor_(reg_oj, reg_oj);
4415 
4416     /* Compute 'top' edge */
4417     if (t_pad > 0) {
4418         if (is_partial) {
4419             cmp(reg_oj, oh_head_overflow_end);
4420             jge(oh_tpad_tail_label_end, T_NEAR);
4421         }
4422         const int overflow
4423                 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
4424         const int underflow = div_up(t_pad, dilate_h);
4425         const int initial_kh = jcp.kh - overflow - underflow;
4426 
4427         // Setup reg_kh, reg_kernel, and reg_src
4428         mov(reg_kh, initial_kh);
4429         add(reg_kernel, filter_step_size * underflow);
4430         if (is_dilated) {
4431             const int tail = t_pad % dilate_h;
4432             const int shift = tail == 0 ? 0 : dilate_h - tail;
4433             mov(reg_ih_shift, shift);
4434             if (!is_partial) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4435             add(reg_src, src_step_size * shift);
4436         }
4437 
4438         if (is_partial) {
4439             Label head_setup, head_setup_finish;
4440             cmp(reg_oj, 0);
4441             je(head_setup_finish, T_NEAR);
4442             mov(reg_oj_setup, reg_oj);
4443 
4444             L(head_setup);
4445             if (is_dilated) {
4446                 inc(reg_ih_shift);
4447                 cmp(reg_ih_shift, dilate_h);
4448                 jl(oh_dilate_setup_label_shift, T_NEAR);
4449                 // unshift src as new kernel element enters
4450                 sub(reg_src, src_step_size * (dilate_h - 1));
4451                 xor_(reg_ih_shift, reg_ih_shift);
4452             }
4453             // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4454             add(reg_kh, stride_h);
4455             sub(reg_kernel, filter_step_size * stride_h);
4456             if (is_dilated) {
4457                 jmp(oh_dilate_setup_label_noshift, T_NEAR);
4458                 L(oh_dilate_setup_label_shift);
4459                 // shift src as old kernel element progresses
4460                 add(reg_src, src_step_size * stride_h);
4461                 L(oh_dilate_setup_label_noshift);
4462             }
4463             sub(reg_oj_setup, 1);
4464             jg(head_setup, T_NEAR);
4465             L(head_setup_finish);
4466 
4467             if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4468             if (oh_head_end < oh_head_overflow_end) {
4469                 cmp(reg_oj, oh_head_end);
4470                 jge(oh_tpad_label_end, T_NEAR);
4471             }
4472         }
4473 
4474         //Setup reg_kernel
4475         // If dilated, shift src ptr
4476         // Loop
4477         L(oh_tpad_label);
4478         compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4479         add(reg_ddst, ddst_step_size);
4480         if (is_dilated) {
4481             mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
4482             inc(reg_ih_shift);
4483             mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4484             cmp(reg_ih_shift, dilate_h);
4485             jl(oh_dilate_label_shift, T_NEAR);
4486             // unshift src as new kernel element enters
4487             sub(reg_src, src_step_size * (dilate_h - 1));
4488             xor_(reg_ih_shift, reg_ih_shift);
4489             mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4490         }
4491         // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4492         add(reg_kh, stride_h);
4493         sub(reg_kernel, filter_step_size * stride_h);
4494         if (is_dilated) {
4495             jmp(oh_dilate_label_noshift, T_NEAR);
4496             L(oh_dilate_label_shift);
4497             // shift src as old kernel element progresses
4498             add(reg_src, src_step_size * stride_h);
4499             L(oh_dilate_label_noshift);
4500         }
4501         inc(reg_oj);
4502 
4503         if (is_partial) {
4504             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4505             jge(oh_bpad_label_end, T_NEAR);
4506         }
4507         cmp(reg_oj, oh_head_end);
4508         jl(oh_tpad_label, T_NEAR);
4509 
4510         L(oh_tpad_label_end);
4511         // need second loop to process kernel if it is larger than the src
4512         // (does not apply to dilations as they must have unit stride)
4513         if (oh_head_end < oh_head_overflow_end) {
4514             assert(!is_dilated);
4515 
4516             cmp(reg_oj, oh_head_overflow_end);
4517             jge(oh_tpad_tail_label_end, T_NEAR);
4518 
4519             mov(reg_kh, jcp.ih);
4520             L(oh_tpad_tail_label);
4521             {
4522                 compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4523                 add(reg_ddst, ddst_step_size);
4524                 sub(reg_kernel, filter_step_size * stride_h);
4525 
4526                 inc(reg_oj);
4527 
4528                 if (is_partial) {
4529                     cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4530                     jge(oh_bpad_label_end, T_NEAR);
4531                 }
4532                 cmp(reg_oj, oh_head_overflow_end);
4533                 jl(oh_tpad_tail_label, T_NEAR);
4534             }
4535         }
4536         if (body_src_start_offset != 0) {
4537             add(reg_kernel, filter_step_size * body_src_start_offset);
4538             add(reg_src, src_step_size * body_src_start_offset);
4539         }
4540         L(oh_tpad_tail_label_end);
4541     }
4542 
4543     if (is_partial) {
4544         cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4545         jge(oh_bpad_label_end, T_NEAR);
4546     }
4547     cmp(reg_oj, oh_body_end);
4548     jge(oh_label_end, T_NEAR);
4549 
4550     /* Compute middle block(s) */
4551     mov(reg_kh, jcp.kh);
4552     L(oh_label);
4553     {
4554         compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4555         add(reg_src, src_step_size * stride_h);
4556         add(reg_ddst, ddst_step_size);
4557 
4558         inc(reg_oj);
4559 
4560         if (is_partial) {
4561             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4562             jge(oh_bpad_label_end, T_NEAR);
4563         }
4564 
4565         cmp(reg_oj, oh_body_end);
4566         jl(oh_label, T_NEAR);
4567     }
4568     L(oh_label_end);
4569 
4570     /* Compute bottom edge */
4571     if (b_pad > 0) {
4572         if (is_partial) {
4573             cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4574             jge(oh_bpad_label_end, T_NEAR);
4575         }
4576         cmp(reg_oj, jcp.oh);
4577         jge(oh_bpad_label_end, T_NEAR);
4578 
4579         if (is_dilated) {
4580             // Assumes unit stride for dilations
4581             mov(reg_kh, jcp.kh - 1);
4582             xor_(reg_ih_shift, reg_ih_shift);
4583         } else {
4584             assert(jcp.dilate_h == 0);
4585             mov(reg_kh, jcp.ih - ih_body_end);
4586         }
4587         if (is_partial) {
4588             lea(reg_oj_setup,
4589                     ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
4590             if (stride_h == 1 && !is_dilated) {
4591                 sub(reg_kh, reg_oj_setup);
4592             } else {
4593                 Label body_setup, body_setup_finish, dilate_skip;
4594                 cmp(reg_oj_setup, 0);
4595                 je(body_setup_finish, T_NEAR);
4596 
4597                 L(body_setup);
4598                 if (is_dilated) {
4599                     inc(reg_ih_shift);
4600                     cmp(reg_ih_shift, dilate_h);
4601                     jl(dilate_skip, T_NEAR);
4602                     xor_(reg_ih_shift, reg_ih_shift);
4603                 }
4604                 sub(reg_kh, stride_h);
4605                 L(dilate_skip);
4606                 sub(reg_oj_setup, 1);
4607                 jg(body_setup, T_NEAR);
4608                 L(body_setup_finish);
4609             }
4610         }
4611 
4612         if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4613         L(oh_bpad_label);
4614         {
4615             compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4616             add(reg_src, src_step_size * stride_h);
4617             add(reg_ddst, ddst_step_size);
4618 
4619             if (is_dilated) {
4620                 mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
4621                 inc(reg_ih_shift);
4622                 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4623                 cmp(reg_ih_shift, dilate_h);
4624                 jl(oh_dilate_label_end, T_NEAR);
4625                 xor_(reg_ih_shift, reg_ih_shift);
4626                 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4627             }
4628             sub(reg_kh, stride_h);
4629             L(oh_dilate_label_end);
4630             inc(reg_oj);
4631             if (is_partial) {
4632                 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4633                 jge(oh_bpad_label_end, T_NEAR);
4634             }
4635             cmp(reg_oj, oh_tail_end);
4636             jl(oh_bpad_label, T_NEAR);
4637         }
4638     }
4639     L(oh_bpad_label_end);
4640 }
4641 
compute_od_loop_common(int nb_ic_blocking,int nb_oc_blocking,bool is_partial)4642 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_od_loop_common(
4643         int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
4644     assert(jcp.harness == harness_3d_reduction);
4645 
4646     const int src_backpad_overlap
4647             = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
4648 
4649     const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
4650     const auto src_shift = get_src_offset(0, 0, jcp.ih);
4651     const auto ddst_shift = get_ddst_offset(0, jcp.oh);
4652 
4653     const int kd_front_pad = nstl::max(0, jcp.f_pad);
4654     const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);
4655 
4656     Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
4657             backpad_end_label, backpad_label;
4658 
4659     /* initially offset 'kd' by f_pad */
4660     mov(reg_src_d, ptr[param + GET_OFF(src)]);
4661     mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);
4662 
4663     if (is_partial) {
4664         add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
4665         mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
4666         mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
4667     } else {
4668         const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
4669         const int kd_offset = get_kernel_offset(
4670                 0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
4671         add(reg_kernel, kd_offset);
4672         xor_(reg_d_index, reg_d_index);
4673         mov(reg_kd_count, kd_padding);
4674     }
4675 
4676     cmp(reg_kd_count, 0);
4677     jle(loop_end_label, T_NEAR); // no iterations along kd
4678     if (is_partial)
4679         cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
4680     else
4681         cmp(reg_d_index, jcp.od);
4682     jge(loop_end_label, T_NEAR); // no iterations along depth dimension
4683 
4684     L(d_loop_label);
4685 
4686     mov(reg_src, reg_src_d);
4687     mov(reg_ddst, reg_ddst_d);
4688 
4689     mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
4690     mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
4691     mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);
4692 
4693     compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);
4694 
4695     mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
4696     mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
4697     mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));
4698 
4699     /* Compute 'front' edge */
4700     if (jcp.f_pad > 0) {
4701         /* Check if within fpad region */
4702         cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
4703         jge(fpad_end_label, T_NEAR);
4704 
4705         /* Fpad steps */
4706         sub(reg_kernel, filter_shift * jcp.stride_d);
4707         add(reg_kd_count, jcp.stride_d);
4708 
4709         /* Final number of kernel elements that overlap with src */
4710         const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
4711         cmp(reg_kd_count, src_ker_overlap);
4712         jle(common_block_label, T_NEAR);
4713 
4714         /* Correct any excess shifts to kernel and src */
4715         if (jcp.f_pad <= jcp.od * jcp.stride_d) {
4716             /* Filter has moved beyond padding (adjust for stride effects) */
4717             if (jcp.f_pad % jcp.stride_d != 0) {
4718                 int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
4719                 add(reg_kernel, filter_shift * src_corr);
4720                 add(reg_src_d, src_shift * src_corr);
4721             }
4722         } else {
4723             /* Filter still overlaps padding (complete reset) */
4724             sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
4725         }
4726 
4727         /* Apply correction */
4728         mov(reg_kd_count, src_ker_overlap);
4729         jmp(common_block_label);
4730 
4731         L(fpad_end_label);
4732     }
4733 
4734     /* Compute bottom edge */
4735     if (jcp.back_pad > 0) {
4736 
4737         /* Check if within back_pad region */
4738         cmp(reg_d_index, src_backpad_overlap - 1);
4739         jl(backpad_end_label, T_NEAR);
4740         jg(backpad_label, T_NEAR);
4741 
4742         /* Execute overlap correction between the filter and the initial
4743          * back_pad region. */
4744         mov(reg_kd_count,
4745                 jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
4746         jmp(backpad_end_label, T_NEAR);
4747 
4748         L(backpad_label);
4749         sub(reg_kd_count, jcp.stride_d);
4750         cmp(reg_kd_count, 0);
4751         jle(loop_end_label, T_NEAR);
4752 
4753         L(backpad_end_label);
4754     }
4755 
4756     /* Compute middle block */
4757     add(reg_src_d, src_shift * jcp.stride_d);
4758 
4759     /* Execute common block and loop */
4760     L(common_block_label);
4761     add(reg_ddst_d, ddst_shift);
4762     inc(reg_d_index);
4763     if (is_partial)
4764         cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
4765     else
4766         cmp(reg_d_index, jcp.od);
4767     jl(d_loop_label, T_NEAR);
4768 
4769     L(loop_end_label);
4770 }
4771 
compute_loop(int nb_ic_blocking,int nb_oc_blocking)4772 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_loop(
4773         int nb_ic_blocking, int nb_oc_blocking) {
4774     mov(reg_src, ptr[param + GET_OFF(src)]);
4775     mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4776     mov(reg_kernel, ptr[param + GET_OFF(filt)]);
4777 
4778     maybe_zero_kernel(nb_ic_blocking, nb_oc_blocking);
4779     maybe_compute_diff_bias(nb_oc_blocking);
4780 
4781     switch (jcp.harness) {
4782         case harness_3d_reduction:
4783             compute_od_loop_common(nb_ic_blocking, nb_oc_blocking, true);
4784             break;
4785         case harness_2d_reduction:
4786             compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking, true);
4787             break;
4788         case harness_mb_reduction:
4789             compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);
4790             break;
4791         case harness_compute_full_spatial:
4792             compute_full_spat_loop(nb_ic_blocking, nb_oc_blocking);
4793             break;
4794         default: assert(!"Invalid harness type");
4795     }
4796 }
4797 
setup_stack_space()4798 void jit_avx512_core_amx_bwd_weights_kernel_t::setup_stack_space() {
4799     kd_count_offset = ic_block_step_stack_size;
4800     src_d_offset = ic_block_step_stack_size + 8;
4801     ddst_d_offset = ic_block_step_stack_size + 16;
4802     d_index_offset = ic_block_step_stack_size + 24;
4803     ih_dilate_offset = ic_block_step_stack_size + 32;
4804     src_save_offset = ic_block_step_stack_size + 40;
4805     ddst_save_offset = ic_block_step_stack_size + 48;
4806     stack_space_needed = ic_block_step_stack_size + 56;
4807 }
4808 
generate()4809 void jit_avx512_core_amx_bwd_weights_kernel_t::generate() {
4810     preamble();
4811 
4812     setup_stack_space();
4813 
4814     sub(rsp, stack_space_needed);
4815 
4816     Label last_ic_block_label, last_blocks_done_label;
4817 
4818     mov(reg_tmp, ptr[param + GET_OFF(last_ic_block)]);
4819     cmp(reg_tmp, 0);
4820     jne(last_ic_block_label, T_NEAR);
4821     { // full nb_ic_blocking
4822         Label last_oc_block_label;
4823         mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
4824         cmp(reg_tmp, 0);
4825         jne(last_oc_block_label, T_NEAR);
4826         { // full nb_oc_blocking
4827             compute_loop(jcp.nb_ic_blocking, jcp.nb_oc_blocking);
4828             jmp(last_blocks_done_label, T_NEAR);
4829         }
4830         L(last_oc_block_label);
4831         { // tail of nb_oc_blocking
4832             compute_loop(jcp.nb_ic_blocking, 1);
4833             jmp(last_blocks_done_label, T_NEAR);
4834         }
4835     }
4836     L(last_ic_block_label);
4837     { // tail nb_ic_blocking
4838         Label last_oc_block_label;
4839         mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
4840         cmp(reg_tmp, 0);
4841         jne(last_oc_block_label, T_NEAR);
4842         { // full nb_oc_blocking
4843             compute_loop(1, jcp.nb_oc_blocking);
4844             jmp(last_blocks_done_label, T_NEAR);
4845         }
4846         L(last_oc_block_label);
4847         { // tail of nb_oc_blocking
4848             compute_loop(1, 1);
4849             jmp(last_blocks_done_label, T_NEAR);
4850         }
4851     }
4852 
4853     L(last_blocks_done_label);
4854     add(rsp, stack_space_needed);
4855 
4856     postamble();
4857 }
4858 
init_conf(jit_conv_conf_t & jcp,const convolution_desc_t & cd,memory_desc_t & src_md,memory_desc_t & diff_weights_md,memory_desc_t & diff_bias_md,memory_desc_t & diff_dst_md,int nthreads)4859 status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_conf(
4860         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
4861         memory_desc_t &src_md, memory_desc_t &diff_weights_md,
4862         memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
4863     const memory_desc_wrapper src_d(&src_md);
4864     const memory_desc_wrapper diff_weights_d(&diff_weights_md);
4865     const memory_desc_wrapper diff_dst_d(&diff_dst_md);
4866     const memory_desc_wrapper diff_bias_d(&diff_bias_md);
4867 
4868     jcp = zero<decltype(jcp)>();
4869 
4870     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
4871     int ndims = src_d.ndims();
4872 
4873     if (!mayiuse(avx512_core_bf16_amx_bf16)) return status::unimplemented;
4874     jcp.isa = avx512_core_bf16_amx_bf16;
4875 
4876     jcp.ver = ver_vnni; // Needed for transpose routines
4877     jcp.nthr = nthreads;
4878 
4879     jcp.ndims = ndims;
4880     jcp.prop_kind = cd.prop_kind;
4881 
4882     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
4883     jcp.mb = src_d.dims()[0];
4884 
4885     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
4886     jcp.oc_without_padding = jcp.oc;
4887     jcp.ic = src_d.dims()[1] / jcp.ngroups;
4888 
4889     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
4890     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
4891     jcp.iw = src_d.dims()[ndims - 1];
4892     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
4893     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
4894     jcp.ow = diff_dst_d.dims()[ndims - 1];
4895 
4896     jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
4897     jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
4898     jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];
4899 
4900     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
4901     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
4902     jcp.l_pad = cd.padding[0][ndims - 3];
4903 
4904     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
4905     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
4906     jcp.stride_w = cd.strides[ndims - 3];
4907 
4908     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
4909     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
4910     jcp.dilate_w = cd.dilates[ndims - 3];
4911 
4912     int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
4913     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4914     int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
4915 
4916     bool ok = true
4917             // general condition to simplify dilations
4918             && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
4919             && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
4920             && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
4921             // special condition to simplify dilations in compute_oh_loop_common
4922             && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
4923     if (!ok) return status::unimplemented;
4924 
4925     ok = true && one_of(ndims, 3, 4, 5)
4926             && everyone_is(
4927                     data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
4928             && one_of(diff_weights_d.data_type(), data_type::f32,
4929                     data_type::bf16);
4930     if (!ok) return status::unimplemented;
4931 
4932     jcp.transform_to_vnni = diff_weights_d.data_type() == data_type::bf16;
4933 
4934     jcp.r_pad = nstl::max(0,
4935             calculate_end_padding(
4936                     jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
4937     jcp.b_pad = nstl::max(0,
4938             calculate_end_padding(
4939                     jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
4940     jcp.back_pad = nstl::max(0,
4941             calculate_end_padding(
4942                     jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));
4943 
4944     /* XXX: no support for padding when dilation_d > 0 */
4945     if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
4946         return status::unimplemented;
4947 
4948     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
4949     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
4950     jcp.ohp = jcp.oh;
4951     jcp.owp = jcp.ow;
4952 
4953     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
4954     if (jcp.is_depthwise)
4955         return status::unimplemented; // TODO: add support of DW convolution
4956 
4957     const int dat_format_tag = ndims - 3;
4958     format_tag_t dat_tag_nspc = utils::pick(dat_format_tag, format_tag::nwc,
4959             format_tag::nhwc, format_tag::ndhwc);
4960     format_tag_t dat_tag_opt = dat_tag_nspc;
4961 
4962     if (src_d.format_kind() == format_kind::any) {
4963         CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
4964         jcp.src_tag = dat_tag_opt;
4965     } else
4966         jcp.src_tag = src_d.matches_one_of_tag(dat_tag_opt);
4967     if (!one_of(jcp.src_tag, dat_tag_opt)) return status::unimplemented;
4968     jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
4969 
4970     if (diff_dst_d.format_kind() == format_kind::any) {
4971         CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
4972         jcp.dst_tag = jcp.src_tag;
4973     } else
4974         jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
4975     if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
4976 
4977     if (!jcp.is_nspc) return status::unimplemented;
4978 
4979     const int wei_format_tag = 2 * ndims - 6 + with_groups;
4980     format_tag_t wei_tag;
4981     if (jcp.transform_to_vnni)
4982         wei_tag = pick(wei_format_tag, format_tag::OIw16i16o2i,
4983                 format_tag::gOIw16i16o2i, format_tag::OIhw16i16o2i,
4984                 format_tag::gOIhw16i16o2i, format_tag::OIdhw16i16o2i,
4985                 format_tag::gOIdhw16i16o2i);
4986     else
4987         wei_tag = pick(wei_format_tag, format_tag::OIw16i16o,
4988                 format_tag::gOIw16i16o, format_tag::OIhw16i16o,
4989                 format_tag::gOIhw16i16o, format_tag::OIdhw16i16o,
4990                 format_tag::gOIdhw16i16o);
4991     if (diff_weights_md.format_kind == format_kind::any) {
4992         CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
4993         jcp.wei_tag = wei_tag;
4994     } else {
4995         jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
4996         if (jcp.wei_tag != wei_tag) return status::unimplemented;
4997     }
4998     jcp.wei_dt = diff_weights_d.data_type();
4999 
5000     /* conditions on bias memory */
5001     jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
5002     if (jcp.with_bias) {
5003         if (diff_bias_d.format_kind() == format_kind::any)
5004             CHECK(memory_desc_init_by_tag(diff_bias_md, format_tag::x));
5005     }
5006     jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
5007     jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
5008 
5009     /* kernel applicability check wrt boundaries
5010      * the conditions are quite general across the kernels we have,
5011      * but ideally the check should belong to a specific kernel... */
5012     const int max_pad_h = ext_kh / 2;
5013     const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
5014             && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
5015             && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd;
5016     if (!boundaries_ok) return status::unimplemented;
5017 
5018     jcp.ic_block = 16;
5019     jcp.oc_block = 16;
5020 
5021     jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
5022     jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
5023 
5024     jcp.ic_tail = jcp.ic % jcp.ic_block;
5025     jcp.oc_tail = jcp.oc % jcp.oc_block;
5026 
5027     jcp.nb_oc_blocking = (jcp.nb_oc > 1) ? 2 : 1;
5028     jcp.nb_ic_blocking = (jcp.nb_ic > 1) ? 2 : 1;
5029 
5030     int max_palette = amx::get_max_palette();
5031     jcp.max_tiles = amx::get_max_tiles(max_palette);
5032     jcp.full_tile_width = amx::get_max_rows(max_palette);
5033 
5034     if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
5035         return status::unimplemented;
5036 
5037     const bool is_2d = (ndims == 4);
5038     const bool is_3d = (ndims == 5);
5039     jcp.typesize_in = sizeof(bfloat16_t);
5040     jcp.typesize_out = sizeof(float);
5041 
5042     // TODO: Find more shapes (especially 3D with large spatials) for which
5043     // local transposition will be beneficial. Furthermore, for TBB threads
5044     // more shapes can potentially benefit from spatial blocking
5045     int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;
5046 
5047     jcp.global_transpose = dnnl_thr_syncable();
5048     jcp.spatial_blk_size = optimal_blk_size;
5049 
5050     const int tr_round = 32; // To load full tile register
5051     int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
5052     jcp.tr_iw = rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
5053             * jcp.stride_w;
5054 
5055     jcp.tr_src_num_guard_elems = tr_pad; // upper bound
5056     jcp.tr_ow = rnd_up(jcp.ow, 2);
5057 
5058     if (jcp.tr_ow <= max_ur_w) {
5059         jcp.ur_w = jcp.tr_ow;
5060         jcp.ur_w_blocks = 1;
5061     } else {
5062         jcp.ur_w = 1;
5063         for (int i = max_ur_w; i >= 1; i -= 2) {
5064             if (jcp.tr_ow % i == 0) {
5065                 jcp.ur_w = i;
5066                 break;
5067             }
5068         }
5069         jcp.ur_w_blocks = jcp.tr_ow / jcp.ur_w;
5070     }
5071 
5072     bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
5073             && jcp.oc <= diff_dst_d.padded_dims()[1]
5074             && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
5075             && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
5076     if (!args_ok) return status::unimplemented;
5077 
5078     bool use_full_spat_loop = jcp.ndims < 5 && jcp.ih == jcp.oh
5079             && jcp.iw == jcp.ow && everyone_is(1, jcp.stride_h, jcp.stride_w)
5080             && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
5081             // TODO: Remove this constraint: only 3x3 kernel works now
5082             && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
5083             && one_of(1, jcp.l_pad, jcp.r_pad) && jcp.kh == jcp.kw
5084             && jcp.ih >= jcp.kh && jcp.iw >= jcp.kw;
5085 
5086     jcp.harness = ndims == 5
5087             ? harness_3d_reduction
5088             : (use_full_spat_loop ? harness_compute_full_spatial
5089                                   : (ndims == 4) ? harness_2d_reduction
5090                                                  : harness_mb_reduction);
5091     switch (jcp.harness) {
5092         case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
5093         case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
5094         case harness_compute_full_spatial:
5095         case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
5096         default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
5097     }
5098     { // balancing
5099         int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
5100         balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
5101         jcp.nthr = nthr;
5102         jcp.nthr_mb = nthr_mb;
5103         jcp.nthr_g = nthr_g;
5104         jcp.nthr_oc_b = nthr_oc_b;
5105         jcp.nthr_ic_b = nthr_ic_b;
5106 
5107         // TODO: Optimize memory allocation when threaded on height and depth
5108         jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
5109         jcp.tr_src_buf_count = jcp.global_transpose
5110                 ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
5111                 : jcp.nthr;
5112 
5113         jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
5114         jcp.tr_diff_dst_buf_count = jcp.global_transpose
5115                 ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
5116                 : jcp.nthr;
5117     }
5118 
5119     return status::success;
5120 }
5121 
init_scratchpad(memory_tracking::registrar_t & scratchpad,const jit_conv_conf_t & jcp,memory_desc_t & src_md,memory_desc_t & diff_weights_md,memory_desc_t & diff_dst_md)5122 status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad(
5123         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
5124         memory_desc_t &src_md, memory_desc_t &diff_weights_md,
5125         memory_desc_t &diff_dst_md) {
5126     const memory_desc_wrapper src_d(&src_md);
5127     const memory_desc_wrapper diff_weights_d(&diff_weights_md);
5128     const memory_desc_wrapper diff_dst_d(&diff_dst_md);
5129 
5130     // XXX: See the comment about tr_iw and guarding elements in
5131     // jit_avx512_core_amx_bwd_weights_kernel_t::init_conf()
5132     const size_t tr_src_size
5133             = (jcp.tr_src_buf_count * jcp.tr_src_buf_size * jcp.nb_ic_blocking)
5134             + jcp.tr_src_num_guard_elems;
5135     scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);
5136 
5137     /* prepare synchronization contexts */
5138     if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
5139         const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
5140         scratchpad.book<simple_barrier::ctx_t>(
5141                 key_conv_tr_src_bctx, tr_src_bctx_size);
5142     }
5143 
5144     const size_t tr_diff_dst_size = jcp.tr_diff_dst_buf_count
5145             * jcp.tr_diff_dst_buf_size * jcp.nb_oc_blocking;
5146 
5147     const size_t min_align = 64;
5148     scratchpad.book(
5149             key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in, min_align);
5150 
5151     /* prepare synchronization contexts */
5152     if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
5153         const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
5154         scratchpad.book<simple_barrier::ctx_t>(
5155                 key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
5156     }
5157 
5158     if (IMPLICATION(jcp.nthr_mb == 1,
5159                 (jcp.with_bias && jcp.bia_dt == data_type::bf16)
5160                         || jcp.wei_dt == data_type::bf16)) {
5161         const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
5162                 * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
5163         const size_t bia_size
5164                 = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;
5165 
5166         const int num_wei_buffers
5167                 = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
5168         const int num_bia_buffers = jcp.with_bias
5169                 ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
5170                                                  : jcp.nthr_mb - 1)
5171                 : 0;
5172 
5173         const size_t wei_bia_reduction_size
5174                 = wei_size * num_wei_buffers + bia_size * num_bia_buffers;
5175 
5176         scratchpad.book<float>(
5177                 key_conv_wei_bia_reduction, wei_bia_reduction_size);
5178 
5179         scratchpad.book<simple_barrier::ctx_t>(
5180                 key_conv_wei_bia_reduction_bctx, 1);
5181     }
5182 
5183     if (jcp.with_bias
5184             && ((jcp.oc_without_padding % jcp.oc_block != 0)
5185                     && jcp.bia_dt == data_type::f32)) {
5186         scratchpad.book(key_conv_padded_bias,
5187                 jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
5188     }
5189     scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
5190 
5191     constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
5192             << 30; // 32Gb - TODO: may it's too large?
5193     const size_t scratchpad_limit_by_tensor_sizes = (size_t)32 * jcp.nthr
5194             * (src_d.size() + diff_weights_d.size() + diff_dst_d.size());
5195     const size_t scratchpad_limit
5196             = nstl::min(scratchpad_limit_by_absolute_value,
5197                     scratchpad_limit_by_tensor_sizes);
5198     if (scratchpad.size() > scratchpad_limit)
5199         return status::unimplemented;
5200     else
5201         return status::success;
5202 }
5203 
balance(const jit_conv_conf_t & j,int & nthr_,int & nthr_mb_,int & nthr_g_,int & nthr_oc_b_,int & nthr_ic_b_)5204 void jit_avx512_core_amx_bwd_weights_kernel_t::balance(const jit_conv_conf_t &j,
5205         int &nthr_, int &nthr_mb_, int &nthr_g_, int &nthr_oc_b_,
5206         int &nthr_ic_b_) {
5207     nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
5208 
5209     const int max_threads = dnnl_get_max_threads();
5210 
5211     if (max_threads < j.ngroups) {
5212         /* simplification... fortunately it doesn't hurt much */
5213         nthr_ = nthr_g_ = max_threads;
5214         return;
5215     }
5216 
5217     nthr_g_ = j.ngroups;
5218     const int nthr = max_threads / nthr_g_;
5219 
5220     auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
5221         /* calculate per thread memory cost (read/write). high level optimizer
5222          * tries to minimize memory consumption. few notes:
5223          *  (n1) if weights tensor size is less than source and destination
5224          *       tensors we apply the ratio of the source and destination
5225          *       tensor sizes to weights one as compensation coefficient to
5226          *       avoid parallelization across batch size only, othervise we
5227          *       apply additional coefficient to source component based on
5228          *       performance measurements
5229          *  (n2) use scales based on output vs input channels ratio for source
5230          *       and destination componets to imporve threading balance across
5231          *       input and output channels */
5232 
5233         const dim_t src_type_size = 2;
5234         const dim_t wei_type_size = 4;
5235 
5236         dim_t src_size
5237                 = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
5238         dim_t dst_size
5239                 = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
5240         dim_t wei_size
5241                 = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;
5242 
5243         float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
5244         float oi_channels_ratio = (float)(j.nb_oc / j.nb_oc_blocking)
5245                 / (j.nb_ic / j.nb_ic_blocking);
5246         auto get_src_coef = [=]() {
5247             float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
5248             if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;
5249 
5250             return src_coef;
5251         };
5252 
5253         auto get_dst_coef
5254                 = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };
5255 
5256         auto get_wei_coef
5257                 = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };
5258 
5259         const float src_coef = get_src_coef();
5260         const float dst_coef = get_dst_coef();
5261         const float wei_coef = get_wei_coef();
5262 
5263         float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
5264                 * div_up(j.ngroups, nthr_g_)
5265                 * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.mb
5266                 * (j.ic_block * j.nb_ic_blocking) * j.id * j.ih * j.tr_iw
5267                 / j.nthr_mb_work / j.stride_d / j.stride_h / j.stride_w;
5268         float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
5269                 * div_up((j.nb_oc / j.nb_oc_blocking),
5270                         (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
5271                 * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.kh * j.kw
5272                 * j.kd * (j.ic_block * j.nb_ic_blocking)
5273                 * (j.oc_block * j.nb_oc_blocking);
5274         float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
5275                 * div_up(j.ngroups, nthr_g_)
5276                 * div_up((j.nb_oc / j.nb_oc_blocking),
5277                         (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
5278                 * j.mb * (j.oc_block * j.nb_oc_blocking) * j.od * j.oh * j.tr_ow
5279                 / j.nthr_mb_work;
5280 
5281         return src_v + dst_v + wei_v;
5282     };
5283 
5284     float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
5285 
5286     /* find the best thread distribution with lowest memory cost */
5287 
5288     const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
5289     for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
5290         const int nthr_par = nthr / nthr_mb;
5291         const int nthr_oc_b_max = nstl::min(nthr_par,
5292                 (j.nb_oc / j.nb_oc_blocking)); // Amount of nb_oc_blocks
5293         for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
5294             int nthr_ic_b = nstl::min(
5295                     nthr_par / nthr_oc_b, (j.nb_ic / j.nb_ic_blocking));
5296 
5297             float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
5298             if (mem_cost <= best_mem_cost) {
5299                 best_mem_cost = mem_cost;
5300                 nthr_mb_ = nthr_mb;
5301                 nthr_oc_b_ = nthr_oc_b;
5302                 nthr_ic_b_ = nthr_ic_b;
5303             }
5304         }
5305     }
5306 
5307     if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
5308         nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
5309     nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
5310 
5311     assert(nthr_ <= max_threads);
5312 }
5313 
5314 } // namespace x64
5315 } // namespace cpu
5316 } // namespace impl
5317 } // namespace dnnl
5318 
5319 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
5320