/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp"

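// GET_OFF(field) yields the byte offset of 'field' within the run-time
// jit_conv_call_s parameter block; the kernels below load their arguments
// from 'param1' using these offsets.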
#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::utils;
using namespace Xbyak;

void jit_avx512_core_amx_compute_zp_pbuff_t::prepare_output(int ur_w) {
    for (int oc = 0; oc < jcp.nb_oc_blocking; oc++)
        for (int ur = 0; ur < ur_w; ur++) {
            const Zmm zmm = zmm_out(ur, oc);
            vpxord(zmm, zmm, zmm);
        }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::store_output(
        int ur_w, bool last_oc_block_flag) {
    assert(jcp.is_nspc);

    const int nb_oc_block = jcp.nb_oc_blocking;
    const int oc_block = jcp.oc_block;

    const auto src_zp_addr = EVEX_compress_addr(reg_src_zero_point, 0, true);
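    // the 'true' broadcast flag makes this an EVEX embedded-broadcast
    // operand, replicating the scalar src zero-point across all dword lanes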

    /* write out register to output_addr */
    for (int oc = 0; oc < nb_oc_block; oc++) {
        const bool mask_flag = last_oc_block_flag && oc == nb_oc_block - 1;
        for (int ur = 0; ur < ur_w; ur++) {
            const int output_offset = sizeof(int32_t)
                    * (oc * oc_block
                            + ur * jcp.oc_without_padding * jcp.ngroups);
            const Zmm zmm_dst = zmm_out(ur, oc);
            const Zmm m_zmm_dst = mask_flag ? zmm_dst | ktail_mask : zmm_dst;
            // multiply dst by src_zero_point
            vpmulld(m_zmm_dst, zmm_dst, src_zp_addr);
            vmovups(EVEX_compress_addr(reg_zp_pbuff, output_offset), m_zmm_dst);
        }
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool padded) {

    const int kw = jcp.kw;
    const int ic_block = jcp.ic_block_int_np;
    const int oc_block = jcp.oc_block;
    const int nb_oc_block = jcp.nb_oc_blocking;

    const bool ic_tail
            = (jcp.ic_without_padding % (jcp.ic_block / ic_inner_block)) > 0;
    const bool masked_write = ic_tail && last_ic_block_flag == last_ic_block;

    /* Skip the last loads of input
        if (ic%16)/ic_sub_step < ic_block/ic_sub_step */
    const int icb = (last_ic_block_flag == last_ic_block)
            ? div_up(
                    (jcp.ic_without_padding % jcp.ic_block_int), ic_inner_block)
            : ic_block / ic_inner_block;

    auto get_filter_offset = [=](int ocb, int ic, int ki) {
        size_t w_step = jcp.is_relo ? jcp.kh : 1;
        size_t kw_offset = static_cast<size_t>(ki) * w_step
                * jcp.ic_block_int_np * jcp.oc_block;
        size_t oc_subblock_step = static_cast<size_t>(jcp.kd) * jcp.kh * jcp.kw
                * jcp.ic_block_int_np * jcp.oc_block;
        size_t offset = kw_offset
                + static_cast<size_t>(ocb) * jcp.nb_ic_int * oc_subblock_step
                + static_cast<size_t>(ic) * oc_block * ic_inner_block;
        return sizeof(char) * offset;
    };
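    // compute_fma accumulates per-output-channel sums of the s8 weights:
    // vpdpbusd multiplies the u8 ones in 'zmm_one' with 4 adjacent signed
    // weight bytes and adds the dword result to 'zmm_accum', which realizes
    // conv(1, wei_s8) for the padded-region zero-point compensation. In the
    // reduced-lowering path the weights are first permuted into VNNI order
    // through vpermi2b.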
    auto compute_fma = [=](const Zmm zmm_accum, const int ic,
                               const Address addr) {
        if (jcp.is_relo) {
            vmovups(zmm_permb, ptr[reg_scratch]); // get permute index table
            const Zmm r_zmm = masked_write && ic == icb - 1
                    ? zmm_permb | kmask_ic_block | T_z
                    : zmm_permb;
            // only values from 'src2' are used to write dst
            vpermi2b(r_zmm, zmm_permb, addr);
            vpdpbusd(zmm_accum, zmm_one,
                    zmm_permb); // XXX - using the same register for all ur_w
        } else {
            vpdpbusd(zmm_accum, zmm_one, addr);
        }
    };

    if (jcp.is_relo && last_ic_block_flag == last_ic_block && ic_tail) {
        const Reg64 reg_tmp = reg_scratch;
        mov(reg_tmp, ic_mask_label);
        kmovq(kmask_ic_block, qword[reg_tmp]);
    }
    if (jcp.is_relo) mov(reg_scratch, permb_idx_label);

    for (int ki = 0; ki < kw; ki++) {
        const int ur_start = get_ow_start(ki, pad_l);
        const int ur_end = get_ow_end(ur_w, ki, pad_r);
        for (int ur = 0; ur < ur_w; ur++) {
            // Calculate zero_point padding as:
            // accum = is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0
            if (ur < ur_start || ur >= ur_end || padded) {
                for (int oc = 0; oc < nb_oc_block; oc++) {
                    const Zmm zmm_accum = zmm_out(ur, oc);
                    for (int ic = 0; ic < icb; ic++) {
                        const auto addr_filt = EVEX_compress_addr(
                                aux_reg_filt, get_filter_offset(oc, ic, ki));
                        compute_fma(zmm_accum, ic, addr_filt);
                    }
                }
            }
        }
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::kh_loop(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {

    Label kh_label, skip_kh_loop;
    const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
    const size_t shift_wei_h_step = sizeof(char)
            * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
            * jcp.oc_block;

    // Compute zero_point compensation for the padded region. Total compute
    // area is 'overflow * kw' where 'overflow' indicates the overlap
    // between the filter and either top_pad or bottom_pad region.
    auto compute_kh_loop = [=](size_t param_overflow) {
        Label overflow_label, no_overflow_label;

        mov(reg_overflow, ptr[param1 + param_overflow]);
        cmp(reg_overflow, 0);
        je(no_overflow_label, T_NEAR);
        L(overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
            add(aux_reg_filt, shift_wei_h_step);
            dec(reg_overflow);
            jne(overflow_label, T_NEAR);
        }
        L(no_overflow_label);
    };

    if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(t_overflow));

    // check for holes and skip computation due to dilation
    mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    if (jcp.dilate_h >= jcp.ih) {
        cmp(reg_kj, 0);
        je(skip_kh_loop, T_NEAR);
    }

    L(kh_label);
    {
        compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);

        add(aux_reg_filt, shift_wei_h_step);
        dec(reg_kj);
        jne(kh_label, T_NEAR);
    }

    L(skip_kh_loop);

    if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(b_overflow));
}

void jit_avx512_core_amx_compute_zp_pbuff_t::kd_loop(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {

    Label kd_label, skip_kd_loop;
    const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
    const size_t shift_wei_h_step = sizeof(char)
            * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
            * jcp.oc_block;

    // Compute zero_point compensation for the padded region. Total compute
    // area is 'overflow * kh * kw' where 'overflow' indicates the overlap
    // between the filter and either front_pad or back_pad region.
    auto compute_kd_loop = [=](size_t param_overflow) {
        Label kh_loop_label;
        Label no_overflow_label, overflow_label;

        mov(reg_ki, ptr[param1 + param_overflow]);
        cmp(reg_ki, 0);
        je(no_overflow_label, T_NEAR);
        L(overflow_label);
        {
            mov(aux_reg_filt, aux_reg_filt_d);
            mov(reg_kj, jcp.kh);
            L(kh_loop_label);
            {
                compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                add(aux_reg_filt, shift_wei_h_step);
                dec(reg_kj);
                jne(kh_loop_label, T_NEAR);
            }
            add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
            dec(reg_ki);
            jne(overflow_label, T_NEAR);
        }
        L(no_overflow_label);
    };

    const bool zp_d_padding
            = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
    if (zp_d_padding) {
        mov(aux_reg_filt_d, reg_filt);
        compute_kd_loop(GET_OFF(f_overflow));

        // check for holes and skip computation due to dilation
        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        if (jcp.dilate_d >= jcp.id) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        L(kd_label);
        mov(aux_reg_filt, aux_reg_filt_d);

    } else {
        mov(aux_reg_filt, reg_filt);
    }

    kh_loop(ur_w, pad_l, pad_r, last_ic_block_flag, handle_h_pad);

    if (zp_d_padding) {
        add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
        dec(reg_ki);
        jne(kd_label, T_NEAR);

        L(skip_kd_loop);

        compute_kd_loop(GET_OFF(back_overflow));
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::icb_loop(
        int ur_w, int pad_l, int pad_r, bool handle_h_pad) {

    Label icb_label;
    const size_t nb_ic = jcp.nb_ic_int;
    const bool do_icb_loop = nb_ic > 1;

    /* Initialize zmm_one for weight accumulation */
    xor_(reg_scratch, reg_scratch);
    const Reg8 _t8 = reg_scratch.cvt8();
    mov(_t8, 0x1);
    vpbroadcastb(zmm_one, _t8);

    prepare_output(ur_w);

    mov(reg_icb, nb_ic);

    L(icb_label);
    if (jcp.ic_without_padding != jcp.ic) {
        Label common_ker, end_ker;
        if (do_icb_loop) {
            cmp(reg_icb, 1); // The last ic block
            jne(common_ker, T_NEAR);
        }
        kd_loop(ur_w, pad_l, pad_r, last_ic_block, handle_h_pad);
        if (do_icb_loop) {
            jmp(end_ker, T_NEAR);

            L(common_ker);
            kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);

            L(end_ker);
        }
    } else {
        kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);
    }
    // End of IC Loop
    if (do_icb_loop) {
        const size_t shift_wei_icb_step = static_cast<size_t>(jcp.kd) * jcp.kh
                * jcp.kw * jcp.oc_block * jcp.ic_block_int_np;
        add(reg_filt, sizeof(char) * shift_wei_icb_step);

        dec(reg_icb);
        cmp(reg_icb, 0);
        jg(icb_label, T_NEAR);

        sub(reg_filt, sizeof(char) * shift_wei_icb_step * nb_ic);
    }

    if (jcp.oc_without_padding != jcp.oc) {
        Label common_store, end_store;

        cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(common_store, T_NEAR);

        store_output(ur_w, true); // last oc block
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);
    } else {
        store_output(ur_w, false);
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::unroll_width(
        const bool h_padding) {

    auto ur_w_shift = [&](const int ur_w) {
        return sizeof(int32_t) * (ur_w * jcp.oc_without_padding * jcp.ngroups);
    };

    const int max_ur_w = jit_avx512_core_amx_compute_zp_pbuff_t::max_regs_ur
            / (jcp.nb_oc_blocking);
    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int l_pad = jcp.l_pad;

    const int l_pad_output = jcp.l_pad_output;
    const int r_pad_output = jcp.r_pad_output;

    // a single middle element (if required) containing only height padding
    const int no_pad = nstl::max(0, jcp.ow - l_pad_output - r_pad_output);

    const int ow_start = nstl::max(jcp.ow - r_pad_output, l_pad_output);
    const int r_pad_start = nstl::min(jcp.ow_pad - l_pad_output, r_pad_output);

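    // Walk the output row in three phases: columns overlapped by left
    // padding, an optional single middle column carrying only height
    // padding, then columns overlapped by right padding; each icb_loop call
    // covers at most max_ur_w output columns.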
    int ow = 0;
    int cur_l_pad_output = l_pad_output;
    while (cur_l_pad_output > 0) {
        const int ur_w = nstl::min(cur_l_pad_output, max_ur_w);
        ow += ur_w;
        const int cur_r_pad = calculate_end_padding(
                jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
        icb_loop(ur_w, l_pad, cur_r_pad, h_padding);
        add(reg_zp_pbuff, ur_w_shift(ur_w));

        l_pad = nstl::max(l_pad - ur_w * jcp.stride_w, 0);
        cur_l_pad_output = nstl::max(cur_l_pad_output - ur_w, 0);
    }

    if (no_pad > 0) {
        const int ur_w = 1;
        if (h_padding) icb_loop(ur_w, 0, 0, true);
        if (h_padding || jcp.ow_mid) add(reg_zp_pbuff, ur_w_shift(ur_w));
    }
    assert(ow + no_pad == ow_start);

    ow = ow_start;
    int cur_r_pad_output = r_pad_start;
    while (cur_r_pad_output > 0 && ow < jcp.ow) {
        const int ur_w = nstl::min(cur_r_pad_output, max_ur_w);
        ow += ur_w;
        const int cur_r_pad = calculate_end_padding(
                jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
        icb_loop(ur_w, 0, cur_r_pad, h_padding);
        add(reg_zp_pbuff, ur_w_shift(ur_w));

        cur_r_pad_output = nstl::max(cur_r_pad_output - ur_w, 0);
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::generate() {
    Label h_pad_label, end_label;

    assert(jcp.req_zero_point_buffer);
    assert(jcp.typesize_in == sizeof(char));

    preamble();

    mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
    mov(reg_zp_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
    mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);

    if (jcp.oc_without_padding != jcp.oc) {
        const Reg32 reg_tmp = reg_scratch.cvt32();
        const int tail_size = jcp.oc_without_padding % jcp.oc_block;
        const int mask = (1 << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovw(ktail_mask, reg_tmp);
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
    }

    mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
    cmp(reg_overflow, 0);
    jne(h_pad_label, T_NEAR);
    mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
    cmp(reg_overflow, 0);
    jne(h_pad_label, T_NEAR);
    if (jcp.ndims == 5 && (jcp.f_pad_output > 0 || jcp.back_pad_output > 0)) {
        mov(reg_overflow, ptr[param1 + GET_OFF(kd_padding)]);
        cmp(reg_overflow, jcp.kd);
        jne(h_pad_label, T_NEAR);
    }

    // Handle width padding region
    unroll_width(false);
    jmp(end_label, T_NEAR);

    // Handle height padding region
    L(h_pad_label);
    unroll_width(true);

    L(end_label);

    postamble();

    // reduced-lowering ('is_relo' == true) weights format is '..i16o', so
    // permute elements through permb into the VNNI layout '...16i4i'.
    if (jcp.is_relo) {
        align(64);
        L(permb_idx_label);
        // permb: id-bit for table selection is bit[6]
        const uint8_t select_src2_bit = 0x40;
        // permb: bits [5:0] select the element within each input table
        const uint8_t permb_idx_table[64] = {0, 16, 32, 48, 1, 17, 33, 49, 2,
                18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22,
                38, 54, 7, 23, 39, 55, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42,
                58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46,
                62, 15, 31, 47, 63};
        for (size_t i = 0; i < 64; ++i)
            db(select_src2_bit | permb_idx_table[i]);

        // write zero-mask (permb) for ic_tail in VNNI format '..16o4i'
        const int ic_tail_size
                = jcp.ic_without_padding % (jcp.ic_block / ic_inner_block);
        if (jcp.ic_without_padding != jcp.ic && ic_tail_size > 0) {
            align(64);
            L(ic_mask_label);

            assert(4 > ic_tail_size);
            // mask is on a 4-bit basis from the 4 ic elements in a zmm
            const int nibble = (1 << ic_tail_size) - 1;
            for (int i = 0; i < 16; ++i) {
                db(nibble | (nibble << 4));
            }
        }
    }
}

void jit_avx512_core_amx_copy_to_wbuffer_t::generate() {

    const bool is_bf16 = jcp.src_dt == data_type::bf16;

    // required for use of VPERMB instruction
    assert(IMPLICATION(!is_bf16, cpu().has(Xbyak::util::Cpu::tAVX512_VBMI)));
    assert(jcp.ic_block_int * jcp.typesize_in == 64);

    preamble();

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_dst, ptr[param1 + GET_OFF(dst)]);

    // load permute indices from data section
    Label permute_index_table;
    mov(reg_tmp, permute_index_table);
    if (is_bf16)
        vmovdqu16(zmm_idx, ptr[reg_tmp]);
    else
        vmovdqu8(zmm_idx, ptr[reg_tmp]);

    const int vnni_width = is_bf16 ? 2 : 4;
    const int r = jcp.kh * jcp.kw * jcp.ic_without_padding;
    const int nb_r = div_up(r, vnni_width);
    const int rtail = (r % vnni_width) * jcp.oc_block;
    if (rtail > 0) {
        uint64_t mask = (UINT64_C(1) << rtail) - 1;
        mov(reg_tmp, mask);
        kmovq(kmask_load, reg_tmp);
    }
    const int nb_z = rnd_up(nb_r, jcp.ic_block);
    if (nb_r < nb_z) vpxord(zmm_zero, zmm_zero, zmm_zero);

    const int tile_size = jcp.ic_block_int * jcp.oc_block * jcp.typesize_in;
    const int ocb_src_step = r * jcp.oc_block * jcp.typesize_in;
    const int ocb_dst_step = rnd_up(ocb_src_step, tile_size);
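    // each output-channel block is rounded up to a whole tile so every ocb
    // starts tile-aligned in the weights buffer; the remaining gap, if any,
    // is zero-filled by the nb_z loop below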

    // reorder from ~Owhi16o -> ~OR16oVr with r := whi and V := vnni_width
    for (int g = 0; g < jcp.ngroups; g++) {
        for (int ocb = 0; ocb < jcp.nb_oc; ocb++) {
            int offset = 0;
            int rb = 0;
            for (; rb < nb_r; offset += 64, rb++) {
                auto zmm_src_tmp = (rtail > 0 && rb == nb_r - 1)
                        ? zmm_src | kmask_load | T_z
                        : zmm_src;
                if (is_bf16) {
                    vmovdqu16(zmm_src_tmp, ptr[reg_src + offset]);
                    vpermw(zmm_dst, zmm_idx, zmm_src);
                    vmovdqu16(ptr[reg_dst + offset], zmm_dst);
                } else {
                    vmovdqu8(zmm_src_tmp, ptr[reg_src + offset]);
                    vpermb(zmm_dst, zmm_idx, zmm_src);
                    vmovdqu8(ptr[reg_dst + offset], zmm_dst);
                }
            }
            for (; rb < nb_z; offset += 64, rb++) {
                if (is_bf16)
                    vmovdqu16(ptr[reg_dst + offset], zmm_zero);
                else
                    vmovdqu8(ptr[reg_dst + offset], zmm_zero);
            }
            add(reg_src, ocb_src_step);
            add(reg_dst, ocb_dst_step);
        }
    }

    postamble();

    align(64);
    L(permute_index_table);
    const uint8_t no = 16; // 16o
    const uint8_t nr = is_bf16 ? 2 : 4; // 2r or 4r
    for (uint8_t o = 0; o < no; ++o) {
        for (uint8_t r = 0; r < nr; r++) {
            const uint8_t index = o + r * no;
            if (is_bf16)
                dw(index);
            else
                db(index);
        }
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_body(
        int lpad, int iw_len, int icb) {

    const bool is_bf16 = jcp.src_dt == data_type::bf16;
    int iwp_idx = 0;
    // there are min(gen_kw, jcp.stride_w) contiguous sets of input
    // data (for each stride idx), they are placed one by one
    // without additional padding
    const bool are_sets_interleaved
            = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
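    // effective filter width after dilation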
    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
    const int num_sets = are_sets_interleaved ? jcp.n_stride_sets : jcp.kw;
    for (int set_idx = 0; set_idx < num_sets; set_idx++) {
        int set_width_padded = !jcp.is_pbuffer_strided
                ? (jcp.ow_block - 1) * jcp.stride_w + gen_kw
                : are_sets_interleaved ? jcp.ow_block - 1 + gen_kw / num_sets
                                + (set_idx < gen_kw % num_sets ? 1 : 0)
                                       : jcp.ow_block;
        for (int set_shift = 0; set_shift < set_width_padded;
                set_shift++, iwp_idx++) {
            int iw_idx = set_idx * (jcp.dilate_w + 1)
                    + set_shift * (jcp.is_pbuffer_strided ? jcp.stride_w : 1)
                    - lpad;
            size_t out_base_offset
                    = (size_t)jcp.typesize_in * iwp_idx * jcp.ic_block_int_np;
            if (iw_idx < 0 || iw_idx >= iw_len) {
                // left or right padding
                vmovups(ptr[reg_aux_out_ptr + out_base_offset], zmm_zero);
            } else if (jcp.is_nspc) {
                size_t inp_w_offset = (size_t)jcp.typesize_in * iw_idx
                        * jcp.ngroups * jcp.ic_without_padding;
                int ic = icb * jcp.ic_block_int_np;
                // TODO: use Xmm or Ymm moves for better small ic efficiency
                auto zmm_tmp_mask
                        = ic + jcp.ic_block_int <= jcp.ic_without_padding
                        ? zmm_tmp
                        : zmm_tmp | ktail_mask | T_z;
                if (is_bf16) {
                    vmovdqu16(
                            zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
                    vmovdqu16(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
                } else {
                    vmovdqu8(zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
                    vmovdqu8(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
                }
            } else {
                assert(is_bf16);
                size_t inp_w_offset
                        = (size_t)jcp.typesize_in * iw_idx * jcp.ic_block;
                for (int j = 0; j < jcp.ic_block_int_np / jcp.ic_block; j++) {
                    int ic = icb * jcp.ic_block_int_np + j * jcp.ic_block;
                    size_t inp_c_w_offset = (size_t)jcp.typesize_in * j * jcp.ih
                                    * jcp.iw * jcp.ic_block
                            + inp_w_offset;
                    if (ic + jcp.ic_block <= jcp.ic) {
                        vmovdqu16(
                                ymm_tmp, ptr[reg_aux_inp_ptr + inp_c_w_offset]);
                    } else {
                        vpxord(ymm_tmp, ymm_tmp, ymm_tmp);
                    }
                    size_t out_offset = out_base_offset
                            + (size_t)jcp.typesize_in * j * jcp.ic_block;
                    vmovdqu16(ptr[reg_aux_out_ptr + out_offset], ymm_tmp);
                }
            }
        }
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row(int icb) {
    if (jcp.nb_ow == 1) {
        copy_row_body(jcp.l_pad, jcp.iw, icb);
    } else {
        auto get_iw_len_required = [&](int cur_ow_block, int cur_lpad) {
            return (cur_ow_block - 1) * jcp.stride_w
                    + (jcp.kw - 1) * (jcp.dilate_w + 1) + 1 - cur_lpad;
        };

        auto get_iw_len_limited = [&](int owb, int cur_ow_block, int cur_lpad) {
            auto len_req = get_iw_len_required(cur_ow_block, cur_lpad);
            if (owb < 0) return len_req;
            int ow_block_start = nstl::max(
                    0, owb * jcp.ow_block * jcp.stride_w - jcp.l_pad);
            return nstl::min(jcp.iw - ow_block_start, len_req);
        };

        int general_owb_cases = jcp.nb_ow;
        Xbyak::Label copy_row_done_label;
        bool special_first_block_case = jcp.l_pad > 0;
        if (special_first_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_first_block_case_label;
            cmp(reg_owb, 0);
            jne(skip_first_block_case_label, T_NEAR);
            copy_row_body(jcp.l_pad,
                    get_iw_len_limited(0, jcp.ow_block, jcp.l_pad), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_first_block_case_label);
        }
        bool special_last_block_case = false
                // has ow_block_tail
                || jcp.ow % jcp.ow_block != 0
                // there is no ow_block_tail but right padding exists
                || get_iw_len_limited(jcp.nb_ow - 1, jcp.ow_block, 0)
                        != get_iw_len_required(jcp.ow_block, 0);
        if (special_last_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_last_block_case_label;
            cmp(reg_owb, jcp.nb_ow - 1);
            jne(skip_last_block_case_label, T_NEAR);
            int ow_block_tail = jcp.ow % jcp.ow_block;
            int cur_ow_block = ow_block_tail > 0 ? ow_block_tail : jcp.ow_block;
            copy_row_body(
                    0, get_iw_len_limited(jcp.nb_ow - 1, cur_ow_block, 0), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_last_block_case_label);
        }

        bool special_penult_block_case = true
                // if nb_ow == 2 and l_pad > 0 it's the same as
                // special_first_block_case
                && jcp.nb_ow >= (special_first_block_case ? 3 : 2)
                // right padding exists in the penultimate block
                && get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0)
                        != get_iw_len_required(jcp.ow_block, 0);
        if (special_penult_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_penult_block_case_label;
            cmp(reg_owb, jcp.nb_ow - 2);
            jne(skip_penult_block_case_label, T_NEAR);
            copy_row_body(
                    0, get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_penult_block_case_label);
        }

        if (general_owb_cases > 0) // general case
            copy_row_body(0, get_iw_len_required(jcp.ow_block, 0), icb);

        L(copy_row_done_label);
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_reduced_lowering() {
    assert(jcp.nb_ic_int == 1);
    assert(jcp.ic_block_int * jcp.typesize_in == 64);
    assert(jcp.is_nspc);

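    // build a k-mask whose low 'tail' bits are set, used to bound loads and
    // stores over the ic tail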
    auto load_mask = [=](int tail, Opmask kmask) {
        uint64_t mask = (UINT64_C(1) << tail) - 1;
        mov(reg_tmp, mask);
        kmovq(kmask, reg_tmp);
    };

    const bool is_bf16 = jcp.src_dt == data_type::bf16;
    const int inp_w_step
            = jcp.ngroups * jcp.ic_without_padding * jcp.typesize_in;
    const int inp_h_step = jcp.iw * inp_w_step;
    const int out_h_step = jcp.ic_without_padding * jcp.typesize_in;
    const int out_w_step = jcp.kh * out_h_step;
    const int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
    if (tail_size > 0) load_mask(tail_size, ktail_mask);

    auto zero_it = [=](reg64_t tmp_out_ptr) {
        for (int ic = 0; ic < jcp.ic_without_padding; ic += jcp.ic_block_int) {
            const int offset = ic * jcp.typesize_in;
            const bool masked = ic + jcp.ic_block_int > jcp.ic_without_padding;
            Zmm zmm = masked ? zmm_zero | ktail_mask : zmm_zero;
            if (is_bf16)
                vmovdqu16(ptr[tmp_out_ptr + offset], zmm);
            else
                vmovdqu8(ptr[tmp_out_ptr + offset], zmm);
        }
    };

    // pointer to 1st needed element in src buffer
    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    // pointer to 1st needed element in dst buffer
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);

    // total number of rows to copy
    mov(reg_kht, ptr[param1 + GET_OFF(kh_offset)]);

    // number of rows of src buffer to copy
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
    // number of zero-padded rows above src buffer to copy
    mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
    // number of zero-padded rows below src buffer to copy
    mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);

    // number of columns of src buffer to copy
    mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
    // number of zero-padded columns before src buffer to copy
    mov(reg_lov, ptr[param1 + GET_OFF(f_overflow)]);
    // number of zero-padded columns after src buffer to copy
    mov(reg_rov, ptr[param1 + GET_OFF(back_overflow)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    { // Handle Left Overflow
        Label label_lov, label_lov_skip;
        test(reg_lov, reg_lov);
        jz(label_lov_skip, T_NEAR);
        L(label_lov); // handle left or right overflow
        {
            Label label_lov_inner;
            mov(reg_aux_out_ptr, reg_out_ptr);
            mov(reg_cnt, reg_kht);
            L(label_lov_inner);
            {
                zero_it(reg_aux_out_ptr);
                add(reg_aux_out_ptr, out_h_step);
                dec(reg_cnt);
                jnz(label_lov_inner, T_NEAR);
            }
            add(reg_out_ptr, out_w_step);
            dec(reg_lov);
            jnz(label_lov, T_NEAR);
        }
        L(label_lov_skip);
    }

    // save output pointer for later use
    mov(reg_save_out_ptr, reg_out_ptr);

    // just in case there is no meat...
    Label label_kwp_end;
    test(reg_kwp, reg_kwp);
    jz(label_kwp_end, T_NEAR);

    // dispatch between top-overflow, kh-body, and bottom-overflow handling
    Label label_tov;
    Label label_khp, label_no_khp;
    Label label_bov;
    test(reg_tov, reg_tov);
    jnz(label_tov, T_NEAR);
    test(reg_khp, reg_khp);
    jnz(label_khp, T_NEAR);
    test(reg_bov, reg_bov);
    jnz(label_bov, T_NEAR);
    jmp(label_kwp_end, T_NEAR); // safe exit in case of bad parameters

    L(label_tov); // handle top overflow
    {
        Label label_tov_inner;
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_tov_inner);
        {
            zero_it(reg_aux_out_ptr);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_tov_inner, T_NEAR);
        }
        add(reg_out_ptr, out_h_step);
        dec(reg_tov);
        jnz(label_tov, T_NEAR);
    }
    test(reg_khp, reg_khp);
    jz(label_no_khp, T_NEAR);
    L(label_khp); // handle kh padding (not fully unrolled)
    {
        Label label_khp_inner;
        mov(reg_aux_inp_ptr, reg_inp_ptr);
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_khp_inner);
        {
            for (int ic = 0; ic < jcp.ic_without_padding;
                    ic += jcp.ic_block_int) {
                const int offset = ic * jcp.typesize_in;
                const bool masked
                        = ic + jcp.ic_block_int > jcp.ic_without_padding;
                // zero masking is needed to avoid dependency on destination
                Zmm zmm_load = masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
                Zmm zmm_store = masked ? zmm_tmp | ktail_mask : zmm_tmp;
                if (is_bf16) {
                    vmovdqu16(zmm_load, ptr[reg_aux_inp_ptr + offset]);
                    vmovdqu16(ptr[reg_aux_out_ptr + offset], zmm_store);
                } else {
                    vmovdqu8(zmm_load, ptr[reg_aux_inp_ptr + offset]);
                    vmovdqu8(ptr[reg_aux_out_ptr + offset], zmm_store);
                }
            }
            add(reg_aux_inp_ptr, inp_w_step);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_khp_inner, T_NEAR);
        }
        add(reg_inp_ptr, inp_h_step);
        add(reg_out_ptr, out_h_step);
        dec(reg_khp);
        jnz(label_khp, T_NEAR);
    }
    L(label_no_khp);
    test(reg_bov, reg_bov);
    jz(label_kwp_end, T_NEAR);
    L(label_bov); // handle bottom overflow
    {
        Label label_bov_inner;
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_bov_inner);
        {
            zero_it(reg_aux_out_ptr);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_bov_inner, T_NEAR);
        }
        add(reg_out_ptr, out_h_step);
        dec(reg_bov);
        jnz(label_bov, T_NEAR);
    }
    L(label_kwp_end);

    { // Handle Right Overflow
        Label label_rov, label_rov_skip;
        // retrieve output pointer
        mov(reg_out_ptr, reg_save_out_ptr);
        // calculate the shift
        imul(reg_tmp, reg_kwp, out_w_step);
        // shift past the body
        add(reg_out_ptr, reg_tmp);
        // skip if no right overflow
        test(reg_rov, reg_rov);
        jz(label_rov_skip, T_NEAR);

        L(label_rov); // handle left or right overflow
        {
            Label label_rov_inner;
            mov(reg_aux_out_ptr, reg_out_ptr);
            mov(reg_cnt, reg_kht);
            L(label_rov_inner);
            {
                zero_it(reg_aux_out_ptr);
                add(reg_aux_out_ptr, out_h_step);
                dec(reg_cnt);
                jnz(label_rov_inner, T_NEAR);
            }
            add(reg_out_ptr, out_w_step);
            dec(reg_rov);
            jnz(label_rov, T_NEAR);
        }
        L(label_rov_skip);
    }

    // For bf16, zero-pad an extra cacheline to avoid NaNs
    // For int8, it is sufficient to zero-pad the weights only
    if (is_bf16) {
        // shift forward to align h index to end of needed buffer
        imul(reg_tmp, reg_kht, out_h_step);
        add(reg_out_ptr, reg_tmp);
        // shift backward to align w index to end of needed buffer
        sub(reg_out_ptr, out_w_step);
        vmovdqu16(ptr[reg_out_ptr], zmm_zero);
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::generate() {

    // Special copy kernel for reduced lowering
    if (jcp.is_relo) {
        assert(jcp.nb_ic_int == 1);
        preamble();
        copy_row_reduced_lowering();
        postamble();
        return;
    }

    preamble();

    const bool is_3d = jcp.ndims == 5;
    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
    if (is_3d) mov(reg_kdp, ptr[param1 + GET_OFF(kd_padding)]);
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
    mov(reg_tover, ptr[param1 + GET_OFF(t_overflow)]);
    mov(reg_bover, ptr[param1 + GET_OFF(b_overflow)]);
    mov(reg_owb, ptr[param1 + GET_OFF(owb)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    if (jcp.is_nspc && jcp.ic_without_padding % jcp.ic_block_int) {
        int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
        Xbyak::Label kd_label, no_kd_label;
        Xbyak::Label kh_label, no_kh_label, icb_label;
        Xbyak::Label kh_tover_label, kh_bover_label;
        Xbyak::Label no_kh_tover_label, no_kh_bover_label;

        mov(reg_aux_inp_ptr, reg_inp_ptr);
        mov(reg_aux_out_ptr, reg_out_ptr);
        if (is_3d) {
            cmp(reg_kdp, 0);
            jle(no_kd_label, T_NEAR);
            mov(reg_kdc, reg_kdp);
            L(kd_label);
            push(reg_aux_inp_ptr);
            push(reg_aux_out_ptr);
        }
        cmp(reg_khp, 0);
        jle(no_kh_bover_label, T_NEAR); // nothing to do
        mov(reg_khc, reg_khp);

        cmp(reg_tover, 0);
        jle(no_kh_tover_label, T_NEAR);

        mov(reg_kh_over, reg_tover);
        L(kh_tover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for (int iw = 0; iw < jcp.iwp; iw++)
                vmovups(ptr[reg_aux_out_ptr
                                + jcp.typesize_in * iw * jcp.ic_block_int_np],
                        zmm_zero);
            int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_kh_over);
            jnz(kh_tover_label, T_NEAR);
        }
        sub(reg_khc, reg_tover);
        L(no_kh_tover_label);

        cmp(reg_khc, reg_bover);
        jle(no_kh_label, T_NEAR);

        L(kh_label);
        {
            copy_row(icb);
            size_t inp_h_offset = !jcp.is_nspc
                    ? (size_t)jcp.typesize_in * jcp.iw * jcp.ic_block
                    : (size_t)jcp.typesize_in * jcp.iw * jcp.ngroups
                            * jcp.ic_without_padding;
            size_t out_h_offset
                    = (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;

            add(reg_aux_inp_ptr, inp_h_offset);
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_khc);
            cmp(reg_khc, reg_bover);
            jg(kh_label, T_NEAR);
        }
        L(no_kh_label);

        cmp(reg_khc, 0);
        jle(no_kh_bover_label, T_NEAR);

        L(kh_bover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for (int iw = 0; iw < jcp.iwp; iw++)
                vmovups(ptr[reg_aux_out_ptr
                                + jcp.typesize_in * iw * jcp.ic_block_int_np],
                        zmm_zero);
            int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_khc);
            jnz(kh_bover_label, T_NEAR);
        }
        size_t out_d_offset = (size_t)jcp.typesize_in
                * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
        L(no_kh_bover_label);
        if (is_3d) {
            size_t inp_d_offset = !jcp.is_nspc
                    ? (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block
                            * (jcp.dilate_d + 1)
                    : (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ngroups
                            * jcp.ic_without_padding * (jcp.dilate_d + 1);
            pop(reg_aux_out_ptr);
            pop(reg_aux_inp_ptr);
            add(reg_aux_inp_ptr, inp_d_offset);
            add(reg_aux_out_ptr, out_d_offset);
            dec(reg_kdc);
            jnz(kd_label, T_NEAR);
            L(no_kd_label);
        }
        // End IC Loop
        size_t inp_cb_offset = !jcp.is_nspc
                ? (size_t)jcp.typesize_in * (jcp.ic_block_int_np / jcp.ic_block)
                        * jcp.id * jcp.ih * jcp.iw * jcp.ic_block
                : (size_t)jcp.typesize_in * jcp.ic_block_int_np;
        size_t out_cb_offset = (size_t)jcp.kd * out_d_offset;

        add(reg_inp_ptr, inp_cb_offset);
        add(reg_out_ptr, out_cb_offset);
    }

    postamble();
}

jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        using namespace binary_injector;
        const auto &rhs_addr_reg = bin_injector_helper_reg_1;
        const auto &rhs_helper_reg = bin_injector_helper_reg_2;
        static constexpr bool preserve_gpr = false;
        static constexpr bool preserve_vmm = false;
        const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                31, rhs_addr_reg, rhs_helper_reg, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, ktail_mask,
                use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp.post_ops, static_params);
    }
    copy_to_pbuffer_
            = utils::make_unique<jit_avx512_core_amx_copy_to_pbuffer_t>(jcp);
    if (jcp.is_relo)
        copy_to_wbuffer_
                = utils::make_unique<jit_avx512_core_amx_copy_to_wbuffer_t>(
                        jcp);
}

status_t jit_avx512_core_amx_fwd_kernel_t::create_kernel() {
    CHECK(jit_generator::create_kernel());
    CHECK(copy_to_pbuffer_->create_kernel());
    if (jcp.is_relo) CHECK(copy_to_wbuffer_->create_kernel());
    if (jcp.req_zero_point_buffer) {
        zp_pbuff_kernel_
                = utils::make_unique<jit_avx512_core_amx_compute_zp_pbuff_t>(
                        jcp);
        if (zp_pbuff_kernel_ == nullptr) return status::out_of_memory;
        CHECK(zp_pbuff_kernel_->create_kernel());
    }
    return status::success;
}

// Tile register decomposition
// { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
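// The 8 AMX tile registers are statically partitioned: tmm0-3 hold the
// output accumulators (C), tmm4-5 the input rows (I), tmm6-7 the weights (W).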
int jit_avx512_core_amx_fwd_kernel_t::get_out_tensor(
        int h, int i, bool is_h_tail) const {
    const int C_BASE = 0;
    const int C_LAST = 4;
    assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(C_LAST);
    const int tile = C_BASE
            + (jcp.nb_oh_blocking > 1
                            ? h * jcp.nb_oh_blocking + i
                            : (int)is_h_tail * jcp.nb_oc_blocking + i);
    assert(C_BASE <= tile && tile < C_LAST);
    return tile;
}
int jit_avx512_core_amx_fwd_kernel_t::get_inp_tensor(
        int h, bool is_h_tail) const {
    const int I_BASE = 4;
    const int I_LAST = 6;
    assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(I_LAST);
    const int tile = I_BASE + (jcp.nb_oh_blocking > 1 ? h : (int)is_h_tail);
    assert(I_BASE <= tile && tile < I_LAST);
    return tile;
}
int jit_avx512_core_amx_fwd_kernel_t::get_wei_tensor(int i) const {
    const int W_BASE = 6;
    const int W_LAST = 8;
    assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(W_LAST);
    const int tile = W_BASE + i;
    assert(W_BASE <= tile && tile < W_LAST);
    return tile;
}

// Shifts and offsets
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_icb_step() const {
    return (size_t)jcp.kd * get_inp_d_step();
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_icb_step() const {
    return (size_t)jcp.typesize_in * jcp.kd * jcp.kh * jcp.kw
            * jcp.ic_block_int_np * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_d_step() const {
    return (size_t)jcp.typesize_in
            * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_h_step() const {
    return (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np
            * (jcp.dilate_h + 1);
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_d_step() const {
    return (size_t)jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block_int_np
            * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_h_step() const {
    return (size_t)jcp.typesize_in * jcp.kw * jcp.ic_block_int_np
            * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_out_ocb_offset(
        int ohb, int ocb, size_t typesize) const {
    size_t el_offset = jcp.is_nspc
            ? (size_t)ocb * jcp.oc_block
                    + (size_t)ohb * jcp.ow * jcp.ngroups
                            * jcp.oc_without_padding
            : (size_t)ocb * jcp.oh * jcp.ow * jcp.oc_block
                    + (size_t)ohb * jcp.ow * jcp.oc_block;
    return (size_t)typesize * el_offset;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_out_row_offset(
        int ohb, int ocb, int j, size_t typesize) const {
    size_t offset_w = jcp.is_nspc
            ? (size_t)typesize * j * jcp.ngroups * jcp.oc_without_padding
            : (size_t)typesize * j * jcp.oc_block;
    return get_out_ocb_offset(ohb, ocb, typesize) + offset_w;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_out_shift(
        int width, size_t typesize) const {
    return jcp.is_nspc
            ? (size_t)typesize * width * jcp.ngroups * jcp.oc_without_padding
            : (size_t)typesize * width * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_ocb_offset(
        int ohb, int ocb) const {
    size_t el_offset = (size_t)ocb * prv_width_ * jcp.oc_block
            + (size_t)ohb * jcp.nb_oc_blocking * jcp.full_tile_width
                    * jcp.oc_block;
    return jcp.typesize_acc * el_offset;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_row_offset(
        int ohb, int ocb, int j) const {
    return get_wsp_ocb_offset(ohb, ocb)
            + (size_t)jcp.typesize_acc * j * jcp.oc_block;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_shift() const {
    return (size_t)jcp.typesize_acc * jcp.nb_oh_blocking * jcp.full_tile_width
            * jcp.oc_block * jcp.nb_oc_blocking;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_offset(int ocb, int kw) const {
    size_t el_offset = (size_t)kw * jcp.ic_block_int_np * jcp.oc_block;
    size_t raw_oc_subblock_step
            = jcp.kd * jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
    size_t oc_subblock_step = jcp.is_relo
            ? rnd_up(raw_oc_subblock_step, jcp.ic_block_int * jcp.oc_block)
            : raw_oc_subblock_step;
    el_offset += (size_t)ocb * jcp.nb_ic_int * oc_subblock_step;
    return jcp.typesize_in * el_offset;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_shift() const {
    size_t w_step = (jcp.is_relo ? jcp.stride_w * jcp.kh
                            : jcp.is_pbuffer_strided ? 1 : jcp.stride_w)
            * jcp.ic_block_int_np;
    return (size_t)jcp.typesize_in * jcp.tile_width * w_step;
}
size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_offset(int ohb, int kw) const {
    if (jcp.is_relo)
        return ohb * jcp.iwp * jcp.kh * jcp.ic_block_int_np * jcp.typesize_in;
    // calculate offset by height dimension
    const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
    const int gen_stride_h = nstl::min(jcp.stride_h, gen_kh);
    size_t el_offset = (size_t)ohb * jcp.oh_per_tile * gen_stride_h * jcp.iwp
            * jcp.ic_block_int_np;

    // add offset by width dimension
    if (IMPLICATION(jcp.is_pbuffer_strided, jcp.stride_w == 1)) {
        el_offset += (size_t)kw * (jcp.dilate_w + 1) * jcp.ic_block_int_np;
    } else if (jcp.dilate_w > 0) {
        el_offset += (size_t)kw * jcp.ow_block * jcp.ic_block_int_np;
    } else {
        // dilate_w == 0 && stride_w > 1
        // there are min(jcp.kw, jcp.stride_w) contiguous sets of input data
        // (for each stride idx), they are placed one by one without
        // additional padding

        // calculate set idx for current kw value
        int set_idx = kw % jcp.stride_w;
        // calculate shift within set for current kw value
        int set_shift = kw / jcp.stride_w;

        // calculate the beginning of the current set along width; each set
        // with index set_i contains a number of elements along width equal to
        // jcp.ow_block - 1 + jcp.kw / jcp.stride_w
        //     + (set_i < jcp.kw % jcp.stride_w)
        size_t set_start = (jcp.ow_block - 1 + jcp.kw / jcp.stride_w) * set_idx
                + nstl::min(set_idx, jcp.kw % jcp.stride_w);
        el_offset += (set_start + set_shift) * jcp.ic_block_int_np;
    }
    return jcp.typesize_in * el_offset;
}

size_t jit_avx512_core_amx_fwd_kernel_t::get_zp_comp_offset(
        int ocb, int zp_h, int zp_w) const {
    const size_t ocb_offset = (size_t)ocb * jcp.oc_block;
    const size_t sp_offset = (size_t)(zp_h * jcp.ow_pad + zp_w) * jcp.ngroups
            * jcp.oc_without_padding;
    return (ocb_offset + sp_offset) * sizeof(int32_t);
}

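// Map an output spatial index to its position in the compressed zero-point
// buffer, which stores only the unique entries: the start-padding rows, up
// to 'mid' shared middle rows, and the end-padding rows.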
int jit_avx512_core_amx_fwd_kernel_t::get_zp_index_offset(
        int index, int mid, int s_pad_output, int e_pad_output) {
    using namespace nstl;
    const int mid_end = e_pad_output - 1;
    int zp_mid = min(mid, max(0, index - mid_end));
    int zp_pad_offset
            = accum_with_upper_bound(index, s_pad_output, e_pad_output);
    return zp_pad_offset + zp_mid;
}

// Code generation
void jit_avx512_core_amx_fwd_kernel_t::prepare_output(int tail) {
    for (int h = 0; h < jcp.nb_oh_blocking; h++)
        for (int i = 0; i < jcp.nb_oc_blocking; i++)
            tilezero(Tmm(get_out_tensor(h, i, tail)));
}

void jit_avx512_core_amx_fwd_kernel_t::init_runtime_counters(
        bool start_with_last_tile_block) {
    prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
            ? jcp.tile_tail
            : jcp.tile_width;

    row_count_ = 0;
    is_store_done_ = false;
    is_buffer_empty_ = true;
}

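// Number of unique entries needed to represent a padded region of
// 'pad_output' elements when blocked by 'block_size': one full block if the
// region spans at least a block, plus the remainder,
// e.g. reduce_to_block(16, 35) = 16 + 3 = 19.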
size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_block(
        const int block_size, const int pad_output) {
    return (size_t)(pad_output >= block_size ? block_size : 0)
            + (pad_output % block_size);
}

size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_blocked_dims(
        const int dim_size, const int block_size, const int s_pad_output,
        const int e_pad_output) {
    using namespace nstl;

    // start padding (s_pad)
    int s_pad_limit = reduce_to_block(block_size, s_pad_output);
    int s_pad_area_blk = rnd_up(s_pad_limit, block_size);

    // middle (no padding)
    int no_pad_area = max(
            0, dim_size - rnd_up(s_pad_output, block_size) - e_pad_output);
    int no_pad_limit = (no_pad_area >= block_size ? block_size : 0);

    // end padding (e_pad)
    int no_pad_area_shift = no_pad_area % block_size;
    int e_pad_area_overlap
            = no_pad_area_shift == 0 ? 0 : block_size - no_pad_area_shift;
    // middle and end padding shift
    int e_pad_shift_limit
            = no_pad_area_shift + min(e_pad_output, e_pad_area_overlap);
    int e_pad_area_blk = max(0, e_pad_output - e_pad_area_overlap);
    // full end padding block
    int e_pad_limit = reduce_to_block(block_size, e_pad_area_blk);

    // calculate reduced size of s_pad, middle and e_pad blocks.
    return min((size_t)dim_size,
            (size_t)s_pad_area_blk + no_pad_limit + e_pad_shift_limit
                    + e_pad_limit);
}

Ymm jit_avx512_core_amx_fwd_kernel_t::ymm_mask(
        const Ymm &ymm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
                     : ymm_in;
}

Zmm jit_avx512_core_amx_fwd_kernel_t::zmm_mask(
        const Zmm &zmm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
                     : zmm_in;
}

void jit_avx512_core_amx_fwd_kernel_t::cvt2ps(data_type_t type_in,
        const Zmm &zmm_in, const Operand &op, bool mask_flag) {
    const Zmm zmm = zmm_mask(zmm_in, mask_flag);
    switch (type_in) {
        case data_type::f32:
        case data_type::s32: vmovups(zmm, op); break;
        case data_type::s8: vpmovsxbd(zmm, op); break;
        case data_type::u8: vpmovzxbd(zmm, op); break;
        default: assert(!"unsupported data type");
    }
    if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
}

void jit_avx512_core_amx_fwd_kernel_t::apply_sum(const Zmm &zmm_out,
        const float *p_sum_scale, const int32_t *p_sum_zp,
        const Xbyak::Address &addr, const bool mask_flag) {
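    // sum post-op: zmm_out += sum_scale * (cvt2ps(prev_dst) - sum_zero_point)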
1350 if (p_sum_scale) {
1351 const float p_sum_scale_val = *p_sum_scale;
1352 const int32_t p_sum_zp_val = *p_sum_zp;
1353 const auto sum_injector = [&, p_sum_scale_val, p_sum_zp_val,
1354 mask_flag]() {
1355 cvt2ps(jcp.sum_dt, zmm_prev_dst, addr, mask_flag);
1356 if (p_sum_zp_val != 0) {
1357 vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
1358 vsubps(zmm_prev_dst, zmm_sum_zp);
1359 }
1360 if (p_sum_scale_val == 1.f)
1361 vaddps(zmm_out, zmm_prev_dst);
1362 else
1363 vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
1364 };
1365 postops_injector_->set_lambda_injector(
1366 primitive_kind::sum, sum_injector);
1367 }
1368 }
1369
apply_postops(const Zmm & zmm_out,const float * p_sum_scale,const int32_t * p_sum_zp,const Xbyak::Address & addr,const size_t off,const bool mask_flag)1370 void jit_avx512_core_amx_fwd_kernel_t::apply_postops(const Zmm &zmm_out,
1371 const float *p_sum_scale, const int32_t *p_sum_zp,
1372 const Xbyak::Address &addr, const size_t off, const bool mask_flag) {
1373 if (jcp.with_eltwise || jcp.with_binary
1374 || (jcp.with_sum && p_sum_scale != nullptr)) {
1375 apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag);
1376
1377 const auto vmm_idx = zmm_out.getIdx();
1378 if (jcp.with_binary) {
1379 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
1380 rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_out_ptr);
1381 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off);
1382 if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
1383
1384 postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
1385 } else {
1386 postops_injector_->compute_vector(vmm_idx);
1387 }
1388 }
1389 }
1390
store_output_vector_bf16(const Zmm & zmm_out,int ocb,int h,int w)1391 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_bf16(
1392 const Zmm &zmm_out, int ocb, int h, int w) {
1393 const bool mask_flag = jcp.is_nspc && jcp.oc_without_padding != jcp.oc
1394 && ocb == (jcp.nb_oc_blocking - 1);
1395
1396 const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1397 auto addr = EVEX_compress_addr(reg_out_ptr, off);
1398
1399 const auto &p = attr_.post_ops_;
1400
1401 const int sum_idx = p.find(primitive_kind::sum);
1402 if (sum_idx != -1) {
1403 if (jcp.dst_dt == data_type::bf16) {
1404 vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
1405 vpslld(zmm_prev_dst, zmm_prev_dst, 16);
1406 vaddps(zmm_out, zmm_prev_dst);
1407 } else {
1408 vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
1409 vaddps(zmm_out, zmm_prev_dst);
1410 }
1411 }
1412 if (jcp.with_bias) {
1413 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
1414 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1415 if (jcp.bia_dt == data_type::bf16) {
1416 vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
1417 vpslld(zmm_bias, zmm_bias, 16);
1418 vaddps(zmm_out, zmm_bias);
1419 } else
1420 vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
1421 }
1422
1423 static constexpr auto skip_sum_injection = nullptr;
1424 apply_postops(zmm_out, skip_sum_injection, skip_sum_injection, addr, off,
1425 mask_flag);
1426
1427 if (jcp.dst_dt == data_type::bf16) {
1428 Ymm ymm_out = Ymm(zmm_out.getIdx());
1429 vcvtneps2bf16(ymm_out, zmm_out);
1430 vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
1431 } else {
1432 vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
1433 }
1434 }
1435
store_output_vector_int8(const Zmm & zmm_out,int ocb,int h,int w,const bool compute_zp,const int zp_h,const int zp_w)1436 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_int8(
1437 const Zmm &zmm_out, int ocb, int h, int w, const bool compute_zp,
1438 const int zp_h, const int zp_w) {
1439 const int nb_oc_block = jcp.nb_oc_blocking;
1440 const int oc_block = jcp.oc_block;
1441 const bool mask_flag = true && jcp.oc_without_padding != jcp.oc
1442 && ocb == (nb_oc_block - 1);
1443
1444 const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1445 auto addr = EVEX_compress_addr(reg_out_ptr, off);
1446
1447 const auto &p = attr_.post_ops_;
1448 const int sum_idx = p.find(primitive_kind::sum);
1449 const float *p_sum_scale = nullptr;
1450 const int32_t *p_sum_zp = nullptr;
1451 if (sum_idx != -1) {
1452 const auto &p_entry = p.entry_[sum_idx];
1453 p_sum_scale = &p_entry.sum.scale;
1454 p_sum_zp = &p_entry.sum.zero_point;
1455 }
1456
1457 if (p_sum_scale) {
1458 if (*p_sum_scale != 1.f)
1459 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
1460 if (*p_sum_zp != 0)
1461 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
1462 }
1463
1464 int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * oc_block);
1465 if (jcp.with_bias) {
1466 int bias_offset = jcp.typesize_bia * ocb * oc_block;
1467 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1468 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
1469 }
1470 if (compute_zp) {
1471 assert(jcp.req_zero_point_buffer);
1472 // add zero-point padding compensation when accum data is S32
1473 const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1474 vmovups(m_zmm_zp,
1475 EVEX_compress_addr(reg_zero_point_pbuff,
1476 get_zp_comp_offset(ocb, zp_h, zp_w)));
1477 const Zmm m_zmm_out = zmm_mask(zmm_out, mask_flag);
1478 vpaddd(m_zmm_out, zmm_out, zmm_zp);
1479 }
1480 if (jcp.src_zero_point) {
1481 // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
1482 int zp_offset = sizeof(int32_t) * ocb * oc_block;
1483 const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1484 vpmulld(m_zmm_zp, zmm_src_zp,
1485 EVEX_compress_addr(reg_zp_compensation, zp_offset));
1486 vpaddd(zmm_out, zmm_out, zmm_zp);
1487 }
1488
1489 /* add bias and zero-point to zmm_accum */
1490 vcvtdq2ps(zmm_out, zmm_out);
1491 if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
1492 const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
1493 vmulps(zmm_out_msk, zmm_out,
1494 EVEX_compress_addr(reg_ptr_scales, scale_offset));
1495
1496 apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag);
1497
1498 if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); }
1499
1500 // Properly saturate the accumulators for integer datatypes
1501 if (one_of(jcp.dst_dt, u8, s8, s32)) {
1502 init_saturate_f32(
1503 zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dst_dt);
1504 saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dst_dt);
1505 vcvtps2dq(zmm_out, zmm_out);
1506 }
1507
1508 const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
1509
1510 switch (jcp.dst_dt) {
1511 case data_type::f32:
1512 case data_type::s32: vmovups(addr, zmm_out_store); break;
1513 case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
1514 case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
1515 default: assert(!"unknown dst_dt");
1516 }
1517 }
1518
store_output_vector(const Zmm & zmm_out,int ocb,int h,int w,const bool compute_zp,const int zp_h,const int zp_w)1519 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector(const Zmm &zmm_out,
1520 int ocb, int h, int w, const bool compute_zp, const int zp_h,
1521 const int zp_w) {
1522 /*
1523 Output:
1524 jcp.is_nspc !jcp.is_nspc
1525 --------------------- ---------------------
1526 INT8: [N][H][W][NBOC][16OC]
1527 BF16: [N][H][W][NBOC][16OC] or [N][NBOC][H][W][16OC]
1528 */
1529 if (jcp.src_dt == data_type::bf16) {
1530 store_output_vector_bf16(zmm_out, ocb, h, w);
1531 } else {
1532 store_output_vector_int8(zmm_out, ocb, h, w, compute_zp, zp_h, zp_w);
1533 }
1534 }
1535
store_output(int width,int tail,bool do_store,const bool handle_h_blk,const int t_pad_output,const int b_pad_output,const int l_pad_output,const int r_pad_output,const bool is_last_oh_block,const bool zp_3d_pad)1536 void jit_avx512_core_amx_fwd_kernel_t::store_output(int width, int tail,
1537 bool do_store, const bool handle_h_blk, const int t_pad_output,
1538 const int b_pad_output, const int l_pad_output, const int r_pad_output,
1539 const bool is_last_oh_block, const bool zp_3d_pad) {
1540 auto store_output_block = [=](int width, int tail, bool do_store,
1541 bool is_last_h = false) {
1542 // Calculate the number of oh blocks; it may differ on last call
1543 const int last_h_blks
1544 = div_up(jcp.oh, jcp.oh_per_tile) % jcp.nb_oh_blocking;
1545 const int h_blks = is_last_h && last_h_blks != 0 ? last_h_blks
1546 : jcp.nb_oh_blocking;
1547 // Calculate the number of oh rows per tile; it may differ on last call
1548 const int h_tail = is_last_h && jcp.oh % jcp.oh_per_tile != 0
1549 ? (h_blks - 1) * jcp.oh_per_tile + jcp.oh % jcp.oh_per_tile
1550 : h_blks * jcp.oh_per_tile;
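        // e.g. (hypothetical shape): oh = 7, oh_per_tile = 2 and
        // nb_oh_blocking = 2 give div_up(7, 2) = 4 oh blocks and
        // last_h_blks = 4 % 2 = 0, so h_blks = 2 on every call, while
        // h_tail = (2 - 1) * 2 + 7 % 2 = 3 output rows on the last call.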
1551 const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
1552 const int owp = gen_kw + jcp.ow - 1;
1553
1554 if (jcp.src_zero_point) {
1555 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1556 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1557 vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1558 }
1559 if (jcp.dst_zero_point) {
1560 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1561 vcvtdq2ps(zmm_dst_zp,
1562 EVEX_compress_addr(reg_dst_zero_point, 0, true));
1563 }
1564
1565 for_(int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
1566 for (int ohb = 0; ohb < h_blks; ohb++) {
1567 /* Formats: Workspace: [NBOC][W][16OC] */
1568 tilestored(ptr[reg_wsp_ptr + reg_wei_stride
1569 + get_wsp_ocb_offset(ohb, ocb)],
1570 Tmm(get_out_tensor(ohb, ocb, tail)));
1571 is_buffer_empty_ = false;
1572 is_store_done_ = false;
1573
1574 // preserve registers used by binary post_ops injector
1575 const injector_utils::conditional_register_preserve_guard_t
1576 cond_register_guard(jcp.with_binary, this,
1577 {bin_injector_helper_reg_1,
1578 bin_injector_helper_reg_2});
1579
1580 for (int tw = 0; tw < width && do_store; tw++) {
1581 // height
1582 const int oh_index = ohb * jcp.oh_per_tile + tw / owp;
1583 const bool zp_h_pad
1584 = oh_index < t_pad_output || oh_index >= b_pad_output;
1585 const int zp_h = get_zp_index_offset(
1586 oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1587 // width
1588 const int ow_index = tw % owp;
1589 const bool zp_w_pad
1590 = ow_index < l_pad_output || ow_index >= r_pad_output;
1591 const int zp_w = get_zp_index_offset(
1592 ow_index, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1593
1594 const bool compute_zp = jcp.req_zero_point_buffer
1595 && (zp_3d_pad || zp_w_pad || zp_h_pad);
1596
1597 assert(IMPLICATION(jcp.oh_per_tile == 1,
1598 ohb == oh_index && tw == ow_index));
1599 if (oh_index < h_tail && ow_index < jcp.ow) {
1600 Zmm zmm_r = zmm_out(tw);
1601 vmovups(zmm_r,
1602 ptr[reg_wsp_ptr
1603 + get_wsp_row_offset(ohb, ocb, tw)]);
1604 store_output_vector(zmm_r, ocb, oh_index, ow_index,
1605 compute_zp, zp_h, zp_w);
1606 }
1607 }
1608 }
1609 };
1610
1611 // adjustment in case interleave store is turned off
1612 do_store = do_store || jcp.per_one_pstore == 0;
1613 if (!do_store) { w_padding.emplace(l_pad_output, r_pad_output); }
1614 if (!handle_h_blk) {
1615 store_output_block(width, tail, do_store, is_last_oh_block);
1616 } else {
1617 if (jcp.oh % (jcp.oh_per_tile * jcp.nb_oh_blocking) == 0) {
1618 store_output_block(width, tail, do_store);
1619 } else {
1620 Label label_oh_oc_store, label_done;
1621 mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
1622 cmp(reg_last_h, 0);
1623 jne(label_oh_oc_store, T_NEAR);
1624 store_output_block(width, tail, do_store, true); // last h
1625 jmp(label_done, T_NEAR);
1626 L(label_oh_oc_store);
1627 store_output_block(width, tail, do_store, false);
1628 L(label_done);
1629 }
1630 }
1631 if (do_store) {
1632 add(reg_out_ptr, get_out_shift(width, jcp.typesize_out));
1633 if (jcp.req_zero_point_buffer) {
1634 const size_t sp_shift
1635 = accum_with_upper_bound(width, l_pad_output, r_pad_output);
1636 add(reg_zero_point_pbuff, get_out_shift(sp_shift, sizeof(int32_t)));
1637 }
1638 }
1639 }
1640
1641 void jit_avx512_core_amx_fwd_kernel_t::interleave_store(int width,
1642 int const t_pad_output, int const b_pad_output, const bool zp_3d_pad) {
1643 for (int c = 0;
1644 c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
1645 c++) {
1646 // row_count = ohb * OCB * TW + ocb * TW + tw
1647 int tw = row_count_ % prv_width_;
1648 int ocb = (row_count_ / prv_width_) % jcp.nb_oc_blocking;
1649 int ohb = (row_count_ / prv_width_) / jcp.nb_oc_blocking;
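        // e.g., with prv_width_ = 16 and jcp.nb_oc_blocking = 2, the value
        // row_count_ = 35 decodes to tw = 3, ocb = 0, ohb = 1.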
1650
1651 // preserve registers used by binary post_ops injector
1652 const injector_utils::conditional_register_preserve_guard_t
1653 cond_register_guard(jcp.with_binary, this,
1654 {bin_injector_helper_reg_1, bin_injector_helper_reg_2});
1655
1656 // height
1657 const int oh_index = ohb;
1658 const bool zp_h_pad
1659 = oh_index < t_pad_output || oh_index >= b_pad_output;
1660 const int zp_h = get_zp_index_offset(
1661 oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1662 // width
1663 const int l_pad_output
1664 = w_padding.empty() ? 0 : w_padding.front().l_pad_output;
1665 const int r_pad_output
1666 = w_padding.empty() ? jcp.ow : w_padding.front().r_pad_output;
1667
1668 const bool zp_w_pad = tw < l_pad_output || tw >= r_pad_output;
1669 const int zp_w = get_zp_index_offset(
1670 tw, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1671
1672 const bool compute_zp = jcp.req_zero_point_buffer
1673 && (zp_3d_pad || zp_w_pad || zp_h_pad);
1674
1675 Zmm zmm_r = zmm_out(tw);
1676 vmovups(zmm_r, ptr[reg_wsp_ptr + get_wsp_row_offset(ohb, ocb, tw)]);
1677 store_output_vector(zmm_r, ocb, ohb, tw, compute_zp, zp_h, zp_w);
1678 row_count_++;
1679
1680 if (row_count_
1681 == prv_width_ * jcp.nb_oc_blocking * jcp.nb_oh_blocking) {
1682 add(reg_out_ptr, get_out_shift(prv_width_, jcp.typesize_out));
1683 if (jcp.req_zero_point_buffer) {
1684 const size_t sp_shift = accum_with_upper_bound(
1685 prv_width_, l_pad_output, r_pad_output);
1686 add(reg_zero_point_pbuff,
1687 get_out_shift(sp_shift, sizeof(int32_t)));
1688 if (!w_padding.empty()) w_padding.pop();
1689 }
1690 row_count_ = 0;
1691 is_store_done_ = true;
1692 prv_width_ = width;
1693 }
1694 }
1695 }
1696
1697 void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width,
1698 bool do_store, const bool handle_h_blk, const int t_pad_output,
1699 const int b_pad_output, const int l_pad_output, const int r_pad_output,
1700 const bool zp_3d_pad, const bool is_last_oh_block) {
1701 const bool tail = width == jcp.tile_tail;
1702
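    // Pick the AMX tile dot-product instruction matching the src/wei data
    // types: tdpbf16ps accumulates bf16 pairs into f32, and the tdpb[su][su]d
    // forms accumulate byte quadruples into s32, with the first letter giving
    // the signedness of the src tile and the second that of the weights tile
    // (e.g. tdpbusd = unsigned src x signed weights).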
1703 auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
1704 if (jcp.src_dt == data_type::bf16 && jcp.wei_dt == data_type::bf16) {
1705 tdpbf16ps(x1, x2, x3);
1706 } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::u8) {
1707 tdpbuud(x1, x2, x3);
1708 } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::s8) {
1709 tdpbusd(x1, x2, x3);
1710 } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::u8) {
1711 tdpbsud(x1, x2, x3);
1712 } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::s8) {
1713 tdpbssd(x1, x2, x3);
1714 } else {
1715 assert(!"unsupported combination");
1716 }
1717 };
1718
1719 prepare_output(tail);
1720
1721 // prepare registers for when 'interleave_store()' is computed
1722 if (jcp.src_zero_point) {
1723 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1724 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1725 vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1726 }
1727 if (jcp.dst_zero_point) {
1728 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1729 vcvtdq2ps(zmm_dst_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));
1730 }
1731
1732 // reduced lowering path
1733 if (jcp.is_relo) {
1734 const int nreduce = jcp.nreduce;
1735         const int stride = jcp.ic_block_int; // i.e., 64 (32) for int8 (bf16)
1736
1737 push(reg_inp_ptr);
1738 push(reg_wei_ptr);
1739
1740 for (int ireduce = 0; ireduce < nreduce; ireduce += stride) {
1741 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1742 tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1743 ptr[reg_inp_ptr + get_inp_offset(ohb, 0)
1744 + reg_inp_stride]);
1745 }
1746 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1747 tileloadd(Tmm(get_wei_tensor(ocb)),
1748 ptr[reg_wei_ptr + get_wei_offset(ocb, 0)
1749 + reg_wei_stride]);
1750 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1751 tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1752 Tmm(get_inp_tensor(ohb, tail)),
1753 Tmm(get_wei_tensor(ocb)));
1754 interleave_store(width, t_pad_output, b_pad_output);
1755 }
1756 }
1757 if (ireduce + stride < nreduce) {
1758 add(reg_inp_ptr, stride * jcp.typesize_in);
1759 add(reg_wei_ptr, stride * jcp.oc_block * jcp.typesize_in);
1760 }
1761 }
1762 pop(reg_wei_ptr);
1763 pop(reg_inp_ptr);
1764
1765 store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1766 b_pad_output, l_pad_output, r_pad_output, is_last_oh_block);
1767
1768 add(reg_inp_ptr, get_inp_shift());
1769
1770 return;
1771 }
1772
1773 auto wei_offset = [&](int icb, int ocb, int kd, int kh, int kw) {
1774 return (size_t)icb * get_wei_icb_step() + kd * get_wei_d_step()
1775 + kh * get_wei_h_step() + get_wei_offset(ocb, kw);
1776 };
1777
1778 auto inp_offset = [&](int icb, int ohb, int kd, int kh, int kw) {
1779 return (size_t)icb * get_inp_icb_step() + kd * get_inp_d_step()
1780 + kh * get_inp_h_step() + get_inp_offset(ohb, kw);
1781 };
1782
1783 auto safe_tileloadd
1784 = [=](const Tmm &t1, const Xbyak::Reg64 ®_ptr, size_t offset,
1785 const Xbyak::Reg64 ®_stride) {
1786 if (offset <= INT32_MAX) {
1787 tileloadd(t1, ptr[reg_ptr + offset + reg_stride]);
1788 } else {
1789 safe_add(reg_ptr, offset, reg_tmp);
1790 tileloadd(t1, ptr[reg_ptr + reg_stride]);
1791 safe_sub(reg_ptr, offset, reg_tmp);
1792 }
1793 };
1794
1795 // normal and k-remainders path
1796 const bool check_kd_padding
1797 = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
1798 for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
1799 Label kd_skip_compute;
1800 if (check_kd_padding) mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1801
1802 for (int kd = 0; kd < jcp.kd; kd++) {
1803 if (check_kd_padding) {
1804 dec(reg_kd);
1805 jl(kd_skip_compute, T_NEAR);
1806 push(reg_kd);
1807 }
1808 for (int kh = 0; kh < jcp.kh; kh++) {
1809 for (int set_idx = 0; set_idx < jcp.n_stride_sets;
1810 set_idx++) { // used to optimize input memory reuse in L1$
1811 for (int kw = set_idx; kw < jcp.kw; kw += jcp.kw_step) {
1812 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1813 const size_t inp_off
1814 = inp_offset(icb, ohb, kd, kh, kw);
1815 safe_tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1816 reg_inp_ptr, inp_off, reg_inp_stride);
1817 }
1818 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1819 const size_t wei_off
1820 = wei_offset(icb, ocb, kd, kh, kw);
1821 safe_tileloadd(Tmm(get_wei_tensor(ocb)),
1822 reg_wei_ptr, wei_off, reg_wei_stride);
1823 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1824 tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1825 Tmm(get_inp_tensor(ohb, tail)),
1826 Tmm(get_wei_tensor(ocb)));
1827 interleave_store(width, t_pad_output,
1828 b_pad_output, zp_3d_pad);
1829 }
1830 }
1831 }
1832 }
1833 }
1834 if (check_kd_padding) pop(reg_kd);
1835 }
1836 L(kd_skip_compute);
1837 }
1838
1839 store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1840 b_pad_output, l_pad_output, r_pad_output, is_last_oh_block,
1841 zp_3d_pad);
1842
1843 add(reg_inp_ptr, get_inp_shift());
1844 }
1845
1846 void jit_avx512_core_amx_fwd_kernel_t::dispatch_icb_loop(int width,
1847 bool do_store, const int l_pad_output, const int r_pad_output,
1848 const bool zp_3d_pad) {
1849 if (jcp.req_zero_point_buffer
1850 && (jcp.t_pad_output > 0 || jcp.b_pad_output > 0)) {
1851 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
1852 const size_t height_limit = reduce_to_blocked_dims(
1853 jcp.oh, oh_step_size, jcp.t_pad_output, jcp.b_pad_output);
1854 const int ur_h = div_up(height_limit, oh_step_size);
1855 assert(6 >= ur_h);
1856
1857 // Use a jump-table to execute the corresponding block
1858 Label h_blk_label[6], h_blk_end_label, jmp_table_label;
1859 mov(reg_jmp_blk, ptr[param1 + GET_OFF(ohb)]);
1860 mov(reg_tmp, jmp_table_label);
1861 jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1862 jmp(h_blk_end_label, T_NEAR); // error, shouldn't happen
1863
1864 align(8);
1865 L(jmp_table_label);
1866 for (int u = 0; u < ur_h; ++u) {
1867 putL(h_blk_label[u]);
1868 }
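        // The table now holds ur_h 8-byte entries, where entry u is the
        // address of h_blk_label[u]; the indirect jmp above thus lands on
        // the unrolled block matching the 'ohb' value passed by the driver.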
1869
1870 // Save value of global variables for the next 'h_blk' iteration
1871 const int local_prv_width = prv_width_;
1872 const int local_row_count = row_count_;
1873 const bool local_is_store_done = is_store_done_;
1874 const bool local_is_buffer_empty = is_buffer_empty_;
1875
1876         // Unroll oh blocks with regard to t_pad_output and b_pad_output
1877 int cur_t_pad = reduce_to_block(oh_step_size, jcp.t_pad_output);
1878 int cur_b_pad = height_limit
1879 - reduce_to_block(oh_step_size, jcp.b_pad_output);
1880 for (int u = 0; u < ur_h; u++) {
1881 bool last = u == ur_h - 1;
1882 L(h_blk_label[u]);
1883
1884 // restore to previous 'h_blk' state of variables
1885 prv_width_ = local_prv_width;
1886 row_count_ = local_row_count;
1887 is_store_done_ = local_is_store_done;
1888 is_buffer_empty_ = local_is_buffer_empty;
1889 compute_icb_loop(width, do_store, false, cur_t_pad, cur_b_pad,
1890 l_pad_output, r_pad_output, zp_3d_pad, last);
1891 cur_t_pad = nstl::max(0, cur_t_pad - oh_step_size);
1892 cur_b_pad = nstl::max(0, cur_b_pad - oh_step_size);
1893 if (!last) jmp(h_blk_end_label, T_NEAR);
1894 }
1895 L(h_blk_end_label);
1896 } else {
1897 compute_icb_loop(width, do_store, true, 0, jcp.oh, l_pad_output,
1898 r_pad_output, zp_3d_pad);
1899 }
1900 }
1901
1902 void jit_avx512_core_amx_fwd_kernel_t::dispatch_zp_3d_compute(int width,
1903 bool do_store, const int l_pad_output, const int r_pad_output) {
1904 if (jcp.req_zero_point_buffer && (jcp.f_pad > 0 || jcp.back_pad > 0)) {
1905 Label compute_3d_zp_label, zp_d_end_label;
1906 mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1907 cmp(reg_kd, jcp.kd);
1908 jne(compute_3d_zp_label, T_NEAR);
1909
1910 // Save value of global variables for next 'dispatch_icb_loop'
1911 const int local_prv_width = prv_width_;
1912 const int local_row_count = row_count_;
1913 const bool local_is_store_done = is_store_done_;
1914 const bool local_is_buffer_empty = is_buffer_empty_;
1915 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1916
1917 jmp(zp_d_end_label, T_NEAR);
1918 L(compute_3d_zp_label);
1919
1920 prv_width_ = local_prv_width;
1921 row_count_ = local_row_count;
1922 is_store_done_ = local_is_store_done;
1923 is_buffer_empty_ = local_is_buffer_empty;
1924 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, true);
1925
1926 L(zp_d_end_label);
1927 } else
1928 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1929 }
1930
1931 void jit_avx512_core_amx_fwd_kernel_t::compute_ow_loop() {
1932 auto compute_ow_loop_body = [=](bool last_owb, int num_tile_blocks,
1933 const int l_pad_output,
1934 const int r_pad_output) {
1935 int cur_l_pad_output = l_pad_output;
1936 int cur_r_pad_output = r_pad_output;
1937 int gen_tile_tail = last_owb && jcp.tile_tail > 0 ? jcp.tile_tail
1938 : jcp.tile_width;
1939 init_runtime_counters(last_owb && num_tile_blocks == 1);
1940 for (int owb = 0; owb < num_tile_blocks - 1; owb++) {
1941 dispatch_zp_3d_compute(
1942 jcp.tile_width, false, cur_l_pad_output, cur_r_pad_output);
1943 cur_l_pad_output = nstl::max(0, cur_l_pad_output - jcp.tile_width);
1944 cur_r_pad_output = nstl::max(0, cur_r_pad_output - jcp.tile_width);
1945 }
1946 dispatch_zp_3d_compute(
1947 gen_tile_tail, true, cur_l_pad_output, cur_r_pad_output);
1948 };
1949
1950 assert(jcp.nb_ow > 0);
1951 if (jcp.nb_ow == 1) {
1952 const int ow_r_pad_start
1953 = nstl::max(jcp.ow - jcp.r_pad_output, jcp.l_pad_output);
1954 compute_ow_loop_body(
1955 true, jcp.ow_blocks, jcp.l_pad_output, ow_r_pad_start);
1956 } else if (jcp.req_zero_point_buffer
1957 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0)) {
1958
1959 const size_t zp_addr_shift
1960 = jcp.ngroups * jcp.oc_without_padding * sizeof(int32_t);
1961 const int ow_step_size = jcp.ow_block;
1962 const int ow_blocks_per_call = div_up(ow_step_size, jcp.tile_width);
1963 const int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call == 0
1964 ? ow_blocks_per_call
1965 : jcp.ow_blocks % ow_blocks_per_call;
1966 const int width_limit = reduce_to_blocked_dims(
1967 jcp.ow, ow_step_size, jcp.l_pad_output, jcp.r_pad_output);
1968 const int ur_w = div_up(width_limit, ow_step_size);
1969 assert(6 >= ur_w);
1970 // Use a jump-table to execute the corresponding block
1971 Label w_blk_label[6], w_blk_end_label, jmp_table_label;
1972 mov(reg_jmp_blk, ptr[param1 + GET_OFF(owb)]);
1973 mov(reg_tmp, jmp_table_label);
1974 jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1975 jmp(w_blk_end_label, T_NEAR); // error, shouldn't happen
1976
1977 align(8);
1978 L(jmp_table_label);
1979 for (int u = 0; u < ur_w; ++u) {
1980 putL(w_blk_label[u]);
1981 }
1982
1983         // Unroll ow_block with regard to l_pad_output and r_pad_output
1984 int cur_l_pad = reduce_to_block(ow_step_size, jcp.l_pad_output);
1985 int cur_r_pad
1986 = width_limit - reduce_to_block(ow_step_size, jcp.r_pad_output);
1987 int zp_offset = 0;
1988 for (int u = 0; u < ur_w; u++) {
1989 const bool last = u == ur_w - 1;
1990 L(w_blk_label[u]);
1991 if (u > 0) add(reg_zero_point_pbuff, zp_offset * zp_addr_shift);
1992 compute_ow_loop_body(last,
1993 last ? last_owb_tile_blocks : ow_blocks_per_call, cur_l_pad,
1994 cur_r_pad);
1995 zp_offset += accum_with_upper_bound(
1996 ow_step_size, cur_l_pad, cur_r_pad);
1997 cur_l_pad = nstl::max(0, cur_l_pad - ow_step_size);
1998 cur_r_pad = nstl::max(0, cur_r_pad - ow_step_size);
1999 if (!last) jmp(w_blk_end_label, T_NEAR);
2000 }
2001 L(w_blk_end_label);
2002
2003 } else {
2004 assert(jcp.oh_per_tile == 1);
2005 Label label_done;
2006 int ow_blocks_per_call = utils::div_up(jcp.ow_block, jcp.tile_width);
2007 int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call;
2008 if (last_owb_tile_blocks == 0 && jcp.tile_tail > 0)
2009 last_owb_tile_blocks = ow_blocks_per_call;
2010 if (last_owb_tile_blocks > 0) {
2011 Label label_not_last_owb;
2012 mov(reg_tmp, ptr[param1 + GET_OFF(owb)]);
2013 cmp(reg_tmp, jcp.nb_ow - 1);
2014 jne(label_not_last_owb, T_NEAR);
2015
2016 compute_ow_loop_body(true, last_owb_tile_blocks, 0, jcp.ow);
2017
2018 jmp(label_done, T_NEAR);
2019
2020 L(label_not_last_owb);
2021 }
2022 compute_ow_loop_body(false, ow_blocks_per_call, 0, jcp.ow);
2023
2024 L(label_done);
2025 }
2026 }
2027
2028 void jit_avx512_core_amx_fwd_kernel_t::generate() {
2029 preamble();
2030
2031 mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
2032 mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]);
2033 mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
2034 mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
2035 if (jcp.req_zero_point_buffer)
2036 mov(reg_zero_point_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
2037
2038 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
2039 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
2040
2041 const int fac = jcp.is_relo ? jcp.stride_w * jcp.kh
2042 : jcp.is_pbuffer_strided ? 1 : jcp.stride_w;
2043 const int inp_stride = fac * jcp.ic_block_int_np * jcp.typesize_in;
2044 const int wei_stride = jcp.oc_block * jcp.typesize_acc;
2045 mov(reg_inp_stride, inp_stride);
2046 mov(reg_wei_stride, wei_stride);
2047
2048 if (jcp.is_nspc && jcp.oc_without_padding != jcp.oc) {
2049         // Use the full mask (all oc_block lanes) by default; it is reset
2050         // below to cover only the oc tail lanes for the last block index
2051         // ocb = occ * jcp.nb_oc_blocking + (jcp.nb_oc_blocking - 1)
2052         // TODO: use masked loads / stores for the last occ only
2053 int current_block_size = jcp.oc_block;
2054 int mask = (1 << current_block_size) - 1;
2055 Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
2056 mov(regw_tmp, mask);
2057 kmovw(ktail_mask, regw_tmp);
2058 Xbyak::Label mask_is_set;
2059 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
2060 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
2061 jne(mask_is_set, T_NEAR);
2062 // Reset the mask
2063 current_block_size = jcp.oc_without_padding % jcp.oc_block;
2064 mask = (1 << current_block_size) - 1;
2065 mov(regw_tmp, mask);
2066 kmovw(ktail_mask, regw_tmp);
2067
2068 L(mask_is_set);
2069 }
2070 compute_ow_loop();
2071
2072 postamble();
2073
2074 if (jcp.with_eltwise) postops_injector_->prepare_table();
2075 }
2076
2077 void jit_avx512_core_amx_fwd_kernel_t::tile_configure(char *tcfg_buff) {
2078 const int vnni_width = jcp.src_dt == data_type::bf16 ? 2 : 4;
2079 // Input tile dimensions
2080 const int a_col = jcp.is_relo ? jcp.ic_block_int
2081 : jcp.ic_block_int_np * jcp.kw_per_tile;
2082 // Weights tile dimensions
2083 const int b_col = jcp.oc_block * vnni_width;
2084 const int b_row = a_col / vnni_width;
2085 // Accumulator tile dimensions
2086 const int c_col = 16;
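    // e.g., for a hypothetical int8 shape (vnni_width = 4,
    // ic_block_int_np = 64, kw_per_tile = 1): A tiles are tile_width rows of
    // a_col = 64 bytes, B tiles are b_row x b_col = 16 x 64 bytes, and
    // C tiles are tile_width rows of 16 s32 accumulators.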
2087
2088 for (size_t i = 0; i < 64; i++)
2089 tcfg_buff[i] = 0;
2090
2091 // Weights (W_BASE) Tensor Tiles
2092 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2093 tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
2094 b_row, b_col * jcp.typesize_in);
2095
2096 // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
2097 for (int h = 0; h < jcp.nb_oh_blocking; h++) {
2098 tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
2099 jcp.tile_width, a_col * jcp.typesize_in);
2100 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2101 tc_configure_tile((palette_config_t *)tcfg_buff,
2102 get_out_tensor(h, i), jcp.tile_width,
2103 c_col * jcp.typesize_acc);
2104 }
2105 if (jcp.tile_tail != 0) {
2106 assert(jcp.nb_oh_blocking == 1);
2107 assert(jcp.oh_per_tile == 1);
2108 assert(jcp.ow > jcp.tile_width);
2109 tc_configure_tile((palette_config_t *)tcfg_buff,
2110 get_inp_tensor(0, true), jcp.tile_tail,
2111 a_col * jcp.typesize_in);
2112 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2113 tc_configure_tile((palette_config_t *)tcfg_buff,
2114 get_out_tensor(0, i, true), jcp.tile_tail,
2115 c_col * jcp.typesize_acc);
2116 }
2117
2118 ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
2119 }
2120
2121 void jit_avx512_core_amx_fwd_kernel_t::set_oh_blk_limits(jit_conv_conf_t &jcp) {
2122
2123 constexpr int size = sizeof(jcp.h_blk_limits) / sizeof(jcp.h_blk_limits[0]);
2124 // set default values
2125 for (int i = 0; i < size; i++)
2126 jcp.h_blk_limits[i] = jcp.oh;
2127
2128 const bool calculate_oh_limits
2129 = jcp.t_pad_output > 0 || jcp.b_pad_output > 0;
2130 if (jcp.req_zero_point_buffer && calculate_oh_limits) {
2131
2132 int limit_idx = 0;
2133 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
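        // Worked example (hypothetical): oh = 20, oh_step_size = 2,
        // t_pad_output = 3 and b_pad_output = 3 produce the limits
        // {2, 3, 16, 18, 20}: full t_pad block, t_pad/no-pad overlap,
        // middle no-pad area, no-pad/b_pad overlap, full b_pad block.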
2134
2135 // full t_pad output block
2136 const int t_pad_blk_end = rnd_dn(jcp.t_pad_output, oh_step_size);
2137 if (jcp.t_pad_output >= oh_step_size) {
2138 jcp.h_blk_limits[limit_idx++] = t_pad_blk_end;
2139 }
2140 // t_pad output overlap with no padding
2141 const int t_pad_shift = jcp.t_pad_output % oh_step_size;
2142 if (t_pad_shift != 0) {
2143 jcp.h_blk_limits[limit_idx++] = t_pad_blk_end + t_pad_shift;
2144 }
2145 const int t_pad_next_blk = rnd_up(jcp.t_pad_output, oh_step_size);
2146 const int oh_blk_tail = jcp.oh % oh_step_size;
2147 const int b_pad_no_tail = nstl::max(0, jcp.b_pad_output - oh_blk_tail);
2148 const int b_pad_start
2149 = nstl::max(jcp.t_pad_output, jcp.oh - jcp.b_pad_output);
2150 const int b_pad_blk_start = rnd_dn(b_pad_start, oh_step_size);
2151 // middle block without padding
2152 const int mid_blk = nstl::max(0, b_pad_blk_start - t_pad_next_blk);
2153 if (mid_blk >= oh_step_size) {
2154 jcp.h_blk_limits[limit_idx++] = b_pad_blk_start;
2155 }
2156 // no padding with b_pad overlap
2157 const int b_pad_shift = b_pad_no_tail % oh_step_size;
2158 if (b_pad_shift != 0) {
2159 jcp.h_blk_limits[limit_idx++] = rnd_up(b_pad_start, oh_step_size);
2160 }
2161 // full b_pad output block
2162 if (b_pad_no_tail >= oh_step_size) {
2163 jcp.h_blk_limits[limit_idx++] = jcp.oh - oh_blk_tail;
2164 }
2165 // b_pad tail block does not require a limit
2166 }
2167 }
2168
2169 void jit_avx512_core_amx_fwd_kernel_t::set_ow_blk_limits(jit_conv_conf_t &jcp) {
2170
2171 jcp.l_pad_blk = 0;
2172 jcp.no_pad_w_blk = 0;
2173 jcp.r_pad_blk = 0;
2174
2175 const bool calculate_ow_limits
2176 = jcp.nb_ow > 1 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0);
2177 if (jcp.req_zero_point_buffer && calculate_ow_limits) {
2178 const int ow_step_size = jcp.ow_block;
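        // Worked example (hypothetical): ow = 64, ow_block = 16,
        // l_pad_output = 5 and r_pad_output = 7 give l_pad_blk = 1,
        // no_pad_w_blk = 1 and r_pad_blk = 1 (the r_pad area fully overlaps
        // the last no-pad block).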
2179
2180 // l_pad
2181 const int l_pad_limit
2182 = (jcp.l_pad_output >= ow_step_size ? ow_step_size : 0)
2183 + (jcp.l_pad_output % ow_step_size);
2184 const int l_pad_area_blk = rnd_up(l_pad_limit, ow_step_size);
2185 jcp.l_pad_blk = div_up(l_pad_limit, ow_step_size);
2186
2187 // middle (area without padding)
2188 const int no_pad_area
2189 = nstl::max(0, jcp.ow - l_pad_area_blk - jcp.r_pad_output);
2190 jcp.no_pad_w_blk = no_pad_area >= ow_step_size ? 1 : 0;
2191
2192 // r_pad
2193 const int no_pad_area_shift = no_pad_area % ow_step_size;
2194 const int r_pad_area_overlap
2195 = no_pad_area_shift == 0 ? 0 : ow_step_size - no_pad_area_shift;
2196 const int r_pad_area
2197 = nstl::max(0, jcp.r_pad_output - r_pad_area_overlap);
2198 const int r_pad_limit = (r_pad_area >= ow_step_size ? ow_step_size : 0)
2199 + (r_pad_area % ow_step_size);
2200 jcp.r_pad_blk = (r_pad_area_overlap > 0 ? 1 : 0)
2201 + div_up(r_pad_limit, ow_step_size);
2202 }
2203 }
2204
2205 status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
2206 const convolution_desc_t &cd, memory_desc_t &src_md,
2207 memory_desc_t &weights_md, memory_desc_t &dst_md,
2208 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
2209 using namespace prop_kind;
2210
2211 const memory_desc_wrapper src_d(&src_md);
2212 const memory_desc_wrapper weights_d(&weights_md);
2213 const memory_desc_wrapper dst_d(&dst_md);
2214 const memory_desc_wrapper bias_d(&bias_md);
2215
2216 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
2217 int ndims = src_d.ndims();
2218 bool is_1d = ndims == 3;
2219 bool is_3d = ndims == 5;
2220
2221 const bool is_bf16_convolution
2222 = everyone_is(true, src_d.data_type() == data_type::bf16,
2223 weights_d.data_type() == data_type::bf16,
2224 one_of(dst_d.data_type(), data_type::bf16, data_type::f32));
2225 const bool is_int8_convolution = everyone_is(true,
2226 (src_d.data_type() == data_type::u8
2227 || src_d.data_type() == data_type::s8),
2228 weights_d.data_type() == data_type::s8,
2229 one_of(dst_d.data_type(), data_type::f32, data_type::s32,
2230 data_type::s8, data_type::u8));
2231
2232 bool supported = false
2233 || (is_bf16_convolution && mayiuse(avx512_core_bf16_amx_bf16))
2234 || (is_int8_convolution && mayiuse(avx512_core_bf16_amx_int8));
2235 if (!supported) return status::unimplemented;
2236
2237 jcp = zero<decltype(jcp)>();
2238 jcp.isa = is_bf16_convolution ? avx512_core_bf16_amx_bf16
2239 : avx512_core_bf16_amx_int8;
2240 jcp.ndims = ndims;
2241 jcp.prop_kind = cd.prop_kind;
2242 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
2243
2244 jcp.mb = src_d.dims()[0];
2245 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
2246 jcp.oc_without_padding = jcp.oc;
2247 jcp.ic = src_d.dims()[1] / jcp.ngroups;
2248 jcp.ic_without_padding = jcp.ic;
2249 jcp.id = is_3d ? src_d.dims()[2] : 1;
2250 jcp.ih = !is_1d ? src_d.dims()[ndims - 2] : 1;
2251 jcp.iw = src_d.dims()[ndims - 1];
2252 jcp.od = is_3d ? dst_d.dims()[2] : 1;
2253 jcp.oh = !is_1d ? dst_d.dims()[ndims - 2] : 1;
2254 jcp.ow = dst_d.dims()[ndims - 1];
2255 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
2256 jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
2257 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
2258 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
2259 jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
2260 jcp.l_pad = cd.padding[0][ndims - 3];
2261 jcp.stride_d = is_3d ? cd.strides[0] : 1;
2262 jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
2263 jcp.stride_w = cd.strides[ndims - 3];
2264 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
2265
2266 jcp.dilate_d = is_3d ? cd.dilates[ndims - 5] : 0;
2267 jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
2268 jcp.dilate_w = cd.dilates[ndims - 3];
2269
2270 const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1;
2271 const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
2272 const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
2273 jcp.back_pad = calculate_end_padding(
2274 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, gen_kd);
2275 jcp.b_pad = calculate_end_padding(
2276 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
2277 jcp.r_pad = calculate_end_padding(
2278 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
2279 if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
2280 || jcp.b_pad >= gen_kh || jcp.f_pad >= gen_kd
2281 || jcp.back_pad >= gen_kd)
2282 return status::unimplemented;
2283
2284 const int max_pad = 28; // akin to maximum jcp.ur_w value in other jits
2285 if (jcp.l_pad > max_pad || jcp.r_pad > max_pad)
2286 return status::unimplemented; // TODO: relax this restriction
2287
2288 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
2289 jcp.dst_dt = cd.dst_desc.data_type;
2290 jcp.src_dt = cd.src_desc.data_type;
2291 jcp.wei_dt = cd.weights_desc.data_type;
2292
2293     jcp.is_depthwise = with_groups && everyone_is(1, jcp.ic, jcp.oc);
2294
2295 if (jcp.is_depthwise)
2296 return status::unimplemented; // TODO: add support of DW convolution
2297
2298 const auto zp = attr.zero_points_;
2299 jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
2300 jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
2301 jcp.zp_src_is_common = zp.common(
2302 DNNL_ARG_SRC); // otherwise, it's per-channel (not supported)
2303
2304 if (!IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)
2305 || !IMPLICATION(jcp.dst_zero_point || jcp.src_zero_point,
2306 is_int8_convolution))
2307 return status::unimplemented;
2308
2309 // Calculate zero-point padding values outside of the main JIT-kernel
2310 // and store the results in an auxiliary buffer.
2311 jcp.req_zero_point_buffer = jcp.src_zero_point
2312 && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0 || jcp.t_pad > 0
2313 || jcp.f_pad > 0 || jcp.back_pad > 0);
2314
2315 format_tag_t dat_tag_ncsp = utils::pick(ndims - 3, format_tag::nCw16c,
2316 format_tag::nChw16c, format_tag::nCdhw16c);
2317 format_tag_t dat_tag_nspc = utils::pick(
2318 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
2319 // To toggle the default data layout for BF16 between nChw16c and nhwc,
2320 // swap the following two variable definitions. Current choice: nhwc.
2321
2324 format_tag_t dat_tag_opt = dat_tag_nspc;
2325 format_tag_t dat_tag_alt
2326 = is_bf16_convolution ? dat_tag_ncsp : dat_tag_nspc;
2327
2328 if (src_d.format_kind() == format_kind::any) {
2329 CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
2330 jcp.src_tag = dat_tag_opt;
2331 } else
2332 jcp.src_tag = src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);
2333
2334 if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
2335 return status::unimplemented;
2336
2337 jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
2338 assert(IMPLICATION(is_int8_convolution, jcp.is_nspc));
2339
2340 // TODO: remove all support for nChw16c from this implementation
2341 if (!jcp.is_nspc) return status::unimplemented;
2342
2343 if (dst_d.format_kind() == format_kind::any) {
2344 CHECK(memory_desc_init_by_tag(dst_md, jcp.src_tag));
2345 jcp.dst_tag = jcp.src_tag;
2346 } else
2347 jcp.dst_tag = dst_d.matches_one_of_tag(jcp.src_tag);
2348
2349 if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
2350
2351 if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
2352 CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
2353
2354 jcp.nthr = nthreads;
2355
2356 jcp.ic_block = 16;
2357 jcp.oc_block = 16;
2358
2359 if (jcp.ngroups == 1) {
2360 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
2361 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
2362 }
2363 bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
2364 if (!args_ok) return status::unimplemented;
2365
2366 const int vnni_width = is_bf16_convolution ? 2 : 4;
2367 jcp.ic_block_int = jcp.ic_block * vnni_width; // 32 for bf16, 64 for int8
2368
2369 // fallback to non-amx impl when accumulation is too small
2370 const dim_t total_k = jcp.ic_without_padding * jcp.kd * jcp.kh * jcp.kw;
2371 const bool is_tiny_k = total_k < jcp.ic_block_int / 2;
2372 if (is_tiny_k) return status::unimplemented;
2373
2374 // small-ic parameters
2375 jcp.ic_block_int_np = jcp.is_nspc
2376 ? nstl::min(jcp.ic_block_int, jcp.ic_without_padding)
2377 : jcp.ic_block_int;
2378 bool is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2379
2380 // reduced lowering
2381 jcp.is_relo = (!is_3d)
2382 && is_small_ic
2383 // no trivial cases
2384 && 1 < jcp.kh * jcp.kw
2385 // required for use of VPERMB instruction in weights copy kernel
2386 && IMPLICATION(is_int8_convolution,
2387 cpu().has(Xbyak::util::Cpu::tAVX512_VBMI))
2388 // no dilation or excessive stride along w-direction
2389 && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
2390 // no dilation or excessive stride along h-direction
2391 && jcp.stride_h <= jcp.kh && jcp.stride_w <= jcp.kw;
2392 jcp.nreduce = jcp.kh * jcp.kw * jcp.ic_block_int_np;
2393
2394 if (!jcp.is_relo) {
2395 jcp.ic_block_int_np = is_bf16_convolution
2396 ? jcp.ic_block_int
2397 : rnd_up(jcp.ic_block_int_np, vnni_width);
2398 is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2399 }
2400
2401 // k-remainders
2402 jcp.kw_per_tile = is_small_ic && !jcp.is_relo && jcp.dilate_w == 0
2403 && jcp.stride_w <= jcp.kw // TODO: relax this restriction
2404 && jcp.kw * jcp.ic_block_int_np <= jcp.ic_block_int
2405 ? jcp.kw
2406 : 1;
2407 jcp.is_pbuffer_strided = (1 == jcp.kw_per_tile);
2408 jcp.n_stride_sets
2409 = jcp.is_pbuffer_strided ? nstl::min(jcp.stride_w, jcp.kw) : 1;
2410 jcp.kw_step = jcp.is_pbuffer_strided ? jcp.stride_w : jcp.kw_per_tile;
2411
2412 if (attr.set_default_formats(&dst_md) != status::success)
2413 return status::unimplemented;
2414
2415 const auto &p = attr.post_ops_;
2416
2417 const int sum_ind = p.find(primitive_kind::sum);
2418 jcp.with_sum = sum_ind != -1;
2419 const int eltwise_ind = p.find(primitive_kind::eltwise);
2420 jcp.with_eltwise = eltwise_ind != -1;
2421 const int binary_ind = p.find(primitive_kind::binary);
2422 jcp.with_binary = binary_ind != -1;
2423 jcp.sum_dt = p.get_sum_dt(jcp.dst_dt);
2424
2425 jcp.post_ops = p;
2426
2427 using namespace injector;
2428 const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16);
2429 const bool sum_requires_scale_one = sum_at_pos_0_only;
2430 const bool sum_requires_zp_zero = sum_at_pos_0_only;
2431 const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
2432 jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
2433 sum_requires_zp_zero});
2434 if (!post_ops_ok_) return status::unimplemented;
2435
2436 auto set_or_check_wei_format = [&]() {
2437 using namespace format_tag;
2438 using namespace memory_extra_flags;
2439 format_tag_t wei_tag;
2440 wei_tag = jcp.is_relo ? pick(with_groups + 2 * (ndims - 3), Owi16o,
2441 gOwi16o, Owhi16o, gOwhi16o) // no 3d support
2442 : is_bf16_convolution
2443 ? pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
2444 gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i,
2445 OIdhw16i16o2i, gOIdhw16i16o2i)
2446 : is_small_ic ? pick(with_groups + 2 * (ndims - 3),
2447 OwI16o4i, gOwI16o4i, OhwI16o4i, gOhwI16o4i,
2448 OdhwI16o4i, gOdhwI16o4i)
2449 : pick(with_groups + 2 * (ndims - 3),
2450 OIw16i16o4i, gOIw16i16o4i,
2451 OIhw16i16o4i, gOIhw16i16o4i,
2452 OIdhw16i16o4i, gOIdhw16i16o4i);
2453
2454 memory_desc_t want_wei_md = weights_md;
2455 memory_desc_init_by_tag(want_wei_md, wei_tag);
2456
2457 if (jcp.src_zero_point) {
2458 want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
2459 want_wei_md.extra.asymm_compensation_mask = (1 << 0)
2460 + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
2461 }
2462 if (weights_md.format_kind == format_kind::any) {
2463 weights_md = want_wei_md;
2464 return true;
2465 }
2466 return weights_md == want_wei_md;
2467 };
2468
2469 if (!set_or_check_wei_format()) return status::unimplemented;
2470
2471 jcp.typesize_in = types::data_type_size(src_d.data_type());
2472 jcp.typesize_out = types::data_type_size(dst_d.data_type());
2473 jcp.typesize_bia
2474 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
2475 jcp.typesize_acc = sizeof(int32_t);
2476
2477 jcp.nb_ic = jcp.ic / jcp.ic_block;
2478 jcp.nb_oc = jcp.oc / jcp.oc_block;
2479 jcp.nb_ic_int = div_up(jcp.ic, jcp.ic_block_int);
2480
2481 jcp.nb_oc_blocking_thr_chunk = 1;
2482
2483 const int max_palette = amx::get_max_palette();
2484 jcp.max_tiles = amx::get_max_tiles(max_palette);
2485 jcp.full_tile_width = amx::get_max_rows(max_palette);
2486 if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
2487 return status::unimplemented;
2488
2489 // Pack n rows per tile, such that:
2490 // ow + (ow + gen_kw - 1) * (n - 1) <= jcp.full_tile_width
2491 auto calculate_tile_width = [&](int n) {
2492 assert(n > 0);
2493 return jcp.ow + (gen_kw + jcp.ow - 1) * (n - 1);
2494 };
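    // e.g. (hypothetical shape): ow = 7, gen_kw = 3 give
    // calculate_tile_width(2) = 7 + 9 = 16 <= full_tile_width, so up to
    // max_oh_per_tile = 1 + (16 - 7) / 9 = 2 output rows fit in one tile.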
2495 const bool ok_to_pack_tile = !jcp.is_relo
2496 && (utils::everyone_is(1, jcp.kh, jcp.kw)
2497 || utils::everyone_is(1, jcp.stride_h, jcp.stride_w));
2498 const int max_oh_per_tile
2499 = 1 + (jcp.full_tile_width - jcp.ow) / (jcp.ow + gen_kw - 1);
2500 jcp.oh_per_tile = ok_to_pack_tile
2501 ? nstl::min(jcp.oh, nstl::max(1, max_oh_per_tile))
2502 : 1;
2503 jcp.tile_width = nstl::min<int>(
2504 jcp.full_tile_width, calculate_tile_width(jcp.oh_per_tile));
2505 jcp.ow_blocks = utils::div_up(jcp.ow, jcp.tile_width);
2506
2507     // Prefer to use a single tile width when possible
2508     // (e.g., ow = 28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
2509 if (jcp.oh_per_tile == 1 && jcp.ow % jcp.ow_blocks == 0)
2510 jcp.tile_width = jcp.ow / jcp.ow_blocks;
2511 jcp.tile_tail = jcp.oh_per_tile == 1 ? jcp.ow % jcp.tile_width : 0;
2512
2513 jcp.nb_oc_blocking = (jcp.nb_oc % 2 == 0) ? 2 : 1;
2514 jcp.nb_ic_blocking = 1;
2515 jcp.nb_oh_blocking
2516 = utils::everyone_is(true, jcp.tile_tail == 0,
2517 // requirement for interleave stores
2518 IMPLICATION(jcp.ow_blocks > 1, jcp.oh % 2 == 0),
2519 // requirement for small spatial
2520 utils::div_up(jcp.oh, jcp.oh_per_tile) > 1,
2521 // choose maximal pbuffer overlap for reduced lowering
2522 !jcp.is_relo)
2523 ? 2
2524 : 1;
2525
2526 // TODO: tune oh blocking
2527 const int oh_blk_size_param = jcp.is_relo ? 1 : 10;
2528 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2529 const int oh_blk_size = rnd_up(oh_blk_size_param, oh_step_size);
2530 jcp.oh_blk_size = rnd_up(nstl::min(jcp.oh, oh_blk_size), oh_step_size);
2531     // Here ihp means the input buffer height including padding (i.e., the
2532     // number of input rows required to compute jcp.oh_blk_size output rows).
2533     // If an input row doesn't participate in the computation of any output
2534     // row, it isn't copied to the buffer at all (e.g., jcp.stride_h > gen_kh).
2535 jcp.ihp = jcp.is_relo
2536 ? jcp.oh_blk_size
2537 : (jcp.oh_blk_size - 1) * nstl::min(jcp.stride_h, gen_kh) + gen_kh;
2538
2539 // TODO: tune ow blocking
2540 const int ow_blocks_per_call = jcp.is_relo ? 10 : 2;
2541 jcp.ow_block = nstl::min(jcp.ow, jcp.tile_width * ow_blocks_per_call);
2542 jcp.nb_ow = utils::div_up(jcp.ow, jcp.ow_block);
2543 // iwp includes all width elements that are really used in calculation
2544 // including left and right zero padding
2545 const bool are_sets_interleaved
2546 = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
2547 jcp.iwp = are_sets_interleaved
2548 ? (jcp.ow_block - 1) * nstl::min(jcp.stride_w, jcp.kw) + gen_kw
2549 : jcp.ow_block * jcp.kw;
2550
2551 // Number of ops per tile store
2552 int ops_tile_store = jcp.tile_width;
2553 // Number of ops per accumulation tile
2554     int available_ops = jcp.is_relo
2555             ? utils::div_up(jcp.nreduce, jcp.ic_block_int)
2556             : jcp.nb_ic_int * jcp.kh * (jcp.kw / jcp.kw_per_tile);
2557 // Number of vectors to store per tile operation
2558 // NOTE: set to zero to turn off interleave store (mostly for debugging)
2559     jcp.per_one_pstore = utils::div_up(ops_tile_store, available_ops);
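    // e.g., ops_tile_store = 14 and available_ops = 4 give
    // per_one_pstore = 4, i.e., four rows of the previous tile buffer are
    // flushed after each tile dot-product.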
2560
2561 if (jcp.is_relo) {
2562 jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.ihp * jcp.iwp * jcp.kh
2563 * jcp.ic_block_int_np
2564 // pbuffer pointer shifts each oh step for reduced-lowering
2565 + (jcp.oh - 1) * jcp.stride_h * jcp.ic_block_int_np
2566 // extra $line due to pbuffer writing full Zmm
2567 + jcp.ic_block_int;
2568 } else {
2569 jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.kd
2570 * ((size_t)jcp.ihp * jcp.iwp * jcp.ic_block_int_np
2571 // extra $line due to pbuffer writing full Zmm
2572 + jcp.ic_block_int);
2573 }
2574 jcp.wei_buffer_size = (size_t)jcp.ngroups * jcp.nb_oc
2575 * rnd_up(jcp.kh * jcp.kw * jcp.ic * jcp.oc_block, 1024);
2576 jcp.wsp_buffer_size = (size_t)jcp.nb_oh_blocking * jcp.nb_oc_blocking
2577 * jcp.full_tile_width * jcp.oc_block;
2578
2579 const auto &oscales = attr.output_scales_;
2580 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
2581
2582     // Note: this case is currently unsupported as it results in a segfault
2583 const int l_pad_output = nstl::min(jcp.ow, div_up(jcp.l_pad, jcp.stride_w));
2584 if (!jcp.is_relo && (l_pad_output > jcp.ow_block))
2585 return status::unimplemented;
2586
2587 // Relevant to 'zero_point padding buffer' (pbuff) jit kernel
2588 if (jcp.req_zero_point_buffer) {
2589 auto calculate_output_padding_dims = [=](int o_dim, int s_pad,
2590 int e_pad,
2591 int &s_pad_output,
2592 int &e_pad_output,
2593 bool &o_mid, int &o_pad,
2594 int stride,
2595 bool req_mid_area) {
2596 s_pad_output = nstl::min(o_dim, div_up(s_pad, stride));
2597 e_pad_output = nstl::min(o_dim, div_up(e_pad, stride));
2598 o_mid = (o_dim - s_pad_output - e_pad_output > 0) && req_mid_area;
2599 o_pad = nstl::min(o_dim,
2600 nstl::max(1, s_pad_output + e_pad_output + (int)o_mid));
2601 };
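        // e.g. (hypothetical): o_dim = 14, s_pad = 2, e_pad = 3, stride = 1
        // and req_mid_area = true give s_pad_output = 2, e_pad_output = 3,
        // o_mid = true and o_pad = 6 padded output points to compute.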
2602
2603 const bool mid_w_area = (jcp.l_pad > 0 || jcp.r_pad > 0)
2604 && (jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0
2605 || jcp.back_pad > 0);
2606 const bool mid_h_area = (jcp.t_pad > 0 || jcp.b_pad > 0)
2607 && (jcp.l_pad > 0 || jcp.r_pad > 0 || jcp.f_pad > 0
2608 || jcp.back_pad > 0);
2609 const bool mid_d_area = (jcp.f_pad > 0 || jcp.back_pad > 0)
2610 && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0
2611 || jcp.t_pad > 0);
2612 calculate_output_padding_dims(jcp.ow, jcp.l_pad, jcp.r_pad,
2613 jcp.l_pad_output, jcp.r_pad_output, jcp.ow_mid, jcp.ow_pad,
2614 jcp.stride_w, mid_w_area);
2615 calculate_output_padding_dims(jcp.oh, jcp.t_pad, jcp.b_pad,
2616 jcp.t_pad_output, jcp.b_pad_output, jcp.oh_mid, jcp.oh_pad,
2617 jcp.stride_h, mid_h_area);
2618 calculate_output_padding_dims(jcp.od, jcp.f_pad, jcp.back_pad,
2619 jcp.f_pad_output, jcp.back_pad_output, jcp.od_mid, jcp.od_pad,
2620 jcp.stride_d, mid_d_area);
2621 jcp.zp_pbuff_size
2622 = jcp.od_pad * jcp.oh_pad * jcp.ow_pad * jcp.oc * jcp.ngroups;
2623
2624 // compute zero-point padding kernel outside of the main parallel
2625 // region when threads are more likely to parallelize work across mb
2626 // within the convolution compute block.
2627 jcp.zp_pbuff_outer_compute = jcp.mb > 1 || is_3d;
2628
2629 const bool params_ok = ((jcp.ow_pad - (int)jcp.ow_mid) <= max_pad * 2);
2630 if (!params_ok) { return status::unimplemented; }
2631 }
2632
2633 // Set default parameters for driver code, but mostly required for
2634 // 'zero_point padding buffer' (pbuff) accumulation over output tensor
2635 set_oh_blk_limits(jcp);
2636 set_ow_blk_limits(jcp);
2637
2638 return status::success;
2639 }
2640
2641 status_t jit_avx512_core_amx_fwd_kernel_t::init_scratchpad(
2642 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
2643 const primitive_attr_t &attr) {
2644
2645 size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
2646 scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
2647 if (jcp.is_relo) {
2648 scratchpad.book(
2649 key_conv_amx_wei_buffer, jcp.wei_buffer_size, jcp.typesize_in);
2650 }
2651
2652 size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
2653 scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
2654 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
2655 assert(jcp.ngroups == 1);
2656 scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
2657 }
2658 scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
2659 if (jcp.req_zero_point_buffer) {
2660 const int nthr = jcp.zp_pbuff_outer_compute ? 1 : jcp.nthr;
2661 scratchpad.book(key_conv_zero_point_pad,
2662 (size_t)nthr * jcp.zp_pbuff_size, sizeof(int32_t));
2663 if (!jcp.zp_pbuff_outer_compute) {
2664 const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
2665 scratchpad.book<bool>(key_conv_zero_point_flag,
2666 (size_t)jcp.nthr * oc_chunks * jcp.ngroups);
2667 }
2668 }
2669
2670 // Keep scratchpad memory footprint under control
2671 const size_t L2_size_per_core = platform::get_per_core_cache_size(2);
2672 const size_t L3_size_per_core = platform::get_per_core_cache_size(3);
2673 const size_t max_scratchpad_size
2674 = jcp.nthr * (L2_size_per_core + L3_size_per_core);
2675 // TODO: tune this relationship as needed
2676 if (scratchpad.size() > max_scratchpad_size) return status::unimplemented;
2677 return status::success;
2678 }
2679
2680 void jit_avx512_core_amx_bwd_data_copy_kernel_t::copy_row(
2681 const bool is_masked) {
2682 assert(jcp.is_nspc && "no support for nChw16c in this copy kernel");
2683
2684 const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
2685 const int inp_w_step
2686 = jcp.ngroups * jcp.oc_without_padding * jcp.typesize_in;
2687 const int inp_h_step = jcp.ow * inp_w_step;
2688 const int out_w_step = jcp.oc_block_int * jcp.typesize_in;
2689 const int out_h_step = jcp.owp * out_w_step;
2690
2691 auto zero_it = [=](reg64_t tmp_out_ptr, int offset) {
2692 // no mask as output is a padded buffer
2693 if (is_bf16)
2694 vmovdqu16(ptr[tmp_out_ptr + offset], zmm_zero);
2695 else
2696 vmovdqu8(ptr[tmp_out_ptr + offset], zmm_zero);
2697 };
2698
2699 auto copy_it = [=](reg64_t tmp_inp_ptr, int inp_off, reg64_t tmp_out_ptr,
2700 int out_off) {
2701 Zmm zmm_load = is_masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
2702 Zmm zmm_stor = zmm_tmp; // no mask as output is padded buffer
2703 if (is_bf16) {
2704 vmovdqu16(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2705 vmovdqu16(ptr[tmp_out_ptr + out_off], zmm_stor);
2706 } else {
2707 vmovdqu8(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2708 vmovdqu8(ptr[tmp_out_ptr + out_off], zmm_stor);
2709 }
2710 };
2711
2712 mov(reg_ptr_aux_out, reg_ptr_out);
2713
2714 { // Handle Top Overflow
2715 Label label_tov_loop, label_tov_skip;
2716 test(reg_tov, reg_tov);
2717 jz(label_tov_skip, T_NEAR);
2718 mov(reg_cnt_tmp, reg_tov);
2719 L(label_tov_loop);
2720 {
2721 for (int ow = 0; ow < jcp.owp; ow++) {
2722 const int offset = ow * out_w_step;
2723 zero_it(reg_ptr_aux_out, offset);
2724 }
2725 add(reg_ptr_aux_out, out_h_step);
2726 dec(reg_cnt_tmp);
2727 jnz(label_tov_loop, T_NEAR);
2728 }
2729 L(label_tov_skip);
2730 }
2731
2732 mov(reg_ptr_aux_inp_h, reg_ptr_inp);
2733
2734 // Handle Middle Loop
2735 Label label_khp_loop, label_khp_skip;
2736 test(reg_khp, reg_khp);
2737 jz(label_khp_skip, T_NEAR);
2738 mov(reg_cnt_khp, reg_khp);
2739 L(label_khp_loop);
2740 {
2741 Label label_lov, label_lov_skip;
2742 Label label_kwp, label_kwp_skip;
2743 Label label_rov, label_rov_skip;
2744 test(reg_lov, reg_lov);
2745 jnz(label_lov, T_NEAR);
2746 test(reg_kwp, reg_kwp);
2747 jnz(label_kwp, T_NEAR);
2748 test(reg_rov, reg_rov);
2749 jnz(label_rov, T_NEAR);
2750
2751 test(reg_lov, reg_lov);
2752 jz(label_lov_skip, T_NEAR); // not really needed, but just to be safe
2753 L(label_lov); // Handle Left Overflow
2754 {
2755 Label label_lov_loop;
2756 mov(reg_cnt_tmp, reg_lov);
2757 L(label_lov_loop);
2758 {
2759 zero_it(reg_ptr_aux_out, 0);
2760 add(reg_ptr_aux_out, out_w_step);
2761 dec(reg_cnt_tmp);
2762 jnz(label_lov_loop, T_NEAR);
2763 }
2764 }
2765 L(label_lov_skip);
2766
2767 test(reg_kwp, reg_kwp);
2768 jz(label_kwp_skip, T_NEAR);
2769 L(label_kwp); // Handle Center Loop
2770 {
2771 Label label_kwp_loop;
2772 mov(reg_ptr_aux_inp_w, reg_ptr_aux_inp_h);
2773 mov(reg_cnt_tmp, reg_kwp);
2774 L(label_kwp_loop);
2775 {
2776 copy_it(reg_ptr_aux_inp_w, 0, reg_ptr_aux_out, 0);
2777 add(reg_ptr_aux_out, out_w_step);
2778 add(reg_ptr_aux_inp_w, inp_w_step);
2779 dec(reg_cnt_tmp);
2780
2781 if (jcp.stride_w > 1) {
2782 jz(label_kwp_skip, T_NEAR);
2783 // Handle Dilation-by-Stride
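                        // e.g., stride_w = 2: source columns d0 d1 d2 land
                        // in the padded buffer as d0 0 d1 0 d2 0.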
2784 for (int sw = 0; sw < jcp.stride_w - 1; sw++) {
2785 const int offset = sw * out_w_step;
2786 zero_it(reg_ptr_aux_out, offset);
2787 }
2788 add(reg_ptr_aux_out, (jcp.stride_w - 1) * out_w_step);
2789 if (jcp.stride_w == 2)
2790 dec(reg_cnt_tmp);
2791 else
2792 sub(reg_cnt_tmp, jcp.stride_w - 1);
2793 jmp(label_kwp_loop, T_NEAR);
2794 } else {
2795 jnz(label_kwp_loop, T_NEAR);
2796 }
2797 }
2798 }
2799 L(label_kwp_skip);
2800
2801 test(reg_rov, reg_rov);
2802 jz(label_rov_skip, T_NEAR);
2803 L(label_rov); // Handle Right Overflow
2804 {
2805 Label label_rov_loop;
2806 mov(reg_cnt_tmp, reg_rov);
2807 L(label_rov_loop);
2808 {
2809 zero_it(reg_ptr_aux_out, 0);
2810 add(reg_ptr_aux_out, out_w_step);
2811 dec(reg_cnt_tmp);
2812 jnz(label_rov_loop, T_NEAR);
2813 }
2814 }
2815 L(label_rov_skip);
2816
2817 add(reg_ptr_aux_inp_h, inp_h_step);
2818 dec(reg_cnt_khp);
2819
2820 if (jcp.stride_h > 1) {
2821 jz(label_khp_skip, T_NEAR);
2822 // Handle Dilation-by-Stride
2823 for (int sh = 0; sh < jcp.stride_h - 1; sh++) {
2824 for (int ow = 0; ow < jcp.owp; ow++) {
2825 const int offset = sh * out_h_step + ow * out_w_step;
2826 zero_it(reg_ptr_aux_out, offset);
2827 }
2828 }
2829 add(reg_ptr_aux_out, (jcp.stride_h - 1) * out_h_step);
2830 if (jcp.stride_h == 2)
2831 dec(reg_cnt_khp);
2832 else
2833 sub(reg_cnt_khp, jcp.stride_h - 1);
2834 jmp(label_khp_loop, T_NEAR);
2835 } else {
2836 jnz(label_khp_loop, T_NEAR);
2837 }
2838 }
2839 L(label_khp_skip);
2840
2841 { // Handle Bottom Overflow
2842 Label label_bov_loop, label_bov_skip;
2843 test(reg_bov, reg_bov);
2844 jz(label_bov_skip, T_NEAR);
2845 mov(reg_cnt_tmp, reg_bov);
2846 L(label_bov_loop);
2847 {
2848 for (int ow = 0; ow < jcp.owp; ow++) {
2849 const int offset = ow * out_w_step;
2850 zero_it(reg_ptr_aux_out, offset);
2851 }
2852 add(reg_ptr_aux_out, out_h_step);
2853 dec(reg_cnt_tmp);
2854 jnz(label_bov_loop, T_NEAR);
2855 }
2856 L(label_bov_skip);
2857 }
2858 }
2859
2860 void jit_avx512_core_amx_bwd_data_copy_kernel_t::generate() {
2861
2862 const int inp_c_step = jcp.oc_block_int * jcp.typesize_in;
2863 const int out_c_step = jcp.ohp * jcp.owp * inp_c_step;
2864 const int nb_oc_int_no_tail = jcp.oc_without_padding / jcp.oc_block_int;
2865 const int oc_block_int_tail = jcp.oc_without_padding % jcp.oc_block_int;
2866
2867 preamble();
2868
2869 // pointer to 1st needed element in src buffer
2870 mov(reg_ptr_inp, ptr[param1 + GET_OFF(src)]);
2871 // pointer to 1st needed element in dst buffer
2872 mov(reg_ptr_out, ptr[param1 + GET_OFF(dst)]);
2873
2874 // number of rows of src buffer to copy
2875 mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
2876 // number of zero-padded rows above src buffer to copy
2877 mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
2878 // number of zero-padded rows below src buffer to copy
2879 mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);
2880
2881 // number of columns of src buffer to copy
2882 mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
2883 // number of zero-padded columns before src buffer to copy
2884 mov(reg_lov, ptr[param1 + GET_OFF(l_overflow)]);
2885 // number of zero-padded columns before src buffer to copy
2886     // number of zero-padded columns after src buffer to copy
2887
2888 vpxord(zmm_zero, zmm_zero, zmm_zero);
2889
2890 if (oc_block_int_tail > 0) {
2891 uint64_t mask = (UINT64_C(1) << oc_block_int_tail) - 1;
2892 mov(reg_tmp, mask);
2893 kmovq(ktail_mask, reg_tmp);
2894 }
2895
2896 if (nb_oc_int_no_tail == 0) {
2897 copy_row(true); // masked
2898 } else if (nb_oc_int_no_tail == 1) {
2899 copy_row(false); // unmasked!
2900 if (oc_block_int_tail > 0) {
2901 add(reg_ptr_inp, inp_c_step);
2902 add(reg_ptr_out, out_c_step);
2903 copy_row(true); // masked
2904 }
2905 } else if (nb_oc_int_no_tail > 1) {
2906 mov(reg_cnt_ocb, nb_oc_int_no_tail);
2907 Label label_ocb_loop;
2908 L(label_ocb_loop);
2909 {
2910 copy_row(false); // unmasked!
2911 add(reg_ptr_inp, inp_c_step);
2912 add(reg_ptr_out, out_c_step);
2913 dec(reg_cnt_ocb);
2914 jnz(label_ocb_loop);
2915 }
2916 if (oc_block_int_tail > 0) copy_row(true); // masked
2917 }
2918
2919 postamble();
2920 }
2921
2922 // Tile register decomposition
2923 // { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
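// With jcp.max_tiles = 8 this maps the palette as tmm0-tmm3 for C
// (accumulators), tmm4-tmm5 for I (inputs) and tmm6-tmm7 for W (weights).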
2924 int jit_avx512_core_amx_bwd_data_kernel_t::get_out_tensor(int h, int i) const {
2925 const int C_BASE = 0;
2926 const int C_LAST = 4;
2927 assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
2928 MAYBE_UNUSED(C_LAST);
2929 const int tile = C_BASE + h * jcp.nb_ih_blocking + i;
2930 assert(C_BASE <= tile && tile < C_LAST);
2931 return tile;
2932 }
2933 int jit_avx512_core_amx_bwd_data_kernel_t::get_inp_tensor(int h) const {
2934 const int I_BASE = 4;
2935 const int I_LAST = 6;
2936 assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
2937 MAYBE_UNUSED(I_LAST);
2938 const int tile = I_BASE + h;
2939 assert(I_BASE <= tile && tile < I_LAST);
2940 return tile;
2941 }
2942 int jit_avx512_core_amx_bwd_data_kernel_t::get_wei_tensor(int i) const {
2943 const int W_BASE = 6;
2944 const int W_LAST = 8;
2945 assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
2946 MAYBE_UNUSED(W_LAST);
2947 const int tile = W_BASE + i;
2948 assert(W_BASE <= tile && tile < W_LAST);
2949 return tile;
2950 }
2951
2952 // Strides, shifts and offsets
2953 // - inp is a padded buffer ~ [nb_oc_int][ohp][owp]{32c,64c}
2954 // - weights is user buffer ~ OIhw16o16i{2o,4o}
2955 // - output is tiled buffer ~ [NBIH][NBIC][tile_width][16c]
2956 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_kh_step() const {
2957 return (size_t)jcp.typesize_in * (jcp.dilate_h + 1) * jcp.owp
2958 * jcp.oc_block_int;
2959 }
2960 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_ocb_step() const {
2961 return (size_t)jcp.typesize_in * jcp.ohp * jcp.owp * jcp.oc_block_int;
2962 }
2963 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_shift() const {
2964 return (size_t)jcp.typesize_in * jcp.tile_width * jcp.oc_block_int;
2965 }
2966 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_offset(
2967 int ihb, int kh, int kw) const {
2968 // calculate offset by src height dimension
2969 size_t sp_offset = (size_t)ihb * jcp.owp;
2970 // add offset by kernel height dimension
2971 sp_offset += (size_t)(jcp.kh - 1 - kh) * (jcp.dilate_h + 1) * jcp.owp;
2972 // add offset by kernel width dimension
2973 sp_offset += (size_t)(jcp.kw - 1 - kw) * (jcp.dilate_w + 1);
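    // (kernel indices are flipped since bwd-data convolves ddst with a
    // 180-degree-rotated filter)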
2974 return jcp.typesize_in * sp_offset * jcp.oc_block_int;
2975 }
2976 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_kh_step() const {
2977 return (size_t)jcp.typesize_in * jcp.kw * jcp.oc_block_int * jcp.ic_block;
2978 }
2979 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_ocb_step() const {
2980 const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
2981 return (size_t)jcp.typesize_in * (is_deconv ? 1 : jcp.nb_ic) * jcp.kh
2982 * jcp.kw * jcp.oc_block_int * jcp.ic_block;
2983 }
2984 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_offset(
2985 int icb, int kh, int kw) const {
2986 const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
2987 const size_t wei_kw_stride = jcp.oc_block_int * jcp.ic_block;
2988 const size_t wei_kh_stride = jcp.kw * wei_kw_stride;
2989 const size_t wei_icb_stride
2990 = (is_deconv ? jcp.nb_oc_int : 1) * jcp.kh * wei_kh_stride;
2991 return jcp.typesize_in
2992 * (icb * wei_icb_stride + kh * wei_kh_stride + kw * wei_kw_stride);
2993 }
get_out_icb_offset(int ihb,int icb) const2994 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_icb_offset(
2995 int ihb, int icb) const {
2996 size_t el_offset = jcp.is_nspc
2997 ? (size_t)icb * jcp.ic_block
2998 + (size_t)ihb * jcp.iw * jcp.ngroups
2999 * jcp.ic_without_padding
3000 : (size_t)icb * jcp.ih * jcp.iw * jcp.ic_block
3001 + (size_t)ihb * jcp.iw * jcp.ic_block;
3002 return (size_t)jcp.typesize_out * el_offset;
3003 }
get_out_row_offset(int ihb,int icb,int j) const3004 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_row_offset(
3005 int ihb, int icb, int j) const {
3006 size_t offset_w = jcp.is_nspc ? (size_t)jcp.typesize_out * j * jcp.ngroups
3007 * jcp.ic_without_padding
3008 : (size_t)jcp.typesize_out * j * jcp.ic_block;
3009 return get_out_icb_offset(ihb, icb) + offset_w;
3010 }
get_out_shift(int width) const3011 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_shift(int width) const {
3012 return jcp.is_nspc ? (size_t)jcp.typesize_out * width * jcp.ngroups
3013 * jcp.ic_without_padding
3014 : (size_t)jcp.typesize_out * width * jcp.ic_block;
3015 }
get_wsp_icb_offset(int ihb,int icb) const3016 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_icb_offset(
3017 int ihb, int icb) const {
3018 size_t el_offset = (size_t)icb * prv_width_ * jcp.ic_block
3019 + (size_t)ihb * jcp.nb_ic_blocking * jcp.full_tile_width
3020 * jcp.ic_block;
3021 return jcp.typesize_acc * el_offset;
3022 }
get_wsp_row_offset(int ihb,int icb,int j) const3023 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_row_offset(
3024 int ihb, int icb, int j) const {
3025 return get_wsp_icb_offset(ihb, icb)
3026 + (size_t)jcp.typesize_acc * j * jcp.ic_block;
3027 }
3028
3029 // Code generation
prepare_output()3030 void jit_avx512_core_amx_bwd_data_kernel_t::prepare_output() {
3031 for (int h = 0; h < jcp.nb_ih_blocking; h++)
3032 for (int i = 0; i < jcp.nb_ic_blocking; i++)
3033 tilezero(Tmm(get_out_tensor(h, i)));
3034 }
3035
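// Reset the bookkeeping used by the interleaved store logic: the current
// tile width, the workspace row counter, and the store/buffer status flags.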
void jit_avx512_core_amx_bwd_data_kernel_t::init_runtime_counters(
        bool start_with_last_tile_block) {
    prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
            ? jcp.tile_tail
            : jcp.tile_width;

    row_count_ = 0;
    is_store_done_ = false;
    is_buffer_empty_ = true;
}

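// Returns whether an eltwise post-op must be applied at the given position
// relative to sum: position 0 checks for eltwise applied before sum,
// position 1 for eltwise applied after sum.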
bool jit_avx512_core_amx_bwd_data_kernel_t::maybe_eltwise(int position) {
    using namespace primitive_kind;
    const auto &p = attr_.post_ops_;

    if (position == 0) {
        /* eltwise before sum */
        return p.contain(eltwise, 0);
    } else if (position == 1) {
        /* eltwise after sum */
        return p.contain(sum, 0) && p.contain(eltwise, 1);
    }

    return false;
}

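// Helpers that apply the channel-tail mask to a vector register: stores use
// merge-masking (unmasked destination lanes in memory are preserved), while
// loads additionally zero the masked-out lanes (T_z).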
Ymm jit_avx512_core_amx_bwd_data_kernel_t::ymm_mask(
        const Ymm &ymm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
                     : ymm_in;
}

Zmm jit_avx512_core_amx_bwd_data_kernel_t::zmm_mask(
        const Zmm &zmm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
                     : zmm_in;
}

void jit_avx512_core_amx_bwd_data_kernel_t::cvt2ps(data_type_t type_in,
        const Zmm &zmm_in, const Operand &op, bool mask_flag) {
    const Zmm zmm = zmm_mask(zmm_in, mask_flag);
    switch (type_in) {
        case data_type::f32:
        case data_type::s32: vmovups(zmm, op); break;
        case data_type::s8: vpmovsxbd(zmm, op); break;
        case data_type::u8: vpmovzxbd(zmm, op); break;
        default: assert(!"unsupported data type");
    }
    if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
}

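// Store one f32 accumulator vector as bf16 (or f32) diff_src, applying the
// sum, bias, and eltwise post-ops on the way. bf16 inputs are widened to f32
// with a zero-extend followed by a 16-bit left shift.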
void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_bf16(
        const Zmm &zmm_out, int icb, int h, int w) {
    const bool mask_flag = jcp.is_nspc && jcp.ic_without_padding != jcp.ic
            && icb == (jcp.nb_ic_blocking - 1);

    auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));

    const auto &p = attr_.post_ops_;

    const int sum_idx = p.find(primitive_kind::sum);
    if (sum_idx != -1) {
        if (jcp.dsrc_dt == data_type::bf16) {
            vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vpslld(zmm_prev_dst, zmm_prev_dst, 16);
            vaddps(zmm_out, zmm_prev_dst);
        } else {
            vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vaddps(zmm_out, zmm_prev_dst);
        }
    }
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * icb * jcp.ic_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        if (jcp.bia_dt == data_type::bf16) {
            vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
            vpslld(zmm_bias, zmm_bias, 16);
            vaddps(zmm_out, zmm_bias);
        } else
            vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
    }

    const int eltwise_ind = p.find(primitive_kind::eltwise);
    if (eltwise_ind != -1) eltwise_injector_->compute_vector(zmm_out.getIdx());

    if (jcp.dsrc_dt == data_type::bf16) {
        Ymm ymm_out = Ymm(zmm_out.getIdx());
        vcvtneps2bf16(ymm_out, zmm_out);
        vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
    } else {
        vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
    }
}

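// Store one s32 accumulator vector as int8/s32/f32 diff_src. The sequence:
// convert the accumulator to f32, add bias, apply the output scale, run the
// eltwise / sum / eltwise post-ops, saturate back to the integer range, and
// store with a type-specific down-conversion.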
void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
        const Zmm &zmm_out, int icb, int h, int w) {
    const int nb_ic_block = jcp.nb_ic_blocking;
    const int ic_block = jcp.ic_block;
    const bool mask_flag
            = jcp.ic_without_padding != jcp.ic && icb == (nb_ic_block - 1);

    auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));

    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    const int32_t *p_sum_zp = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
        p_sum_zp = &p_entry.sum.zero_point;
    }

    if (p_sum_scale) {
        if (*p_sum_scale != 1.f)
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        if (*p_sum_zp != 0)
            mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
    }

    int scale_offset = jcp.is_ic_scale * (sizeof(float) * icb * ic_block);
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * icb * ic_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
    }
    /* add bias to zmm_accum */
    vcvtdq2ps(zmm_out, zmm_out);
    if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
    const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
    vmulps(zmm_out_msk, zmm_out,
            EVEX_compress_addr(reg_ptr_scales, scale_offset));

    /* Do post-ops */
    if (maybe_eltwise(0)) eltwise_injector_->compute_vector(zmm_out.getIdx());
    if (p_sum_scale) { // post_op: sum
        cvt2ps(jcp.dsrc_dt, zmm_prev_dst, addr, mask_flag);
        if (*p_sum_zp != 0) {
            vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
            vsubps(zmm_prev_dst, zmm_sum_zp);
        }
        if (*p_sum_scale == 1.f)
            vaddps(zmm_out, zmm_prev_dst);
        else
            vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
    }
    if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx());

    // Properly saturate the accumulators for integer datatypes
    if (one_of(jcp.dsrc_dt, u8, s8, s32)) {
        init_saturate_f32(
                zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dsrc_dt);
        saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dsrc_dt);
        vcvtps2dq(zmm_out, zmm_out);
    }

    const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);

    switch (jcp.dsrc_dt) {
        case data_type::f32:
        case data_type::s32: vmovups(addr, zmm_out_store); break;
        case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
        case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
        default: assert(!"unknown dst_dt");
    }
}

void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector(
        const Zmm &zmm_out, int icb, int h, int w) {
    /*
    Output layouts:
                  jcp.is_nspc                  !jcp.is_nspc
              -------------------------    ------------------------
        INT8: [N][H][W][NBIC][16IC]
        BF16: [N][H][W][NBIC][16IC]    or  [N][NBIC][H][W][16IC]
    */
    if (jcp.ddst_dt == data_type::bf16) {
        store_output_vector_bf16(zmm_out, icb, h, w);
    } else {
        store_output_vector_int8(zmm_out, icb, h, w);
    }
}

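// Flush the accumulator tiles to the workspace and, when do_store is set,
// convert the workspace rows into the final diff_src values. The last
// ih-block may hold fewer rows, which is resolved at runtime via reg_last_h.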
void jit_avx512_core_amx_bwd_data_kernel_t::store_output(
        int width, bool do_store) {
    auto store_output_block = [=](int width, bool do_store,
                                      bool is_last_ih_blks) {
        // Calculate the number of ih blocks; it may differ on the last call
        const int n_ih_blks = is_last_ih_blks ? jcp.ih % jcp.nb_ih_blocking
                                              : jcp.nb_ih_blocking;
        for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
            for (int ihb = 0; ihb < n_ih_blks; ihb++) {
                /* Formats: Workspace: [NBIH][NBIC][W][16OC] */
                tilestored(ptr[reg_wsp_ptr + reg_wei_stride
                                   + get_wsp_icb_offset(ihb, icb)],
                        Tmm(get_out_tensor(ihb, icb)));
                is_buffer_empty_ = false;
                is_store_done_ = false;
                for (int tw = 0; tw < width && do_store; tw++) {
                    Zmm zmm_out = Zmm(tw);
                    vmovups(zmm_out,
                            ptr[reg_wsp_ptr
                                    + get_wsp_row_offset(ihb, icb, tw)]);
                    store_output_vector(zmm_out, icb, ihb, tw);
                }
            }
        }
    };

    // adjustment in case interleave store is turned off
    do_store = do_store || jcp.per_one_pstore == 0;
    if (jcp.ih % jcp.nb_ih_blocking == 0) {
        store_output_block(width, do_store, /* is_last_ih_blks = */ false);
    } else {
        Label label_full_store, label_done;
        cmp(reg_last_h, 0);
        jne(label_full_store, T_NEAR);
        store_output_block(width, do_store, /* is_last_ih_blks = */ true);
        jmp(label_done, T_NEAR);
        L(label_full_store);
        store_output_block(width, do_store, /* is_last_ih_blks = */ false);
        L(label_done);
    }
    if (do_store) add(reg_out_ptr, get_out_shift(width));
}

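// Drain up to jcp.per_one_pstore workspace rows per call. This is invoked
// between tile multiplies so that converting and storing previously computed
// rows can overlap with the current block's computation.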
void jit_avx512_core_amx_bwd_data_kernel_t::interleave_store(int width) {
    for (int c = 0;
            c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
            c++) {
        // row_count = ihb * ICB * TW + icb * TW + tw
        int tw = row_count_ % prv_width_;
        int icb = (row_count_ / prv_width_) % jcp.nb_ic_blocking;
        int ihb = (row_count_ / prv_width_) / jcp.nb_ic_blocking;

        Zmm zmm_out = Zmm(tw);
        vmovups(zmm_out, ptr[reg_wsp_ptr + get_wsp_row_offset(ihb, icb, tw)]);
        store_output_vector(zmm_out, icb, ihb, tw);
        row_count_++;

        if (row_count_
                == prv_width_ * jcp.nb_ic_blocking * jcp.nb_ih_blocking) {
            add(reg_out_ptr, get_out_shift(prv_width_));
            row_count_ = 0;
            is_store_done_ = true;
            prv_width_ = width;
        }
    }
}

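// Main reduction loop over oc blocks and filter spatial positions: load the
// diff_dst and weights tiles, issue the AMX tile multiply matching the data
// type, and interleave stores of previously computed results.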
void jit_avx512_core_amx_bwd_data_kernel_t::compute_ocb_loop(
        int width, bool do_store) {

    auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
        switch (jcp.ddst_dt) {
            using namespace data_type;
            case bf16: tdpbf16ps(x1, x2, x3); break;
            case s8: tdpbssd(x1, x2, x3); break;
            case u8: tdpbusd(x1, x2, x3); break;
            default: assert(!"unsupported data type");
        }
    };

    prepare_output();

    for (int ocb = 0; ocb < jcp.nb_oc_int; ocb++) {
        // reverse order through spatial components of weights so that
        // input buffer is accessed in a monotonically increasing fashion
        for (int kh = jcp.kh - 1; kh >= 0; kh--) {
            for (int kw = jcp.kw - 1; kw >= 0; kw--) {
                for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
                    tileloadd(Tmm(get_inp_tensor(ihb)),
                            ptr[reg_inp_ptr + get_inp_offset(ihb, kh, kw)
                                    + reg_inp_stride]);
                }
                for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
                    tileloadd(Tmm(get_wei_tensor(icb)),
                            ptr[reg_wei_ptr + get_wei_offset(icb, kh, kw)
                                    + reg_wei_stride]);
                    for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
                        tdpbxxd(Tmm(get_out_tensor(ihb, icb)),
                                Tmm(get_inp_tensor(ihb)),
                                Tmm(get_wei_tensor(icb)));
                        interleave_store(width);
                    }
                }
            }
        }
        add(reg_inp_ptr, get_inp_ocb_step());
        add(reg_wei_ptr, get_wei_ocb_step());
    }
    sub(reg_inp_ptr, get_inp_ocb_step() * jcp.nb_oc_int);
    sub(reg_wei_ptr, get_wei_ocb_step() * jcp.nb_oc_int);

    store_output(width, do_store);

    add(reg_inp_ptr, get_inp_shift());
}

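// Outer loop over iw tile blocks. All blocks but the last are computed with
// the store deferred; the final block triggers the store of the (possibly
// narrower) tail tile.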
void jit_avx512_core_amx_bwd_data_kernel_t::compute_iw_loop() {
    auto compute_iw_loop_body = [=](bool last_iwb, int num_tile_blocks) {
        int gen_tile_tail = last_iwb && jcp.tile_tail > 0 ? jcp.tile_tail
                                                          : jcp.tile_width;
        init_runtime_counters(last_iwb && num_tile_blocks == 1);
        for (int iwb = 0; iwb < num_tile_blocks - 1; iwb++)
            compute_ocb_loop(jcp.tile_width, false);
        compute_ocb_loop(gen_tile_tail, true);
    };

    if (jcp.nb_iw == 1) {
        compute_iw_loop_body(true, jcp.iw_blocks);
    } else {
        Label label_done;
        int iw_blocks_per_call = div_up(jcp.iw_block, jcp.tile_width);
        int last_iwb_tile_blocks = jcp.iw_blocks % iw_blocks_per_call;
        if (last_iwb_tile_blocks == 0 && jcp.tile_tail > 0)
            last_iwb_tile_blocks = iw_blocks_per_call;
        if (last_iwb_tile_blocks > 0) {
            Label label_not_last_iwb;
            mov(reg_tmp, ptr[param1 + GET_OFF(iwb)]);
            cmp(reg_tmp, jcp.nb_iw - 1);
            jne(label_not_last_iwb, T_NEAR);

            compute_iw_loop_body(true, last_iwb_tile_blocks);

            jmp(label_done, T_NEAR);

            L(label_not_last_iwb);
        }
        compute_iw_loop_body(false, iw_blocks_per_call);

        L(label_done);
    }
}

void jit_avx512_core_amx_bwd_data_kernel_t::generate() {
    preamble();

    mov(reg_inp_ptr, ptr[param1 + GET_OFF(dst)]); // padded buffer of diff_dst
    mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]); // weights
    mov(reg_out_ptr, ptr[param1 + GET_OFF(src)]); // diff_src
    mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);

    if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(bias)]);

    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);

    mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);

    const int inp_stride = jcp.oc_block_int * jcp.typesize_in;
    const int wei_stride = jcp.ic_block * jcp.typesize_acc;
    mov(reg_inp_stride, inp_stride);
    mov(reg_wei_stride, wei_stride);

    if (jcp.is_nspc && jcp.ic_without_padding != jcp.ic) {
        // Use mask 0xF by default for all output data and post-ops
        // loads / stores with block index
        // icb = icc * jcp.nb_ic_blocking + (jcp.nb_ic_blocking - 1)
        // TODO: use masked loads / stores for the last icc only
        int current_block_size = jcp.ic_block;
        int mask = (1 << current_block_size) - 1;
        Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
        Xbyak::Label mask_is_set;
        mov(reg_ic_blocks, ptr[param1 + GET_OFF(ic_blocks)]);
        cmp(reg_ic_blocks, jcp.nb_ic - jcp.nb_ic_blocking);
        jne(mask_is_set, T_NEAR);
        // Reset the mask
        current_block_size = jcp.ic_without_padding % jcp.ic_block;
        mask = (1 << current_block_size) - 1;
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);

        L(mask_is_set);
    }
    compute_iw_loop();

    postamble();

    if (jcp.with_eltwise) eltwise_injector_->prepare_table();
}

bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok(
        const jit_conv_conf_t &jcp, primitive_attr_t &attr) {
    using namespace primitive_kind;
    const auto &p = attr.post_ops_;
    const bool is_bf16 = jcp.ddst_dt == data_type::bf16;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };

    auto is_sum = [&](int idx) {
        if (is_bf16)
            return p.entry_[idx].is_sum();
        else
            return p.contain(sum, idx);
    };

    switch (p.len()) {
        case 0: return true;
        case 1: return is_eltwise(0) || is_sum(0);
        case 2:
            return (is_sum(0) && is_eltwise(1))
                    || (!is_bf16 && is_sum(1) && is_eltwise(0));
        default: return false;
    }

    return false;
}

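// Fill the 64-byte tile configuration buffer (palette) with the row counts
// and per-row byte widths of the input, weights, and accumulator tiles,
// derived from the blocking parameters.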
void jit_avx512_core_amx_bwd_data_kernel_t::tile_configure(char *tcfg_buff) {
    const int vnni_width = jcp.ddst_dt == data_type::bf16 ? 2 : 4;
    // Input tile dimensions
    const int a_col = jcp.oc_block_int;
    const int a_row = jcp.tile_width;
    // Weights tile dimensions
    const int b_col = jcp.ic_block * vnni_width;
    const int b_row = a_col / vnni_width;
    // Accumulator tile dimensions
    const int c_col = jcp.ic_block;
    const int c_row = a_row;

    for (size_t i = 0; i < 64; i++)
        tcfg_buff[i] = 0;

    // Weights (W_BASE) Tensor Tiles
    for (int i = 0; i < jcp.nb_ic_blocking; i++)
        tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
                b_row, b_col * jcp.typesize_in);

    // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
    for (int h = 0; h < jcp.nb_ih_blocking; h++) {
        tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
                a_row, a_col * jcp.typesize_in);
        for (int i = 0; i < jcp.nb_ic_blocking; i++)
            tc_configure_tile((palette_config_t *)tcfg_buff,
                    get_out_tensor(h, i), c_row, c_col * jcp.typesize_acc);
    }

    ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
}

status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, memory_desc_t &diff_src_md,
        memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
        memory_desc_t *bias_md, primitive_attr_t &attr, int nthreads) {
    using namespace prop_kind;

    const memory_desc_wrapper diff_src_d(&diff_src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper bias_d(bias_md);

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    int ndims = diff_src_d.ndims();
    bool is_1d = ndims == 3;
    bool is_3d = ndims == 5;

    if (is_3d) return status::unimplemented;

    using namespace data_type;
    const bool is_deconv = cd.prop_kind != prop_kind::backward_data;
    const bool is_bf16 = everyone_is(true, diff_dst_d.data_type() == bf16,
            weights_d.data_type() == bf16,
            one_of(diff_src_d.data_type(), bf16, f32));
    const bool is_bf16_convolution = is_bf16 && !is_deconv;
    const bool is_bf16_deconvolution = is_bf16 && is_deconv;
    const bool is_int8_deconvolution = is_deconv
            && everyone_is(true, one_of(diff_dst_d.data_type(), s8, u8),
                    weights_d.data_type() == s8,
                    one_of(diff_src_d.data_type(), f32, s32, s8, u8));

    const bool supported = (is_bf16 && mayiuse(avx512_core_bf16_amx_bf16))
            || (is_int8_deconvolution && mayiuse(avx512_core_bf16_amx_int8));
    if (!supported) return status::unimplemented;

    jcp = zero<decltype(jcp)>();
    jcp.isa = is_bf16 ? avx512_core_bf16_amx_bf16 : avx512_core_bf16_amx_int8;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;

    jcp.mb = diff_src_d.dims()[0];
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;
    jcp.ih = !is_1d ? diff_src_d.dims()[ndims - 2] : 1;
    jcp.iw = diff_src_d.dims()[ndims - 1];
    jcp.oh = !is_1d ? diff_dst_d.dims()[ndims - 2] : 1;
    jcp.ow = diff_dst_d.dims()[ndims - 1];
    jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
    jcp.stride_w = cd.strides[ndims - 3];

    // No bias for bf16 case to simplify integration with ref_deconvolution
    jcp.with_bias = bias_md && !is_bf16_convolution
            && cd.bias_desc.format_kind != format_kind::undef;

    jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
    jcp.dilate_w = cd.dilates[ndims - 3];

    const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
    if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
            || jcp.b_pad >= gen_kh)
        return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    if (is_deconv) {
        jcp.ddst_dt = cd.src_desc.data_type;
        jcp.dsrc_dt = cd.dst_desc.data_type;
    } else {
        jcp.ddst_dt = cd.diff_dst_desc.data_type;
        jcp.dsrc_dt = cd.diff_src_desc.data_type;
    }
    jcp.wei_dt = cd.weights_desc.data_type;

    jcp.is_depthwise = with_groups && everyone_is(1, jcp.ic, jcp.oc);

    if (jcp.is_depthwise)
        return status::unimplemented; // TODO: add support of DW convolution

    format_tag_t dat_tag_ncsp
            = pick(ndims - 3, format_tag::nCw16c, format_tag::nChw16c);
    format_tag_t dat_tag_nspc
            = pick(ndims - 3, format_tag::nwc, format_tag::nhwc);
    // To toggle the default data layout for BF16 between nChw16c and nhwc,
    // swap the following two variable definitions. Current choice: nhwc.
    format_tag_t dat_tag_opt = dat_tag_nspc;
    format_tag_t dat_tag_alt = is_bf16 ? dat_tag_ncsp : dat_tag_nspc;

    if (diff_src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag_opt));
        jcp.src_tag = dat_tag_opt;
    } else
        jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);

    if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
        return status::unimplemented;

    jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
    assert(IMPLICATION(is_int8_deconvolution, jcp.is_nspc));

    // TODO: remove all support for nChw16c from this implementation
    if (!jcp.is_nspc) return status::unimplemented;

    if (diff_dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
        jcp.dst_tag = jcp.src_tag;
    } else
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);

    if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;

    if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
        CHECK(memory_desc_init_by_tag(*bias_md, format_tag::x));

    jcp.nthr = nthreads;

    jcp.ic_block = 16;
    jcp.oc_block = 16;

    if (jcp.ngroups == 1) {
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    }
    bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
    if (!args_ok) return status::unimplemented;

    const int vnni_width = is_bf16 ? 2 : 4;
    jcp.oc_block_int = jcp.oc_block * vnni_width; // 32 for bf16, 64 for int8

    if (attr.set_default_formats(&diff_src_md) != status::success)
        return status::unimplemented;
    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    auto set_or_check_wei_format = [&]() {
        using namespace format_tag;
        format_tag_t wei_tag;
        if (is_bf16_convolution)
            wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16o16i2o,
                    gOIw16o16i2o, OIhw16o16i2o, gOIhw16o16i2o);
        else if (is_bf16_deconvolution)
            wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
                    gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i);
        else if (is_int8_deconvolution)
            wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o4i,
                    gOIw16i16o4i, OIhw16i16o4i, gOIhw16i16o4i);
        else {
            assert(!"unsupported combination");
            return false;
        }

        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);

        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }
        return weights_md == want_wei_md;
    };

    if (!set_or_check_wei_format()) return status::unimplemented;

    jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
    jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
    jcp.typesize_acc = sizeof(int32_t);

    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_oc_int = div_up(jcp.oc, jcp.oc_block_int);

    const int max_palette = amx::get_max_palette();
    jcp.max_tiles = amx::get_max_tiles(max_palette);
    jcp.full_tile_width = amx::get_max_rows(max_palette);
    if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
        return status::unimplemented;

    jcp.tile_width = nstl::min(jcp.full_tile_width, jcp.iw);
    jcp.iw_blocks = div_up(jcp.iw, jcp.tile_width);

    // Prefer to use a single tile width when possible
    // (e.g. iw28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
    if (jcp.iw % jcp.iw_blocks == 0) jcp.tile_width = jcp.iw / jcp.iw_blocks;
    jcp.tile_tail = jcp.iw % jcp.tile_width;

    jcp.nb_ic_blocking = (jcp.nb_ic % 2 == 0) ? 2 : 1;
    jcp.nb_ih_blocking
            = everyone_is(true, jcp.ih > 1,
                      // requirement for interleave stores
                      IMPLICATION(jcp.iw_blocks > 1, jcp.ih % 2 == 0))
            ? 2
            : 1;

    // TODO: tune ih blocking
    const int ih_blk_size_tmp = 10;
    const int ih_step = jcp.nb_ih_blocking;
    jcp.ih_blk_size = rnd_up(nstl::min(jcp.ih, ih_blk_size_tmp), ih_step);
    // ohp includes all elements that are really used in calculation,
    // including zero-padded "dilate-by-strides" and top and bottom overflow
    jcp.ohp = jcp.ih_blk_size + gen_kh - 1;

    // TODO: tune iw blocking
    const int iw_blocks_per_call = 2;
    jcp.iw_block = jcp.tile_width * iw_blocks_per_call;
    jcp.nb_iw = div_up(jcp.iw, jcp.iw_block);
    // owp includes all elements that are really used in calculation,
    // including zero-padded "dilate-by-strides" and left and right overflow
    jcp.owp = jcp.iw_block + gen_kw - 1;

    // Number of ops per tile store
    int ops_tile_store = jcp.tile_width;
    // Number of ops per accumulation tile
    int available_ops = jcp.nb_oc_int * jcp.kh * jcp.kw;
    // Number of vectors to store per tile operation
    // NOTE: set to zero to turn off interleave store (mostly for debugging)
    jcp.per_one_pstore = div_up(ops_tile_store, available_ops);

    jcp.inp_buffer_size
            = (size_t)jcp.nb_oc_int * jcp.ohp * jcp.owp * jcp.oc_block_int;
    jcp.wsp_buffer_size = (size_t)jcp.nb_ih_blocking * jcp.nb_ic_blocking
            * jcp.full_tile_width * jcp.ic_block;

    const auto &oscales = attr.output_scales_;
    jcp.is_ic_scale = oscales.mask_ == 1 << 1;

    return status::success;
}

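// Book per-thread scratchpad space for the padded diff_dst copy, the s32
// accumulator workspace, a padded bias buffer when the channel count is
// padded, and one cacheline for the tile configuration.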
void jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {

    size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
    scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
    size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
    scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
    if (jcp.with_bias && jcp.ic != jcp.ic_without_padding) {
        assert(jcp.ngroups == 1);
        scratchpad.book(key_conv_padded_bias, jcp.ic, jcp.typesize_bia);
    }
    scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
}

const int jit_avx512_core_amx_bwd_weights_kernel_t::max_ur_w = 32;

// Tile register decomposition
// { C_BASE = 0, A_BASE = 4, B_BASE = 6, }
int jit_avx512_core_amx_bwd_weights_kernel_t::get_wei_tensor(
        int ocb, int icb) const {
    const int C_BASE = 0;
    const int C_LAST = 4;
    assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(C_LAST);
    const int tile = C_BASE + ocb * jcp.nb_oc_blocking + icb;
    assert(C_BASE <= tile && tile < C_LAST);
    return tile;
}
int jit_avx512_core_amx_bwd_weights_kernel_t::get_src_tensor(int icb) const {
    const int A_BASE = 4;
    const int A_LAST = 6;
    assert(0 <= A_BASE && A_BASE < A_LAST && A_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(A_LAST);
    const int tile = A_BASE + icb;
    assert(A_BASE <= tile && tile < A_LAST);
    return tile;
}
int jit_avx512_core_amx_bwd_weights_kernel_t::get_ddst_tensor(int ocb) const {
    const int B_BASE = 6;
    const int B_LAST = 8;
    assert(0 <= B_BASE && B_BASE < B_LAST && B_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(B_LAST);
    const int tile = B_BASE + ocb;
    assert(B_BASE <= tile && tile < B_LAST);
    return tile;
}

void jit_avx512_core_amx_bwd_weights_kernel_t::tile_configure(char *tcfg_buff) {
    // Input tile dimensions
    const int a_col = jcp.ur_w;
    const int a_row = jcp.ic_block;
    // Weights tile dimensions
    const int b_col = jcp.oc_block * 2;
    const int b_row = a_col / 2;
    // Accumulator tile dimensions
    const int c_col = jcp.oc_block;
    const int c_row = a_row;

    for (size_t i = 0; i < 64; i++)
        tcfg_buff[i] = 0;

    for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
        tc_configure_tile((palette_config_t *)tcfg_buff, get_src_tensor(icb),
                a_row, a_col * jcp.typesize_in);

    for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
        tc_configure_tile((palette_config_t *)tcfg_buff, get_ddst_tensor(ocb),
                b_row, b_col * jcp.typesize_in);

    for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
        for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
            tc_configure_tile((palette_config_t *)tcfg_buff,
                    get_wei_tensor(ocb, icb), c_row, c_col * jcp.typesize_out);

    ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
}

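// Rewind reg_src and reg_kernel after the kd (od_*) or kh (oh_*) loops so
// that the enclosing loop resumes from its original positions.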
void jit_avx512_core_amx_bwd_weights_kernel_t::od_step_comeback_pointers() {
    Label kd_comeback_label;
    mov(kj, reg_kd_count);
    L(kd_comeback_label);
    {
        sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(kj);
        jnz(kd_comeback_label, T_NEAR);
    }
}

void jit_avx512_core_amx_bwd_weights_kernel_t::oh_step_comeback_pointers() {
    Label kh_comeback_label;
    mov(kj, reg_kh);
    L(kh_comeback_label);
    {
        sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
        sub(reg_kernel, get_kernel_offset(0, jcp.kw));
        dec(kj);
        jnz(kh_comeback_label, T_NEAR);
    }
}

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_full_spat_loop(
        int nb_ic_blocking, int nb_oc_blocking) {
    // General code layout:
    //
    // Blocking over OH -- top level
    // (Reduces L2 pressure; not very useful right now)
    //  Loop over all KHxKW kernel -- emit_kh_kw_loop()
    //    Loop over OH block -- emit_h_loop()
    //      Loop over OW blocks -- emit_fma_block()
    //      (Supports both fully unrolled and partially unrolled
    //      versions to reduce code size)
    //          Loop over OW block -- emit_fma_step()

    auto src_row_size = get_src_offset(0, 0, 1);
    auto ddst_row_size = get_ddst_offset(0, 1);
    auto row_size = src_row_size + ddst_row_size;

    int h_block_size = jcp.oh;
    int h_last_block_size = h_block_size;
    int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
    auto working_set_size = row_size * h_block_size;

    if (working_set_size > full_spat_max_working_set_size) {
        assert(full_spat_opt_working_set_size
                < full_spat_max_working_set_size);

        while (working_set_size > full_spat_opt_working_set_size
                && h_block_size >= min_h_block_size) {
            for (int i = 2; i <= h_block_size; i++)
                if (i == h_block_size)
                    h_block_size = h_block_size / 2;
                else if (h_block_size % i == 0) {
                    h_block_size = h_block_size / i;
                    break;
                }
            working_set_size = row_size * h_block_size;
        }
        h_block_size = nstl::max(min_h_block_size, h_block_size);
        h_last_block_size = jcp.oh % h_block_size;
        if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
    }

    Opmask reg_h_block = k1;
    Reg64 reg_kh = rax;
    Reg64 reg_kw = rbx;
    Reg64 reg_tmp = abi_not_param1;
    Reg32 reg_tmp_w = reg_tmp.cvt32();
    Reg64 reg_ohs = rdx;
    Reg64 reg_ihs = rsi;
    Reg64 reg_h = r8;
    Reg64 reg_j = r10;

    Reg64 reg_src = r13;
    Reg64 reg_ddst = r14;
    Reg64 reg_ker = r15;

    Reg64 reg_dense_stride = abi_param1;
    Reg64 reg_a_stride = reg_tmp;

    auto emit_block = [&]() {
        mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
        for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
            dim_t ur_w_src_offset = ur_w_b * get_src_offset(0, jcp.ur_w);
            dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);

            for (int icb = 0; icb < nb_ic_blocking; icb++) {
                dim_t icb_offset = jcp.typesize_in * icb * jcp.tr_src_buf_size;
                tileloadd(Tmm(get_src_tensor(icb)),
                        ptr[reg_src + reg_a_stride + icb_offset
                                + ur_w_src_offset]);
            }
            for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
                tileloadd(Tmm(get_ddst_tensor(ocb)),
                        ptr[reg_ddst + reg_dense_stride
                                + jcp.typesize_in * ocb
                                        * jcp.tr_diff_dst_buf_size
                                + ur_w_ddst_offset]);
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
                            Tmm(get_src_tensor(icb)),
                            Tmm(get_ddst_tensor(ocb)));
            }
        }
    };

    auto emit_h_loop = [&]() {
        Label h_loop, skip_h_loop;
        mov(reg_j, 1);
        cmp(reg_j, reg_h);
        je(skip_h_loop, T_NEAR);
        L(h_loop);
        {
            emit_block();

            add(reg_src, get_src_offset(0, 0, 1));
            add(reg_ddst, get_ddst_offset(0, 1));
            add(reg_j, 1);
            cmp(reg_j, reg_h);
            jb(h_loop);
        }
        L(skip_h_loop);

        emit_block();
    };

    auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
        xor_(reg_kh, reg_kh);
        Label kh_loop, kh_loop_end;

        int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
        // NB: this is correct because we only support t_pad = kh / 2 and thus
        // ih == oh
        int ih_block_size = oh_block_size
                + (!is_first_block + !is_last_block) * jcp.t_pad;

        L(kh_loop);
        {
            if (is_first_block) {
                xor_(reg_tmp, reg_tmp);
                mov(reg_ohs, jcp.t_pad);
                sub(reg_ohs, reg_kh);
                cmovb(reg_ohs, reg_tmp);

                mov(reg_ihs, reg_ohs);
                sub(reg_ihs, jcp.t_pad);
                add(reg_ihs, reg_kh);
            } else {
                xor_(reg_ohs, reg_ohs);
                mov(reg_ihs, reg_kh);
            }

            mov(reg_tmp, oh_block_size);
            sub(reg_tmp, reg_ohs);
            mov(reg_h, ih_block_size);
            sub(reg_h, reg_ihs);
            cmp(reg_tmp, reg_h);
            cmovb(reg_h, reg_tmp);

            Label kh_loop_work;
            cmp(reg_h, 0);
            jg(kh_loop_work, T_NEAR);

            // empty h loop for this jcp.kh:
            // - set the ddst to 0 if necessary
            // - move the kernel pointer
            // - jump to the end
            sub(reg_h, 1);
            Label skip_ker_zeroing;

            // The reg_ker ptr has highest bit set if the ddst needs to be
            // zeroed. Those who have byte-aligned their data will suffer the
            // consequences :(
            // TODO: move the flag to a mask register? (Roma)
            test(reg_ker, 1);
            jz(skip_ker_zeroing, T_NEAR);

            Label zeroing_loop;
            vpxord(zmm0, zmm0, zmm0);
            and_(reg_ker, ~1); // temporarily clear the zeroing flag

            mov(reg_dense_stride, 64);
            tilezero(Tmm(get_wei_tensor(0, 0)));
            for (int kw = 0; kw < jcp.kw; kw++) {
                // dim_t kw_offset = kw * get_kernel_offset(jcp.ic_block, 0);
                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, kw)],
                            Tmm(get_wei_tensor(0, 0)));
            }
            // restore the zeroing flag (it will be cleared after the end of
            // emit_kh_kw_loop, but we may need it until then)
            or_(reg_ker, 1);
            jmp(kh_loop_end, T_NEAR);

            L(skip_ker_zeroing);
            add(reg_ker, get_kernel_offset(0, jcp.kw));
            jmp(kh_loop_end, T_NEAR);

            L(kh_loop_work);

            mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
            mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));

            add(reg_src, reg_ihs);
            add(reg_ddst, reg_ohs);

            Label kw_loop;
            xor_(reg_kw, reg_kw);

            mov(reg_dense_stride, 64);
            L(kw_loop);
            {
                Label do_zero, ker_init_done;
                test(reg_ker, 1);
                jnz(do_zero, T_NEAR);

                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tileloadd(Tmm(get_wei_tensor(ocb, icb)),
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, 0)]);
                jmp(ker_init_done);
                L(do_zero);
                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilezero(Tmm(get_wei_tensor(ocb, icb)));

                L(ker_init_done);

                mov(ptr[rsp + ddst_save_offset], reg_ddst);
                mov(ptr[rsp + src_save_offset], reg_src);

                lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
                emit_h_loop();

                mov(reg_ddst, ptr[rsp + ddst_save_offset]);
                mov(reg_src, ptr[rsp + src_save_offset]);

                // The reg_ker ptr has highest bit set if the ddst needs to
                // be zeroed. Those who have byte-aligned their data will
                // suffer the consequences :(
                mov(reg_tmp, reg_ker);
                and_(reg_ker, ~1);

                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, 0)],
                            Tmm(get_wei_tensor(ocb, icb)));

                mov(reg_ker, reg_tmp);
                add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
                add(reg_kw, 1);
                cmp(reg_kw, jcp.kw);
                jl(kw_loop);
            }

            sub(reg_src, reg_ihs);
            sub(reg_ddst, reg_ohs);

            L(kh_loop_end);
            add(reg_kh, 1);
            cmp(reg_kh, jcp.kh);
            jl(kh_loop);
        }
    };

    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);
    mov(reg_ker, ptr[param + GET_OFF(filt)]);
    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    or_(reg_ker, reg_tmp);

    bool single_kh_kw_loop = (h_last_block_size == jcp.oh);

    auto src_row_step = get_src_offset(0, 0, 1);
    auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
    auto ddst_block_step = get_ddst_offset(0, h_block_size);

    emit_kh_kw_loop(true, single_kh_kw_loop);

    if (!single_kh_kw_loop) {
        auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
        sub(reg_ker, ker_reset_offset);
        and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates

        add(reg_src, first_src_block_step);
        add(reg_ddst, ddst_block_step);

        int num_innermost_iters
                = (jcp.oh - h_last_block_size) / h_block_size - 1;
        if (num_innermost_iters > 0) {
            Label h_block_loop;

            mov(reg_tmp_w, num_innermost_iters);
            kmovw(reg_h_block, reg_tmp_w);
            L(h_block_loop);
            {
                emit_kh_kw_loop(false, false);
                sub(reg_ker, ker_reset_offset);
                add(reg_src, src_row_step * h_block_size);
                add(reg_ddst, ddst_block_step);

                kmovw(reg_tmp_w, reg_h_block);
                sub(reg_tmp_w, 1);
                kmovw(reg_h_block, reg_tmp_w);
                jnz(h_block_loop);
            }
        }

        emit_kh_kw_loop(false, true);
    }
}

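// Inner-channel loop of the weights-gradient kernel: for each stride phase
// and filter column, load the current weights tiles, accumulate src x
// diff_dst tile products over the unrolled ow blocks, and store the tiles
// back.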
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_ic_loop(
        int ic_block, int nb_ic_blocking, int nb_oc_blocking) {
    assert(jcp.ur_w % 2 == 0);
    const int str_w = jcp.stride_w;
    assert(jcp.tr_iw % str_w == 0);
    const int src_stride_w_shift = jcp.tr_iw / str_w;

    mov(reg_b_stride, 64);
    mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);

    for (int s = 0; s < str_w; s++) {
        for (int i_kw = s; i_kw < jcp.kw; i_kw += str_w) {

            for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tileloadd(Tmm(get_wei_tensor(ocb, icb)),
                            ptr[reg_kernel + reg_b_stride
                                    + get_full_kernel_offset(
                                            ocb, icb, 0, i_kw)]);

            int src_offset_l = (i_kw * (jcp.dilate_w + 1)) / str_w
                    + s * src_stride_w_shift;

            for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
                dim_t ur_w_src_offset = ur_w_b
                        * get_src_offset(0, filter_w_to_src(0, jcp.ur_w, 0));
                dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
                for (int icb = 0; icb < nb_ic_blocking; icb++) {
                    dim_t icb_offset = icb * jcp.tr_src_buf_size;
                    tileloadd(Tmm(get_src_tensor(icb)),
                            ptr[reg_src
                                    + jcp.typesize_in
                                            * (src_offset_l + icb_offset)
                                    + ur_w_src_offset + reg_a_stride]);
                }
                for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
                    tileloadd(Tmm(get_ddst_tensor(ocb)),
                            ptr[reg_ddst
                                    + jcp.typesize_in * ocb
                                            * jcp.tr_diff_dst_buf_size
                                    + ur_w_ddst_offset + reg_b_stride]);
                    for (int icb = 0; icb < nb_ic_blocking; icb++)
                        tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
                                Tmm(get_src_tensor(icb)),
                                Tmm(get_ddst_tensor(ocb)));
                }
            }

            for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(ptr[reg_kernel + reg_b_stride
                                       + get_full_kernel_offset(
                                               ocb, icb, 0, i_kw)],
                            Tmm(get_wei_tensor(ocb, icb)));
        }
    }
    safe_add(reg_src, get_src_offset(ic_block, 0), reg_long_offt);
    add(reg_kernel, get_kernel_offset(ic_block, 0));
}

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_init(int ocb) {
    auto reg_unit_val = reg_tmp.cvt16();
    mov(reg_unit_val, 0x3f80); // bf16 value of 1.
    vpbroadcastw(vreg_bias_unit, reg_unit_val);

    mov(reg_tmp, ptr[param + GET_OFF(bias)]);
    vmovups(vreg_bias_acc, ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block]);
}

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_row(
        bool is_partial, int ocb) {
    if (!jcp.with_bias) return;
    mov(reg_tmp, ptr[param + GET_OFF(flags)]);
    Label skip_label;
    test(reg_tmp, FLAG_IC_FIRST);
    jz(skip_label, T_NEAR);

    if (is_partial) { compute_diff_bias_init(ocb); }
    auto compute_step = [&]() {
        vmovups(vreg_bias_ddst, ptr[reg_ddst]);
        vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
    };

    Label ow_loop, ow_tail;
    int niters = jcp.tr_ow / 2;
    if (niters > 0) {
        mov(reg_tmp, jcp.tr_ow / 2);
        L(ow_loop);
        compute_step();
        add(reg_ddst, get_ddst_offset(2));
        sub(reg_tmp, 1);
        jnz(ow_loop, T_NEAR);
    }
    if (jcp.tr_ow % 2) compute_step();

    if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));

    if (is_partial) {
        mov(reg_tmp, ptr[param + GET_OFF(bias)]);
        vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
                vreg_bias_acc);
    }

    L(skip_label);
}

void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_compute_diff_bias(
        int nb_oc_blocking) {
    // In the harness_3d_reduction case the diff_bias calculation is called
    // for every ow row separately to be aligned with the od loop in
    // compute_od_loop_common()
    if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
    mov(reg_tmp, ptr[param + GET_OFF(flags)]);

    Label skip_label;
    test(reg_tmp, FLAG_IC_FIRST);
    jz(skip_label, T_NEAR);

    for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
        Label bias_loop, skip_label_local;

        mov(reg_ddst, ptr[param + GET_OFF(dst)]);
        add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size);

        switch (jcp.harness) {
            case harness_2d_reduction:
                mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
                sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
                break;
            case harness_mb_reduction:
            case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
            case harness_3d_reduction:
            default: assert(!"Invalid harness type");
        }

        cmp(reg_oj, 0);
        jle(skip_label_local, T_NEAR); // nothing to do

        compute_diff_bias_init(ocb);
        L(bias_loop);
        {
            compute_diff_bias_row(false, ocb);
            add(reg_ddst, get_ddst_offset(0, 1));

            sub(reg_oj, 1);
            jnz(bias_loop, T_NEAR);
        }

        mov(reg_tmp, ptr[param + GET_OFF(bias)]);
        vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
                vreg_bias_acc);

        L(skip_label_local);
    }
    // restore reg_ddst value
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);

    L(skip_label);
}

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_step_common(
        int nb_ic_blocking, int nb_oc_blocking) {
    Label kh_label, ic_block_label, ow_block_label, kd_label;

    int ic_block = jcp.ic_block;
    int ic_tail = jcp.ic_tail;

    auto ic_loop = [&](int nb_ic_blocking, int nb_oc_blocking) {
        Label ic_tail_label, ic_loop_done_label;

        if (ic_tail) {
            mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
            cmp(reg_icb, jcp.ic_tail);
            jne(ic_tail_label, T_NEAR);

            compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
            jmp(ic_loop_done_label, T_NEAR);

            L(ic_tail_label);
            compute_ic_loop(ic_tail, nb_ic_blocking, nb_oc_blocking);
            add(reg_kernel, get_kernel_offset(jcp.ic_block - ic_tail, 0));
            safe_add(reg_src,
                    get_src_offset(0, 0, filter_h_to_src(1))
                            - get_src_offset(ic_tail, 0),
                    reg_long_offt);
            L(ic_loop_done_label);
        } else {
            compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
        }
    };

    if (jcp.ndims == 5) {
        /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
         * 'mov's must be preserved. */
        mov(ki, reg_kd_count);
        mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
        mov(aux_reg_src, reg_src);
        mov(aux_reg_kernel, reg_kernel);

        L(kd_label);
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    L(kh_label);
    {
        ic_loop(nb_ic_blocking, nb_oc_blocking);

        if (jcp.dilate_h > 0) {
            add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
        }
        // subtract the pointer shift made within the ic block loop
        // and move to the next kh index
        add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
        dec(kj);
        cmp(kj, 0);
        jg(kh_label, T_NEAR);
    }
    if (jcp.ndims == 5) {
        add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
    }
    // In the harness_3d_reduction case the diff_bias calculation is called
    // for every ow row separately to be aligned with the od loop in
    // compute_od_loop_common()
    if (jcp.harness == harness_3d_reduction) {
        auto reg_save_ddst = reg_a_stride;
        mov(reg_save_ddst, reg_ddst);
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            safe_add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size,
                    reg_long_offt);
            compute_diff_bias_row(true, ocb);
        }
        mov(reg_ddst, reg_save_ddst);
    }

    if (jcp.ndims == 5) {
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
        mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
        od_step_comeback_pointers();
    } else {
        oh_step_comeback_pointers();
    }
}

void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_zero_kernel(
        int nb_ic_blocking, int nb_oc_blocking) {
    if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
    Label skip_zeroing, zeroing_loop;

    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    cmp(reg_tmp, 0);
    jz(skip_zeroing, T_NEAR);

    Zmm zero = Zmm(0);
    vpxord(zero, zero, zero);
    if (jcp.with_bias) {
        Label skip_bias_zeroing;
        mov(reg_tmp, ptr[param + GET_OFF(flags)]);
        test(reg_tmp, FLAG_IC_FIRST);
        jz(skip_bias_zeroing, T_NEAR);
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            mov(reg_tmp, ptr[param + GET_OFF(bias)]);
            vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block], zero);
        }
        L(skip_bias_zeroing);
        if (jcp.harness == harness_compute_full_spatial)
            jmp(skip_zeroing, T_NEAR);
    }

    mov(reg_b_stride, 64);
    tilezero(Tmm(get_wei_tensor(0, 0)));
    for (dim_t shift = 0;
            shift < get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
            shift += get_kernel_offset(jcp.ic_block, 0)) {
        for_(int icb = 0; icb < nb_ic_blocking; icb++)
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            tilestored(
                    ptr[reg_kernel + reg_b_stride
                            + get_full_kernel_offset(ocb, icb, 0, 0) + shift],
                    Tmm(get_wei_tensor(0, 0)));
        }
    }
    L(skip_zeroing);
}

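// Loop over output rows, split into the zero-padded top edge, the unpadded
// middle block(s), and the zero-padded bottom edge. The filter pointer and
// the effective kernel height (reg_kh) are adjusted as rows enter or leave
// the padded regions; is_partial restricts the range to
// [os_index_begin, os_index_end).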
compute_oh_loop_common(int nb_ic_blocking,int nb_oc_blocking,bool is_partial)4410 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_loop_common(
4411 int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
4412 int b_pad = jcp.b_pad;
4413 int t_pad = jcp.t_pad;
4414
4415 bool is_dilated = jcp.dilate_h != 0;
4416 int dilate_h = jcp.dilate_h + 1;
4417 int stride_h = jcp.stride_h;
4418 auto filter_step_size = get_kernel_offset(0, jcp.kw);
4419 auto src_step_size = get_src_offset(0, 0, 1);
4420 auto ddst_step_size = get_ddst_offset(0, 1);
4421 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
4422 oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
4423 oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
4424 oh_dilate_label_end, oh_dilate_setup_label_shift,
4425 oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;
4426
4427 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4428 int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
4429 int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
4430 int oh_head_overflow_end = div_up(t_pad, stride_h);
4431 int oh_tail_end = jcp.oh;
4432
4433 int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
4434 int ih_body_end
4435 = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);
4436
4437 if (is_partial)
4438 mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
4439 else
4440 xor_(reg_oj, reg_oj);
4441
4442 /* Compute 'top' edge */
4443 if (t_pad > 0) {
4444 if (is_partial) {
4445 cmp(reg_oj, oh_head_overflow_end);
4446 jge(oh_tpad_tail_label_end, T_NEAR);
4447 }
4448 const int overflow
4449 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
4450 const int underflow = div_up(t_pad, dilate_h);
4451 const int initial_kh = jcp.kh - overflow - underflow;
4452
4453 // Setup reg_kh, reg_kernel, and reg_src
4454 mov(reg_kh, initial_kh);
4455 add(reg_kernel, filter_step_size * underflow);
4456 if (is_dilated) {
4457 const int tail = t_pad % dilate_h;
4458 const int shift = tail == 0 ? 0 : dilate_h - tail;
4459 mov(reg_ih_shift, shift);
4460 if (!is_partial) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4461 add(reg_src, src_step_size * shift);
4462 }
4463
4464 if (is_partial) {
4465 Label head_setup, head_setup_finish;
4466 cmp(reg_oj, 0);
4467 je(head_setup_finish, T_NEAR);
4468 mov(reg_oj_setup, reg_oj);
4469
4470 L(head_setup);
4471 if (is_dilated) {
4472 inc(reg_ih_shift);
4473 cmp(reg_ih_shift, dilate_h);
4474 jl(oh_dilate_setup_label_shift, T_NEAR);
4475 // unshift src as new kernel element enters
4476 sub(reg_src, src_step_size * (dilate_h - 1));
4477 xor_(reg_ih_shift, reg_ih_shift);
4478 }
4479 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4480 add(reg_kh, stride_h);
4481 sub(reg_kernel, filter_step_size * stride_h);
4482 if (is_dilated) {
4483 jmp(oh_dilate_setup_label_noshift, T_NEAR);
4484 L(oh_dilate_setup_label_shift);
4485 // shift src as old kernel element progresses
4486 add(reg_src, src_step_size * stride_h);
4487 L(oh_dilate_setup_label_noshift);
4488 }
4489 sub(reg_oj_setup, 1);
4490 jg(head_setup, T_NEAR);
4491 L(head_setup_finish);
4492
4493 if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4494 if (oh_head_end < oh_head_overflow_end) {
4495 cmp(reg_oj, oh_head_end);
4496 jge(oh_tpad_label_end, T_NEAR);
4497 }
4498 }
4499
4500 //Setup reg_kernel
4501 // If dilated, shift src ptr
4502 // Loop
        L(oh_tpad_label);
        compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
        add(reg_ddst, ddst_step_size);
        if (is_dilated) {
            mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
            inc(reg_ih_shift);
            mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
            cmp(reg_ih_shift, dilate_h);
            jl(oh_dilate_label_shift, T_NEAR);
            // unshift src as new kernel element enters
            sub(reg_src, src_step_size * (dilate_h - 1));
            xor_(reg_ih_shift, reg_ih_shift);
            mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
        }
        // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
        add(reg_kh, stride_h);
        sub(reg_kernel, filter_step_size * stride_h);
        if (is_dilated) {
            jmp(oh_dilate_label_noshift, T_NEAR);
            L(oh_dilate_label_shift);
            // shift src as old kernel element progresses
            add(reg_src, src_step_size * stride_h);
            L(oh_dilate_label_noshift);
        }
        inc(reg_oj);

        if (is_partial) {
            cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            jge(oh_bpad_label_end, T_NEAR);
        }
        cmp(reg_oj, oh_head_end);
        jl(oh_tpad_label, T_NEAR);

        L(oh_tpad_label_end);
        // need second loop to process kernel if it is larger than the src
        // (does not apply to dilations as they must have unit stride)
        if (oh_head_end < oh_head_overflow_end) {
            assert(!is_dilated);

            cmp(reg_oj, oh_head_overflow_end);
            jge(oh_tpad_tail_label_end, T_NEAR);

            mov(reg_kh, jcp.ih);
            L(oh_tpad_tail_label);
            {
                compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
                add(reg_ddst, ddst_step_size);
                sub(reg_kernel, filter_step_size * stride_h);

                inc(reg_oj);

                if (is_partial) {
                    cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
                    jge(oh_bpad_label_end, T_NEAR);
                }
                cmp(reg_oj, oh_head_overflow_end);
                jl(oh_tpad_tail_label, T_NEAR);
            }
        }
        if (body_src_start_offset != 0) {
            add(reg_kernel, filter_step_size * body_src_start_offset);
            add(reg_src, src_step_size * body_src_start_offset);
        }
        L(oh_tpad_tail_label_end);
    }

    if (is_partial) {
        cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
        jge(oh_bpad_label_end, T_NEAR);
    }
    cmp(reg_oj, oh_body_end);
    jge(oh_label_end, T_NEAR);

    /* Compute middle block(s) */
    mov(reg_kh, jcp.kh);
    L(oh_label);
    {
        compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
        add(reg_src, src_step_size * stride_h);
        add(reg_ddst, ddst_step_size);

        inc(reg_oj);

        if (is_partial) {
            cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            jge(oh_bpad_label_end, T_NEAR);
        }

        cmp(reg_oj, oh_body_end);
        jl(oh_label, T_NEAR);
    }
    L(oh_label_end);

    /* Compute bottom edge */
    if (b_pad > 0) {
        if (is_partial) {
            cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            jge(oh_bpad_label_end, T_NEAR);
        }
        cmp(reg_oj, jcp.oh);
        jge(oh_bpad_label_end, T_NEAR);

        if (is_dilated) {
            // Assumes unit stride for dilations
            mov(reg_kh, jcp.kh - 1);
            xor_(reg_ih_shift, reg_ih_shift);
        } else {
            assert(jcp.dilate_h == 0);
            mov(reg_kh, jcp.ih - ih_body_end);
        }
        if (is_partial) {
            lea(reg_oj_setup,
                    ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
            if (stride_h == 1 && !is_dilated) {
                sub(reg_kh, reg_oj_setup);
            } else {
                Label body_setup, body_setup_finish, dilate_skip;
                cmp(reg_oj_setup, 0);
                je(body_setup_finish, T_NEAR);

                L(body_setup);
                if (is_dilated) {
                    inc(reg_ih_shift);
                    cmp(reg_ih_shift, dilate_h);
                    jl(dilate_skip, T_NEAR);
                    xor_(reg_ih_shift, reg_ih_shift);
                }
                sub(reg_kh, stride_h);
                L(dilate_skip);
                sub(reg_oj_setup, 1);
                jg(body_setup, T_NEAR);
                L(body_setup_finish);
            }
        }

        if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
        L(oh_bpad_label);
        {
            compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
            add(reg_src, src_step_size * stride_h);
            add(reg_ddst, ddst_step_size);

            if (is_dilated) {
                mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
                inc(reg_ih_shift);
                mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
                cmp(reg_ih_shift, dilate_h);
                jl(oh_dilate_label_end, T_NEAR);
                xor_(reg_ih_shift, reg_ih_shift);
                mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
            }
            sub(reg_kh, stride_h);
            L(oh_dilate_label_end);
            inc(reg_oj);
            if (is_partial) {
                cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
                jge(oh_bpad_label_end, T_NEAR);
            }
            cmp(reg_oj, oh_tail_end);
            jl(oh_bpad_label, T_NEAR);
        }
    }
    L(oh_bpad_label_end);
}

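// Drives the reduction over the depth dimension for 3D convolution: walks the
// output-depth index (reg_d_index), adjusting the kernel pointer and the count
// of kernel elements overlapping the src (reg_kd_count) through the
// front-padding, body, and back-padding regions, and runs the full oh loop at
// each depth position.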
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_od_loop_common(
        int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
    assert(jcp.harness == harness_3d_reduction);

    const int src_backpad_overlap
            = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);

    const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
    const auto src_shift = get_src_offset(0, 0, jcp.ih);
    const auto ddst_shift = get_ddst_offset(0, jcp.oh);

    const int kd_front_pad = nstl::max(0, jcp.f_pad);
    const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);

    Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
            backpad_end_label, backpad_label;

    /* initially offset 'kd' by f_pad */
    mov(reg_src_d, ptr[param + GET_OFF(src)]);
    mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);

    if (is_partial) {
        add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
        mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
        mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
    } else {
        const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
        const int kd_offset = get_kernel_offset(
                0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
        add(reg_kernel, kd_offset);
        xor_(reg_d_index, reg_d_index);
        mov(reg_kd_count, kd_padding);
    }

    cmp(reg_kd_count, 0);
    jle(loop_end_label, T_NEAR); // no iterations along kd
    if (is_partial)
        cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
    else
        cmp(reg_d_index, jcp.od);
    jge(loop_end_label, T_NEAR); // no iterations along depth dimension

    L(d_loop_label);

    mov(reg_src, reg_src_d);
    mov(reg_ddst, reg_ddst_d);

    mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
    mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
    mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);

    compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);

    mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
    mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
    mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));

    /* Compute 'front' edge */
    if (jcp.f_pad > 0) {
        /* Check if within fpad region */
        cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
        jge(fpad_end_label, T_NEAR);

        /* Fpad steps */
        sub(reg_kernel, filter_shift * jcp.stride_d);
        add(reg_kd_count, jcp.stride_d);

        /* Final number of kernel elements that overlap with src */
        const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
        cmp(reg_kd_count, src_ker_overlap);
        jle(common_block_label, T_NEAR);

        /* Correct any excess shifts to kernel and src */
        if (jcp.f_pad <= jcp.od * jcp.stride_d) {
            /* Filter has moved beyond padding (adjust for stride effects) */
            if (jcp.f_pad % jcp.stride_d != 0) {
                int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
                add(reg_kernel, filter_shift * src_corr);
                add(reg_src_d, src_shift * src_corr);
            }
        } else {
            /* Filter still overlaps padding (complete reset) */
            sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
        }

        /* Apply correction */
        mov(reg_kd_count, src_ker_overlap);
        jmp(common_block_label);

        L(fpad_end_label);
    }

    /* Compute back edge (back_pad region of the depth dimension) */
    if (jcp.back_pad > 0) {

        /* Check if within back_pad region */
        cmp(reg_d_index, src_backpad_overlap - 1);
        jl(backpad_end_label, T_NEAR);
        jg(backpad_label, T_NEAR);

        /* Execute overlap correction between the filter and the initial
         * back_pad region. */
        mov(reg_kd_count,
                jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
        jmp(backpad_end_label, T_NEAR);

        L(backpad_label);
        sub(reg_kd_count, jcp.stride_d);
        cmp(reg_kd_count, 0);
        jle(loop_end_label, T_NEAR);

        L(backpad_end_label);
    }

    /* Compute middle block */
    add(reg_src_d, src_shift * jcp.stride_d);

    /* Execute common block and loop */
    L(common_block_label);
    add(reg_ddst_d, ddst_shift);
    inc(reg_d_index);
    if (is_partial)
        cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
    else
        cmp(reg_d_index, jcp.od);
    jl(d_loop_label, T_NEAR);

    L(loop_end_label);
}

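// Top-level loop driver for one call into the generated kernel: loads the
// src/diff_dst/filter pointers from the call arguments, optionally zeroes the
// weights and accumulates the bias, then dispatches to the loop that matches
// the configured harness.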
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_loop(
        int nb_ic_blocking, int nb_oc_blocking) {
    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);
    mov(reg_kernel, ptr[param + GET_OFF(filt)]);

    maybe_zero_kernel(nb_ic_blocking, nb_oc_blocking);
    maybe_compute_diff_bias(nb_oc_blocking);

    switch (jcp.harness) {
        case harness_3d_reduction:
            compute_od_loop_common(nb_ic_blocking, nb_oc_blocking, true);
            break;
        case harness_2d_reduction:
            compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking, true);
            break;
        case harness_mb_reduction:
            compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);
            break;
        case harness_compute_full_spatial:
            compute_full_spat_loop(nb_ic_blocking, nb_oc_blocking);
            break;
        default: assert(!"Invalid harness type");
    }
}

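// Editor's note on the frame below: each named offset reserves one 8-byte
// spill slot above the ic_block_step scratch area, i.e.
//
//   rsp + ic_block_step_stack_size +  0  kd_count
//                                  +  8  src_d pointer
//                                  + 16  ddst_d pointer
//                                  + 24  d_index
//                                  + 32  ih_dilate counter
//                                  + 40  src save slot
//                                  + 48  ddst save slot
//
// for a total frame of ic_block_step_stack_size + 56 bytes.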
void jit_avx512_core_amx_bwd_weights_kernel_t::setup_stack_space() {
    kd_count_offset = ic_block_step_stack_size;
    src_d_offset = ic_block_step_stack_size + 8;
    ddst_d_offset = ic_block_step_stack_size + 16;
    d_index_offset = ic_block_step_stack_size + 24;
    ih_dilate_offset = ic_block_step_stack_size + 32;
    src_save_offset = ic_block_step_stack_size + 40;
    ddst_save_offset = ic_block_step_stack_size + 48;
    stack_space_needed = ic_block_step_stack_size + 56;
}

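// Kernel entry point: emits four variants of the compute loop, one per
// combination of the runtime last_ic_block / last_oc_block flags, so channel
// tails fall back to a blocking factor of 1 without a separate kernel binary.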
void jit_avx512_core_amx_bwd_weights_kernel_t::generate() {
    preamble();

    setup_stack_space();

    sub(rsp, stack_space_needed);

    Label last_ic_block_label, last_blocks_done_label;

    mov(reg_tmp, ptr[param + GET_OFF(last_ic_block)]);
    cmp(reg_tmp, 0);
    jne(last_ic_block_label, T_NEAR);
    { // full nb_ic_blocking
        Label last_oc_block_label;
        mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
        cmp(reg_tmp, 0);
        jne(last_oc_block_label, T_NEAR);
        { // full nb_oc_blocking
            compute_loop(jcp.nb_ic_blocking, jcp.nb_oc_blocking);
            jmp(last_blocks_done_label, T_NEAR);
        }
        L(last_oc_block_label);
        { // tail of nb_oc_blocking
            compute_loop(jcp.nb_ic_blocking, 1);
            jmp(last_blocks_done_label, T_NEAR);
        }
    }
    L(last_ic_block_label);
    { // tail nb_ic_blocking
        Label last_oc_block_label;
        mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
        cmp(reg_tmp, 0);
        jne(last_oc_block_label, T_NEAR);
        { // full nb_oc_blocking
            compute_loop(1, jcp.nb_oc_blocking);
            jmp(last_blocks_done_label, T_NEAR);
        }
        L(last_oc_block_label);
        { // tail of nb_oc_blocking
            compute_loop(1, 1);
            jmp(last_blocks_done_label, T_NEAR);
        }
    }

    L(last_blocks_done_label);
    add(rsp, stack_space_needed);

    postamble();
}

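// Fills jcp with the problem descriptor, rejects shapes the AMX bwd-weights
// kernel cannot handle, resolves memory formats that were left as 'any', and
// derives the blocking and threading parameters used by the driver.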
status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper diff_bias_d(&diff_bias_md);

    jcp = zero<decltype(jcp)>();

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    if (!mayiuse(avx512_core_bf16_amx_bf16)) return status::unimplemented;
    jcp.isa = avx512_core_bf16_amx_bf16;

    jcp.ver = ver_vnni; // Needed for transpose routines
    jcp.nthr = nthreads;

    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    bool ok = true
            // general condition to simplify dilations
            && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
            && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
            && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
            // special condition to simplify dilations in compute_oh_loop_common
            && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
    if (!ok) return status::unimplemented;

    ok = true && one_of(ndims, 3, 4, 5)
            && everyone_is(
                    data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
            && one_of(diff_weights_d.data_type(), data_type::f32,
                    data_type::bf16);
    if (!ok) return status::unimplemented;

    jcp.transform_to_vnni = diff_weights_d.data_type() == data_type::bf16;

    jcp.r_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
    jcp.b_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
    jcp.back_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));

    /* XXX: no support for padding when dilation_d > 0 */
    if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
        return status::unimplemented;

    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
    if (jcp.is_depthwise)
        return status::unimplemented; // TODO: add support of DW convolution

    const int dat_format_tag = ndims - 3;
    format_tag_t dat_tag_nspc = utils::pick(dat_format_tag, format_tag::nwc,
            format_tag::nhwc, format_tag::ndhwc);
    format_tag_t dat_tag_opt = dat_tag_nspc;

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
        jcp.src_tag = dat_tag_opt;
    } else
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag_opt);
    if (!one_of(jcp.src_tag, dat_tag_opt)) return status::unimplemented;
    jcp.is_nspc = jcp.src_tag == dat_tag_nspc;

    if (diff_dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
        jcp.dst_tag = jcp.src_tag;
    } else
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
    if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;

    if (!jcp.is_nspc) return status::unimplemented;

    const int wei_format_tag = 2 * ndims - 6 + with_groups;
    format_tag_t wei_tag;
    if (jcp.transform_to_vnni)
        wei_tag = pick(wei_format_tag, format_tag::OIw16i16o2i,
                format_tag::gOIw16i16o2i, format_tag::OIhw16i16o2i,
                format_tag::gOIhw16i16o2i, format_tag::OIdhw16i16o2i,
                format_tag::gOIdhw16i16o2i);
    else
        wei_tag = pick(wei_format_tag, format_tag::OIw16i16o,
                format_tag::gOIw16i16o, format_tag::OIhw16i16o,
                format_tag::gOIhw16i16o, format_tag::OIdhw16i16o,
                format_tag::gOIdhw16i16o);
    if (diff_weights_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }
    jcp.wei_dt = diff_weights_d.data_type();

    /* conditions on bias memory */
    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
    if (jcp.with_bias) {
        if (diff_bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_bias_md, format_tag::x));
    }
    jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
    jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    const int max_pad_h = ext_kh / 2;
    const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
            && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
            && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd;
    if (!boundaries_ok) return status::unimplemented;

    jcp.ic_block = 16;
    jcp.oc_block = 16;

    jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
    jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);

    jcp.ic_tail = jcp.ic % jcp.ic_block;
    jcp.oc_tail = jcp.oc % jcp.oc_block;

    jcp.nb_oc_blocking = (jcp.nb_oc > 1) ? 2 : 1;
    jcp.nb_ic_blocking = (jcp.nb_ic > 1) ? 2 : 1;

    int max_palette = amx::get_max_palette();
    jcp.max_tiles = amx::get_max_tiles(max_palette);
    jcp.full_tile_width = amx::get_max_rows(max_palette);

    if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
        return status::unimplemented;

    const bool is_2d = (ndims == 4);
    const bool is_3d = (ndims == 5);
    jcp.typesize_in = sizeof(bfloat16_t);
    jcp.typesize_out = sizeof(float);

    // TODO: Find more shapes (especially 3D with large spatials) for which
    // local transposition will be beneficial. Furthermore, for TBB threads
    // more shapes can potentially benefit from spatial blocking
    int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;

    jcp.global_transpose = dnnl_thr_syncable();
    jcp.spatial_blk_size = optimal_blk_size;

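    // Editor's note - worked example, assuming iw = 17, stride_w = 1 and
    // l_pad = r_pad = 1: tr_pad = rnd_up(max(1, 1 + 1), 32) = 32 and
    // tr_iw = rnd_up(div_up(17, 1) + 32, 32) * 1 = 64, i.e. a transposed src
    // row is padded out to a multiple of 32 bf16 elements so a full tile
    // register row can be loaded without masking.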
    const int tr_round = 32; // To load full tile register
    int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
    jcp.tr_iw = rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
            * jcp.stride_w;

    jcp.tr_src_num_guard_elems = tr_pad; // upper bound
    jcp.tr_ow = rnd_up(jcp.ow, 2);

    if (jcp.tr_ow <= max_ur_w) {
        jcp.ur_w = jcp.tr_ow;
        jcp.ur_w_blocks = 1;
    } else {
        jcp.ur_w = 1;
        for (int i = max_ur_w; i >= 1; i -= 2) {
            if (jcp.tr_ow % i == 0) {
                jcp.ur_w = i;
                break;
            }
        }
        jcp.ur_w_blocks = jcp.tr_ow / jcp.ur_w;
    }
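    // tr_ow is even by construction, so the divisor search above always
    // succeeds. E.g., assuming tr_ow = 50 and max_ur_w = 32 (illustrative
    // values only), it picks ur_w = 10 and ur_w_blocks = 5.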

    bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    bool use_full_spat_loop = jcp.ndims < 5 && jcp.ih == jcp.oh
            && jcp.iw == jcp.ow && everyone_is(1, jcp.stride_h, jcp.stride_w)
            && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
            // TODO: Remove this constraint: only 3x3 kernel works now
            && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
            && one_of(1, jcp.l_pad, jcp.r_pad) && jcp.kh == jcp.kw
            && jcp.ih >= jcp.kh && jcp.iw >= jcp.kw;

    jcp.harness = ndims == 5
            ? harness_3d_reduction
            : (use_full_spat_loop ? harness_compute_full_spatial
                            : (ndims == 4) ? harness_2d_reduction
                                           : harness_mb_reduction);
    switch (jcp.harness) {
        case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
        case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
        case harness_compute_full_spatial:
        case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
        default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
    }
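    // nthr_mb_work counts the independent reduction chunks available to the
    // minibatch threading dimension: whole images for the mb/full-spatial
    // harnesses, and image rows (mb * oh) or planes (mb * od) when the
    // reduction is additionally split along height or depth.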
    { // balancing
        int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
        balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
        jcp.nthr = nthr;
        jcp.nthr_mb = nthr_mb;
        jcp.nthr_g = nthr_g;
        jcp.nthr_oc_b = nthr_oc_b;
        jcp.nthr_ic_b = nthr_ic_b;

        // TODO: Optimize memory allocation when threaded on height and depth
        jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
        jcp.tr_src_buf_count = jcp.global_transpose
                ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
                : jcp.nthr;

        jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
        jcp.tr_diff_dst_buf_count = jcp.global_transpose
                ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
                : jcp.nthr;
    }

    return status::success;
}

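// Books the scratchpad buffers the driver relies on: transposed copies of src
// and diff_dst (with barrier contexts when the transpose is global), an f32
// reduction buffer for weights/bias when several minibatch threads reduce or
// the output is bf16, a padded-bias buffer, and the AMX tile configuration.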
status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_dst_md) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);

    // XXX: See the comment about tr_iw and guarding elements in
    // jit_avx512_core_amx_bwd_weights_kernel_t::init_conf()
    const size_t tr_src_size
            = (jcp.tr_src_buf_count * jcp.tr_src_buf_size * jcp.nb_ic_blocking)
            + jcp.tr_src_num_guard_elems;
    scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);

    /* prepare synchronization contexts */
    if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
        const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_tr_src_bctx, tr_src_bctx_size);
    }

    const size_t tr_diff_dst_size = jcp.tr_diff_dst_buf_count
            * jcp.tr_diff_dst_buf_size * jcp.nb_oc_blocking;

    const size_t min_align = 64;
    scratchpad.book(
            key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in, min_align);

    /* prepare synchronization contexts */
    if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
        const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
    }

    if (IMPLICATION(jcp.nthr_mb == 1,
                (jcp.with_bias && jcp.bia_dt == data_type::bf16)
                        || jcp.wei_dt == data_type::bf16)) {
        const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
                * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
        const size_t bia_size
                = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;

        const int num_wei_buffers
                = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
        const int num_bia_buffers = jcp.with_bias
                ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
                                                 : jcp.nthr_mb - 1)
                : 0;

        const size_t wei_bia_reduction_size
                = wei_size * num_wei_buffers + bia_size * num_bia_buffers;

        scratchpad.book<float>(
                key_conv_wei_bia_reduction, wei_bia_reduction_size);

        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx, 1);
    }

    if (jcp.with_bias
            && ((jcp.oc_without_padding % jcp.oc_block != 0)
                    && jcp.bia_dt == data_type::f32)) {
        scratchpad.book(key_conv_padded_bias,
                jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
    }
    scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline

    constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
            << 30; // 32 GB - TODO: maybe this is too large?
    const size_t scratchpad_limit_by_tensor_sizes = (size_t)32 * jcp.nthr
            * (src_d.size() + diff_weights_d.size() + diff_dst_d.size());
    const size_t scratchpad_limit
            = nstl::min(scratchpad_limit_by_absolute_value,
                    scratchpad_limit_by_tensor_sizes);
    if (scratchpad.size() > scratchpad_limit)
        return status::unimplemented;
    else
        return status::success;
}

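// Chooses the thread decomposition nthr = nthr_mb * nthr_g * nthr_oc_b
// * nthr_ic_b by exhaustively scoring the candidate splits with the
// per-thread memory-cost model below and keeping the cheapest one.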
void jit_avx512_core_amx_bwd_weights_kernel_t::balance(const jit_conv_conf_t &j,
        int &nthr_, int &nthr_mb_, int &nthr_g_, int &nthr_oc_b_,
        int &nthr_ic_b_) {
    nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;

    const int max_threads = dnnl_get_max_threads();

    if (max_threads < j.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        nthr_ = nthr_g_ = max_threads;
        return;
    }

    nthr_g_ = j.ngroups;
    const int nthr = max_threads / nthr_g_;

    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per thread memory cost (read/write). The high-level
         * optimizer tries to minimize memory consumption. A few notes:
         * (n1) if the weights tensor is smaller than the source and
         *      destination tensors, we apply the ratio of the source and
         *      destination tensor sizes to the weights one as a compensation
         *      coefficient to avoid parallelization across batch size only;
         *      otherwise we apply an additional coefficient to the source
         *      component based on performance measurements
         * (n2) use scales based on the output-to-input channel ratio for the
         *      source and destination components to improve threading balance
         *      across input and output channels */

        const dim_t src_type_size = 2;
        const dim_t wei_type_size = 4;

        dim_t src_size
                = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
        dim_t dst_size
                = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
        dim_t wei_size
                = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;

        float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
        float oi_channels_ratio = (float)(j.nb_oc / j.nb_oc_blocking)
                / (j.nb_ic / j.nb_ic_blocking);
        auto get_src_coef = [=]() {
            float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
            if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;

            return src_coef;
        };

        auto get_dst_coef
                = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };

        auto get_wei_coef
                = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };

        const float src_coef = get_src_coef();
        const float dst_coef = get_dst_coef();
        const float wei_coef = get_wei_coef();

        float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.mb
                * (j.ic_block * j.nb_ic_blocking) * j.id * j.ih * j.tr_iw
                / j.nthr_mb_work / j.stride_d / j.stride_h / j.stride_w;
        float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_oc / j.nb_oc_blocking),
                        (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
                * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.kh * j.kw
                * j.kd * (j.ic_block * j.nb_ic_blocking)
                * (j.oc_block * j.nb_oc_blocking);
        float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_oc / j.nb_oc_blocking),
                        (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
                * j.mb * (j.oc_block * j.nb_oc_blocking) * j.od * j.oh * j.tr_ow
                / j.nthr_mb_work;

        return src_v + dst_v + wei_v;
    };

    float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);

    /* find the best thread distribution with lowest memory cost */

    const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par,
                (j.nb_oc / j.nb_oc_blocking)); // Amount of nb_oc_blocks
        for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            int nthr_ic_b = nstl::min(
                    nthr_par / nthr_oc_b, (j.nb_ic / j.nb_ic_blocking));

            float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                nthr_mb_ = nthr_mb;
                nthr_oc_b_ = nthr_oc_b;
                nthr_ic_b_ = nthr_ic_b;
            }
        }
    }

    if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
        nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
    nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;

    assert(nthr_ <= max_threads);
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s