/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)
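// For example, GET_OFF(src) expands to offsetof(jit_conv_call_s, src), so
// 'ptr[param1 + GET_OFF(src)]' below reads the 'src' field of the call
// parameters structure whose address is passed to the kernel in 'param1'.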

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::utils;
using namespace Xbyak;

void jit_avx512_core_amx_compute_zp_pbuff_t::prepare_output(int ur_w) {
    for (int oc = 0; oc < jcp.nb_oc_blocking; oc++)
        for (int ur = 0; ur < ur_w; ur++) {
            const Zmm zmm = zmm_out(ur, oc);
            vpxord(zmm, zmm, zmm);
        }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::store_output(
        int ur_w, bool last_oc_block_flag) {
    assert(jcp.is_nspc);

    const int nb_oc_block = jcp.nb_oc_blocking;
    const int oc_block = jcp.oc_block;

    const auto src_zp_addr = EVEX_compress_addr(reg_src_zero_point, 0, true);

    /* write out register to output_addr */
    for (int oc = 0; oc < nb_oc_block; oc++) {
        const bool mask_flag = last_oc_block_flag && oc == nb_oc_block - 1;
        for (int ur = 0; ur < ur_w; ur++) {
            const int output_offset = sizeof(int32_t)
                    * (oc * oc_block
                            + ur * jcp.oc_without_padding * jcp.ngroups);
            const Zmm zmm_dst = zmm_out(ur, oc);
            const Zmm m_zmm_dst = mask_flag ? zmm_dst | ktail_mask : zmm_dst;
            // multiply dst by src_zero_point
            vpmulld(m_zmm_dst, zmm_dst, src_zp_addr);
            vmovups(EVEX_compress_addr(reg_zp_pbuff, output_offset), m_zmm_dst);
        }
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool padded) {

    const int kw = jcp.kw;
    const int ic_block = jcp.ic_block_int_np;
    const int oc_block = jcp.oc_block;
    const int nb_oc_block = jcp.nb_oc_blocking;

    const bool ic_tail
            = (jcp.ic_without_padding % (jcp.ic_block / ic_inner_block)) > 0;
    const bool masked_write = ic_tail && last_ic_block_flag == last_ic_block;

    /* Skip the last loads of input
        if (ic%16)/ic_sub_step < ic_block/ic_sub_step */
    const int icb = (last_ic_block_flag == last_ic_block)
            ? div_up(
                    (jcp.ic_without_padding % jcp.ic_block_int), ic_inner_block)
            : ic_block / ic_inner_block;

    auto get_filter_offset = [=](int ocb, int ic, int ki) {
        size_t w_step = jcp.is_relo ? jcp.kh : 1;
        size_t kw_offset = static_cast<size_t>(ki) * w_step
                * jcp.ic_block_int_np * jcp.oc_block;
        size_t oc_subblock_step = static_cast<size_t>(jcp.kd) * jcp.kh * jcp.kw
                * jcp.ic_block_int_np * jcp.oc_block;
        size_t offset = kw_offset
                + static_cast<size_t>(ocb) * jcp.nb_ic_int * oc_subblock_step
                + static_cast<size_t>(ic) * oc_block * ic_inner_block;
        return sizeof(char) * offset;
    };
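    // Offset arithmetic illustration (hypothetical shapes, assuming
    // ic_inner_block == 4): with kd = 1, kh = kw = 3, ic_block_int_np = 64,
    // oc_block = 16 and is_relo == false:
    //   oc_subblock_step = 1 * 3 * 3 * 64 * 16 = 9216 bytes,
    //   kw_offset(ki = 1) = 1 * 1 * 64 * 16 = 1024 bytes,
    //   ic term(ic = 2) = 2 * 16 * 4 = 128 bytes.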
    auto compute_fma = [=](const Zmm zmm_accum, const int ic,
                               const Address addr) {
        if (jcp.is_relo) {
            vmovups(zmm_permb, ptr[reg_scratch]); // get permute index table
            const Zmm r_zmm = masked_write && ic == icb - 1
                    ? zmm_permb | kmask_ic_block | T_z
                    : zmm_permb;
            // only values from 'src2' are used to write dst
            vpermi2b(r_zmm, zmm_permb, addr);
            vpdpbusd(zmm_accum, zmm_one,
                    zmm_permb); // XXX - using the same register for all ur_w
        } else {
            vpdpbusd(zmm_accum, zmm_one, addr);
        }
    };
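    // Note: with zmm_one holding 0x01 in every byte, vpdpbusd accumulates the
    // sum of each group of four adjacent s8 weights into the corresponding
    // s32 lane, i.e. it computes conv(1, wei_s8) for this filter position.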

    if (jcp.is_relo && last_ic_block_flag == last_ic_block && ic_tail) {
        const Reg64 reg_tmp = reg_scratch;
        mov(reg_tmp, ic_mask_label);
        kmovq(kmask_ic_block, qword[reg_tmp]);
    }
    if (jcp.is_relo) mov(reg_scratch, permb_idx_label);

    for (int ki = 0; ki < kw; ki++) {
        const int ur_start = get_ow_start(ki, pad_l);
        const int ur_end = get_ow_end(ur_w, ki, pad_r);
        for (int ur = 0; ur < ur_w; ur++) {
            // Calculate zero_point padding as:
            // accum = is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0
            if (ur < ur_start || ur >= ur_end || padded) {
                for (int oc = 0; oc < nb_oc_block; oc++) {
                    const Zmm zmm_accum = zmm_out(ur, oc);
                    for (int ic = 0; ic < icb; ic++) {
                        const auto addr_filt = EVEX_compress_addr(
                                aux_reg_filt, get_filter_offset(oc, ic, ki));
                        compute_fma(zmm_accum, ic, addr_filt);
                    }
                }
            }
        }
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::kh_loop(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {

    Label kh_label, skip_kh_loop;
    const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
    const size_t shift_wei_h_step = sizeof(char)
            * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
            * jcp.oc_block;

    // Compute zero_point compensation for the padded region. Total compute
    // area is 'overflow * kw' where 'overflow' indicates the overlap
    // between the filter and either top_pad or bottom_pad region.
    auto compute_kh_loop = [=](size_t param_overflow) {
        Label overflow_label, no_overflow_label;

        mov(reg_overflow, ptr[param1 + param_overflow]);
        cmp(reg_overflow, 0);
        je(no_overflow_label, T_NEAR);
        L(overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
            add(aux_reg_filt, shift_wei_h_step);
            dec(reg_overflow);
            jne(overflow_label, T_NEAR);
        }
        L(no_overflow_label);
    };

    if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(t_overflow));

    // check for holes and skip computation due to dilation
    mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    if (jcp.dilate_h >= jcp.ih) {
        cmp(reg_kj, 0);
        je(skip_kh_loop, T_NEAR);
    }

    L(kh_label);
    {
        compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);

        add(aux_reg_filt, shift_wei_h_step);
        dec(reg_kj);
        jne(kh_label, T_NEAR);
    }

    L(skip_kh_loop);

    if (handle_h_pad && jcp.ndims > 3) compute_kh_loop(GET_OFF(b_overflow));
}

void jit_avx512_core_amx_compute_zp_pbuff_t::kd_loop(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool handle_h_pad) {

    Label kd_label, skip_kd_loop;
    const size_t wei_h_step = jcp.is_relo ? 1 : jcp.kw;
    const size_t shift_wei_h_step = sizeof(char)
            * static_cast<size_t>(wei_h_step) * jcp.ic_block_int_np
            * jcp.oc_block;

    // Compute zero_point compensation for the padded region. Total compute
    // area is 'overflow * kh * kw' where 'overflow' indicates the overlap
    // between the filter and either front_pad or back_pad region.
    auto compute_kd_loop = [=](size_t param_overflow) {
        Label kh_loop_label;
        Label no_overflow_label, overflow_label;

        mov(reg_ki, ptr[param1 + param_overflow]);
        cmp(reg_ki, 0);
        je(no_overflow_label, T_NEAR);
        L(overflow_label);
        {
            mov(aux_reg_filt, aux_reg_filt_d);
            mov(reg_kj, jcp.kh);
            L(kh_loop_label);
            {
                compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                add(aux_reg_filt, shift_wei_h_step);
                dec(reg_kj);
                jne(kh_loop_label, T_NEAR);
            }
            add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
            dec(reg_ki);
            jne(overflow_label, T_NEAR);
        }
        L(no_overflow_label);
    };

    const bool zp_d_padding
            = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
    if (zp_d_padding) {
        mov(aux_reg_filt_d, reg_filt);
        compute_kd_loop(GET_OFF(f_overflow));

        // check for holes and skip computation due to dilation
        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        if (jcp.dilate_d >= jcp.id) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        L(kd_label);
        mov(aux_reg_filt, aux_reg_filt_d);

    } else {
        mov(aux_reg_filt, reg_filt);
    }

    kh_loop(ur_w, pad_l, pad_r, last_ic_block_flag, handle_h_pad);

    if (zp_d_padding) {
        add(aux_reg_filt_d, shift_wei_h_step * jcp.kh);
        dec(reg_ki);
        jne(kd_label, T_NEAR);

        L(skip_kd_loop);

        compute_kd_loop(GET_OFF(back_overflow));
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::icb_loop(
        int ur_w, int pad_l, int pad_r, bool handle_h_pad) {

    Label icb_label;
    const size_t nb_ic = jcp.nb_ic_int;
    const bool do_icb_loop = nb_ic > 1;

    /* Initialize zmm_one for weight accumulation */
    xor_(reg_scratch, reg_scratch);
    const Reg8 _t8 = reg_scratch.cvt8();
    mov(_t8, 0x1);
    vpbroadcastb(zmm_one, _t8);

    prepare_output(ur_w);

    mov(reg_icb, nb_ic);

    L(icb_label);
    if (jcp.ic_without_padding != jcp.ic) {
        Label common_ker, end_ker;
        if (do_icb_loop) {
            cmp(reg_icb, 1); // The last ic block
            jne(common_ker, T_NEAR);
        }
        kd_loop(ur_w, pad_l, pad_r, last_ic_block, handle_h_pad);
        if (do_icb_loop) {
            jmp(end_ker, T_NEAR);

            L(common_ker);
            kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);

            L(end_ker);
        }
    } else {
        kd_loop(ur_w, pad_l, pad_r, no_last_block, handle_h_pad);
    }
    // End of IC Loop
    if (do_icb_loop) {
        const size_t shift_wei_icb_step = static_cast<size_t>(jcp.kd) * jcp.kh
                * jcp.kw * jcp.oc_block * jcp.ic_block_int_np;
        add(reg_filt, sizeof(char) * shift_wei_icb_step);

        dec(reg_icb);
        cmp(reg_icb, 0);
        jg(icb_label, T_NEAR);

        sub(reg_filt, sizeof(char) * shift_wei_icb_step * nb_ic);
    }

    if (jcp.oc_without_padding != jcp.oc) {
        Label common_store, end_store;

        cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(common_store, T_NEAR);

        store_output(ur_w, true); // last oc block
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);
    } else {
        store_output(ur_w, false);
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::unroll_width(
        const bool h_padding) {

    auto ur_w_shift = [&](const int ur_w) {
        return sizeof(int32_t) * (ur_w * jcp.oc_without_padding * jcp.ngroups);
    };

    const int max_ur_w = jit_avx512_core_amx_compute_zp_pbuff_t::max_regs_ur
            / (jcp.nb_oc_blocking);
    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int l_pad = jcp.l_pad;

    const int l_pad_output = jcp.l_pad_output;
    const int r_pad_output = jcp.r_pad_output;

    // a single middle element (if required) containing only height padding
    const int no_pad = nstl::max(0, jcp.ow - l_pad_output - r_pad_output);

    const int ow_start = nstl::max(jcp.ow - r_pad_output, l_pad_output);
    const int r_pad_start = nstl::min(jcp.ow_pad - l_pad_output, r_pad_output);

    int ow = 0;
    int cur_l_pad_output = l_pad_output;
    while (cur_l_pad_output > 0) {
        const int ur_w = nstl::min(cur_l_pad_output, max_ur_w);
        ow += ur_w;
        const int cur_r_pad = calculate_end_padding(
                jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
        icb_loop(ur_w, l_pad, cur_r_pad, h_padding);
        add(reg_zp_pbuff, ur_w_shift(ur_w));

        l_pad = nstl::max(l_pad - ur_w * jcp.stride_w, 0);
        cur_l_pad_output = nstl::max(cur_l_pad_output - ur_w, 0);
    }

    if (no_pad > 0) {
        const int ur_w = 1;
        if (h_padding) icb_loop(ur_w, 0, 0, true);
        if (h_padding || jcp.ow_mid) add(reg_zp_pbuff, ur_w_shift(ur_w));
    }
    assert(ow + no_pad == ow_start);

    ow = ow_start;
    int cur_r_pad_output = r_pad_start;
    while (cur_r_pad_output > 0 && ow < jcp.ow) {
        const int ur_w = nstl::min(cur_r_pad_output, max_ur_w);
        ow += ur_w;
        const int cur_r_pad = calculate_end_padding(
                jcp.l_pad, ow, jcp.iw, jcp.stride_w, ext_kw);
        icb_loop(ur_w, 0, cur_r_pad, h_padding);
        add(reg_zp_pbuff, ur_w_shift(ur_w));

        cur_r_pad_output = nstl::max(cur_r_pad_output - ur_w, 0);
    }
}

void jit_avx512_core_amx_compute_zp_pbuff_t::generate() {
    Label h_pad_label, end_label;

    assert(jcp.req_zero_point_buffer);
    assert(jcp.typesize_in == sizeof(char));

    preamble();

    mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
    mov(reg_zp_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
    mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);

    if (jcp.oc_without_padding != jcp.oc) {
        const Reg32 reg_tmp = reg_scratch.cvt32();
        const int tail_size = jcp.oc_without_padding % jcp.oc_block;
        const int mask = (1 << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovw(ktail_mask, reg_tmp);
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
    }

    mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
    cmp(reg_overflow, 0);
    jne(h_pad_label, T_NEAR);
    mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
    cmp(reg_overflow, 0);
    jne(h_pad_label, T_NEAR);
    if (jcp.ndims == 5 && (jcp.f_pad_output > 0 || jcp.back_pad_output > 0)) {
        mov(reg_overflow, ptr[param1 + GET_OFF(kd_padding)]);
        cmp(reg_overflow, jcp.kd);
        jne(h_pad_label, T_NEAR);
    }

    // Handle width padding region
    unroll_width(false);
    jmp(end_label, T_NEAR);

    // handle height padding region
    L(h_pad_label);
    unroll_width(true);

    L(end_label);

    postamble();

    // reduced-lowering ('is_relo' == true) weights format is '..i16o', so
    // permute elements through permb into the VNNI layout '..16o4i'.
    if (jcp.is_relo) {
        align(64);
        L(permb_idx_label);
        // permb: id-bit for table selection is bit[6]
        const uint8_t select_src2_bit = 0x40;
        // permb: bits [5:0] select the element within each input table
        const uint8_t permb_idx_table[64] = {0, 16, 32, 48, 1, 17, 33, 49, 2,
                18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22,
                38, 54, 7, 23, 39, 55, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42,
                58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46,
                62, 15, 31, 47, 63};
        for (size_t i = 0; i < 64; ++i)
            db(select_src2_bit | permb_idx_table[i]);
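        // e.g. output byte 0 reads src2 byte 0, byte 1 reads src2 byte 16,
        // byte 2 -> 32, byte 3 -> 48: since bit[6] is set in every index,
        // vpermi2b always selects from 'src2' (the weights), and each output
        // dword gathers four '..i16o' rows of one oc into a '..16o4i' group.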
        // write zero-mask (permb) for ic_tail in VNNI format '..16o4i'
        const int ic_tail_size
                = jcp.ic_without_padding % (jcp.ic_block / ic_inner_block);
        if (jcp.ic_without_padding != jcp.ic && ic_tail_size > 0) {
            align(64);
            L(ic_mask_label);

            assert(4 > ic_tail_size);
            // mask is on a 4-bit basis from the 4 ic elements in a zmm
            const int nibble = (1 << ic_tail_size) - 1;
            for (int i = 0; i < 16; ++i) {
                db(nibble | (nibble << 4));
            }
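            // Illustration (hypothetical tail): ic_tail_size == 3 gives
            // nibble = 0b0111, so db(0x77) is emitted 16 times and the
            // resulting 64-bit kmask keeps 3 of every 4 interleaved ic
            // elements per oc.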
        }
    }
}

void jit_avx512_core_amx_copy_to_wbuffer_t::generate() {

    const bool is_bf16 = jcp.src_dt == data_type::bf16;

    // required for use of VPERMB instruction
    assert(IMPLICATION(!is_bf16, cpu().has(Xbyak::util::Cpu::tAVX512_VBMI)));
    assert(jcp.ic_block_int * jcp.typesize_in == 64);

    preamble();

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_dst, ptr[param1 + GET_OFF(dst)]);

    // load permute indices from data section
    Label permute_index_table;
    mov(reg_tmp, permute_index_table);
    if (is_bf16)
        vmovdqu16(zmm_idx, ptr[reg_tmp]);
    else
        vmovdqu8(zmm_idx, ptr[reg_tmp]);

    const int vnni_width = is_bf16 ? 2 : 4;
    const int r = jcp.kh * jcp.kw * jcp.ic_without_padding;
    const int nb_r = div_up(r, vnni_width);
    const int rtail = (r % vnni_width) * jcp.oc_block;
    if (rtail > 0) {
        uint64_t mask = (UINT64_C(1) << rtail) - 1;
        mov(reg_tmp, mask);
        kmovq(kmask_load, reg_tmp);
    }
    const int nb_z = rnd_up(nb_r, jcp.ic_block);
    if (nb_r < nb_z) vpxord(zmm_zero, zmm_zero, zmm_zero);

    const int tile_size = jcp.ic_block_int * jcp.oc_block * jcp.typesize_in;
    const int ocb_src_step = r * jcp.oc_block * jcp.typesize_in;
    const int ocb_dst_step = rnd_up(ocb_src_step, tile_size);

    // reorder from ~Owhi16o -> ~OR16oVr with r := whi and V := vnni_width
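    // Size bookkeeping example (hypothetical shapes): for int8 with
    // kh = kw = 3 and ic_without_padding = 7, r = 3 * 3 * 7 = 63 rows,
    // nb_r = div_up(63, 4) = 16 reordered rows, and the last row is partial:
    // rtail = (63 % 4) * 16 = 48, so kmask_load keeps 48 of 64 bytes.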
    for (int g = 0; g < jcp.ngroups; g++) {
        for (int ocb = 0; ocb < jcp.nb_oc; ocb++) {
            int offset = 0;
            int rb = 0;
            for (; rb < nb_r; offset += 64, rb++) {
                auto zmm_src_tmp = (rtail > 0 && rb == nb_r - 1)
                        ? zmm_src | kmask_load | T_z
                        : zmm_src;
                if (is_bf16) {
                    vmovdqu16(zmm_src_tmp, ptr[reg_src + offset]);
                    vpermw(zmm_dst, zmm_idx, zmm_src);
                    vmovdqu16(ptr[reg_dst + offset], zmm_dst);
                } else {
                    vmovdqu8(zmm_src_tmp, ptr[reg_src + offset]);
                    vpermb(zmm_dst, zmm_idx, zmm_src);
                    vmovdqu8(ptr[reg_dst + offset], zmm_dst);
                }
            }
            for (; rb < nb_z; offset += 64, rb++) {
                if (is_bf16)
                    vmovdqu16(ptr[reg_dst + offset], zmm_zero);
                else
                    vmovdqu8(ptr[reg_dst + offset], zmm_zero);
            }
            add(reg_src, ocb_src_step);
            add(reg_dst, ocb_dst_step);
        }
    }

    postamble();

    align(64);
    L(permute_index_table);
    const uint8_t no = 16; // 16o
    const uint8_t nr = is_bf16 ? 2 : 4; // 2r or 4r
    for (uint8_t o = 0; o < no; ++o) {
        for (uint8_t r = 0; r < nr; r++) {
            const uint8_t index = o + r * no;
            if (is_bf16)
                dw(index);
            else
                db(index);
        }
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_body(
        int lpad, int iw_len, int icb) {

    const bool is_bf16 = jcp.src_dt == data_type::bf16;
    int iwp_idx = 0;
    // there are min(gen_kw, jcp.stride_w) continuous sets of input
    // data (for each stride idx), they are placed one by one
    // without additional padding
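    // e.g. (hypothetical) stride_w = 2, kw = 3, dilate_w = 0: gen_kw = 3 and
    // min(gen_kw, stride_w) = 2 sets; set 0 holds inputs for kw = 0 and 2,
    // set 1 holds inputs for kw = 1.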
    const bool are_sets_interleaved
            = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
    const int num_sets = are_sets_interleaved ? jcp.n_stride_sets : jcp.kw;
    for (int set_idx = 0; set_idx < num_sets; set_idx++) {
        int set_width_padded = !jcp.is_pbuffer_strided
                ? (jcp.ow_block - 1) * jcp.stride_w + gen_kw
                : are_sets_interleaved ? jcp.ow_block - 1 + gen_kw / num_sets
                                + (set_idx < gen_kw % num_sets ? 1 : 0)
                                       : jcp.ow_block;
        for (int set_shift = 0; set_shift < set_width_padded;
                set_shift++, iwp_idx++) {
            int iw_idx = set_idx * (jcp.dilate_w + 1)
                    + set_shift * (jcp.is_pbuffer_strided ? jcp.stride_w : 1)
                    - lpad;
            size_t out_base_offset
                    = (size_t)jcp.typesize_in * iwp_idx * jcp.ic_block_int_np;
            if (iw_idx < 0 || iw_idx >= iw_len) {
                // left or right padding
                vmovups(ptr[reg_aux_out_ptr + out_base_offset], zmm_zero);
            } else if (jcp.is_nspc) {
                size_t inp_w_offset = (size_t)jcp.typesize_in * iw_idx
                        * jcp.ngroups * jcp.ic_without_padding;
                int ic = icb * jcp.ic_block_int_np;
                // TODO: use Xmm or Ymm moves for better small ic efficiency
                auto zmm_tmp_mask
                        = ic + jcp.ic_block_int <= jcp.ic_without_padding
                        ? zmm_tmp
                        : zmm_tmp | ktail_mask | T_z;
                if (is_bf16) {
                    vmovdqu16(
                            zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
                    vmovdqu16(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
                } else {
                    vmovdqu8(zmm_tmp_mask, ptr[reg_aux_inp_ptr + inp_w_offset]);
                    vmovdqu8(ptr[reg_aux_out_ptr + out_base_offset], zmm_tmp);
                }
            } else {
                assert(is_bf16);
                size_t inp_w_offset
                        = (size_t)jcp.typesize_in * iw_idx * jcp.ic_block;
                for (int j = 0; j < jcp.ic_block_int_np / jcp.ic_block; j++) {
                    int ic = icb * jcp.ic_block_int_np + j * jcp.ic_block;
                    size_t inp_c_w_offset = (size_t)jcp.typesize_in * j * jcp.ih
                                    * jcp.iw * jcp.ic_block
                            + inp_w_offset;
                    if (ic + jcp.ic_block <= jcp.ic) {
                        vmovdqu16(
                                ymm_tmp, ptr[reg_aux_inp_ptr + inp_c_w_offset]);
                    } else {
                        vpxord(ymm_tmp, ymm_tmp, ymm_tmp);
                    }
                    size_t out_offset = out_base_offset
                            + (size_t)jcp.typesize_in * j * jcp.ic_block;
                    vmovdqu16(ptr[reg_aux_out_ptr + out_offset], ymm_tmp);
                }
            }
        }
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row(int icb) {
    if (jcp.nb_ow == 1) {
        copy_row_body(jcp.l_pad, jcp.iw, icb);
    } else {
        auto get_iw_len_required = [&](int cur_ow_block, int cur_lpad) {
            return (cur_ow_block - 1) * jcp.stride_w
                    + (jcp.kw - 1) * (jcp.dilate_w + 1) + 1 - cur_lpad;
        };

        auto get_iw_len_limited = [&](int owb, int cur_ow_block, int cur_lpad) {
            auto len_req = get_iw_len_required(cur_ow_block, cur_lpad);
            if (owb < 0) return len_req;
            int ow_block_start = nstl::max(
                    0, owb * jcp.ow_block * jcp.stride_w - jcp.l_pad);
            return nstl::min(jcp.iw - ow_block_start, len_req);
        };

        int general_owb_cases = jcp.nb_ow;
        Xbyak::Label copy_row_done_label;
        bool special_first_block_case = jcp.l_pad > 0;
        if (special_first_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_first_block_case_label;
            cmp(reg_owb, 0);
            jne(skip_first_block_case_label, T_NEAR);
            copy_row_body(jcp.l_pad,
                    get_iw_len_limited(0, jcp.ow_block, jcp.l_pad), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_first_block_case_label);
        }
        bool special_last_block_case = false
                // has ow_block_tail
                || jcp.ow % jcp.ow_block != 0
                // there is no ow_block_tail but right padding exists
                || get_iw_len_limited(jcp.nb_ow - 1, jcp.ow_block, 0)
                        != get_iw_len_required(jcp.ow_block, 0);
        if (special_last_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_last_block_case_label;
            cmp(reg_owb, jcp.nb_ow - 1);
            jne(skip_last_block_case_label, T_NEAR);
            int ow_block_tail = jcp.ow % jcp.ow_block;
            int cur_ow_block = ow_block_tail > 0 ? ow_block_tail : jcp.ow_block;
            copy_row_body(
                    0, get_iw_len_limited(jcp.nb_ow - 1, cur_ow_block, 0), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_last_block_case_label);
        }

        bool special_penult_block_case = true
                // if nb_ow = 2 and l_pad > 0 it's the same as
                // special_first_block_case
                && jcp.nb_ow >= (special_first_block_case ? 3 : 2)
                // right padding exists in penult block
                && get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0)
                        != get_iw_len_required(jcp.ow_block, 0);
        if (special_penult_block_case) {
            general_owb_cases--;
            Xbyak::Label skip_penult_block_case_label;
            cmp(reg_owb, jcp.nb_ow - 2);
            jne(skip_penult_block_case_label, T_NEAR);
            copy_row_body(
                    0, get_iw_len_limited(jcp.nb_ow - 2, jcp.ow_block, 0), icb);
            jmp(copy_row_done_label, T_NEAR);
            L(skip_penult_block_case_label);
        }

        if (general_owb_cases > 0) // general case
            copy_row_body(0, get_iw_len_required(jcp.ow_block, 0), icb);

        L(copy_row_done_label);
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::copy_row_reduced_lowering() {
    assert(jcp.nb_ic_int == 1);
    assert(jcp.ic_block_int * jcp.typesize_in == 64);
    assert(jcp.is_nspc);

    auto load_mask = [=](int tail, Opmask kmask) {
        uint64_t mask = (UINT64_C(1) << tail) - 1;
        mov(reg_tmp, mask);
        kmovq(kmask, reg_tmp);
    };

    const bool is_bf16 = jcp.src_dt == data_type::bf16;
    const int inp_w_step
            = jcp.ngroups * jcp.ic_without_padding * jcp.typesize_in;
    const int inp_h_step = jcp.iw * inp_w_step;
    const int out_h_step = jcp.ic_without_padding * jcp.typesize_in;
    const int out_w_step = jcp.kh * out_h_step;
    const int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
    if (tail_size > 0) load_mask(tail_size, ktail_mask);

    auto zero_it = [=](reg64_t tmp_out_ptr) {
        for (int ic = 0; ic < jcp.ic_without_padding; ic += jcp.ic_block_int) {
            const int offset = ic * jcp.typesize_in;
            const bool masked = ic + jcp.ic_block_int > jcp.ic_without_padding;
            Zmm zmm = masked ? zmm_zero | ktail_mask : zmm_zero;
            if (is_bf16)
                vmovdqu16(ptr[tmp_out_ptr + offset], zmm);
            else
                vmovdqu8(ptr[tmp_out_ptr + offset], zmm);
        }
    };

    // pointer to 1st needed element in src buffer
    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    // pointer to 1st needed element in dst buffer
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);

    // total number of rows to copy
    mov(reg_kht, ptr[param1 + GET_OFF(kh_offset)]);

    // number of rows of src buffer to copy
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
    // number of zero-padded rows above src buffer to copy
    mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
    // number of zero-padded rows below src buffer to copy
    mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);

    // number of columns of src buffer to copy
    mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
    // number of zero-padded columns before src buffer to copy
    mov(reg_lov, ptr[param1 + GET_OFF(f_overflow)]);
    // number of zero-padded columns after src buffer to copy
    mov(reg_rov, ptr[param1 + GET_OFF(back_overflow)]);
    vpxord(zmm_zero, zmm_zero, zmm_zero);

    { // Handle Left Overflow
        Label label_lov, label_lov_skip;
        test(reg_lov, reg_lov);
        jz(label_lov_skip, T_NEAR);
        L(label_lov); // handle left or right overflow
        {
            Label label_lov_inner;
            mov(reg_aux_out_ptr, reg_out_ptr);
            mov(reg_cnt, reg_kht);
            L(label_lov_inner);
            {
                zero_it(reg_aux_out_ptr);
                add(reg_aux_out_ptr, out_h_step);
                dec(reg_cnt);
                jnz(label_lov_inner, T_NEAR);
            }
            add(reg_out_ptr, out_w_step);
            dec(reg_lov);
            jnz(label_lov, T_NEAR);
        }
        L(label_lov_skip);
    }

    // save output pointer for later use
    mov(reg_save_out_ptr, reg_out_ptr);

    // just in case there is no meat...
    Label label_kwp_end;
    test(reg_kwp, reg_kwp);
    jz(label_kwp_end, T_NEAR);

    // Unroll over W-dimension in powers of 2
    Label label_tov;
    Label label_khp, label_no_khp;
    Label label_bov;
    test(reg_tov, reg_tov);
    jnz(label_tov, T_NEAR);
    test(reg_khp, reg_khp);
    jnz(label_khp, T_NEAR);
    test(reg_bov, reg_bov);
    jnz(label_bov, T_NEAR);
    jmp(label_kwp_end, T_NEAR); // safe exit in case of bad parameters

    L(label_tov); // handle top overflow
    {
        Label label_tov_inner;
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_tov_inner);
        {
            zero_it(reg_aux_out_ptr);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_tov_inner, T_NEAR);
        }
        add(reg_out_ptr, out_h_step);
        dec(reg_tov);
        jnz(label_tov, T_NEAR);
    }
    test(reg_khp, reg_khp);
    jz(label_no_khp, T_NEAR);
    L(label_khp); // handle kh padding (not fully unrolled)
    {
        Label label_khp_inner;
        mov(reg_aux_inp_ptr, reg_inp_ptr);
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_khp_inner);
        {
            for (int ic = 0; ic < jcp.ic_without_padding;
                    ic += jcp.ic_block_int) {
                const int offset = ic * jcp.typesize_in;
                const bool masked
                        = ic + jcp.ic_block_int > jcp.ic_without_padding;
                // zero masking is needed to avoid dependency on destination
                Zmm zmm_load = masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
                Zmm zmm_store = masked ? zmm_tmp | ktail_mask : zmm_tmp;
                if (is_bf16) {
                    vmovdqu16(zmm_load, ptr[reg_aux_inp_ptr + offset]);
                    vmovdqu16(ptr[reg_aux_out_ptr + offset], zmm_store);
                } else {
                    vmovdqu8(zmm_load, ptr[reg_aux_inp_ptr + offset]);
                    vmovdqu8(ptr[reg_aux_out_ptr + offset], zmm_store);
                }
            }
            add(reg_aux_inp_ptr, inp_w_step);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_khp_inner, T_NEAR);
        }
        add(reg_inp_ptr, inp_h_step);
        add(reg_out_ptr, out_h_step);
        dec(reg_khp);
        jnz(label_khp, T_NEAR);
    }
    L(label_no_khp);
    test(reg_bov, reg_bov);
    jz(label_kwp_end, T_NEAR);
    L(label_bov); // handle bottom overflow
    {
        Label label_bov_inner;
        mov(reg_aux_out_ptr, reg_out_ptr);
        mov(reg_cnt, reg_kwp);
        L(label_bov_inner);
        {
            zero_it(reg_aux_out_ptr);
            add(reg_aux_out_ptr, out_w_step);
            dec(reg_cnt);
            jnz(label_bov_inner, T_NEAR);
        }
        add(reg_out_ptr, out_h_step);
        dec(reg_bov);
        jnz(label_bov, T_NEAR);
    }
    L(label_kwp_end);

    { // Handle Right Overflow
        Label label_rov, label_rov_skip;
        // retrieve output pointer
        mov(reg_out_ptr, reg_save_out_ptr);
        // calculate the shift
        imul(reg_tmp, reg_kwp, out_w_step);
        // shift past the body
        add(reg_out_ptr, reg_tmp);
        // skip if no right overflow
        test(reg_rov, reg_rov);
        jz(label_rov_skip, T_NEAR);

        L(label_rov); // handle left or right overflow
        {
            Label label_rov_inner;
            mov(reg_aux_out_ptr, reg_out_ptr);
            mov(reg_cnt, reg_kht);
            L(label_rov_inner);
            {
                zero_it(reg_aux_out_ptr);
                add(reg_aux_out_ptr, out_h_step);
                dec(reg_cnt);
                jnz(label_rov_inner, T_NEAR);
            }
            add(reg_out_ptr, out_w_step);
            dec(reg_rov);
            jnz(label_rov, T_NEAR);
        }
        L(label_rov_skip);
    }

    // For bf16, zero-pad an extra cacheline to avoid NaNs
    // For int8, it is sufficient to zero-pad the weights only
    if (is_bf16) {
        // shift forward to align h index to end of needed buffer
        imul(reg_tmp, reg_kht, out_h_step);
        add(reg_out_ptr, reg_tmp);
        // shift backward to align w index to end of needed buffer
        sub(reg_out_ptr, out_w_step);
        vmovdqu16(ptr[reg_out_ptr], zmm_zero);
    }
}

void jit_avx512_core_amx_copy_to_pbuffer_t::generate() {

    // Special copy kernel for reduced lowering
    if (jcp.is_relo) {
        assert(jcp.nb_ic_int == 1);
        preamble();
        copy_row_reduced_lowering();
        postamble();
        return;
    }

    preamble();

    const bool is_3d = jcp.ndims == 5;
    mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
    if (is_3d) mov(reg_kdp, ptr[param1 + GET_OFF(kd_padding)]);
    mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
    mov(reg_tover, ptr[param1 + GET_OFF(t_overflow)]);
    mov(reg_bover, ptr[param1 + GET_OFF(b_overflow)]);
    mov(reg_owb, ptr[param1 + GET_OFF(owb)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    if (jcp.is_nspc && jcp.ic_without_padding % jcp.ic_block_int) {
        int tail_size = jcp.ic_without_padding % jcp.ic_block_int;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
        Xbyak::Label kd_label, no_kd_label;
        Xbyak::Label kh_label, no_kh_label, icb_label;
        Xbyak::Label kh_tover_label, kh_bover_label;
        Xbyak::Label no_kh_tover_label, no_kh_bover_label;

        mov(reg_aux_inp_ptr, reg_inp_ptr);
        mov(reg_aux_out_ptr, reg_out_ptr);
        if (is_3d) {
            cmp(reg_kdp, 0);
            jle(no_kd_label, T_NEAR);
            mov(reg_kdc, reg_kdp);
            L(kd_label);
            push(reg_aux_inp_ptr);
            push(reg_aux_out_ptr);
        }
        cmp(reg_khp, 0);
        jle(no_kh_bover_label, T_NEAR); // nothing to do
        mov(reg_khc, reg_khp);

        cmp(reg_tover, 0);
        jle(no_kh_tover_label, T_NEAR);

        mov(reg_kh_over, reg_tover);
        L(kh_tover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for (int iw = 0; iw < jcp.iwp; iw++)
                vmovups(ptr[reg_aux_out_ptr
                                + jcp.typesize_in * iw * jcp.ic_block_int_np],
                        zmm_zero);
            int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_kh_over);
            jnz(kh_tover_label, T_NEAR);
        }
        sub(reg_khc, reg_tover);
        L(no_kh_tover_label);

        cmp(reg_khc, reg_bover);
        jle(no_kh_label, T_NEAR);

        L(kh_label);
        {
            copy_row(icb);
            size_t inp_h_offset = !jcp.is_nspc
                    ? (size_t)jcp.typesize_in * jcp.iw * jcp.ic_block
                    : (size_t)jcp.typesize_in * jcp.iw * jcp.ngroups
                            * jcp.ic_without_padding;
            size_t out_h_offset
                    = (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;

            add(reg_aux_inp_ptr, inp_h_offset);
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_khc);
            cmp(reg_khc, reg_bover);
            jg(kh_label, T_NEAR);
        }
        L(no_kh_label);

        cmp(reg_khc, 0);
        jle(no_kh_bover_label, T_NEAR);

        L(kh_bover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for (int iw = 0; iw < jcp.iwp; iw++)
                vmovups(ptr[reg_aux_out_ptr
                                + jcp.typesize_in * iw * jcp.ic_block_int_np],
                        zmm_zero);
            int out_h_offset = jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np;
            add(reg_aux_out_ptr, out_h_offset);

            dec(reg_khc);
            jnz(kh_bover_label, T_NEAR);
        }
        size_t out_d_offset = (size_t)jcp.typesize_in
                * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
        L(no_kh_bover_label);
        if (is_3d) {
            size_t inp_d_offset = !jcp.is_nspc
                    ? (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block
                            * (jcp.dilate_d + 1)
                    : (size_t)jcp.typesize_in * jcp.ih * jcp.iw * jcp.ngroups
                            * jcp.ic_without_padding * (jcp.dilate_d + 1);
            pop(reg_aux_out_ptr);
            pop(reg_aux_inp_ptr);
            add(reg_aux_inp_ptr, inp_d_offset);
            add(reg_aux_out_ptr, out_d_offset);
            dec(reg_kdc);
            jnz(kd_label, T_NEAR);
            L(no_kd_label);
        }
        // End IC Loop
        size_t inp_cb_offset = !jcp.is_nspc
                ? (size_t)jcp.typesize_in * (jcp.ic_block_int_np / jcp.ic_block)
                        * jcp.id * jcp.ih * jcp.iw * jcp.ic_block
                : (size_t)jcp.typesize_in * jcp.ic_block_int_np;
        size_t out_cb_offset = (size_t)jcp.kd * out_d_offset;

        add(reg_inp_ptr, inp_cb_offset);
        add(reg_out_ptr, out_cb_offset);
    }

    postamble();
}

jit_avx512_core_amx_fwd_kernel_t::jit_avx512_core_amx_fwd_kernel_t(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        using namespace binary_injector;
        const auto &rhs_addr_reg = bin_injector_helper_reg_1;
        const auto &rhs_helper_reg = bin_injector_helper_reg_2;
        static constexpr bool preserve_gpr = false;
        static constexpr bool preserve_vmm = false;
        const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                31, rhs_addr_reg, rhs_helper_reg, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec),
                memory_desc_wrapper(dst_md), tail_size, ktail_mask,
                use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp.post_ops, static_params);
    }
    copy_to_pbuffer_
            = utils::make_unique<jit_avx512_core_amx_copy_to_pbuffer_t>(jcp);
    if (jcp.is_relo)
        copy_to_wbuffer_
                = utils::make_unique<jit_avx512_core_amx_copy_to_wbuffer_t>(
                        jcp);
}

status_t jit_avx512_core_amx_fwd_kernel_t::create_kernel() {
    CHECK(jit_generator::create_kernel());
    CHECK(copy_to_pbuffer_->create_kernel());
    if (jcp.is_relo) CHECK(copy_to_wbuffer_->create_kernel());
    if (jcp.req_zero_point_buffer) {
        zp_pbuff_kernel_
                = utils::make_unique<jit_avx512_core_amx_compute_zp_pbuff_t>(
                        jcp);
        if (zp_pbuff_kernel_ == nullptr) return status::out_of_memory;
        CHECK(zp_pbuff_kernel_->create_kernel());
    }
    return status::success;
}

// Tile register decomposition
// { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
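// Of the 8 AMX tile registers: tmm0-tmm3 hold C (output) accumulators,
// tmm4-tmm5 hold input (activation) tiles, and tmm6-tmm7 hold weight tiles,
// as enforced by the {C,I,W}_BASE/_LAST asserts in the getters below.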
get_out_tensor(int h,int i,bool is_h_tail) const1101 int jit_avx512_core_amx_fwd_kernel_t::get_out_tensor(
1102 int h, int i, bool is_h_tail) const {
1103 const int C_BASE = 0;
1104 const int C_LAST = 4;
1105 assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
1106 MAYBE_UNUSED(C_LAST);
1107 const int tile = C_BASE
1108 + (jcp.nb_oh_blocking > 1
1109 ? h * jcp.nb_oh_blocking + i
1110 : (int)is_h_tail * jcp.nb_oc_blocking + i);
1111 assert(C_BASE <= tile && tile < C_LAST);
1112 return tile;
1113 }
get_inp_tensor(int h,bool is_h_tail) const1114 int jit_avx512_core_amx_fwd_kernel_t::get_inp_tensor(
1115 int h, bool is_h_tail) const {
1116 const int I_BASE = 4;
1117 const int I_LAST = 6;
1118 assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
1119 MAYBE_UNUSED(I_LAST);
1120 const int tile = I_BASE + (jcp.nb_oh_blocking > 1 ? h : (int)is_h_tail);
1121 assert(I_BASE <= tile && tile < I_LAST);
1122 return tile;
1123 }
get_wei_tensor(int i) const1124 int jit_avx512_core_amx_fwd_kernel_t::get_wei_tensor(int i) const {
1125 const int W_BASE = 6;
1126 const int W_LAST = 8;
1127 assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
1128 MAYBE_UNUSED(W_LAST);
1129 const int tile = W_BASE + i;
1130 assert(W_BASE <= tile && tile < W_LAST);
1131 return tile;
1132 }
1133
1134 // Shifts and offsets
get_inp_icb_step() const1135 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_icb_step() const {
1136 return (size_t)jcp.kd * get_inp_d_step();
1137 }
get_wei_icb_step() const1138 size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_icb_step() const {
1139 return (size_t)jcp.typesize_in * jcp.kd * jcp.kh * jcp.kw
1140 * jcp.ic_block_int_np * jcp.oc_block;
1141 }
get_inp_d_step() const1142 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_d_step() const {
1143 return (size_t)jcp.typesize_in
1144 * (jcp.ihp * jcp.iwp * jcp.ic_block_int_np + jcp.ic_block_int);
1145 }
get_inp_h_step() const1146 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_h_step() const {
1147 return (size_t)jcp.typesize_in * jcp.iwp * jcp.ic_block_int_np
1148 * (jcp.dilate_h + 1);
1149 }
get_wei_d_step() const1150 size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_d_step() const {
1151 return (size_t)jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block_int_np
1152 * jcp.oc_block;
1153 }
get_wei_h_step() const1154 size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_h_step() const {
1155 return (size_t)jcp.typesize_in * jcp.kw * jcp.ic_block_int_np
1156 * jcp.oc_block;
1157 }
get_out_ocb_offset(int ohb,int ocb,size_t typesize) const1158 size_t jit_avx512_core_amx_fwd_kernel_t::get_out_ocb_offset(
1159 int ohb, int ocb, size_t typesize) const {
1160 size_t el_offset = jcp.is_nspc
1161 ? (size_t)ocb * jcp.oc_block
1162 + (size_t)ohb * jcp.ow * jcp.ngroups
1163 * jcp.oc_without_padding
1164 : (size_t)ocb * jcp.oh * jcp.ow * jcp.oc_block
1165 + (size_t)ohb * jcp.ow * jcp.oc_block;
1166 return (size_t)typesize * el_offset;
1167 }
get_out_row_offset(int ohb,int ocb,int j,size_t typesize) const1168 size_t jit_avx512_core_amx_fwd_kernel_t::get_out_row_offset(
1169 int ohb, int ocb, int j, size_t typesize) const {
1170 size_t offset_w = jcp.is_nspc
1171 ? (size_t)typesize * j * jcp.ngroups * jcp.oc_without_padding
1172 : (size_t)typesize * j * jcp.oc_block;
1173 return get_out_ocb_offset(ohb, ocb, typesize) + offset_w;
1174 }
get_out_shift(int width,size_t typesize) const1175 size_t jit_avx512_core_amx_fwd_kernel_t::get_out_shift(
1176 int width, size_t typesize) const {
1177 return jcp.is_nspc
1178 ? (size_t)typesize * width * jcp.ngroups * jcp.oc_without_padding
1179 : (size_t)typesize * width * jcp.oc_block;
1180 }
get_wsp_ocb_offset(int ohb,int ocb) const1181 size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_ocb_offset(
1182 int ohb, int ocb) const {
1183 size_t el_offset = (size_t)ocb * prv_width_ * jcp.oc_block
1184 + (size_t)ohb * jcp.nb_oc_blocking * jcp.full_tile_width
1185 * jcp.oc_block;
1186 return jcp.typesize_acc * el_offset;
1187 }
get_wsp_row_offset(int ohb,int ocb,int j) const1188 size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_row_offset(
1189 int ohb, int ocb, int j) const {
1190 return get_wsp_ocb_offset(ohb, ocb)
1191 + (size_t)jcp.typesize_acc * j * jcp.oc_block;
1192 }
get_wsp_shift() const1193 size_t jit_avx512_core_amx_fwd_kernel_t::get_wsp_shift() const {
1194 return (size_t)jcp.typesize_acc * jcp.nb_oh_blocking * jcp.full_tile_width
1195 * jcp.oc_block * jcp.nb_oc_blocking;
1196 }
get_wei_offset(int ocb,int kw) const1197 size_t jit_avx512_core_amx_fwd_kernel_t::get_wei_offset(int ocb, int kw) const {
1198 size_t el_offset = (size_t)kw * jcp.ic_block_int_np * jcp.oc_block;
1199 size_t raw_oc_subblock_step
1200 = jcp.kd * jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
1201 size_t oc_subblock_step = jcp.is_relo
1202 ? rnd_up(raw_oc_subblock_step, jcp.ic_block_int * jcp.oc_block)
1203 : raw_oc_subblock_step;
1204 el_offset += (size_t)ocb * jcp.nb_ic_int * oc_subblock_step;
1205 return jcp.typesize_in * el_offset;
1206 }
get_inp_shift() const1207 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_shift() const {
1208 size_t w_step = (jcp.is_relo ? jcp.stride_w * jcp.kh
1209 : jcp.is_pbuffer_strided ? 1 : jcp.stride_w)
1210 * jcp.ic_block_int_np;
1211 return (size_t)jcp.typesize_in * jcp.tile_width * w_step;
1212 }
get_inp_offset(int ohb,int kw) const1213 size_t jit_avx512_core_amx_fwd_kernel_t::get_inp_offset(int ohb, int kw) const {
1214 if (jcp.is_relo)
1215 return ohb * jcp.iwp * jcp.kh * jcp.ic_block_int_np * jcp.typesize_in;
1216 // calculate offset by height dimension
1217 const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
1218 const int gen_stride_h = nstl::min(jcp.stride_h, gen_kh);
1219 size_t el_offset = (size_t)ohb * jcp.oh_per_tile * gen_stride_h * jcp.iwp
1220 * jcp.ic_block_int_np;
1221
1222 // add offset by width dimension
1223 if (IMPLICATION(jcp.is_pbuffer_strided, jcp.stride_w == 1)) {
1224 el_offset += (size_t)kw * (jcp.dilate_w + 1) * jcp.ic_block_int_np;
1225 } else if (jcp.dilate_w > 0) {
1226 el_offset += (size_t)kw * jcp.ow_block * jcp.ic_block_int_np;
1227 } else {
1228 // dilate_w == 0 && stride_w > 1
1229 // there are min(jcp.kw, jcp.stride_w) continuous sets of input data
1230 // (foreach stride idx), they are placed one by one without additional
1231 // padding
1232
1233 // calculate set idx for current kw value
1234 int set_idx = kw % jcp.stride_w;
1235 // calculate shift within set for current kw value
1236 int set_shift = kw / jcp.stride_w;
1237
        // calculate the beginning of the current set along width; each set
        // with index set_i contains a number of elements along width equal to
        // jcp.ow_block - 1 + jcp.kw / jcp.stride_w
        //     + (set_i < jcp.kw % jcp.stride_w)
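        // e.g. (hypothetical) ow_block = 14, kw = 3, stride_w = 2: set 0 has
        // 13 + 1 + 1 = 15 elements and set 1 has 13 + 1 + 0 = 14, so for
        // kw = 1 (set_idx = 1, set_shift = 0) set_start = 14 * 1 + 1 = 15.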
        size_t set_start = (jcp.ow_block - 1 + jcp.kw / jcp.stride_w) * set_idx
                + nstl::min(set_idx, jcp.kw % jcp.stride_w);
        el_offset += (set_start + set_shift) * jcp.ic_block_int_np;
    }
    return jcp.typesize_in * el_offset;
}

size_t jit_avx512_core_amx_fwd_kernel_t::get_zp_comp_offset(
        int ocb, int zp_h, int zp_w) const {
    const size_t ocb_offset = (size_t)ocb * jcp.oc_block;
    const size_t sp_offset = (size_t)(zp_h * jcp.ow_pad + zp_w) * jcp.ngroups
            * jcp.oc_without_padding;
    return (ocb_offset + sp_offset) * sizeof(int32_t);
}

int jit_avx512_core_amx_fwd_kernel_t::get_zp_index_offset(
        int index, int mid, int s_pad_output, int e_pad_output) {
    using namespace nstl;
    const int mid_end = e_pad_output - 1;
    int zp_mid = min(mid, max(0, index - mid_end));
    int zp_pad_offset
            = accum_with_upper_bound(index, s_pad_output, e_pad_output);
    return zp_pad_offset + zp_mid;
}

// Code generation
void jit_avx512_core_amx_fwd_kernel_t::prepare_output(int tail) {
    for (int h = 0; h < jcp.nb_oh_blocking; h++)
        for (int i = 0; i < jcp.nb_oc_blocking; i++)
            tilezero(Tmm(get_out_tensor(h, i, tail)));
}

void jit_avx512_core_amx_fwd_kernel_t::init_runtime_counters(
        bool start_with_last_tile_block) {
    prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
            ? jcp.tile_tail
            : jcp.tile_width;

    row_count_ = 0;
    is_store_done_ = false;
    is_buffer_empty_ = true;
}

size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_block(
        const int block_size, const int pad_output) {
    return (size_t)(pad_output >= block_size ? block_size : 0)
            + (pad_output % block_size);
}
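// e.g. with block_size = 16: pad_output = 35 -> 16 + (35 % 16) = 19 rows to
// compute (one full block plus the remainder); pad_output = 7 -> 0 + 7 = 7.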
size_t jit_avx512_core_amx_fwd_kernel_t::reduce_to_blocked_dims(
        const int dim_size, const int block_size, const int s_pad_output,
        const int e_pad_output) {
    using namespace nstl;

    // start padding (s_pad)
    int s_pad_limit = reduce_to_block(block_size, s_pad_output);
    int s_pad_area_blk = rnd_up(s_pad_limit, block_size);

    // middle (no padding)
    int no_pad_area = max(
            0, dim_size - rnd_up(s_pad_output, block_size) - e_pad_output);
    int no_pad_limit = (no_pad_area >= block_size ? block_size : 0);

    // end padding (e_pad)
    int no_pad_area_shift = no_pad_area % block_size;
    int e_pad_area_overlap
            = no_pad_area_shift == 0 ? 0 : block_size - no_pad_area_shift;
    // middle and end padding shift
    int e_pad_shift_limit
            = no_pad_area_shift + min(e_pad_output, e_pad_area_overlap);
    int e_pad_area_blk = max(0, e_pad_output - e_pad_area_overlap);
    // full end padding block
    int e_pad_limit = reduce_to_block(block_size, e_pad_area_blk);

    // calculate reduced size of s_pad, middle and e_pad blocks.
    return min((size_t)dim_size,
            (size_t)s_pad_area_blk + no_pad_limit + e_pad_shift_limit
                    + e_pad_limit);
}
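// Worked example (hypothetical sizes): dim_size = 100, block_size = 16,
// s_pad_output = 20, e_pad_output = 10:
//   s_pad_limit = 20, s_pad_area_blk = 32; no_pad_area = 58, no_pad_limit
//   = 16, no_pad_area_shift = 10; e_pad_area_overlap = 6, e_pad_shift_limit
//   = 16, e_pad_area_blk = 4, e_pad_limit = 4;
// so 100 rows reduce to min(100, 32 + 16 + 16 + 4) = 68 unique rows.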
1321
ymm_mask(const Ymm & ymm_in,bool mask_flag,bool store)1322 Ymm jit_avx512_core_amx_fwd_kernel_t::ymm_mask(
1323 const Ymm &ymm_in, bool mask_flag, bool store) {
1324 return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
1325 : ymm_in;
1326 }
1327
zmm_mask(const Zmm & zmm_in,bool mask_flag,bool store)1328 Zmm jit_avx512_core_amx_fwd_kernel_t::zmm_mask(
1329 const Zmm &zmm_in, bool mask_flag, bool store) {
1330 return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
1331 : zmm_in;
1332 }
1333
cvt2ps(data_type_t type_in,const Zmm & zmm_in,const Operand & op,bool mask_flag)1334 void jit_avx512_core_amx_fwd_kernel_t::cvt2ps(data_type_t type_in,
1335 const Zmm &zmm_in, const Operand &op, bool mask_flag) {
1336 const Zmm zmm = zmm_mask(zmm_in, mask_flag);
1337 switch (type_in) {
1338 case data_type::f32:
1339 case data_type::s32: vmovups(zmm, op); break;
1340 case data_type::s8: vpmovsxbd(zmm, op); break;
1341 case data_type::u8: vpmovzxbd(zmm, op); break;
1342 default: assert(!"unsupported data type");
1343 }
1344 if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
1345 }
1346
apply_sum(const Zmm & zmm_out,const float * p_sum_scale,const Xbyak::Address & addr,const bool mask_flag)1347 void jit_avx512_core_amx_fwd_kernel_t::apply_sum(const Zmm &zmm_out,
1348 const float *p_sum_scale, const Xbyak::Address &addr,
1349 const bool mask_flag) {
1350 if (p_sum_scale) {
1351 const float p_sum_scale_val = *p_sum_scale;
1352 const auto sum_injector = [&, p_sum_scale_val, mask_flag]() {
1353 cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
1354 if (p_sum_scale_val == 1.f)
1355 vaddps(zmm_out, zmm_prev_dst);
1356 else
1357 vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
1358 };
1359 postops_injector_->set_lambda_injector(
1360 primitive_kind::sum, sum_injector);
1361 }
1362 }
1363
apply_postops(const Zmm & zmm_out,const float * p_sum_scale,const Xbyak::Address & addr,const bool mask_flag,const size_t off,const int ocb)1364 void jit_avx512_core_amx_fwd_kernel_t::apply_postops(const Zmm &zmm_out,
1365 const float *p_sum_scale, const Xbyak::Address &addr,
1366 const bool mask_flag, const size_t off, const int ocb) {
1367 if (jcp.with_eltwise || jcp.with_binary
1368 || (jcp.with_sum && p_sum_scale != nullptr)) {
1369 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
1370
1371 apply_sum(zmm_out, p_sum_scale, addr, mask_flag);
1372
1373 const auto vmm_idx = zmm_out.getIdx();
1374 if (jcp.with_binary) {
1375 const int oc_l_offset = ocb * jcp.oc_block;
1376 rhs_arg_params.vmm_idx_to_oc_elem_off_addr.emplace(
1377 vmm_idx, ptr[param1 + GET_OFF(oc_l_off)]);
1378 rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(
1379 vmm_idx, oc_l_offset);
1380 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
1381 vmm_idx, static_cast<int>(off));
1382 if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
1383 }
1384
1385 postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
1386 }
1387 }
1388
store_output_vector_bf16(const Zmm & zmm_out,int ocb,int h,int w)1389 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_bf16(
1390 const Zmm &zmm_out, int ocb, int h, int w) {
1391 const bool mask_flag = jcp.is_nspc && jcp.oc_without_padding != jcp.oc
1392 && ocb == (jcp.nb_oc_blocking - 1);
1393
1394 const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1395 auto addr = EVEX_compress_addr(reg_out_ptr, off);
1396
1397 const auto &p = attr_.post_ops_;
1398
1399 const int sum_idx = p.find(primitive_kind::sum);
1400 if (sum_idx != -1) {
1401 if (jcp.dst_dt == data_type::bf16) {
1402 vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
1403 vpslld(zmm_prev_dst, zmm_prev_dst, 16);
1404 vaddps(zmm_out, zmm_prev_dst);
1405 } else {
1406 vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
1407 vaddps(zmm_out, zmm_prev_dst);
1408 }
1409 }
1410 if (jcp.with_bias) {
1411 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
1412 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1413 if (jcp.bia_dt == data_type::bf16) {
1414 vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
1415 vpslld(zmm_bias, zmm_bias, 16);
1416 vaddps(zmm_out, zmm_bias);
1417 } else
1418 vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
1419 }
1420
1421 static constexpr auto skip_sum_injection = nullptr;
1422 apply_postops(zmm_out, skip_sum_injection, addr, mask_flag, off, ocb);
1423
1424 if (jcp.dst_dt == data_type::bf16) {
1425 Ymm ymm_out = Ymm(zmm_out.getIdx());
1426 vcvtneps2bf16(ymm_out, zmm_out);
1427 vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
1428 } else {
1429 vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
1430 }
1431 }
1432
store_output_vector_int8(const Zmm & zmm_out,int ocb,int h,int w,const bool compute_zp,const int zp_h,const int zp_w)1433 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector_int8(
1434 const Zmm &zmm_out, int ocb, int h, int w, const bool compute_zp,
1435 const int zp_h, const int zp_w) {
1436 const int nb_oc_block = jcp.nb_oc_blocking;
1437 const int oc_block = jcp.oc_block;
1438 const bool mask_flag = true && jcp.oc_without_padding != jcp.oc
1439 && ocb == (nb_oc_block - 1);
1440
1441 const auto off = get_out_row_offset(h, ocb, w, jcp.typesize_out);
1442 auto addr = EVEX_compress_addr(reg_out_ptr, off);
1443
1444 const auto &p = attr_.post_ops_;
1445 const int sum_idx = p.find(primitive_kind::sum);
1446 const float *p_sum_scale = nullptr;
1447 if (sum_idx != -1) {
1448 const auto &p_entry = p.entry_[sum_idx];
1449 p_sum_scale = &p_entry.sum.scale;
1450 }
1451
1452 if (p_sum_scale && *p_sum_scale != 1.f)
1453 mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
1454
1455 int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * oc_block);
1456 if (jcp.with_bias) {
1457 int bias_offset = jcp.typesize_bia * ocb * oc_block;
1458 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
1459 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
1460 }
1461 if (compute_zp) {
1462 assert(jcp.req_zero_point_buffer);
1463 // add zero-point padding compensation when accum data is S32
1464 const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1465 vmovups(m_zmm_zp,
1466 EVEX_compress_addr(reg_zero_point_pbuff,
1467 get_zp_comp_offset(ocb, zp_h, zp_w)));
1468 const Zmm m_zmm_out = zmm_mask(zmm_out, mask_flag);
1469 vpaddd(m_zmm_out, zmm_out, zmm_zp);
1470 }
1471 if (jcp.src_zero_point) {
1472 // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
1473 int zp_offset = sizeof(int32_t) * ocb * oc_block;
1474 const Zmm m_zmm_zp = zmm_mask(zmm_zp, mask_flag);
1475 vpmulld(m_zmm_zp, zmm_src_zp,
1476 EVEX_compress_addr(reg_zp_compensation, zp_offset));
1477 vpaddd(zmm_out, zmm_out, zmm_zp);
1478 }
1479
1480 /* add bias and zero-point to zmm_accum */
1481 vcvtdq2ps(zmm_out, zmm_out);
1482 if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
1483 const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
1484 vmulps(zmm_out_msk, zmm_out,
1485 EVEX_compress_addr(reg_ptr_scales, scale_offset));
1486
1487 apply_postops(zmm_out, p_sum_scale, addr, mask_flag, off, ocb);
1488
1489 if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); }
1490
1491 // Properly saturate the accumulators for integer datatypes
1492 if (one_of(jcp.dst_dt, u8, s8, s32)) {
1493 init_saturate_f32(
1494 zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dst_dt);
1495 saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dst_dt);
1496 vcvtps2dq(zmm_out, zmm_out);
1497 }
1498
1499 const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
1500
1501 switch (jcp.dst_dt) {
1502 case data_type::f32:
1503 case data_type::s32: vmovups(addr, zmm_out_store); break;
1504 case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
1505 case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
1506 default: assert(!"unknown dst_dt");
1507 }
1508 }
1509
store_output_vector(const Zmm & zmm_out,int ocb,int h,int w,const bool compute_zp,const int zp_h,const int zp_w)1510 void jit_avx512_core_amx_fwd_kernel_t::store_output_vector(const Zmm &zmm_out,
1511 int ocb, int h, int w, const bool compute_zp, const int zp_h,
1512 const int zp_w) {
1513 /*
1514 Output:
1515 jcp.is_nspc !jcp.is_nspc
1516 --------------------- ---------------------
1517 INT8: [N][H][W][NBOC][16OC]
1518 BF16: [N][H][W][NBOC][16OC] or [N][NBOC][H][W][16OC]
1519 */
1520 if (jcp.src_dt == data_type::bf16) {
1521 store_output_vector_bf16(zmm_out, ocb, h, w);
1522 } else {
1523 store_output_vector_int8(zmm_out, ocb, h, w, compute_zp, zp_h, zp_w);
1524 }
1525 }
1526
store_output(int width,int tail,bool do_store,const bool handle_h_blk,const int t_pad_output,const int b_pad_output,const int l_pad_output,const int r_pad_output,const bool is_last_oh_block,const bool zp_3d_pad)1527 void jit_avx512_core_amx_fwd_kernel_t::store_output(int width, int tail,
1528 bool do_store, const bool handle_h_blk, const int t_pad_output,
1529 const int b_pad_output, const int l_pad_output, const int r_pad_output,
1530 const bool is_last_oh_block, const bool zp_3d_pad) {
1531 auto store_output_block = [=](int width, int tail, bool do_store,
1532 bool is_last_h = false) {
1533 // Calculate the number of oh blocks; it may differ on last call
1534 const int last_h_blks
1535 = div_up(jcp.oh, jcp.oh_per_tile) % jcp.nb_oh_blocking;
1536 const int h_blks = is_last_h && last_h_blks != 0 ? last_h_blks
1537 : jcp.nb_oh_blocking;
1538 // Calculate the number of oh rows per tile; it may differ on last call
1539 const int h_tail = is_last_h && jcp.oh % jcp.oh_per_tile != 0
1540 ? (h_blks - 1) * jcp.oh_per_tile + jcp.oh % jcp.oh_per_tile
1541 : h_blks * jcp.oh_per_tile;
1542 const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
1543 const int owp = gen_kw + jcp.ow - 1;
1544
1545 if (jcp.src_zero_point) {
1546 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1547 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1548 vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1549 }
1550 if (jcp.dst_zero_point) {
1551 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1552 vcvtdq2ps(zmm_dst_zp,
1553 EVEX_compress_addr(reg_dst_zero_point, 0, true));
1554 }
1555
1556 for_(int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
1557 for (int ohb = 0; ohb < h_blks; ohb++) {
1558 /* Formats: Workspace: [NBOC][W][16OC] */
1559 tilestored(ptr[reg_wsp_ptr + reg_wei_stride
1560 + get_wsp_ocb_offset(ohb, ocb)],
1561 Tmm(get_out_tensor(ohb, ocb, tail)));
1562 is_buffer_empty_ = false;
1563 is_store_done_ = false;
1564
1565 // preserve registers used by binary post_ops injector
1566 const injector_utils::conditional_register_preserve_guard_t
1567 cond_register_guard(jcp.with_binary, this,
1568 {bin_injector_helper_reg_1,
1569 bin_injector_helper_reg_2});
1570
1571 for (int tw = 0; tw < width && do_store; tw++) {
1572 // height
1573 const int oh_index = ohb * jcp.oh_per_tile + tw / owp;
1574 const bool zp_h_pad
1575 = oh_index < t_pad_output || oh_index >= b_pad_output;
1576 const int zp_h = get_zp_index_offset(
1577 oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1578 // width
1579 const int ow_index = tw % owp;
1580 const bool zp_w_pad
1581 = ow_index < l_pad_output || ow_index >= r_pad_output;
1582 const int zp_w = get_zp_index_offset(
1583 ow_index, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1584
1585 const bool compute_zp = jcp.req_zero_point_buffer
1586 && (zp_3d_pad || zp_w_pad || zp_h_pad);
1587
1588 assert(IMPLICATION(jcp.oh_per_tile == 1,
1589 ohb == oh_index && tw == ow_index));
1590 if (oh_index < h_tail && ow_index < jcp.ow) {
1591 Zmm zmm_r = zmm_out(tw);
1592 vmovups(zmm_r,
1593 ptr[reg_wsp_ptr
1594 + get_wsp_row_offset(ohb, ocb, tw)]);
1595 store_output_vector(zmm_r, ocb, oh_index, ow_index,
1596 compute_zp, zp_h, zp_w);
1597 }
1598 }
1599 }
1600 };
1601
1602 // adjustment in case interleave store is turned off
1603 do_store = do_store || jcp.per_one_pstore == 0;
1604 if (!do_store) { w_padding.emplace(l_pad_output, r_pad_output); }
1605 if (!handle_h_blk) {
1606 store_output_block(width, tail, do_store, is_last_oh_block);
1607 } else {
1608 if (jcp.oh % (jcp.oh_per_tile * jcp.nb_oh_blocking) == 0) {
1609 store_output_block(width, tail, do_store);
1610 } else {
1611 Label label_oh_oc_store, label_done;
1612 mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
1613 cmp(reg_last_h, 0);
1614 jne(label_oh_oc_store, T_NEAR);
1615 store_output_block(width, tail, do_store, true); // last h
1616 jmp(label_done, T_NEAR);
1617 L(label_oh_oc_store);
1618 store_output_block(width, tail, do_store, false);
1619 L(label_done);
1620 }
1621 }
1622 if (do_store) {
1623 add(reg_out_ptr, get_out_shift(width, jcp.typesize_out));
1624 if (jcp.req_zero_point_buffer) {
1625 const size_t sp_shift
1626 = accum_with_upper_bound(width, l_pad_output, r_pad_output);
1627 add(reg_zero_point_pbuff, get_out_shift(sp_shift, sizeof(int32_t)));
1628 }
1629 }
1630 }
1631
1632 void jit_avx512_core_amx_fwd_kernel_t::interleave_store(int width,
1633 int const t_pad_output, int const b_pad_output, const bool zp_3d_pad) {
1634 for (int c = 0;
1635 c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
1636 c++) {
1637 // row_count = ohb * OCB * TW + ocb * TW + tw
1638 int tw = row_count_ % prv_width_;
1639 int ocb = (row_count_ / prv_width_) % jcp.nb_oc_blocking;
1640 int ohb = (row_count_ / prv_width_) / jcp.nb_oc_blocking;
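        // Illustrative decoding (a sketch): with prv_width_ = 16 and
        // jcp.nb_oc_blocking = 2, row_count_ = 50 decodes to tw = 50 % 16 = 2,
        // ocb = (50 / 16) % 2 = 1 and ohb = (50 / 16) / 2 = 1; recombining
        // gives 1 * (2 * 16) + 1 * 16 + 2 = 50, matching the formula above.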
1641
1642 // preserve registers used by binary post_ops injector
1643 const injector_utils::conditional_register_preserve_guard_t
1644 cond_register_guard(jcp.with_binary, this,
1645 {bin_injector_helper_reg_1, bin_injector_helper_reg_2});
1646
1647 // height
1648 const int oh_index = ohb;
1649 const bool zp_h_pad
1650 = oh_index < t_pad_output || oh_index >= b_pad_output;
1651 const int zp_h = get_zp_index_offset(
1652 oh_index, (int)jcp.oh_mid, t_pad_output, b_pad_output);
1653 // width
1654 const int l_pad_output
1655 = w_padding.empty() ? 0 : w_padding.front().l_pad_output;
1656 const int r_pad_output
1657 = w_padding.empty() ? jcp.ow : w_padding.front().r_pad_output;
1658
1659 const bool zp_w_pad = tw < l_pad_output || tw >= r_pad_output;
1660 const int zp_w = get_zp_index_offset(
1661 tw, (int)jcp.ow_mid, l_pad_output, r_pad_output);
1662
1663 const bool compute_zp = jcp.req_zero_point_buffer
1664 && (zp_3d_pad || zp_w_pad || zp_h_pad);
1665
1666 Zmm zmm_r = zmm_out(tw);
1667 vmovups(zmm_r, ptr[reg_wsp_ptr + get_wsp_row_offset(ohb, ocb, tw)]);
1668 store_output_vector(zmm_r, ocb, ohb, tw, compute_zp, zp_h, zp_w);
1669 row_count_++;
1670
1671 if (row_count_
1672 == prv_width_ * jcp.nb_oc_blocking * jcp.nb_oh_blocking) {
1673 add(reg_out_ptr, get_out_shift(prv_width_, jcp.typesize_out));
1674 if (jcp.req_zero_point_buffer) {
1675 const size_t sp_shift = accum_with_upper_bound(
1676 prv_width_, l_pad_output, r_pad_output);
1677 add(reg_zero_point_pbuff,
1678 get_out_shift(sp_shift, sizeof(int32_t)));
1679 if (!w_padding.empty()) w_padding.pop();
1680 }
1681 row_count_ = 0;
1682 is_store_done_ = true;
1683 prv_width_ = width;
1684 }
1685 }
1686 }
1687
1688 void jit_avx512_core_amx_fwd_kernel_t::compute_icb_loop(int width,
1689 bool do_store, const bool handle_h_blk, const int t_pad_output,
1690 const int b_pad_output, const int l_pad_output, const int r_pad_output,
1691 const bool zp_3d_pad, const bool is_last_oh_block) {
1692 const bool tail = width == jcp.tile_tail;
1693
1694 auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
1695 if (jcp.src_dt == data_type::bf16 && jcp.wei_dt == data_type::bf16) {
1696 tdpbf16ps(x1, x2, x3);
1697 } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::u8) {
1698 tdpbuud(x1, x2, x3);
1699 } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::s8) {
1700 tdpbusd(x1, x2, x3);
1701 } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::u8) {
1702 tdpbsud(x1, x2, x3);
1703 } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::s8) {
1704 tdpbssd(x1, x2, x3);
1705 } else {
1706 assert(!"unsupported combination");
1707 }
1708 };
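    // Illustrative scalar model (a sketch, not part of the emitted code) of
    // the tdpbusd case dispatched above, with M = C.rows, N = C.colsb / 4 and
    // K = A.colsb / 4:
    //   for (m = 0; m < M; m++)
    //     for (k = 0; k < K; k++)
    //       for (n = 0; n < N; n++)
    //         for (b = 0; b < 4; b++) // 4-byte VNNI group, s32 accumulation
    //           C[m][n] += (uint8_t)A[m][4 * k + b] * (int8_t)B[k][4 * n + b];
    // tdpbf16ps is analogous with 2-element bf16 groups accumulating into f32.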
1709
1710 prepare_output(tail);
1711
1712 // prepare registers for when 'interleave_store()' is computed
1713 if (jcp.src_zero_point) {
1714 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
1715 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
1716 vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
1717 }
1718 if (jcp.dst_zero_point) {
1719 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
1720 vcvtdq2ps(zmm_dst_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));
1721 }
1722
1723 // reduced lowering path
1724 if (jcp.is_relo) {
1725 const int nreduce = jcp.nreduce;
1726         const int stride = jcp.ic_block_int; // i.e. 64 for int8, 32 for bf16
1727
1728 push(reg_inp_ptr);
1729 push(reg_wei_ptr);
1730
1731 for (int ireduce = 0; ireduce < nreduce; ireduce += stride) {
1732 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1733 tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1734 ptr[reg_inp_ptr + get_inp_offset(ohb, 0)
1735 + reg_inp_stride]);
1736 }
1737 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1738 tileloadd(Tmm(get_wei_tensor(ocb)),
1739 ptr[reg_wei_ptr + get_wei_offset(ocb, 0)
1740 + reg_wei_stride]);
1741 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1742 tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1743 Tmm(get_inp_tensor(ohb, tail)),
1744 Tmm(get_wei_tensor(ocb)));
1745 interleave_store(width, t_pad_output, b_pad_output);
1746 }
1747 }
1748 if (ireduce + stride < nreduce) {
1749 add(reg_inp_ptr, stride * jcp.typesize_in);
1750 add(reg_wei_ptr, stride * jcp.oc_block * jcp.typesize_in);
1751 }
1752 }
1753 pop(reg_wei_ptr);
1754 pop(reg_inp_ptr);
1755
1756 store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1757 b_pad_output, l_pad_output, r_pad_output, is_last_oh_block);
1758
1759 add(reg_inp_ptr, get_inp_shift());
1760
1761 return;
1762 }
1763
1764 auto wei_offset = [&](int icb, int ocb, int kd, int kh, int kw) {
1765 return (size_t)icb * get_wei_icb_step() + kd * get_wei_d_step()
1766 + kh * get_wei_h_step() + get_wei_offset(ocb, kw);
1767 };
1768
1769 auto inp_offset = [&](int icb, int ohb, int kd, int kh, int kw) {
1770 return (size_t)icb * get_inp_icb_step() + kd * get_inp_d_step()
1771 + kh * get_inp_h_step() + get_inp_offset(ohb, kw);
1772 };
1773
1774 auto safe_tileloadd
1775 = [=](const Tmm &t1, const Xbyak::Reg64 ®_ptr, size_t offset,
1776 const Xbyak::Reg64 ®_stride) {
1777 if (offset <= INT32_MAX) {
1778 tileloadd(t1, ptr[reg_ptr + offset + reg_stride]);
1779 } else {
1780 safe_add(reg_ptr, offset, reg_tmp);
1781 tileloadd(t1, ptr[reg_ptr + reg_stride]);
1782 safe_sub(reg_ptr, offset, reg_tmp);
1783 }
1784 };
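    // Rationale (a sketch): x86 addressing encodes the displacement as a
    // signed 32-bit immediate, so an offset above INT32_MAX (possible for
    // multi-GiB 5D tensors) cannot be folded into the address; the base
    // register is temporarily advanced with safe_add and restored afterwards.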
1785
1786 // normal and k-remainders path
1787 const bool check_kd_padding
1788 = jcp.ndims == 5 && (jcp.f_pad > 0 || jcp.back_pad > 0);
1789 for (int icb = 0; icb < jcp.nb_ic_int; icb++) {
1790 Label kd_skip_compute;
1791 if (check_kd_padding) mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1792
1793 for (int kd = 0; kd < jcp.kd; kd++) {
1794 if (check_kd_padding) {
1795 dec(reg_kd);
1796 jl(kd_skip_compute, T_NEAR);
1797 push(reg_kd);
1798 }
1799 for (int kh = 0; kh < jcp.kh; kh++) {
1800 for (int set_idx = 0; set_idx < jcp.n_stride_sets;
1801 set_idx++) { // used to optimize input memory reuse in L1$
1802 for (int kw = set_idx; kw < jcp.kw; kw += jcp.kw_step) {
1803 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1804 const size_t inp_off
1805 = inp_offset(icb, ohb, kd, kh, kw);
1806 safe_tileloadd(Tmm(get_inp_tensor(ohb, tail)),
1807 reg_inp_ptr, inp_off, reg_inp_stride);
1808 }
1809 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1810 const size_t wei_off
1811 = wei_offset(icb, ocb, kd, kh, kw);
1812 safe_tileloadd(Tmm(get_wei_tensor(ocb)),
1813 reg_wei_ptr, wei_off, reg_wei_stride);
1814 for (int ohb = 0; ohb < jcp.nb_oh_blocking; ohb++) {
1815 tdpbxxd(Tmm(get_out_tensor(ohb, ocb, tail)),
1816 Tmm(get_inp_tensor(ohb, tail)),
1817 Tmm(get_wei_tensor(ocb)));
1818 interleave_store(width, t_pad_output,
1819 b_pad_output, zp_3d_pad);
1820 }
1821 }
1822 }
1823 }
1824 }
1825 if (check_kd_padding) pop(reg_kd);
1826 }
1827 L(kd_skip_compute);
1828 }
1829
1830 store_output(width, tail, do_store, handle_h_blk, t_pad_output,
1831 b_pad_output, l_pad_output, r_pad_output, is_last_oh_block,
1832 zp_3d_pad);
1833
1834 add(reg_inp_ptr, get_inp_shift());
1835 }
1836
1837 void jit_avx512_core_amx_fwd_kernel_t::dispatch_icb_loop(int width,
1838 bool do_store, const int l_pad_output, const int r_pad_output,
1839 const bool zp_3d_pad) {
1840 if (jcp.req_zero_point_buffer
1841 && (jcp.t_pad_output > 0 || jcp.b_pad_output > 0)) {
1842 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
1843 const size_t height_limit = reduce_to_blocked_dims(
1844 jcp.oh, oh_step_size, jcp.t_pad_output, jcp.b_pad_output);
1845 const int ur_h = div_up(height_limit, oh_step_size);
1846 assert(6 >= ur_h);
1847
1848 // Use a jump-table to execute the corresponding block
1849 Label h_blk_label[6], h_blk_end_label, jmp_table_label;
1850 mov(reg_jmp_blk, ptr[param1 + GET_OFF(ohb)]);
1851 mov(reg_tmp, jmp_table_label);
1852 jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1853 jmp(h_blk_end_label, T_NEAR); // error, shouldn't happen
1854
1855 align(8);
1856 L(jmp_table_label);
1857 for (int u = 0; u < ur_h; ++u) {
1858 putL(h_blk_label[u]);
1859 }
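        // Illustrative shape of the dispatch emitted above (a sketch):
        // putL() stores each label's absolute address as pointer-sized data,
        // so the generated code is effectively
        //     jmp qword [jmp_table_label + ohb * 8]
        //   jmp_table_label:
        //     dq h_blk_label[0]
        //     dq h_blk_label[1]
        //     ...              ; one 8-byte entry per unrolled 'ur_h' block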
1860
1861 // Save value of global variables for the next 'h_blk' iteration
1862 const int local_prv_width = prv_width_;
1863 const int local_row_count = row_count_;
1864 const bool local_is_store_done = is_store_done_;
1865 const bool local_is_buffer_empty = is_buffer_empty_;
1866
1867         // Unroll oh_block with regard to t_pad_output and b_pad_output
1868 int cur_t_pad = reduce_to_block(oh_step_size, jcp.t_pad_output);
1869 int cur_b_pad = height_limit
1870 - reduce_to_block(oh_step_size, jcp.b_pad_output);
1871 for (int u = 0; u < ur_h; u++) {
1872 bool last = u == ur_h - 1;
1873 L(h_blk_label[u]);
1874
1875             // restore the variables saved before this 'h_blk' iteration
1876 prv_width_ = local_prv_width;
1877 row_count_ = local_row_count;
1878 is_store_done_ = local_is_store_done;
1879 is_buffer_empty_ = local_is_buffer_empty;
1880 compute_icb_loop(width, do_store, false, cur_t_pad, cur_b_pad,
1881 l_pad_output, r_pad_output, zp_3d_pad, last);
1882 cur_t_pad = nstl::max(0, cur_t_pad - oh_step_size);
1883 cur_b_pad = nstl::max(0, cur_b_pad - oh_step_size);
1884 if (!last) jmp(h_blk_end_label, T_NEAR);
1885 }
1886 L(h_blk_end_label);
1887 } else {
1888 compute_icb_loop(width, do_store, true, 0, jcp.oh, l_pad_output,
1889 r_pad_output, zp_3d_pad);
1890 }
1891 }
1892
1893 void jit_avx512_core_amx_fwd_kernel_t::dispatch_zp_3d_compute(int width,
1894 bool do_store, const int l_pad_output, const int r_pad_output) {
1895 if (jcp.req_zero_point_buffer && (jcp.f_pad > 0 || jcp.back_pad > 0)) {
1896 Label compute_3d_zp_label, zp_d_end_label;
1897 mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
1898 cmp(reg_kd, jcp.kd);
1899 jne(compute_3d_zp_label, T_NEAR);
1900
1901 // Save value of global variables for next 'dispatch_icb_loop'
1902 const int local_prv_width = prv_width_;
1903 const int local_row_count = row_count_;
1904 const bool local_is_store_done = is_store_done_;
1905 const bool local_is_buffer_empty = is_buffer_empty_;
1906 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1907
1908 jmp(zp_d_end_label, T_NEAR);
1909 L(compute_3d_zp_label);
1910
1911 prv_width_ = local_prv_width;
1912 row_count_ = local_row_count;
1913 is_store_done_ = local_is_store_done;
1914 is_buffer_empty_ = local_is_buffer_empty;
1915 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, true);
1916
1917 L(zp_d_end_label);
1918 } else
1919 dispatch_icb_loop(width, do_store, l_pad_output, r_pad_output, false);
1920 }
1921
1922 void jit_avx512_core_amx_fwd_kernel_t::compute_ow_loop() {
1923 auto compute_ow_loop_body = [=](bool last_owb, int num_tile_blocks,
1924 const int l_pad_output,
1925 const int r_pad_output) {
1926 int cur_l_pad_output = l_pad_output;
1927 int cur_r_pad_output = r_pad_output;
1928 int gen_tile_tail = last_owb && jcp.tile_tail > 0 ? jcp.tile_tail
1929 : jcp.tile_width;
1930 init_runtime_counters(last_owb && num_tile_blocks == 1);
1931 for (int owb = 0; owb < num_tile_blocks - 1; owb++) {
1932 dispatch_zp_3d_compute(
1933 jcp.tile_width, false, cur_l_pad_output, cur_r_pad_output);
1934 cur_l_pad_output = nstl::max(0, cur_l_pad_output - jcp.tile_width);
1935 cur_r_pad_output = nstl::max(0, cur_r_pad_output - jcp.tile_width);
1936 }
1937 dispatch_zp_3d_compute(
1938 gen_tile_tail, true, cur_l_pad_output, cur_r_pad_output);
1939 };
1940
1941 assert(jcp.nb_ow > 0);
1942 if (jcp.nb_ow == 1) {
1943 const int ow_r_pad_start
1944 = nstl::max(jcp.ow - jcp.r_pad_output, jcp.l_pad_output);
1945 compute_ow_loop_body(
1946 true, jcp.ow_blocks, jcp.l_pad_output, ow_r_pad_start);
1947 } else if (jcp.req_zero_point_buffer
1948 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0)) {
1949
1950 const size_t zp_addr_shift
1951 = jcp.ngroups * jcp.oc_without_padding * sizeof(int32_t);
1952 const int ow_step_size = jcp.ow_block;
1953 const int ow_blocks_per_call = div_up(ow_step_size, jcp.tile_width);
1954 const int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call == 0
1955 ? ow_blocks_per_call
1956 : jcp.ow_blocks % ow_blocks_per_call;
1957 const int width_limit = reduce_to_blocked_dims(
1958 jcp.ow, ow_step_size, jcp.l_pad_output, jcp.r_pad_output);
1959 const int ur_w = div_up(width_limit, ow_step_size);
1960 assert(6 >= ur_w);
1961 // Use a jump-table to execute the corresponding block
1962 Label w_blk_label[6], w_blk_end_label, jmp_table_label;
1963 mov(reg_jmp_blk, ptr[param1 + GET_OFF(owb)]);
1964 mov(reg_tmp, jmp_table_label);
1965 jmp(ptr[reg_tmp + reg_jmp_blk * sizeof(void *)]);
1966 jmp(w_blk_end_label, T_NEAR); // error, shouldn't happen
1967
1968 align(8);
1969 L(jmp_table_label);
1970 for (int u = 0; u < ur_w; ++u) {
1971 putL(w_blk_label[u]);
1972 }
1973
1974         // Unroll ow_block with regard to l_pad_output and r_pad_output
1975 int cur_l_pad = reduce_to_block(ow_step_size, jcp.l_pad_output);
1976 int cur_r_pad
1977 = width_limit - reduce_to_block(ow_step_size, jcp.r_pad_output);
1978 int zp_offset = 0;
1979 for (int u = 0; u < ur_w; u++) {
1980 const bool last = u == ur_w - 1;
1981 L(w_blk_label[u]);
1982 if (u > 0) add(reg_zero_point_pbuff, zp_offset * zp_addr_shift);
1983 compute_ow_loop_body(last,
1984 last ? last_owb_tile_blocks : ow_blocks_per_call, cur_l_pad,
1985 cur_r_pad);
1986 zp_offset += accum_with_upper_bound(
1987 ow_step_size, cur_l_pad, cur_r_pad);
1988 cur_l_pad = nstl::max(0, cur_l_pad - ow_step_size);
1989 cur_r_pad = nstl::max(0, cur_r_pad - ow_step_size);
1990 if (!last) jmp(w_blk_end_label, T_NEAR);
1991 }
1992 L(w_blk_end_label);
1993
1994 } else {
1995 assert(jcp.oh_per_tile == 1);
1996 Label label_done;
1997 int ow_blocks_per_call = utils::div_up(jcp.ow_block, jcp.tile_width);
1998 int last_owb_tile_blocks = jcp.ow_blocks % ow_blocks_per_call;
1999 if (last_owb_tile_blocks == 0 && jcp.tile_tail > 0)
2000 last_owb_tile_blocks = ow_blocks_per_call;
2001 if (last_owb_tile_blocks > 0) {
2002 Label label_not_last_owb;
2003 mov(reg_tmp, ptr[param1 + GET_OFF(owb)]);
2004 cmp(reg_tmp, jcp.nb_ow - 1);
2005 jne(label_not_last_owb, T_NEAR);
2006
2007 compute_ow_loop_body(true, last_owb_tile_blocks, 0, jcp.ow);
2008
2009 jmp(label_done, T_NEAR);
2010
2011 L(label_not_last_owb);
2012 }
2013 compute_ow_loop_body(false, ow_blocks_per_call, 0, jcp.ow);
2014
2015 L(label_done);
2016 }
2017 }
2018
2019 void jit_avx512_core_amx_fwd_kernel_t::generate() {
2020 preamble();
2021
2022 mov(reg_inp_ptr, ptr[param1 + GET_OFF(src)]);
2023 mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]);
2024 mov(reg_out_ptr, ptr[param1 + GET_OFF(dst)]);
2025 mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);
2026 if (jcp.req_zero_point_buffer)
2027 mov(reg_zero_point_pbuff, ptr[param1 + GET_OFF(zero_point_pbuff)]);
2028
2029 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
2030 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
2031
2032 const int fac = jcp.is_relo ? jcp.stride_w * jcp.kh
2033 : jcp.is_pbuffer_strided ? 1 : jcp.stride_w;
2034 const int inp_stride = fac * jcp.ic_block_int_np * jcp.typesize_in;
2035 const int wei_stride = jcp.oc_block * jcp.typesize_acc;
2036 mov(reg_inp_stride, inp_stride);
2037 mov(reg_wei_stride, wei_stride);
2038
2039 if (jcp.is_nspc && jcp.oc_without_padding != jcp.oc) {
2040 // Use mask 0xF by default for all output data and post-ops
2041 // loads / stores with block index
2042 // ocb = occ * jcp.nb_oc_blocking + (jcp.nb_oc_blocking - 1)
2043 // TODO: use masked loads / stores for the last occ only
2044 int current_block_size = jcp.oc_block;
2045 int mask = (1 << current_block_size) - 1;
2046 Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
2047 mov(regw_tmp, mask);
2048 kmovw(ktail_mask, regw_tmp);
2049 Xbyak::Label mask_is_set;
2050 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
2051 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
2052 jne(mask_is_set, T_NEAR);
2053 // Reset the mask
2054 current_block_size = jcp.oc_without_padding % jcp.oc_block;
2055 mask = (1 << current_block_size) - 1;
2056 mov(regw_tmp, mask);
2057 kmovw(ktail_mask, regw_tmp);
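        // Illustrative mask values (a sketch): with oc_block = 16 the default
        // mask is (1 << 16) - 1 = 0xffff (all dwords live); if, e.g.,
        // jcp.oc_without_padding % jcp.oc_block == 3, the tail mask becomes
        // (1 << 3) - 1 = 0b111, so masked stores touch only 3 valid channels.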
2058
2059 L(mask_is_set);
2060 }
2061 compute_ow_loop();
2062
2063 postamble();
2064
2065 if (jcp.with_eltwise) postops_injector_->prepare_table();
2066 }
2067
2068 void jit_avx512_core_amx_fwd_kernel_t::tile_configure(char *tcfg_buff) {
2069 const int vnni_width = jcp.src_dt == data_type::bf16 ? 2 : 4;
2070 // Input tile dimensions
2071 const int a_col = jcp.is_relo ? jcp.ic_block_int
2072 : jcp.ic_block_int_np * jcp.kw_per_tile;
2073 // Weights tile dimensions
2074 const int b_col = jcp.oc_block * vnni_width;
2075 const int b_row = a_col / vnni_width;
2076 // Accumulator tile dimensions
2077 const int c_col = 16;
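    // Illustrative dimensions (a sketch) for an int8, non-relo case with
    // jcp.kw_per_tile = 1 and a full jcp.ic_block_int_np = 64:
    //   vnni_width = 4, a_col = 64 -> input tile: tile_width rows x 64 bytes
    //   b_col = 16 * 4 = 64, b_row = 64 / 4 = 16
    //                              -> weights tile: 16 rows x 64 bytes
    //   c_col = 16                 -> accum tile: tile_width rows x 64 bytes
    // All fit the AMX palette limit of 8 tiles of 16 rows x 64 bytes each.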
2078
2079 for (size_t i = 0; i < 64; i++)
2080 tcfg_buff[i] = 0;
2081
2082 // Weights (W_BASE) Tensor Tiles
2083 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2084 tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
2085 b_row, b_col * jcp.typesize_in);
2086
2087 // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
2088 for (int h = 0; h < jcp.nb_oh_blocking; h++) {
2089 tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
2090 jcp.tile_width, a_col * jcp.typesize_in);
2091 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2092 tc_configure_tile((palette_config_t *)tcfg_buff,
2093 get_out_tensor(h, i), jcp.tile_width,
2094 c_col * jcp.typesize_acc);
2095 }
2096 if (jcp.tile_tail != 0) {
2097 assert(jcp.nb_oh_blocking == 1);
2098 assert(jcp.oh_per_tile == 1);
2099 assert(jcp.ow > jcp.tile_width);
2100 tc_configure_tile((palette_config_t *)tcfg_buff,
2101 get_inp_tensor(0, true), jcp.tile_tail,
2102 a_col * jcp.typesize_in);
2103 for (int i = 0; i < jcp.nb_oc_blocking; i++)
2104 tc_configure_tile((palette_config_t *)tcfg_buff,
2105 get_out_tensor(0, i, true), jcp.tile_tail,
2106 c_col * jcp.typesize_acc);
2107 }
2108
2109 ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
2110 }
2111
2112 void jit_avx512_core_amx_fwd_kernel_t::set_oh_blk_limits(jit_conv_conf_t &jcp) {
2113
2114 constexpr int size = sizeof(jcp.h_blk_limits) / sizeof(jcp.h_blk_limits[0]);
2115 // set default values
2116 for (int i = 0; i < size; i++)
2117 jcp.h_blk_limits[i] = jcp.oh;
2118
2119 const bool calculate_oh_limits
2120 = jcp.t_pad_output > 0 || jcp.b_pad_output > 0;
2121 if (jcp.req_zero_point_buffer && calculate_oh_limits) {
2122
2123 int limit_idx = 0;
2124 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2125
2126 // full t_pad output block
2127 const int t_pad_blk_end = rnd_dn(jcp.t_pad_output, oh_step_size);
2128 if (jcp.t_pad_output >= oh_step_size) {
2129 jcp.h_blk_limits[limit_idx++] = t_pad_blk_end;
2130 }
2131 // t_pad output overlap with no padding
2132 const int t_pad_shift = jcp.t_pad_output % oh_step_size;
2133 if (t_pad_shift != 0) {
2134 jcp.h_blk_limits[limit_idx++] = t_pad_blk_end + t_pad_shift;
2135 }
2136 const int t_pad_next_blk = rnd_up(jcp.t_pad_output, oh_step_size);
2137 const int oh_blk_tail = jcp.oh % oh_step_size;
2138 const int b_pad_no_tail = nstl::max(0, jcp.b_pad_output - oh_blk_tail);
2139 const int b_pad_start
2140 = nstl::max(jcp.t_pad_output, jcp.oh - jcp.b_pad_output);
2141 const int b_pad_blk_start = rnd_dn(b_pad_start, oh_step_size);
2142 // middle block without padding
2143 const int mid_blk = nstl::max(0, b_pad_blk_start - t_pad_next_blk);
2144 if (mid_blk >= oh_step_size) {
2145 jcp.h_blk_limits[limit_idx++] = b_pad_blk_start;
2146 }
2147 // no padding with b_pad overlap
2148 const int b_pad_shift = b_pad_no_tail % oh_step_size;
2149 if (b_pad_shift != 0) {
2150 jcp.h_blk_limits[limit_idx++] = rnd_up(b_pad_start, oh_step_size);
2151 }
2152 // full b_pad output block
2153 if (b_pad_no_tail >= oh_step_size) {
2154 jcp.h_blk_limits[limit_idx++] = jcp.oh - oh_blk_tail;
2155 }
2156 // b_pad tail block does not require a limit
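        // Illustrative example (a sketch): oh = 20, oh_step_size = 2,
        // t_pad_output = 3 and b_pad_output = 5 give t_pad_blk_end = 2,
        // t_pad_shift = 1, b_pad_blk_start = 14 and rnd_up(b_pad_start) = 16,
        // so h_blk_limits = {2, 3, 14, 16, 20, ...}: the driver then walks oh
        // in segments [0,2)[2,3)[3,14)[14,16)[16,20), each with a uniform
        // zero-point padding pattern.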
2157 }
2158 }
2159
2160 void jit_avx512_core_amx_fwd_kernel_t::set_ow_blk_limits(jit_conv_conf_t &jcp) {
2161
2162 jcp.l_pad_blk = 0;
2163 jcp.no_pad_w_blk = 0;
2164 jcp.r_pad_blk = 0;
2165
2166 const bool calculate_ow_limits
2167 = jcp.nb_ow > 1 && (jcp.l_pad_output > 0 || jcp.r_pad_output > 0);
2168 if (jcp.req_zero_point_buffer && calculate_ow_limits) {
2169 const int ow_step_size = jcp.ow_block;
2170
2171 // l_pad
2172 const int l_pad_limit
2173 = (jcp.l_pad_output >= ow_step_size ? ow_step_size : 0)
2174 + (jcp.l_pad_output % ow_step_size);
2175 const int l_pad_area_blk = rnd_up(l_pad_limit, ow_step_size);
2176 jcp.l_pad_blk = div_up(l_pad_limit, ow_step_size);
2177
2178 // middle (area without padding)
2179 const int no_pad_area
2180 = nstl::max(0, jcp.ow - l_pad_area_blk - jcp.r_pad_output);
2181 jcp.no_pad_w_blk = no_pad_area >= ow_step_size ? 1 : 0;
2182
2183 // r_pad
2184 const int no_pad_area_shift = no_pad_area % ow_step_size;
2185 const int r_pad_area_overlap
2186 = no_pad_area_shift == 0 ? 0 : ow_step_size - no_pad_area_shift;
2187 const int r_pad_area
2188 = nstl::max(0, jcp.r_pad_output - r_pad_area_overlap);
2189 const int r_pad_limit = (r_pad_area >= ow_step_size ? ow_step_size : 0)
2190 + (r_pad_area % ow_step_size);
2191 jcp.r_pad_blk = (r_pad_area_overlap > 0 ? 1 : 0)
2192 + div_up(r_pad_limit, ow_step_size);
2193 }
2194 }
2195
2196 status_t jit_avx512_core_amx_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
2197 const convolution_desc_t &cd, memory_desc_t &src_md,
2198 memory_desc_t &weights_md, memory_desc_t &dst_md,
2199 memory_desc_t &bias_md, const primitive_attr_t &attr, int nthreads) {
2200 using namespace prop_kind;
2201
2202 const memory_desc_wrapper src_d(&src_md);
2203 const memory_desc_wrapper weights_d(&weights_md);
2204 const memory_desc_wrapper dst_d(&dst_md);
2205 const memory_desc_wrapper bias_d(&bias_md);
2206
2207 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
2208 int ndims = src_d.ndims();
2209 bool is_1d = ndims == 3;
2210 bool is_3d = ndims == 5;
2211
2212 const bool is_bf16_convolution
2213 = everyone_is(true, src_d.data_type() == data_type::bf16,
2214 weights_d.data_type() == data_type::bf16,
2215 one_of(dst_d.data_type(), data_type::bf16, data_type::f32));
2216 const bool is_int8_convolution = everyone_is(true,
2217 (src_d.data_type() == data_type::u8
2218 || src_d.data_type() == data_type::s8),
2219 weights_d.data_type() == data_type::s8,
2220 one_of(dst_d.data_type(), data_type::f32, data_type::s32,
2221 data_type::s8, data_type::u8));
2222
2223 bool supported = false
2224 || (is_bf16_convolution && mayiuse(avx512_core_bf16_amx_bf16))
2225 || (is_int8_convolution && mayiuse(avx512_core_bf16_amx_int8));
2226 if (!supported) return status::unimplemented;
2227
2228 jcp = zero<decltype(jcp)>();
2229 jcp.isa = is_bf16_convolution ? avx512_core_bf16_amx_bf16
2230 : avx512_core_bf16_amx_int8;
2231 jcp.ndims = ndims;
2232 jcp.prop_kind = cd.prop_kind;
2233 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
2234
2235 jcp.mb = src_d.dims()[0];
2236 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
2237 jcp.oc_without_padding = jcp.oc;
2238 jcp.ic = src_d.dims()[1] / jcp.ngroups;
2239 jcp.ic_without_padding = jcp.ic;
2240 jcp.id = is_3d ? src_d.dims()[2] : 1;
2241 jcp.ih = !is_1d ? src_d.dims()[ndims - 2] : 1;
2242 jcp.iw = src_d.dims()[ndims - 1];
2243 jcp.od = is_3d ? dst_d.dims()[2] : 1;
2244 jcp.oh = !is_1d ? dst_d.dims()[ndims - 2] : 1;
2245 jcp.ow = dst_d.dims()[ndims - 1];
2246 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
2247 jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
2248 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
2249 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
2250 jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
2251 jcp.l_pad = cd.padding[0][ndims - 3];
2252 jcp.stride_d = is_3d ? cd.strides[0] : 1;
2253 jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
2254 jcp.stride_w = cd.strides[ndims - 3];
2255 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
2256
2257 jcp.dilate_d = is_3d ? cd.dilates[ndims - 5] : 0;
2258 jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
2259 jcp.dilate_w = cd.dilates[ndims - 3];
2260
2261 const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1;
2262 const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
2263 const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
2264 jcp.back_pad = calculate_end_padding(
2265 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, gen_kd);
2266 jcp.b_pad = calculate_end_padding(
2267 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
2268 jcp.r_pad = calculate_end_padding(
2269 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
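    // Illustrative numbers (a sketch, assuming calculate_end_padding follows
    // the usual end_pad = max(0, (o - 1) * stride + gen_k - (i + start_pad))
    // relation): kw = 3 with dilate_w = 1 gives gen_kw = (3 - 1) * 2 + 1 = 5,
    // and iw = 10, ow = 10, stride_w = 1, l_pad = 2 then yield
    // r_pad = (10 - 1) * 1 + 5 - (10 + 2) = 2.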
2270 if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
2271 || jcp.b_pad >= gen_kh || jcp.f_pad >= gen_kd
2272 || jcp.back_pad >= gen_kd)
2273 return status::unimplemented;
2274
2275 const int max_pad = 28; // akin to maximum jcp.ur_w value in other jits
2276 if (jcp.l_pad > max_pad || jcp.r_pad > max_pad)
2277 return status::unimplemented; // TODO: relax this restriction
2278
2279 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
2280 jcp.dst_dt = cd.dst_desc.data_type;
2281 jcp.src_dt = cd.src_desc.data_type;
2282 jcp.wei_dt = cd.weights_desc.data_type;
2283
2284 jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
2285
2286 if (jcp.is_depthwise)
2287 return status::unimplemented; // TODO: add support of DW convolution
2288
2289 const auto zp = attr.zero_points_;
2290 jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
2291 jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
2292 jcp.zp_src_is_common = zp.common(
2293 DNNL_ARG_SRC); // otherwise, it's per-channel (not supported)
2294
2295 if (!IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)
2296 || !IMPLICATION(jcp.dst_zero_point || jcp.src_zero_point,
2297 is_int8_convolution))
2298 return status::unimplemented;
2299
2300 // Calculate zero-point padding values outside of the main JIT-kernel
2301 // and store the results in an auxiliary buffer.
2302 jcp.req_zero_point_buffer = jcp.src_zero_point
2303 && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0 || jcp.t_pad > 0
2304 || jcp.f_pad > 0 || jcp.back_pad > 0);
2305
2306 format_tag_t dat_tag_ncsp = utils::pick(ndims - 3, format_tag::nCw16c,
2307 format_tag::nChw16c, format_tag::nCdhw16c);
2308 format_tag_t dat_tag_nspc = utils::pick(
2309 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
2310 // To toggle the default data layout for BF16 between nChw16c and nhwc,
2311 // swap the following two variable definitions. Current choice: nhwc.
2312
2315 format_tag_t dat_tag_opt = dat_tag_nspc;
2316 format_tag_t dat_tag_alt
2317 = is_bf16_convolution ? dat_tag_ncsp : dat_tag_nspc;
2318
2319 if (src_d.format_kind() == format_kind::any) {
2320 CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
2321 jcp.src_tag = dat_tag_opt;
2322 } else
2323 jcp.src_tag = src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);
2324
2325 if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
2326 return status::unimplemented;
2327
2328 jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
2329 assert(IMPLICATION(is_int8_convolution, jcp.is_nspc));
2330
2331 // TODO: remove all support for nChw16c from this implementation
2332 if (!jcp.is_nspc) return status::unimplemented;
2333
2334 if (dst_d.format_kind() == format_kind::any) {
2335 CHECK(memory_desc_init_by_tag(dst_md, jcp.src_tag));
2336 jcp.dst_tag = jcp.src_tag;
2337 } else
2338 jcp.dst_tag = dst_d.matches_one_of_tag(jcp.src_tag);
2339
2340 if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
2341
2342 if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
2343 CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
2344
2345 jcp.nthr = nthreads;
2346
2347 jcp.ic_block = 16;
2348 jcp.oc_block = 16;
2349
2350 if (jcp.ngroups == 1) {
2351 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
2352 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
2353 }
2354 bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
2355 if (!args_ok) return status::unimplemented;
2356
2357 const int vnni_width = is_bf16_convolution ? 2 : 4;
2358 jcp.ic_block_int = jcp.ic_block * vnni_width; // 32 for bf16, 64 for int8
2359
2360 // fallback to non-amx impl when accumulation is too small
2361 const dim_t total_k = jcp.ic_without_padding * jcp.kd * jcp.kh * jcp.kw;
2362 const bool is_tiny_k = total_k < jcp.ic_block_int / 2;
2363 if (is_tiny_k) return status::unimplemented;
2364
2365 // small-ic parameters
2366 jcp.ic_block_int_np = jcp.is_nspc
2367 ? nstl::min(jcp.ic_block_int, jcp.ic_without_padding)
2368 : jcp.ic_block_int;
2369 bool is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2370
2371 // reduced lowering
2372 jcp.is_relo = (!is_3d)
2373 && is_small_ic
2374 // no trivial cases
2375 && 1 < jcp.kh * jcp.kw
2376 // required for use of VPERMB instruction in weights copy kernel
2377 && IMPLICATION(is_int8_convolution,
2378 cpu().has(Xbyak::util::Cpu::tAVX512_VBMI))
2379             // no dilation along the h- or w-direction
2380             && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
2381             // no excessive stride along the h- or w-direction
2382 && jcp.stride_h <= jcp.kh && jcp.stride_w <= jcp.kw;
2383 jcp.nreduce = jcp.kh * jcp.kw * jcp.ic_block_int_np;
2384
2385 if (!jcp.is_relo) {
2386 jcp.ic_block_int_np = is_bf16_convolution
2387 ? jcp.ic_block_int
2388 : rnd_up(jcp.ic_block_int_np, vnni_width);
2389 is_small_ic = jcp.ic_block_int_np < jcp.ic_block_int;
2390 }
2391
2392 // k-remainders
2393 jcp.kw_per_tile = is_small_ic && !jcp.is_relo && jcp.dilate_w == 0
2394 && jcp.stride_w <= jcp.kw // TODO: relax this restriction
2395 && jcp.kw * jcp.ic_block_int_np <= jcp.ic_block_int
2396 ? jcp.kw
2397 : 1;
2398 jcp.is_pbuffer_strided = (1 == jcp.kw_per_tile);
2399 jcp.n_stride_sets
2400 = jcp.is_pbuffer_strided ? nstl::min(jcp.stride_w, jcp.kw) : 1;
2401 jcp.kw_step = jcp.is_pbuffer_strided ? jcp.stride_w : jcp.kw_per_tile;
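    // Illustrative configurations (a sketch), assuming a non-relo int8 case
    // with jcp.ic_block_int = 64, kw = 3, stride_w = 1, dilate_w = 0:
    //  - ic_block_int_np = 16: kw * 16 = 48 <= 64, so kw_per_tile = 3,
    //    is_pbuffer_strided = false, n_stride_sets = 1, kw_step = 3 (all
    //    three kw taps are folded into a single tile load);
    //  - ic_block_int_np = 32: kw * 32 = 96 > 64, so kw_per_tile = 1,
    //    is_pbuffer_strided = true, n_stride_sets = min(1, 3) = 1, kw_step = 1.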
2402
2403 const auto &p = attr.post_ops_;
2404
2405 const int sum_ind = p.find(primitive_kind::sum);
2406 jcp.with_sum = sum_ind != -1;
2407 const int eltwise_ind = p.find(primitive_kind::eltwise);
2408 jcp.with_eltwise = eltwise_ind != -1;
2409 const int binary_ind = p.find(primitive_kind::binary);
2410 jcp.with_binary = binary_ind != -1;
2411
2412 jcp.post_ops = p;
2413
2414 using namespace injector;
2415 const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16);
2416 const bool sum_requires_scale_one = sum_at_pos_0_only;
2417 const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
2418 jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
2419 {broadcasting_strategy_t::scalar,
2420 broadcasting_strategy_t::per_oc}});
2421 if (!post_ops_ok_) return status::unimplemented;
2422
2423 auto set_or_check_wei_format = [&]() {
2424 using namespace format_tag;
2425 using namespace memory_extra_flags;
2426 format_tag_t wei_tag;
2427 wei_tag = jcp.is_relo ? pick(with_groups + 2 * (ndims - 3), Owi16o,
2428 gOwi16o, Owhi16o, gOwhi16o) // no 3d support
2429 : is_bf16_convolution
2430 ? pick(with_groups + 2 * (ndims - 3), OIw16i16o2i,
2431 gOIw16i16o2i, OIhw16i16o2i, gOIhw16i16o2i,
2432 OIdhw16i16o2i, gOIdhw16i16o2i)
2433 : is_small_ic ? pick(with_groups + 2 * (ndims - 3),
2434 OwI16o4i, gOwI16o4i, OhwI16o4i, gOhwI16o4i,
2435 OdhwI16o4i, gOdhwI16o4i)
2436 : pick(with_groups + 2 * (ndims - 3),
2437 OIw16i16o4i, gOIw16i16o4i,
2438 OIhw16i16o4i, gOIhw16i16o4i,
2439 OIdhw16i16o4i, gOIdhw16i16o4i);
2440
2441 memory_desc_t want_wei_md = weights_md;
2442 memory_desc_init_by_tag(want_wei_md, wei_tag);
2443
2444 if (jcp.src_zero_point) {
2445 want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
2446 want_wei_md.extra.asymm_compensation_mask = (1 << 0)
2447 + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
2448 }
2449 if (weights_md.format_kind == format_kind::any) {
2450 weights_md = want_wei_md;
2451 return true;
2452 }
2453 return weights_md == want_wei_md;
2454 };
2455
2456 if (!set_or_check_wei_format()) return status::unimplemented;
2457
2458 jcp.typesize_in = types::data_type_size(src_d.data_type());
2459 jcp.typesize_out = types::data_type_size(dst_d.data_type());
2460 jcp.typesize_bia
2461 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
2462 jcp.typesize_acc = sizeof(int32_t);
2463
2464 jcp.nb_ic = jcp.ic / jcp.ic_block;
2465 jcp.nb_oc = jcp.oc / jcp.oc_block;
2466 jcp.nb_ic_int = div_up(jcp.ic, jcp.ic_block_int);
2467
2468 jcp.nb_oc_blocking_thr_chunk = 1;
2469
2470 const int max_palette = amx::get_max_palette();
2471 jcp.max_tiles = amx::get_max_tiles(max_palette);
2472 jcp.full_tile_width = amx::get_max_rows(max_palette);
2473 if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
2474 return status::unimplemented;
2475
2476 // Pack n rows per tile, such that:
2477 // ow + (ow + gen_kw - 1) * (n - 1) <= jcp.full_tile_width
2478 auto calculate_tile_width = [&](int n) {
2479 assert(n > 0);
2480 return jcp.ow + (gen_kw + jcp.ow - 1) * (n - 1);
2481 };
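    // Illustrative example (a sketch): ow = 7, gen_kw = 3, full_tile_width =
    // 16 give calculate_tile_width(2) = 7 + (3 + 7 - 1) * 1 = 16 <= 16 and
    // max_oh_per_tile = 1 + (16 - 7) / (7 + 2) = 2 below, so two output rows
    // can be packed per tile, each padded to ow + gen_kw - 1 columns.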
2482 const bool ok_to_pack_tile = !jcp.is_relo
2483 && (utils::everyone_is(1, jcp.kh, jcp.kw)
2484 || utils::everyone_is(1, jcp.stride_h, jcp.stride_w));
2485 const int max_oh_per_tile
2486 = 1 + (jcp.full_tile_width - jcp.ow) / (jcp.ow + gen_kw - 1);
2487 jcp.oh_per_tile = ok_to_pack_tile
2488 ? nstl::min(jcp.oh, nstl::max(1, max_oh_per_tile))
2489 : 1;
2490 jcp.tile_width = nstl::min<int>(
2491 jcp.full_tile_width, calculate_tile_width(jcp.oh_per_tile));
2492 jcp.ow_blocks = utils::div_up(jcp.ow, jcp.tile_width);
2493
2494 // Prefer to use a single tile width when possible
2495     // (e.g. ow = 28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
2496 if (jcp.oh_per_tile == 1 && jcp.ow % jcp.ow_blocks == 0)
2497 jcp.tile_width = jcp.ow / jcp.ow_blocks;
2498 jcp.tile_tail = jcp.oh_per_tile == 1 ? jcp.ow % jcp.tile_width : 0;
2499
2500 jcp.nb_oc_blocking = (jcp.nb_oc % 2 == 0) ? 2 : 1;
2501 jcp.nb_ic_blocking = 1;
2502 jcp.nb_oh_blocking
2503 = utils::everyone_is(true, jcp.tile_tail == 0,
2504 // requirement for interleave stores
2505 IMPLICATION(jcp.ow_blocks > 1, jcp.oh % 2 == 0),
2506 // requirement for small spatial
2507 utils::div_up(jcp.oh, jcp.oh_per_tile) > 1,
2508 // choose maximal pbuffer overlap for reduced lowering
2509 !jcp.is_relo)
2510 ? 2
2511 : 1;
2512
2513 // TODO: tune oh blocking
2514 const int oh_blk_size_param = jcp.is_relo ? 1 : 10;
2515 const int oh_step_size = jcp.nb_oh_blocking * jcp.oh_per_tile;
2516 const int oh_blk_size = rnd_up(oh_blk_size_param, oh_step_size);
2517 jcp.oh_blk_size = rnd_up(nstl::min(jcp.oh, oh_blk_size), oh_step_size);
2518     // Here ihp means the input buffer height including padding (i.e. the number
2519     // of input rows required for the computation of jcp.oh_blk_size output rows).
2520     // If an input row doesn't participate in the computation of any output row,
2521     // it isn't copied to the buffer at all (e.g. when jcp.stride_h > gen_kh).
2522 jcp.ihp = jcp.is_relo
2523 ? jcp.oh_blk_size
2524 : (jcp.oh_blk_size - 1) * nstl::min(jcp.stride_h, gen_kh) + gen_kh;
2525
2526 // TODO: tune ow blocking
2527 const int ow_blocks_per_call = jcp.is_relo ? 10 : 2;
2528 jcp.ow_block = nstl::min(jcp.ow, jcp.tile_width * ow_blocks_per_call);
2529 jcp.nb_ow = utils::div_up(jcp.ow, jcp.ow_block);
2530 // iwp includes all width elements that are really used in calculation
2531 // including left and right zero padding
2532 const bool are_sets_interleaved
2533 = IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1);
2534 jcp.iwp = are_sets_interleaved
2535 ? (jcp.ow_block - 1) * nstl::min(jcp.stride_w, jcp.kw) + gen_kw
2536 : jcp.ow_block * jcp.kw;
2537
2538 // Number of ops per tile store
2539 int ops_tile_store = jcp.tile_width;
2540 // Number of ops per accumulation tile
2541     int available_ops = jcp.is_relo
2542             ? utils::div_up(jcp.nreduce, jcp.ic_block_int)
2543             : jcp.nb_ic_int * jcp.kh * (jcp.kw / jcp.kw_per_tile);
2544     // Number of vectors to store per tile operation
2545     // NOTE: set to zero to turn off interleave store (mostly for debugging)
2546     jcp.per_one_pstore = utils::div_up(ops_tile_store, available_ops);
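    // Illustrative rate (a sketch): tile_width = 16 gives ops_tile_store = 16;
    // a non-relo case with nb_ic_int = 1, kh = 3 and kw / kw_per_tile = 3
    // gives available_ops = 9, so per_one_pstore = div_up(16, 9) = 2, i.e.
    // each tdpbxxd in compute_icb_loop interleaves up to two row stores.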
2547
2548 if (jcp.is_relo) {
2549 jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.ihp * jcp.iwp * jcp.kh
2550 * jcp.ic_block_int_np
2551 // pbuffer pointer shifts each oh step for reduced-lowering
2552 + (jcp.oh - 1) * jcp.stride_h * jcp.ic_block_int_np
2553 // extra $line due to pbuffer writing full Zmm
2554 + jcp.ic_block_int;
2555 } else {
2556 jcp.inp_buffer_size = (size_t)jcp.nb_ic_int * jcp.kd
2557 * ((size_t)jcp.ihp * jcp.iwp * jcp.ic_block_int_np
2558 // extra $line due to pbuffer writing full Zmm
2559 + jcp.ic_block_int);
2560 }
2561 jcp.wei_buffer_size = (size_t)jcp.ngroups * jcp.nb_oc
2562 * rnd_up(jcp.kh * jcp.kw * jcp.ic * jcp.oc_block, 1024);
2563 jcp.wsp_buffer_size = (size_t)jcp.nb_oh_blocking * jcp.nb_oc_blocking
2564 * jcp.full_tile_width * jcp.oc_block;
2565
2566 const auto &oscales = attr.output_scales_;
2567 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
2568
2569 // Note: currently unsupported, results in seg-fault
2570 const int l_pad_output = nstl::min(jcp.ow, div_up(jcp.l_pad, jcp.stride_w));
2571 if (!jcp.is_relo && (l_pad_output > jcp.ow_block))
2572 return status::unimplemented;
2573
2574 // Relevant to 'zero_point padding buffer' (pbuff) jit kernel
2575 if (jcp.req_zero_point_buffer) {
2576 auto calculate_output_padding_dims = [=](int o_dim, int s_pad,
2577 int e_pad,
2578 int &s_pad_output,
2579 int &e_pad_output,
2580 bool &o_mid, int &o_pad,
2581 int stride,
2582 bool req_mid_area) {
2583 s_pad_output = nstl::min(o_dim, div_up(s_pad, stride));
2584 e_pad_output = nstl::min(o_dim, div_up(e_pad, stride));
2585 o_mid = (o_dim - s_pad_output - e_pad_output > 0) && req_mid_area;
2586 o_pad = nstl::min(o_dim,
2587 nstl::max(1, s_pad_output + e_pad_output + (int)o_mid));
2588 };
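        // Illustrative example (a sketch): ow = 10, l_pad = 3, r_pad = 3,
        // stride_w = 2 with a middle area required give
        // l_pad_output = div_up(3, 2) = 2, r_pad_output = 2, ow_mid = true
        // and ow_pad = min(10, 2 + 2 + 1) = 5: the pbuff keeps 2 left-pad
        // columns, 2 right-pad columns and 1 shared middle value.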
2589
2590 const bool mid_w_area = (jcp.l_pad > 0 || jcp.r_pad > 0)
2591 && (jcp.t_pad > 0 || jcp.b_pad > 0 || jcp.f_pad > 0
2592 || jcp.back_pad > 0);
2593 const bool mid_h_area = (jcp.t_pad > 0 || jcp.b_pad > 0)
2594 && (jcp.l_pad > 0 || jcp.r_pad > 0 || jcp.f_pad > 0
2595 || jcp.back_pad > 0);
2596 const bool mid_d_area = (jcp.f_pad > 0 || jcp.back_pad > 0)
2597 && (jcp.r_pad > 0 || jcp.l_pad > 0 || jcp.b_pad > 0
2598 || jcp.t_pad > 0);
2599 calculate_output_padding_dims(jcp.ow, jcp.l_pad, jcp.r_pad,
2600 jcp.l_pad_output, jcp.r_pad_output, jcp.ow_mid, jcp.ow_pad,
2601 jcp.stride_w, mid_w_area);
2602 calculate_output_padding_dims(jcp.oh, jcp.t_pad, jcp.b_pad,
2603 jcp.t_pad_output, jcp.b_pad_output, jcp.oh_mid, jcp.oh_pad,
2604 jcp.stride_h, mid_h_area);
2605 calculate_output_padding_dims(jcp.od, jcp.f_pad, jcp.back_pad,
2606 jcp.f_pad_output, jcp.back_pad_output, jcp.od_mid, jcp.od_pad,
2607 jcp.stride_d, mid_d_area);
2608 jcp.zp_pbuff_size
2609 = jcp.od_pad * jcp.oh_pad * jcp.ow_pad * jcp.oc * jcp.ngroups;
2610
2611 // compute zero-point padding kernel outside of the main parallel
2612 // region when threads are more likely to parallelize work across mb
2613 // within the convolution compute block.
2614 jcp.zp_pbuff_outer_compute = jcp.mb > 1 || is_3d;
2615
2616 const bool params_ok = ((jcp.ow_pad - (int)jcp.ow_mid) <= max_pad * 2);
2617 if (!params_ok) { return status::unimplemented; }
2618 }
2619
2620 // Set default parameters for driver code, but mostly required for
2621 // 'zero_point padding buffer' (pbuff) accumulation over output tensor
2622 set_oh_blk_limits(jcp);
2623 set_ow_blk_limits(jcp);
2624
2625 return status::success;
2626 }
2627
2628 status_t jit_avx512_core_amx_fwd_kernel_t::init_scratchpad(
2629 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
2630 const primitive_attr_t &attr) {
2631
2632 size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
2633 scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
2634 if (jcp.is_relo) {
2635 scratchpad.book(
2636 key_conv_amx_wei_buffer, jcp.wei_buffer_size, jcp.typesize_in);
2637 }
2638
2639 size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
2640 scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
2641 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
2642 assert(jcp.ngroups == 1);
2643 scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
2644 }
2645 scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
2646 if (jcp.req_zero_point_buffer) {
2647 const int nthr = jcp.zp_pbuff_outer_compute ? 1 : jcp.nthr;
2648 scratchpad.book(key_conv_zero_point_pad,
2649 (size_t)nthr * jcp.zp_pbuff_size, sizeof(int32_t));
2650 if (!jcp.zp_pbuff_outer_compute) {
2651 const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
2652 scratchpad.book<bool>(key_conv_zero_point_flag,
2653 (size_t)jcp.nthr * oc_chunks * jcp.ngroups);
2654 }
2655 }
2656
2657 // Keep scratchpad memory footprint under control
2658 const size_t L2_size_per_core = platform::get_per_core_cache_size(2);
2659 const size_t L3_size_per_core = platform::get_per_core_cache_size(3);
2660 const size_t max_scratchpad_size
2661 = jcp.nthr * (L2_size_per_core + L3_size_per_core);
2662 // TODO: tune this relationship as needed
2663 if (scratchpad.size() > max_scratchpad_size) return status::unimplemented;
2664 return status::success;
2665 }
2666
2667 void jit_avx512_core_amx_bwd_data_copy_kernel_t::copy_row(
2668 const bool is_masked) {
2669 assert(jcp.is_nspc && "no support for nChw16c in this copy kernel");
2670
2671 const bool is_bf16 = jcp.ddst_dt == data_type::bf16;
2672 const int inp_w_step
2673 = jcp.ngroups * jcp.oc_without_padding * jcp.typesize_in;
2674 const int inp_h_step = jcp.ow * inp_w_step;
2675 const int out_w_step = jcp.oc_block_int * jcp.typesize_in;
2676 const int out_h_step = jcp.owp * out_w_step;
2677
2678 auto zero_it = [=](reg64_t tmp_out_ptr, int offset) {
2679 // no mask as output is a padded buffer
2680 if (is_bf16)
2681 vmovdqu16(ptr[tmp_out_ptr + offset], zmm_zero);
2682 else
2683 vmovdqu8(ptr[tmp_out_ptr + offset], zmm_zero);
2684 };
2685
2686 auto copy_it = [=](reg64_t tmp_inp_ptr, int inp_off, reg64_t tmp_out_ptr,
2687 int out_off) {
2688 Zmm zmm_load = is_masked ? zmm_tmp | ktail_mask | T_z : zmm_tmp;
2689 Zmm zmm_stor = zmm_tmp; // no mask as output is padded buffer
2690 if (is_bf16) {
2691 vmovdqu16(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2692 vmovdqu16(ptr[tmp_out_ptr + out_off], zmm_stor);
2693 } else {
2694 vmovdqu8(zmm_load, ptr[tmp_inp_ptr + inp_off]);
2695 vmovdqu8(ptr[tmp_out_ptr + out_off], zmm_stor);
2696 }
2697 };
2698
2699 mov(reg_ptr_aux_out, reg_ptr_out);
2700
2701 { // Handle Top Overflow
2702 Label label_tov_loop, label_tov_skip;
2703 test(reg_tov, reg_tov);
2704 jz(label_tov_skip, T_NEAR);
2705 mov(reg_cnt_tmp, reg_tov);
2706 L(label_tov_loop);
2707 {
2708 for (int ow = 0; ow < jcp.owp; ow++) {
2709 const int offset = ow * out_w_step;
2710 zero_it(reg_ptr_aux_out, offset);
2711 }
2712 add(reg_ptr_aux_out, out_h_step);
2713 dec(reg_cnt_tmp);
2714 jnz(label_tov_loop, T_NEAR);
2715 }
2716 L(label_tov_skip);
2717 }
2718
2719 mov(reg_ptr_aux_inp_h, reg_ptr_inp);
2720
2721 // Handle Middle Loop
2722 Label label_khp_loop, label_khp_skip;
2723 test(reg_khp, reg_khp);
2724 jz(label_khp_skip, T_NEAR);
2725 mov(reg_cnt_khp, reg_khp);
2726 L(label_khp_loop);
2727 {
2728 Label label_lov, label_lov_skip;
2729 Label label_kwp, label_kwp_skip;
2730 Label label_rov, label_rov_skip;
2731 test(reg_lov, reg_lov);
2732 jnz(label_lov, T_NEAR);
2733 test(reg_kwp, reg_kwp);
2734 jnz(label_kwp, T_NEAR);
2735 test(reg_rov, reg_rov);
2736 jnz(label_rov, T_NEAR);
2737
2738 test(reg_lov, reg_lov);
2739 jz(label_lov_skip, T_NEAR); // not really needed, but just to be safe
2740 L(label_lov); // Handle Left Overflow
2741 {
2742 Label label_lov_loop;
2743 mov(reg_cnt_tmp, reg_lov);
2744 L(label_lov_loop);
2745 {
2746 zero_it(reg_ptr_aux_out, 0);
2747 add(reg_ptr_aux_out, out_w_step);
2748 dec(reg_cnt_tmp);
2749 jnz(label_lov_loop, T_NEAR);
2750 }
2751 }
2752 L(label_lov_skip);
2753
2754 test(reg_kwp, reg_kwp);
2755 jz(label_kwp_skip, T_NEAR);
2756 L(label_kwp); // Handle Center Loop
2757 {
2758 Label label_kwp_loop;
2759 mov(reg_ptr_aux_inp_w, reg_ptr_aux_inp_h);
2760 mov(reg_cnt_tmp, reg_kwp);
2761 L(label_kwp_loop);
2762 {
2763 copy_it(reg_ptr_aux_inp_w, 0, reg_ptr_aux_out, 0);
2764 add(reg_ptr_aux_out, out_w_step);
2765 add(reg_ptr_aux_inp_w, inp_w_step);
2766 dec(reg_cnt_tmp);
2767
2768 if (jcp.stride_w > 1) {
2769 jz(label_kwp_skip, T_NEAR);
2770 // Handle Dilation-by-Stride
2771 for (int sw = 0; sw < jcp.stride_w - 1; sw++) {
2772 const int offset = sw * out_w_step;
2773 zero_it(reg_ptr_aux_out, offset);
2774 }
2775 add(reg_ptr_aux_out, (jcp.stride_w - 1) * out_w_step);
2776 if (jcp.stride_w == 2)
2777 dec(reg_cnt_tmp);
2778 else
2779 sub(reg_cnt_tmp, jcp.stride_w - 1);
2780 jmp(label_kwp_loop, T_NEAR);
2781 } else {
2782 jnz(label_kwp_loop, T_NEAR);
2783 }
2784 }
2785 }
2786 L(label_kwp_skip);
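        // Illustrative effect of the stride handling above (a sketch): for
        // jcp.stride_w = 2, each copied column is followed by a zeroed one,
        // so a src row [a b c] lands in the padded buffer roughly as
        // [a 0 b 0 c 0]; this "dilates" ddst by the forward stride so that
        // backward data reduces to a unit-stride convolution over the buffer.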
2787
2788 test(reg_rov, reg_rov);
2789 jz(label_rov_skip, T_NEAR);
2790 L(label_rov); // Handle Right Overflow
2791 {
2792 Label label_rov_loop;
2793 mov(reg_cnt_tmp, reg_rov);
2794 L(label_rov_loop);
2795 {
2796 zero_it(reg_ptr_aux_out, 0);
2797 add(reg_ptr_aux_out, out_w_step);
2798 dec(reg_cnt_tmp);
2799 jnz(label_rov_loop, T_NEAR);
2800 }
2801 }
2802 L(label_rov_skip);
2803
2804 add(reg_ptr_aux_inp_h, inp_h_step);
2805 dec(reg_cnt_khp);
2806
2807 if (jcp.stride_h > 1) {
2808 jz(label_khp_skip, T_NEAR);
2809 // Handle Dilation-by-Stride
2810 for (int sh = 0; sh < jcp.stride_h - 1; sh++) {
2811 for (int ow = 0; ow < jcp.owp; ow++) {
2812 const int offset = sh * out_h_step + ow * out_w_step;
2813 zero_it(reg_ptr_aux_out, offset);
2814 }
2815 }
2816 add(reg_ptr_aux_out, (jcp.stride_h - 1) * out_h_step);
2817 if (jcp.stride_h == 2)
2818 dec(reg_cnt_khp);
2819 else
2820 sub(reg_cnt_khp, jcp.stride_h - 1);
2821 jmp(label_khp_loop, T_NEAR);
2822 } else {
2823 jnz(label_khp_loop, T_NEAR);
2824 }
2825 }
2826 L(label_khp_skip);
2827
2828 { // Handle Bottom Overflow
2829 Label label_bov_loop, label_bov_skip;
2830 test(reg_bov, reg_bov);
2831 jz(label_bov_skip, T_NEAR);
2832 mov(reg_cnt_tmp, reg_bov);
2833 L(label_bov_loop);
2834 {
2835 for (int ow = 0; ow < jcp.owp; ow++) {
2836 const int offset = ow * out_w_step;
2837 zero_it(reg_ptr_aux_out, offset);
2838 }
2839 add(reg_ptr_aux_out, out_h_step);
2840 dec(reg_cnt_tmp);
2841 jnz(label_bov_loop, T_NEAR);
2842 }
2843 L(label_bov_skip);
2844 }
2845 }
2846
2847 void jit_avx512_core_amx_bwd_data_copy_kernel_t::generate() {
2848
2849 const int inp_c_step = jcp.oc_block_int * jcp.typesize_in;
2850 const int out_c_step = jcp.ohp * jcp.owp * inp_c_step;
2851 const int nb_oc_int_no_tail = jcp.oc_without_padding / jcp.oc_block_int;
2852 const int oc_block_int_tail = jcp.oc_without_padding % jcp.oc_block_int;
2853
2854 preamble();
2855
2856 // pointer to 1st needed element in src buffer
2857 mov(reg_ptr_inp, ptr[param1 + GET_OFF(src)]);
2858 // pointer to 1st needed element in dst buffer
2859 mov(reg_ptr_out, ptr[param1 + GET_OFF(dst)]);
2860
2861 // number of rows of src buffer to copy
2862 mov(reg_khp, ptr[param1 + GET_OFF(kh_padding)]);
2863 // number of zero-padded rows above src buffer to copy
2864 mov(reg_tov, ptr[param1 + GET_OFF(t_overflow)]);
2865 // number of zero-padded rows below src buffer to copy
2866 mov(reg_bov, ptr[param1 + GET_OFF(b_overflow)]);
2867
2868 // number of columns of src buffer to copy
2869 mov(reg_kwp, ptr[param1 + GET_OFF(kw_padding)]);
2870 // number of zero-padded columns before src buffer to copy
2871 mov(reg_lov, ptr[param1 + GET_OFF(l_overflow)]);
2872     // number of zero-padded columns after src buffer to copy
2873 mov(reg_rov, ptr[param1 + GET_OFF(r_overflow)]);
2874
2875 vpxord(zmm_zero, zmm_zero, zmm_zero);
2876
2877 if (oc_block_int_tail > 0) {
2878 uint64_t mask = (UINT64_C(1) << oc_block_int_tail) - 1;
2879 mov(reg_tmp, mask);
2880 kmovq(ktail_mask, reg_tmp);
2881 }
2882
2883 if (nb_oc_int_no_tail == 0) {
2884 copy_row(true); // masked
2885 } else if (nb_oc_int_no_tail == 1) {
2886 copy_row(false); // unmasked!
2887 if (oc_block_int_tail > 0) {
2888 add(reg_ptr_inp, inp_c_step);
2889 add(reg_ptr_out, out_c_step);
2890 copy_row(true); // masked
2891 }
2892 } else if (nb_oc_int_no_tail > 1) {
2893 mov(reg_cnt_ocb, nb_oc_int_no_tail);
2894 Label label_ocb_loop;
2895 L(label_ocb_loop);
2896 {
2897 copy_row(false); // unmasked!
2898 add(reg_ptr_inp, inp_c_step);
2899 add(reg_ptr_out, out_c_step);
2900 dec(reg_cnt_ocb);
2901 jnz(label_ocb_loop);
2902 }
2903 if (oc_block_int_tail > 0) copy_row(true); // masked
2904 }
2905
2906 postamble();
2907 }
2908
2909 // Tile register decomposition
2910 // { C_BASE = 0, I_BASE = 4, W_BASE = 6, }
2911 int jit_avx512_core_amx_bwd_data_kernel_t::get_out_tensor(int h, int i) const {
2912 const int C_BASE = 0;
2913 const int C_LAST = 4;
2914 assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
2915 MAYBE_UNUSED(C_LAST);
2916 const int tile = C_BASE + h * jcp.nb_ih_blocking + i;
2917 assert(C_BASE <= tile && tile < C_LAST);
2918 return tile;
2919 }
2920 int jit_avx512_core_amx_bwd_data_kernel_t::get_inp_tensor(int h) const {
2921 const int I_BASE = 4;
2922 const int I_LAST = 6;
2923 assert(0 <= I_BASE && I_BASE < I_LAST && I_LAST <= jcp.max_tiles);
2924 MAYBE_UNUSED(I_LAST);
2925 const int tile = I_BASE + h;
2926 assert(I_BASE <= tile && tile < I_LAST);
2927 return tile;
2928 }
2929 int jit_avx512_core_amx_bwd_data_kernel_t::get_wei_tensor(int i) const {
2930 const int W_BASE = 6;
2931 const int W_LAST = 8;
2932 assert(0 <= W_BASE && W_BASE < W_LAST && W_LAST <= jcp.max_tiles);
2933 MAYBE_UNUSED(W_LAST);
2934 const int tile = W_BASE + i;
2935 assert(W_BASE <= tile && tile < W_LAST);
2936 return tile;
2937 }
2938
2939 // Strides, shifts and offsets
2940 // - inp is a padded buffer ~ [nb_oc_int][ohp][owp]{32c,64c}
2941 // - weights is user buffer ~ OIhw16o16i{2o,4o}
2942 // - output is tiled buffer ~ [NBIH][NBIC][tile_width][16c]
2943 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_kh_step() const {
2944 return (size_t)jcp.typesize_in * (jcp.dilate_h + 1) * jcp.owp
2945 * jcp.oc_block_int;
2946 }
2947 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_ocb_step() const {
2948 return (size_t)jcp.typesize_in * jcp.ohp * jcp.owp * jcp.oc_block_int;
2949 }
2950 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_shift() const {
2951 return (size_t)jcp.typesize_in * jcp.tile_width * jcp.oc_block_int;
2952 }
2953 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_inp_offset(
2954 int ihb, int kh, int kw) const {
2955 // calculate offset by src height dimension
2956 size_t sp_offset = (size_t)ihb * jcp.owp;
2957 // add offset by kernel height dimension
2958 sp_offset += (size_t)(jcp.kh - 1 - kh) * (jcp.dilate_h + 1) * jcp.owp;
2959 // add offset by kernel width dimension
2960 sp_offset += (size_t)(jcp.kw - 1 - kw) * (jcp.dilate_w + 1);
2961 return jcp.typesize_in * sp_offset * jcp.oc_block_int;
2962 }
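// Illustrative mapping (a sketch): backward data is computed as a convolution
// over the padded ddst buffer with spatially rotated weights; e.g. for kh = 3
// and dilate_h = 0, the kh = 0 tap reads padded row ihb + 2 while the kh = 2
// tap reads row ihb + 0, which is what the (jcp.kh - 1 - kh) term implements.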
2963 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_kh_step() const {
2964 return (size_t)jcp.typesize_in * jcp.kw * jcp.oc_block_int * jcp.ic_block;
2965 }
2966 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_ocb_step() const {
2967 const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
2968 return (size_t)jcp.typesize_in * (is_deconv ? 1 : jcp.nb_ic) * jcp.kh
2969 * jcp.kw * jcp.oc_block_int * jcp.ic_block;
2970 }
2971 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wei_offset(
2972 int icb, int kh, int kw) const {
2973 const bool is_deconv = jcp.prop_kind != prop_kind::backward_data;
2974 const size_t wei_kw_stride = jcp.oc_block_int * jcp.ic_block;
2975 const size_t wei_kh_stride = jcp.kw * wei_kw_stride;
2976 const size_t wei_icb_stride
2977 = (is_deconv ? jcp.nb_oc_int : 1) * jcp.kh * wei_kh_stride;
2978 return jcp.typesize_in
2979 * (icb * wei_icb_stride + kh * wei_kh_stride + kw * wei_kw_stride);
2980 }
2981 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_icb_offset(
2982 int ihb, int icb) const {
2983 size_t el_offset = jcp.is_nspc
2984 ? (size_t)icb * jcp.ic_block
2985 + (size_t)ihb * jcp.iw * jcp.ngroups
2986 * jcp.ic_without_padding
2987 : (size_t)icb * jcp.ih * jcp.iw * jcp.ic_block
2988 + (size_t)ihb * jcp.iw * jcp.ic_block;
2989 return (size_t)jcp.typesize_out * el_offset;
2990 }
2991 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_row_offset(
2992 int ihb, int icb, int j) const {
2993 size_t offset_w = jcp.is_nspc ? (size_t)jcp.typesize_out * j * jcp.ngroups
2994 * jcp.ic_without_padding
2995 : (size_t)jcp.typesize_out * j * jcp.ic_block;
2996 return get_out_icb_offset(ihb, icb) + offset_w;
2997 }
get_out_shift(int width) const2998 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_out_shift(int width) const {
2999 return jcp.is_nspc ? (size_t)jcp.typesize_out * width * jcp.ngroups
3000 * jcp.ic_without_padding
3001 : (size_t)jcp.typesize_out * width * jcp.ic_block;
3002 }
get_wsp_icb_offset(int ihb,int icb) const3003 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_icb_offset(
3004 int ihb, int icb) const {
3005 size_t el_offset = (size_t)icb * prv_width_ * jcp.ic_block
3006 + (size_t)ihb * jcp.nb_ic_blocking * jcp.full_tile_width
3007 * jcp.ic_block;
3008 return jcp.typesize_acc * el_offset;
3009 }
get_wsp_row_offset(int ihb,int icb,int j) const3010 size_t jit_avx512_core_amx_bwd_data_kernel_t::get_wsp_row_offset(
3011 int ihb, int icb, int j) const {
3012 return get_wsp_icb_offset(ihb, icb)
3013 + (size_t)jcp.typesize_acc * j * jcp.ic_block;
3014 }
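
// Worked example for the nspc (nhwc) case: with ic_without_padding = 32,
// ngroups = 1 and f32 output (typesize_out = 4), one diff_src row holds
// iw * 32 floats, so get_out_row_offset(ihb, icb, j)
//     = 4 * (icb * 16 + ihb * iw * 32 + j * 32) bytes,
// i.e. consecutive j's are a full row of channels apart while icb selects
// the 16-channel block within the row.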

// Code generation
void jit_avx512_core_amx_bwd_data_kernel_t::prepare_output() {
    for (int h = 0; h < jcp.nb_ih_blocking; h++)
        for (int i = 0; i < jcp.nb_ic_blocking; i++)
            tilezero(Tmm(get_out_tensor(h, i)));
}

void jit_avx512_core_amx_bwd_data_kernel_t::init_runtime_counters(
        bool start_with_last_tile_block) {
    prv_width_ = start_with_last_tile_block && jcp.tile_tail > 0
            ? jcp.tile_tail
            : jcp.tile_width;

    row_count_ = 0;
    is_store_done_ = false;
    is_buffer_empty_ = true;
}

bool jit_avx512_core_amx_bwd_data_kernel_t::maybe_eltwise(int position) {
    using namespace primitive_kind;
    const auto &p = attr_.post_ops_;

    if (position == 0) {
        /* eltwise before sum */
        return p.contain(eltwise, 0);
    } else if (position == 1) {
        /* eltwise after sum */
        return p.contain(sum, 0) && p.contain(eltwise, 1);
    }

    return false;
}
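
// For example, the post-op chain {sum, eltwise} makes maybe_eltwise(1) true
// (eltwise runs after the previous dst is added in), while {eltwise, sum}
// makes maybe_eltwise(0) true (eltwise runs on the convolution result
// first); longer chains are rejected by post_ops_ok() below.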

Ymm jit_avx512_core_amx_bwd_data_kernel_t::ymm_mask(
        const Ymm &ymm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
                     : ymm_in;
}

Zmm jit_avx512_core_amx_bwd_data_kernel_t::zmm_mask(
        const Zmm &zmm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
                     : zmm_in;
}
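
// In both helpers the channel-tail mask uses merge-masking for stores
// (masked-out destination elements stay untouched in memory) and
// zero-masking (T_z) for loads, so lanes beyond ic_without_padding read as
// zero instead of stale register contents.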

void jit_avx512_core_amx_bwd_data_kernel_t::cvt2ps(data_type_t type_in,
        const Zmm &zmm_in, const Operand &op, bool mask_flag) {
    const Zmm zmm = zmm_mask(zmm_in, mask_flag);
    switch (type_in) {
        case data_type::f32:
        case data_type::s32: vmovups(zmm, op); break;
        case data_type::s8: vpmovsxbd(zmm, op); break;
        case data_type::u8: vpmovzxbd(zmm, op); break;
        default: assert(!"unsupported data type");
    }
    if (type_in != data_type::f32) vcvtdq2ps(zmm_in, zmm_in);
}
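
// cvt2ps() deliberately has no bf16 case: the bf16 store path below widens
// bf16 inline with vpmovzxwd + vpslld(16), which places each bf16 value in
// the upper half of an f32 lane and is exact (bf16 is a truncated f32).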

void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_bf16(
        const Zmm &zmm_out, int icb, int h, int w) {
    const bool mask_flag = jcp.is_nspc && jcp.ic_without_padding != jcp.ic
            && icb == (jcp.nb_ic_blocking - 1);

    auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));

    const auto &p = attr_.post_ops_;

    const int sum_idx = p.find(primitive_kind::sum);
    if (sum_idx != -1) {
        if (jcp.dsrc_dt == data_type::bf16) {
            vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vpslld(zmm_prev_dst, zmm_prev_dst, 16);
            vaddps(zmm_out, zmm_prev_dst);
        } else {
            vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vaddps(zmm_out, zmm_prev_dst);
        }
    }
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * icb * jcp.ic_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        if (jcp.bia_dt == data_type::bf16) {
            vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
            vpslld(zmm_bias, zmm_bias, 16);
            vaddps(zmm_out, zmm_bias);
        } else
            vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
    }

    const int eltwise_ind = p.find(primitive_kind::eltwise);
    if (eltwise_ind != -1) eltwise_injector_->compute_vector(zmm_out.getIdx());

    if (jcp.dsrc_dt == data_type::bf16) {
        Ymm ymm_out = Ymm(zmm_out.getIdx());
        vcvtneps2bf16(ymm_out, zmm_out);
        vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
    } else {
        vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
    }
}
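
// The int8 path below follows the usual quantized epilogue order: convert
// the s32 accumulator to f32, add bias, apply the output scale, run the
// eltwise/sum post-ops, saturate to the destination range, and only then
// convert back to an integer type for the store.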

void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector_int8(
        const Zmm &zmm_out, int icb, int h, int w) {
    const int nb_ic_block = jcp.nb_ic_blocking;
    const int ic_block = jcp.ic_block;
    const bool mask_flag = true && jcp.ic_without_padding != jcp.ic
            && icb == (nb_ic_block - 1);

    auto addr = EVEX_compress_addr(reg_out_ptr, get_out_row_offset(h, icb, w));

    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
    }

    if (p_sum_scale && *p_sum_scale != 1.f)
        mov(reg_ptr_sum_scale, (size_t)p_sum_scale);

    int scale_offset = jcp.is_ic_scale * (sizeof(float) * icb * ic_block);
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * icb * ic_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
    }
    /* add bias to zmm_accum */
    vcvtdq2ps(zmm_out, zmm_out);
    if (jcp.with_bias) vaddps(zmm_out, zmm_out, zmm_bias);
    const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
    vmulps(zmm_out_msk, zmm_out,
            EVEX_compress_addr(reg_ptr_scales, scale_offset));

    /* Do post-ops */
    if (maybe_eltwise(0)) eltwise_injector_->compute_vector(zmm_out.getIdx());
    if (p_sum_scale) { // post_op: sum
        cvt2ps(jcp.dsrc_dt, zmm_prev_dst, addr, mask_flag);
        if (*p_sum_scale == 1.f)
            vaddps(zmm_out, zmm_prev_dst);
        else
            vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
    }
    if (maybe_eltwise(1)) eltwise_injector_->compute_vector(zmm_out.getIdx());

    // Properly saturate the accumulators for integer datatypes
    if (one_of(jcp.dsrc_dt, u8, s8, s32)) {
        init_saturate_f32(
                zmm_zero, zmm_saturation, reg_aux_saturation, f32, jcp.dsrc_dt);
        saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dsrc_dt);
        vcvtps2dq(zmm_out, zmm_out);
    }

    const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);

    switch (jcp.dsrc_dt) {
        case data_type::f32:
        case data_type::s32: vmovups(addr, zmm_out_store); break;
        case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
        case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
        default: assert(!"unknown dst_dt");
    }
}

void jit_avx512_core_amx_bwd_data_kernel_t::store_output_vector(
        const Zmm &zmm_out, int icb, int h, int w) {
    /*
    Output:
              jcp.is_nspc                  !jcp.is_nspc
              ---------------------        ---------------------
        INT8: [N][H][W][NBIC][16IC]
        BF16: [N][H][W][NBIC][16IC]   or   [N][NBIC][H][W][16IC]
    */
    if (jcp.ddst_dt == data_type::bf16) {
        store_output_vector_bf16(zmm_out, icb, h, w);
    } else {
        store_output_vector_int8(zmm_out, icb, h, w);
    }
}

void jit_avx512_core_amx_bwd_data_kernel_t::store_output(
        int width, bool do_store) {
    auto store_output_block = [=](int width, bool do_store,
                                      bool is_last_ih_blks) {
        // Calculate the number of ih blocks; it may differ on last call
        const int n_ih_blks = is_last_ih_blks ? jcp.ih % jcp.nb_ih_blocking
                                              : jcp.nb_ih_blocking;
        for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
            for (int ihb = 0; ihb < n_ih_blks; ihb++) {
                /* Formats: Workspace: [NBIH][NBIC][W][16OC] */
                tilestored(ptr[reg_wsp_ptr + reg_wei_stride
                                   + get_wsp_icb_offset(ihb, icb)],
                        Tmm(get_out_tensor(ihb, icb)));
                is_buffer_empty_ = false;
                is_store_done_ = false;
                for (int tw = 0; tw < width && do_store; tw++) {
                    Zmm zmm_out = Zmm(tw);
                    vmovups(zmm_out,
                            ptr[reg_wsp_ptr
                                    + get_wsp_row_offset(ihb, icb, tw)]);
                    store_output_vector(zmm_out, icb, ihb, tw);
                }
            }
        }
    };

    // adjustment in case interleave store is turned off
    do_store = do_store || jcp.per_one_pstore == 0;
    if (jcp.ih % jcp.nb_ih_blocking == 0) {
        store_output_block(width, do_store, /* is_last_ih_blks = */ false);
    } else {
        Label label_full_store, label_done;
        cmp(reg_last_h, 0);
        jne(label_full_store, T_NEAR);
        store_output_block(width, do_store, /* is_last_ih_blks = */ true);
        jmp(label_done, T_NEAR);
        L(label_full_store);
        store_output_block(width, do_store, /* is_last_ih_blks = */ false);
        L(label_done);
    }
    if (do_store) add(reg_out_ptr, get_out_shift(width));
}
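
// interleave_store() below drains at most jcp.per_one_pstore accumulator
// rows from the workspace each time it is called between tdpbxxd issues, so
// the post-op/store work overlaps the tile multiplies instead of arriving
// as one burst; per_one_pstore == 0 disables this and store_output() then
// performs a full flush.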

void jit_avx512_core_amx_bwd_data_kernel_t::interleave_store(int width) {
    for (int c = 0;
            c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
            c++) {
        // row_count = ihb * ICB * TW + icb * TW + tw
        int tw = row_count_ % prv_width_;
        int icb = (row_count_ / prv_width_) % jcp.nb_ic_blocking;
        int ihb = (row_count_ / prv_width_) / jcp.nb_ic_blocking;

        Zmm zmm_out = Zmm(tw);
        vmovups(zmm_out, ptr[reg_wsp_ptr + get_wsp_row_offset(ihb, icb, tw)]);
        store_output_vector(zmm_out, icb, ihb, tw);
        row_count_++;

        if (row_count_
                == prv_width_ * jcp.nb_ic_blocking * jcp.nb_ih_blocking) {
            add(reg_out_ptr, get_out_shift(prv_width_));
            row_count_ = 0;
            is_store_done_ = true;
            prv_width_ = width;
        }
    }
}

void jit_avx512_core_amx_bwd_data_kernel_t::compute_ocb_loop(
        int width, bool do_store) {

    auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
        switch (jcp.ddst_dt) {
            using namespace data_type;
            case bf16: tdpbf16ps(x1, x2, x3); break;
            case s8: tdpbssd(x1, x2, x3); break;
            case u8: tdpbusd(x1, x2, x3); break;
            default: assert(!"unsupported data type");
        }
    };

    prepare_output();

    for (int ocb = 0; ocb < jcp.nb_oc_int; ocb++) {
        // reverse order through spatial components of weights so that
        // input buffer is accessed in a monotonically increasing fashion
        for (int kh = jcp.kh - 1; kh >= 0; kh--) {
            for (int kw = jcp.kw - 1; kw >= 0; kw--) {
                for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
                    tileloadd(Tmm(get_inp_tensor(ihb)),
                            ptr[reg_inp_ptr + get_inp_offset(ihb, kh, kw)
                                    + reg_inp_stride]);
                }
                for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
                    tileloadd(Tmm(get_wei_tensor(icb)),
                            ptr[reg_wei_ptr + get_wei_offset(icb, kh, kw)
                                    + reg_wei_stride]);
                    for (int ihb = 0; ihb < jcp.nb_ih_blocking; ihb++) {
                        tdpbxxd(Tmm(get_out_tensor(ihb, icb)),
                                Tmm(get_inp_tensor(ihb)),
                                Tmm(get_wei_tensor(icb)));
                        interleave_store(width);
                    }
                }
            }
        }
        add(reg_inp_ptr, get_inp_ocb_step());
        add(reg_wei_ptr, get_wei_ocb_step());
    }
    sub(reg_inp_ptr, get_inp_ocb_step() * jcp.nb_oc_int);
    sub(reg_wei_ptr, get_wei_ocb_step() * jcp.nb_oc_int);

    store_output(width, do_store);

    add(reg_inp_ptr, get_inp_shift());
}

void jit_avx512_core_amx_bwd_data_kernel_t::compute_iw_loop() {
    auto compute_iw_loop_body = [=](bool last_iwb, int num_tile_blocks) {
        int gen_tile_tail = last_iwb && jcp.tile_tail > 0 ? jcp.tile_tail
                                                          : jcp.tile_width;
        init_runtime_counters(last_iwb && num_tile_blocks == 1);
        for (int iwb = 0; iwb < num_tile_blocks - 1; iwb++)
            compute_ocb_loop(jcp.tile_width, false);
        compute_ocb_loop(gen_tile_tail, true);
    };

    if (jcp.nb_iw == 1) {
        compute_iw_loop_body(true, jcp.iw_blocks);
    } else {
        Label label_done;
        int iw_blocks_per_call = div_up(jcp.iw_block, jcp.tile_width);
        int last_iwb_tile_blocks = jcp.iw_blocks % iw_blocks_per_call;
        if (last_iwb_tile_blocks == 0 && jcp.tile_tail > 0)
            last_iwb_tile_blocks = iw_blocks_per_call;
        if (last_iwb_tile_blocks > 0) {
            Label label_not_last_iwb;
            mov(reg_tmp, ptr[param1 + GET_OFF(iwb)]);
            cmp(reg_tmp, jcp.nb_iw - 1);
            jne(label_not_last_iwb, T_NEAR);

            compute_iw_loop_body(true, last_iwb_tile_blocks);

            jmp(label_done, T_NEAR);

            L(label_not_last_iwb);
        }
        compute_iw_loop_body(false, iw_blocks_per_call);

        L(label_done);
    }
}

void jit_avx512_core_amx_bwd_data_kernel_t::generate() {
    preamble();

    mov(reg_inp_ptr, ptr[param1 + GET_OFF(dst)]); // padded buffer of diff_dst
    mov(reg_wei_ptr, ptr[param1 + GET_OFF(filt)]); // weights
    mov(reg_out_ptr, ptr[param1 + GET_OFF(src)]); // diff_src
    mov(reg_wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);

    if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(bias)]);

    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);

    mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);

    const int inp_stride = jcp.oc_block_int * jcp.typesize_in;
    const int wei_stride = jcp.ic_block * jcp.typesize_acc;
    mov(reg_inp_stride, inp_stride);
    mov(reg_wei_stride, wei_stride);

    if (jcp.is_nspc && jcp.ic_without_padding != jcp.ic) {
        // Use the full 16-lane mask (0xffff) by default for all output data
        // and post-op loads / stores with block index
        // icb = icc * jcp.nb_ic_blocking + (jcp.nb_ic_blocking - 1)
        // TODO: use masked loads / stores for the last icc only
        int current_block_size = jcp.ic_block;
        int mask = (1 << current_block_size) - 1;
        Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
        Xbyak::Label mask_is_set;
        mov(reg_ic_blocks, ptr[param1 + GET_OFF(ic_blocks)]);
        cmp(reg_ic_blocks, jcp.nb_ic - jcp.nb_ic_blocking);
        jne(mask_is_set, T_NEAR);
        // Reset the mask
        current_block_size = jcp.ic_without_padding % jcp.ic_block;
        mask = (1 << current_block_size) - 1;
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);

        L(mask_is_set);
    }
    compute_iw_loop();

    postamble();

    if (jcp.with_eltwise) eltwise_injector_->prepare_table();
}

bool jit_avx512_core_amx_bwd_data_kernel_t::post_ops_ok(
        const jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    using namespace primitive_kind;
    const auto &p = attr.post_ops_;
    const bool is_bf16 = jcp.ddst_dt == data_type::bf16;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };

    auto is_sum = [&](int idx) {
        if (is_bf16)
            return p.entry_[idx].is_sum();
        else
            return p.contain(sum, idx);
    };

    switch (p.len()) {
        case 0: return true;
        case 1: return is_eltwise(0) || is_sum(0);
        case 2:
            return (is_sum(0) && is_eltwise(1))
                    || (!is_bf16 && is_sum(1) && is_eltwise(0));
        default: return false;
    }

    return false;
}

void jit_avx512_core_amx_bwd_data_kernel_t::tile_configure(char *tcfg_buff) {
    const int vnni_width = jcp.ddst_dt == data_type::bf16 ? 2 : 4;
    // Input tile dimensions
    const int a_col = jcp.oc_block_int;
    const int a_row = jcp.tile_width;
    // Weights tile dimensions
    const int b_col = jcp.ic_block * vnni_width;
    const int b_row = a_col / vnni_width;
    // Accumulator tile dimensions
    const int c_col = jcp.ic_block;
    const int c_row = a_row;

    for (size_t i = 0; i < 64; i++)
        tcfg_buff[i] = 0;

    // Weights (W_BASE) Tensor Tiles
    for (int i = 0; i < jcp.nb_ic_blocking; i++)
        tc_configure_tile((palette_config_t *)tcfg_buff, get_wei_tensor(i),
                b_row, b_col * jcp.typesize_in);

    // Input (I_BASE) and Accumulator (C_BASE) Tensor Tiles
    for (int h = 0; h < jcp.nb_ih_blocking; h++) {
        tc_configure_tile((palette_config_t *)tcfg_buff, get_inp_tensor(h),
                a_row, a_col * jcp.typesize_in);
        for (int i = 0; i < jcp.nb_ic_blocking; i++)
            tc_configure_tile((palette_config_t *)tcfg_buff,
                    get_out_tensor(h, i), c_row, c_col * jcp.typesize_acc);
    }

    ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
}
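
// Tile-shape example for the bf16 case (vnni_width = 2, typesize_in = 2):
// oc_block_int = 32, so an input tile is tile_width rows x 64 bytes, a
// weights tile is 32 / 2 = 16 rows x (16 * 2 * 2) = 64 bytes, and an
// accumulator tile is tile_width rows x (16 * 4) = 64 bytes -- every tile
// uses the full 64-byte row supported by the hardware palette.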

status_t jit_avx512_core_amx_bwd_data_kernel_t::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, memory_desc_t &diff_src_md,
        memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
        memory_desc_t *bias_md, const primitive_attr_t &attr, int nthreads) {
    using namespace prop_kind;

    const memory_desc_wrapper diff_src_d(&diff_src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper bias_d(bias_md);

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    int ndims = diff_src_d.ndims();
    bool is_1d = ndims == 3;
    bool is_3d = ndims == 5;

    if (is_3d) return status::unimplemented;

    using namespace data_type;
    const bool is_deconv = cd.prop_kind != prop_kind::backward_data;
    const bool is_bf16_convolution = !is_deconv
            && everyone_is(true, diff_dst_d.data_type() == bf16,
                    weights_d.data_type() == bf16,
                    one_of(diff_src_d.data_type(), bf16, f32));
    const bool is_int8_deconvolution = is_deconv
            && everyone_is(true, one_of(diff_dst_d.data_type(), s8, u8),
                    weights_d.data_type() == s8,
                    one_of(diff_src_d.data_type(), f32, s32, s8, u8));

    bool supported = false
            || (is_bf16_convolution && mayiuse(avx512_core_bf16_amx_bf16))
            || (is_int8_deconvolution && mayiuse(avx512_core_bf16_amx_int8));
    if (!supported) return status::unimplemented;

    jcp = zero<decltype(jcp)>();
    jcp.isa = is_bf16_convolution ? avx512_core_bf16_amx_bf16
                                  : avx512_core_bf16_amx_int8;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;

    jcp.mb = diff_src_d.dims()[0];
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;
    jcp.ih = !is_1d ? diff_src_d.dims()[ndims - 2] : 1;
    jcp.iw = diff_src_d.dims()[ndims - 1];
    jcp.oh = !is_1d ? diff_dst_d.dims()[ndims - 2] : 1;
    jcp.ow = diff_dst_d.dims()[ndims - 1];
    jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
    jcp.stride_w = cd.strides[ndims - 3];

    // No bias for the bf16 case to simplify integration with
    // ref_deconvolution
    jcp.with_bias = bias_md && !is_bf16_convolution
            && cd.bias_desc.format_kind != format_kind::undef;

    jcp.dilate_h = !is_1d ? cd.dilates[ndims - 4] : 0;
    jcp.dilate_w = cd.dilates[ndims - 3];

    const int gen_kh = (jcp.kh - 1) * (jcp.dilate_h + 1) + 1;
    const int gen_kw = (jcp.kw - 1) * (jcp.dilate_w + 1) + 1;
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, gen_kh);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, gen_kw);
    if (jcp.l_pad >= gen_kw || jcp.r_pad >= gen_kw || jcp.t_pad >= gen_kh
            || jcp.b_pad >= gen_kh)
        return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    if (is_deconv) {
        jcp.ddst_dt = cd.src_desc.data_type;
        jcp.dsrc_dt = cd.dst_desc.data_type;
    } else {
        jcp.ddst_dt = cd.diff_dst_desc.data_type;
        jcp.dsrc_dt = cd.diff_src_desc.data_type;
    }
    jcp.wei_dt = cd.weights_desc.data_type;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);

    if (jcp.is_depthwise)
        return status::unimplemented; // TODO: add support of DW convolution

    format_tag_t dat_tag_ncsp
            = pick(ndims - 3, format_tag::nCw16c, format_tag::nChw16c);
    format_tag_t dat_tag_nspc
            = pick(ndims - 3, format_tag::nwc, format_tag::nhwc);
    // To toggle the default data layout for BF16 between nChw16c and nhwc,
    // swap the following two variable definitions. Current choice: nhwc.
    format_tag_t dat_tag_opt = dat_tag_nspc;
    format_tag_t dat_tag_alt
            = is_bf16_convolution ? dat_tag_ncsp : dat_tag_nspc;

    if (diff_src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag_opt));
        jcp.src_tag = dat_tag_opt;
    } else
        jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag_alt, dat_tag_opt);

    if (!one_of(jcp.src_tag, dat_tag_alt, dat_tag_opt))
        return status::unimplemented;

    jcp.is_nspc = jcp.src_tag == dat_tag_nspc;
    assert(IMPLICATION(is_int8_deconvolution, jcp.is_nspc));

    // TODO: remove all support for nChw16c from this implementation
    if (!jcp.is_nspc) return status::unimplemented;

    if (diff_dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
        jcp.dst_tag = jcp.src_tag;
    } else
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);

    if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;

    if (jcp.with_bias && bias_d.format_kind() == format_kind::any)
        CHECK(memory_desc_init_by_tag(*bias_md, format_tag::x));

    jcp.nthr = nthreads;

    jcp.ic_block = 16;
    jcp.oc_block = 16;

    if (jcp.ngroups == 1) {
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    }
    bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0;
    if (!args_ok) return status::unimplemented;

    const int vnni_width = is_bf16_convolution ? 2 : 4;
    jcp.oc_block_int = jcp.oc_block * vnni_width; // 32 for bf16, 64 for int8

    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    auto set_or_check_wei_format = [&]() {
        using namespace format_tag;
        format_tag_t wei_tag;
        if (is_bf16_convolution)
            wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16o16i2o,
                    gOIw16o16i2o, OIhw16o16i2o, gOIhw16o16i2o);
        else if (is_int8_deconvolution)
            wei_tag = pick(with_groups + 2 * (ndims - 3), OIw16i16o4i,
                    gOIw16i16o4i, OIhw16i16o4i, gOIhw16i16o4i);
        else {
            assert(!"unsupported combination");
            return false;
        }

        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);

        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }
        return weights_md == want_wei_md;
    };

    if (!set_or_check_wei_format()) return status::unimplemented;

    jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
    jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
    jcp.typesize_acc = sizeof(int32_t);

    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_oc_int = div_up(jcp.oc, jcp.oc_block_int);

    const int max_palette = amx::get_max_palette();
    jcp.max_tiles = amx::get_max_tiles(max_palette);
    jcp.full_tile_width = amx::get_max_rows(max_palette);
    if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
        return status::unimplemented;

    jcp.tile_width = nstl::min(jcp.full_tile_width, jcp.iw);
    jcp.iw_blocks = div_up(jcp.iw, jcp.tile_width);

    // Prefer to use a single tile width when possible
    // (e.g. iw = 28 => 2 tiles of 14 vs 1 of 16 and 1 of 12)
    if (jcp.iw % jcp.iw_blocks == 0) jcp.tile_width = jcp.iw / jcp.iw_blocks;
    jcp.tile_tail = jcp.iw % jcp.tile_width;

    jcp.nb_ic_blocking = (jcp.nb_ic % 2 == 0) ? 2 : 1;
    jcp.nb_ih_blocking
            = everyone_is(true, jcp.ih > 1,
                      // requirement for interleave stores
                      IMPLICATION(jcp.iw_blocks > 1, jcp.ih % 2 == 0))
            ? 2
            : 1;

    // TODO: tune ih blocking
    const int ih_blk_size_tmp = 10;
    const int ih_step = jcp.nb_ih_blocking;
    jcp.ih_blk_size = rnd_up(nstl::min(jcp.ih, ih_blk_size_tmp), ih_step);
    // ohp includes all elements that are really used in the calculation,
    // including the zero-padded "dilate-by-strides" and top and bottom
    // overflow
    jcp.ohp = jcp.ih_blk_size + gen_kh - 1;

    // TODO: tune iw blocking
    const int iw_blocks_per_call = 2;
    jcp.iw_block = jcp.tile_width * iw_blocks_per_call;
    jcp.nb_iw = div_up(jcp.iw, jcp.iw_block);
    // owp includes all elements that are really used in the calculation,
    // including the zero-padded "dilate-by-strides" and left and right
    // overflow
    jcp.owp = jcp.iw_block + gen_kw - 1;

    // Number of ops per tile store
    int ops_tile_store = jcp.tile_width;
    // Number of ops per accumulation tile
    int available_ops = jcp.nb_oc_int * jcp.kh * jcp.kw;
    // Number of vectors to store per tile operation
    // NOTE: set to zero to turn off interleave store (mostly for debugging)
    jcp.per_one_pstore = div_up(ops_tile_store, available_ops);

    jcp.inp_buffer_size
            = (size_t)jcp.nb_oc_int * jcp.ohp * jcp.owp * jcp.oc_block_int;
    jcp.wsp_buffer_size = (size_t)jcp.nb_ih_blocking * jcp.nb_ic_blocking
            * jcp.full_tile_width * jcp.ic_block;

    const auto &oscales = attr.output_scales_;
    jcp.is_ic_scale = oscales.mask_ == 1 << 1;

    return status::success;
}

void jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {

    size_t inp_buffer_size = jcp.nthr * jcp.inp_buffer_size;
    scratchpad.book(key_conv_amx_inp_buffer, inp_buffer_size, jcp.typesize_in);
    size_t wsp_size = jcp.nthr * jcp.wsp_buffer_size;
    scratchpad.book(key_conv_amx_wsp_buffer, wsp_size, jcp.typesize_acc);
    if (jcp.with_bias && jcp.ic != jcp.ic_without_padding) {
        assert(jcp.ngroups == 1);
        scratchpad.book(key_conv_padded_bias, jcp.ic, jcp.typesize_bia);
    }
    scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
}

const int jit_avx512_core_amx_bwd_weights_kernel_t::max_ur_w = 32;

// Tile register decomposition
// { C_BASE = 0, A_BASE = 4, B_BASE = 6, }
int jit_avx512_core_amx_bwd_weights_kernel_t::get_wei_tensor(
        int ocb, int icb) const {
    const int C_BASE = 0;
    const int C_LAST = 4;
    assert(0 <= C_BASE && C_BASE < C_LAST && C_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(C_LAST);
    const int tile = C_BASE + ocb * jcp.nb_oc_blocking + icb;
    assert(C_BASE <= tile && tile < C_LAST);
    return tile;
}
int jit_avx512_core_amx_bwd_weights_kernel_t::get_src_tensor(int icb) const {
    const int A_BASE = 4;
    const int A_LAST = 6;
    assert(0 <= A_BASE && A_BASE < A_LAST && A_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(A_LAST);
    const int tile = A_BASE + icb;
    assert(A_BASE <= tile && tile < A_LAST);
    return tile;
}
int jit_avx512_core_amx_bwd_weights_kernel_t::get_ddst_tensor(int ocb) const {
    const int B_BASE = 6;
    const int B_LAST = 8;
    assert(0 <= B_BASE && B_BASE < B_LAST && B_LAST <= jcp.max_tiles);
    MAYBE_UNUSED(B_LAST);
    const int tile = B_BASE + ocb;
    assert(B_BASE <= tile && tile < B_LAST);
    return tile;
}

void jit_avx512_core_amx_bwd_weights_kernel_t::tile_configure(
        char *tcfg_buff) {
    // Input tile dimensions
    const int a_col = jcp.ur_w;
    const int a_row = jcp.ic_block;
    // Weights tile dimensions
    const int b_col = jcp.oc_block * 2;
    const int b_row = a_col / 2;
    // Accumulator tile dimensions
    const int c_col = jcp.oc_block;
    const int c_row = a_row;

    for (size_t i = 0; i < 64; i++)
        tcfg_buff[i] = 0;

    for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
        tc_configure_tile((palette_config_t *)tcfg_buff, get_src_tensor(icb),
                a_row, a_col * jcp.typesize_in);

    for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
        tc_configure_tile((palette_config_t *)tcfg_buff, get_ddst_tensor(ocb),
                b_row, b_col * jcp.typesize_in);

    for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
        for (int icb = 0; icb < jcp.nb_ic_blocking; icb++)
            tc_configure_tile((palette_config_t *)tcfg_buff,
                    get_wei_tensor(ocb, icb), c_row, c_col * jcp.typesize_out);

    ((palette_config_t *)tcfg_buff)->palette_id = amx::get_max_palette();
}
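
// In this bwd-weights configuration the accumulators are the 16x16 f32
// weight tiles: the "A" tiles hold transposed src (ic_block = 16 rows of
// ur_w bf16 elements), the "B" tiles hold repacked diff_dst (ur_w / 2 rows
// of oc_block * 2 values), and tdpbf16ps reduces over the spatial ur_w
// dimension rather than over channels as in the data kernels.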

void jit_avx512_core_amx_bwd_weights_kernel_t::od_step_comeback_pointers() {
    Label kd_comeback_label;
    mov(kj, reg_kd_count);
    L(kd_comeback_label);
    {
        sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(kj);
        jnz(kd_comeback_label, T_NEAR);
    }
}

void jit_avx512_core_amx_bwd_weights_kernel_t::oh_step_comeback_pointers() {
    Label kh_comeback_label;
    mov(kj, reg_kh);
    L(kh_comeback_label);
    {
        sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
        sub(reg_kernel, get_kernel_offset(0, jcp.kw));
        dec(kj);
        jnz(kh_comeback_label, T_NEAR);
    }
}

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_full_spat_loop(
        int nb_ic_blocking, int nb_oc_blocking) {
    // General code layout:
    //
    // Blocking over OH -- top level
    // (Reduces L2 pressure; not very useful right now)
    //  Loop over all KHxKW kernel -- emit_kh_kw_loop()
    //    Loop over OH block -- emit_h_loop()
    //      Loop over OW blocks -- emit_fma_block()
    //      (Supports both fully unrolled and partially unrolled
    //      versions to reduce code size)
    //          Loop over OW block -- emit_fma_step()

    auto src_row_size = get_src_offset(0, 0, 1);
    auto ddst_row_size = get_ddst_offset(0, 1);
    auto row_size = src_row_size + ddst_row_size;

    int h_block_size = jcp.oh;
    int h_last_block_size = h_block_size;
    int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
    auto working_set_size = row_size * h_block_size;

    if (working_set_size > full_spat_max_working_set_size) {
        assert(full_spat_opt_working_set_size
                < full_spat_max_working_set_size);

        while (working_set_size > full_spat_opt_working_set_size
                && h_block_size >= min_h_block_size) {
            for (int i = 2; i <= h_block_size; i++)
                if (i == h_block_size)
                    h_block_size = h_block_size / 2;
                else if (h_block_size % i == 0) {
                    h_block_size = h_block_size / i;
                    break;
                }
            working_set_size = row_size * h_block_size;
        }
        h_block_size = nstl::max(min_h_block_size, h_block_size);
        h_last_block_size = jcp.oh % h_block_size;
        if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
    }

    Opmask reg_h_block = k1;
    Reg64 reg_kh = rax;
    Reg64 reg_kw = rbx;
    Reg64 reg_tmp = abi_not_param1;
    Reg32 reg_tmp_w = reg_tmp.cvt32();
    Reg64 reg_ohs = rdx;
    Reg64 reg_ihs = rsi;
    Reg64 reg_h = r8;
    Reg64 reg_j = r10;

    Reg64 reg_src = r13;
    Reg64 reg_ddst = r14;
    Reg64 reg_ker = r15;

    Reg64 reg_dense_stride = abi_param1;
    Reg64 reg_a_stride = reg_tmp;

    auto emit_block = [&]() {
        mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);
        for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
            dim_t ur_w_src_offset = ur_w_b * get_src_offset(0, jcp.ur_w);
            dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);

            for (int icb = 0; icb < nb_ic_blocking; icb++) {
                dim_t icb_offset = jcp.typesize_in * icb * jcp.tr_src_buf_size;
                tileloadd(Tmm(get_src_tensor(icb)),
                        ptr[reg_src + reg_a_stride + icb_offset
                                + ur_w_src_offset]);
            }
            for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
                tileloadd(Tmm(get_ddst_tensor(ocb)),
                        ptr[reg_ddst + reg_dense_stride
                                + jcp.typesize_in * ocb
                                        * jcp.tr_diff_dst_buf_size
                                + ur_w_ddst_offset]);
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
                            Tmm(get_src_tensor(icb)),
                            Tmm(get_ddst_tensor(ocb)));
            }
        }
    };

    auto emit_h_loop = [&]() {
        Label h_loop, skip_h_loop;
        mov(reg_j, 1);
        cmp(reg_j, reg_h);
        je(skip_h_loop, T_NEAR);
        L(h_loop);
        {
            emit_block();

            add(reg_src, get_src_offset(0, 0, 1));
            add(reg_ddst, get_ddst_offset(0, 1));
            add(reg_j, 1);
            cmp(reg_j, reg_h);
            jb(h_loop);
        }
        L(skip_h_loop);

        emit_block();
    };

    auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
        xor_(reg_kh, reg_kh);
        Label kh_loop, kh_loop_end;

        int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
        // NB: this is correct because we only support t_pad = kh / 2 and thus
        // ih == oh
        int ih_block_size = oh_block_size
                + (!is_first_block + !is_last_block) * jcp.t_pad;

        L(kh_loop);
        {
            if (is_first_block) {
                xor_(reg_tmp, reg_tmp);
                mov(reg_ohs, jcp.t_pad);
                sub(reg_ohs, reg_kh);
                cmovb(reg_ohs, reg_tmp);

                mov(reg_ihs, reg_ohs);
                sub(reg_ihs, jcp.t_pad);
                add(reg_ihs, reg_kh);
            } else {
                xor_(reg_ohs, reg_ohs);
                mov(reg_ihs, reg_kh);
            }

            mov(reg_tmp, oh_block_size);
            sub(reg_tmp, reg_ohs);
            mov(reg_h, ih_block_size);
            sub(reg_h, reg_ihs);
            cmp(reg_tmp, reg_h);
            cmovb(reg_h, reg_tmp);

            Label kh_loop_work;
            cmp(reg_h, 0);
            jg(kh_loop_work, T_NEAR);

            // empty h loop for this jcp.kh:
            // - set the ddst to 0 if necessary
            // - move ker ptr
            // - jump to the end
            sub(reg_h, 1);
            Label skip_ker_zeroing;

            // The reg_ker ptr has highest bit set if the ddst needs to be
            // zeroed. Those who have byte-aligned their data will suffer the
            // consequences :(
            // TODO: move the flag to a mask register? (Roma)
            test(reg_ker, 1);
            jz(skip_ker_zeroing, T_NEAR);

            Label zeroing_loop;
            vpxord(zmm0, zmm0, zmm0);
            and_(reg_ker, ~1); // temporarily clear the zeroing flag

            mov(reg_dense_stride, 64);
            tilezero(Tmm(get_wei_tensor(0, 0)));
            for (int kw = 0; kw < jcp.kw; kw++) {
                // dim_t kw_offset = kw * get_kernel_offset(jcp.ic_block, 0);
                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, kw)],
                            Tmm(get_wei_tensor(0, 0)));
            }
            // restore the zeroing flag (it will be cleared after the end of
            // emit_kh_kw_loop, but we may need it until then)
            or_(reg_ker, 1);
            jmp(kh_loop_end, T_NEAR);

            L(skip_ker_zeroing);
            add(reg_ker, get_kernel_offset(0, jcp.kw));
            jmp(kh_loop_end, T_NEAR);

            L(kh_loop_work);

            mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
            mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));

            add(reg_src, reg_ihs);
            add(reg_ddst, reg_ohs);

            Label kw_loop;
            xor_(reg_kw, reg_kw);

            mov(reg_dense_stride, 64);
            L(kw_loop);
            {
                Label do_zero, ker_init_done;
                test(reg_ker, 1);
                jnz(do_zero, T_NEAR);

                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tileloadd(Tmm(get_wei_tensor(ocb, icb)),
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, 0)]);
                jmp(ker_init_done);
                L(do_zero);
                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilezero(Tmm(get_wei_tensor(ocb, icb)));

                L(ker_init_done);

                mov(ptr[rsp + ddst_save_offset], reg_ddst);
                mov(ptr[rsp + src_save_offset], reg_src);

                lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
                emit_h_loop();

                mov(reg_ddst, ptr[rsp + ddst_save_offset]);
                mov(reg_src, ptr[rsp + src_save_offset]);

                // The reg_ker ptr has highest bit set if the ddst needs to
                // be zeroed. Those who have byte-aligned their data will
                // suffer the consequences :(
                mov(reg_tmp, reg_ker);
                and_(reg_ker, ~1);

                for_(int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(
                            ptr[reg_ker + reg_dense_stride
                                    + get_full_kernel_offset(ocb, icb, 0, 0)],
                            Tmm(get_wei_tensor(ocb, icb)));

                mov(reg_ker, reg_tmp);
                add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
                add(reg_kw, 1);
                cmp(reg_kw, jcp.kw);
                jl(kw_loop);
            }

            sub(reg_src, reg_ihs);
            sub(reg_ddst, reg_ohs);

            L(kh_loop_end);
            add(reg_kh, 1);
            cmp(reg_kh, jcp.kh);
            jl(kh_loop);
        }
    };

    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);
    mov(reg_ker, ptr[param + GET_OFF(filt)]);
    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    or_(reg_ker, reg_tmp);

    bool single_kh_kw_loop = (h_last_block_size == jcp.oh);

    auto src_row_step = get_src_offset(0, 0, 1);
    auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
    auto ddst_block_step = get_ddst_offset(0, h_block_size);

    emit_kh_kw_loop(true, single_kh_kw_loop);

    if (!single_kh_kw_loop) {
        auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
        sub(reg_ker, ker_reset_offset);
        and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates

        add(reg_src, first_src_block_step);
        add(reg_ddst, ddst_block_step);

        int num_innermost_iters
                = (jcp.oh - h_last_block_size) / h_block_size - 1;
        if (num_innermost_iters > 0) {
            Label h_block_loop;

            mov(reg_tmp_w, num_innermost_iters);
            kmovw(reg_h_block, reg_tmp_w);
            L(h_block_loop);
            {
                emit_kh_kw_loop(false, false);
                sub(reg_ker, ker_reset_offset);
                add(reg_src, src_row_step * h_block_size);
                add(reg_ddst, ddst_block_step);

                kmovw(reg_tmp_w, reg_h_block);
                sub(reg_tmp_w, 1);
                kmovw(reg_h_block, reg_tmp_w);
                jnz(h_block_loop);
            }
        }

        emit_kh_kw_loop(false, true);
    }
}
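
// Note on compute_ic_loop() below: the src_offset_l arithmetic assumes the
// transposed src buffer groups columns by (iw mod stride_w), each group
// being tr_iw / stride_w elements wide; for remainder s, the filter taps
// i_kw = s, s + stride_w, ... then read consecutively from the group at
// offset s * src_stride_w_shift.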

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_ic_loop(
        int ic_block, int nb_ic_blocking, int nb_oc_blocking) {
    assert(jcp.ur_w % 2 == 0);
    const int str_w = jcp.stride_w;
    assert(jcp.tr_iw % str_w == 0);
    const int src_stride_w_shift = jcp.tr_iw / str_w;

    mov(reg_b_stride, 64);
    mov(reg_a_stride, jcp.tr_iw * jcp.typesize_in);

    for (int s = 0; s < str_w; s++) {
        for (int i_kw = s; i_kw < jcp.kw; i_kw += str_w) {

            for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tileloadd(Tmm(get_wei_tensor(ocb, icb)),
                            ptr[reg_kernel + reg_b_stride
                                    + get_full_kernel_offset(
                                            ocb, icb, 0, i_kw)]);

            int src_offset_l = (i_kw * (jcp.dilate_w + 1)) / str_w
                    + s * src_stride_w_shift;

            for (int ur_w_b = 0; ur_w_b < jcp.ur_w_blocks; ur_w_b++) {
                dim_t ur_w_src_offset = ur_w_b
                        * get_src_offset(0, filter_w_to_src(0, jcp.ur_w, 0));
                dim_t ur_w_ddst_offset = ur_w_b * get_ddst_offset(jcp.ur_w);
                for (int icb = 0; icb < nb_ic_blocking; icb++) {
                    dim_t icb_offset = icb * jcp.tr_src_buf_size;
                    tileloadd(Tmm(get_src_tensor(icb)),
                            ptr[reg_src
                                    + jcp.typesize_in
                                            * (src_offset_l + icb_offset)
                                    + ur_w_src_offset + reg_a_stride]);
                }
                for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
                    tileloadd(Tmm(get_ddst_tensor(ocb)),
                            ptr[reg_ddst
                                    + jcp.typesize_in * ocb
                                            * jcp.tr_diff_dst_buf_size
                                    + ur_w_ddst_offset + reg_b_stride]);
                    for (int icb = 0; icb < nb_ic_blocking; icb++)
                        tdpbf16ps(Tmm(get_wei_tensor(ocb, icb)),
                                Tmm(get_src_tensor(icb)),
                                Tmm(get_ddst_tensor(ocb)));
                }
            }

            for (int ocb = 0; ocb < nb_oc_blocking; ocb++)
                for (int icb = 0; icb < nb_ic_blocking; icb++)
                    tilestored(ptr[reg_kernel + reg_b_stride
                                       + get_full_kernel_offset(
                                               ocb, icb, 0, i_kw)],
                            Tmm(get_wei_tensor(ocb, icb)));
        }
    }
    safe_add(reg_src, get_src_offset(ic_block, 0), reg_long_offt);
    add(reg_kernel, get_kernel_offset(ic_block, 0));
}

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_init(
        int ocb) {
    auto reg_unit_val = reg_tmp.cvt16();
    mov(reg_unit_val, 0x3f80); // bf16 value of 1.
    vpbroadcastw(vreg_bias_unit, reg_unit_val);

    mov(reg_tmp, ptr[param + GET_OFF(bias)]);
    vmovups(vreg_bias_acc, ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block]);
}
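
// compute_diff_bias_row() below reduces diff_dst rows using
// vdpbf16ps(acc, ddst, unit): multiplying by the broadcast bf16 1.0
// (0x3f80) set up above sums each adjacent pair of bf16 elements into one
// f32 lane, which is why the ow loop advances two elements per iteration
// and an odd tr_ow is handled by one extra step.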

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_diff_bias_row(
        bool is_partial, int ocb) {
    if (!jcp.with_bias) return;
    mov(reg_tmp, ptr[param + GET_OFF(flags)]);
    Label skip_label;
    test(reg_tmp, FLAG_IC_FIRST);
    jz(skip_label, T_NEAR);

    if (is_partial) { compute_diff_bias_init(ocb); }
    auto compute_step = [&]() {
        vmovups(vreg_bias_ddst, ptr[reg_ddst]);
        vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
    };

    Label ow_loop, ow_tail;
    int niters = jcp.tr_ow / 2;
    if (niters > 0) {
        mov(reg_tmp, jcp.tr_ow / 2);
        L(ow_loop);
        compute_step();
        add(reg_ddst, get_ddst_offset(2));
        sub(reg_tmp, 1);
        jnz(ow_loop, T_NEAR);
    }
    if (jcp.tr_ow % 2) compute_step();

    if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));

    if (is_partial) {
        mov(reg_tmp, ptr[param + GET_OFF(bias)]);
        vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
                vreg_bias_acc);
    }

    L(skip_label);
}

void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_compute_diff_bias(
        int nb_oc_blocking) {
    // In the harness_3d_reduction case the diff_bias calculation is called
    // for every ow row separately, to stay aligned with the od loop in
    // compute_od_loop_common()
    if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
    mov(reg_tmp, ptr[param + GET_OFF(flags)]);

    Label skip_label;
    test(reg_tmp, FLAG_IC_FIRST);
    jz(skip_label, T_NEAR);

    for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
        Label bias_loop, skip_label_local;

        mov(reg_ddst, ptr[param + GET_OFF(dst)]);
        add(reg_ddst, jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size);

        switch (jcp.harness) {
            case harness_2d_reduction:
                mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
                sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
                break;
            case harness_mb_reduction:
            case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
            case harness_3d_reduction:
            default: assert(!"Invalid harness type");
        }

        cmp(reg_oj, 0);
        jle(skip_label_local, T_NEAR); // nothing to do

        compute_diff_bias_init(ocb);
        L(bias_loop);
        {
            compute_diff_bias_row(false, ocb);
            add(reg_ddst, get_ddst_offset(0, 1));

            sub(reg_oj, 1);
            jnz(bias_loop, T_NEAR);
        }

        mov(reg_tmp, ptr[param + GET_OFF(bias)]);
        vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block],
                vreg_bias_acc);

        L(skip_label_local);
    }
    // restore reg_ddst value
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);

    L(skip_label);
}

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_step_common(
        int nb_ic_blocking, int nb_oc_blocking) {
    Label kh_label, ic_block_label, ow_block_label, kd_label;

    int ic_block = jcp.ic_block;
    int ic_tail = jcp.ic_tail;

    auto ic_loop = [&](int nb_ic_blocking, int nb_oc_blocking) {
        Label ic_tail_label, ic_loop_done_label;

        if (ic_tail) {
            mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
            cmp(reg_icb, jcp.ic_tail);
            jne(ic_tail_label, T_NEAR);

            compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
            jmp(ic_loop_done_label, T_NEAR);

            L(ic_tail_label);
            compute_ic_loop(ic_tail, nb_ic_blocking, nb_oc_blocking);
            add(reg_kernel, get_kernel_offset(jcp.ic_block - ic_tail, 0));
            safe_add(reg_src,
                    get_src_offset(0, 0, filter_h_to_src(1))
                            - get_src_offset(ic_tail, 0),
                    reg_long_offt);
            L(ic_loop_done_label);
        } else {
            compute_ic_loop(ic_block, nb_ic_blocking, nb_oc_blocking);
        }
    };

    if (jcp.ndims == 5) {
        /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
         * 'movs' must be guaranteed. */
        mov(ki, reg_kd_count);
        mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
        mov(aux_reg_src, reg_src);
        mov(aux_reg_kernel, reg_kernel);

        L(kd_label);
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
    }

    mov(kj, reg_kh);
    L(kh_label);
    {
        ic_loop(nb_ic_blocking, nb_oc_blocking);

        if (jcp.dilate_h > 0) {
            add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
        }
        // subtract the pointer shift made within the ic block loop
        // and move to the next kh index
        add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
        dec(kj);
        cmp(kj, 0);
        jg(kh_label, T_NEAR);
    }
    if (jcp.ndims == 5) {
        add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
        add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
        dec(ki);
        cmp(ki, 0);
        jg(kd_label, T_NEAR);
    }
    // In the harness_3d_reduction case the diff_bias calculation is called
    // for every ow row separately, to stay aligned with the od loop in
    // compute_od_loop_common()
    if (jcp.harness == harness_3d_reduction) {
        auto reg_save_ddst = reg_a_stride;
        mov(reg_save_ddst, reg_ddst);
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            safe_add(reg_ddst,
                    jcp.typesize_in * ocb * jcp.tr_diff_dst_buf_size,
                    reg_long_offt);
            compute_diff_bias_row(true, ocb);
        }
        mov(reg_ddst, reg_save_ddst);
    }

    if (jcp.ndims == 5) {
        mov(reg_src, aux_reg_src);
        mov(reg_kernel, aux_reg_kernel);
        mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
        od_step_comeback_pointers();
    } else {
        oh_step_comeback_pointers();
    }
}

void jit_avx512_core_amx_bwd_weights_kernel_t::maybe_zero_kernel(
        int nb_ic_blocking, int nb_oc_blocking) {
    if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
    Label skip_zeroing, zeroing_loop;

    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    cmp(reg_tmp, 0);
    jz(skip_zeroing, T_NEAR);

    Zmm zero = Zmm(0);
    vpxord(zero, zero, zero);
    if (jcp.with_bias) {
        Label skip_bias_zeroing;
        mov(reg_tmp, ptr[param + GET_OFF(flags)]);
        test(reg_tmp, FLAG_IC_FIRST);
        jz(skip_bias_zeroing, T_NEAR);
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            mov(reg_tmp, ptr[param + GET_OFF(bias)]);
            vmovups(ptr[reg_tmp + sizeof(float) * ocb * jcp.oc_block], zero);
        }
        L(skip_bias_zeroing);
        if (jcp.harness == harness_compute_full_spatial)
            jmp(skip_zeroing, T_NEAR);
    }

    mov(reg_b_stride, 64);
    tilezero(Tmm(get_wei_tensor(0, 0)));
    for (dim_t shift = 0;
            shift < get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
            shift += get_kernel_offset(jcp.ic_block, 0)) {
        for_(int icb = 0; icb < nb_ic_blocking; icb++)
        for (int ocb = 0; ocb < nb_oc_blocking; ocb++) {
            tilestored(
                    ptr[reg_kernel + reg_b_stride
                            + get_full_kernel_offset(ocb, icb, 0, 0) + shift],
                    Tmm(get_wei_tensor(0, 0)));
        }
    }
    L(skip_zeroing);
}
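
// A roadmap for compute_oh_loop_common() below: the oh range is processed
// in three regions -- rows whose filter footprint overlaps the top padding
// (reg_kh grows by stride_h per row while reg_kernel walks backwards), the
// interior rows that use the full kh, and rows overlapping the bottom
// padding (reg_kh shrinks again). With is_partial the same updates are
// first replayed without computing, to fast-forward the state to
// os_index_begin.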
4383
compute_oh_loop_common(int nb_ic_blocking,int nb_oc_blocking,bool is_partial)4384 void jit_avx512_core_amx_bwd_weights_kernel_t::compute_oh_loop_common(
4385 int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
4386 int b_pad = jcp.b_pad;
4387 int t_pad = jcp.t_pad;
4388
4389 bool is_dilated = jcp.dilate_h != 0;
4390 int dilate_h = jcp.dilate_h + 1;
4391 int stride_h = jcp.stride_h;
4392 auto filter_step_size = get_kernel_offset(0, jcp.kw);
4393 auto src_step_size = get_src_offset(0, 0, 1);
4394 auto ddst_step_size = get_ddst_offset(0, 1);
4395 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
4396 oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
4397 oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
4398 oh_dilate_label_end, oh_dilate_setup_label_shift,
4399 oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;
4400
4401 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4402 int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
4403 int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
4404 int oh_head_overflow_end = div_up(t_pad, stride_h);
4405 int oh_tail_end = jcp.oh;
4406
4407 int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
4408 int ih_body_end
4409 = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);
4410
4411 if (is_partial)
4412 mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
4413 else
4414 xor_(reg_oj, reg_oj);
4415
4416 /* Compute 'top' edge */
4417 if (t_pad > 0) {
4418 if (is_partial) {
4419 cmp(reg_oj, oh_head_overflow_end);
4420 jge(oh_tpad_tail_label_end, T_NEAR);
4421 }
4422 const int overflow
4423 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
4424 const int underflow = div_up(t_pad, dilate_h);
4425 const int initial_kh = jcp.kh - overflow - underflow;
4426
4427 // Setup reg_kh, reg_kernel, and reg_src
4428 mov(reg_kh, initial_kh);
4429 add(reg_kernel, filter_step_size * underflow);
4430 if (is_dilated) {
4431 const int tail = t_pad % dilate_h;
4432 const int shift = tail == 0 ? 0 : dilate_h - tail;
4433 mov(reg_ih_shift, shift);
4434 if (!is_partial) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4435 add(reg_src, src_step_size * shift);
4436 }
4437
4438 if (is_partial) {
4439 Label head_setup, head_setup_finish;
4440 cmp(reg_oj, 0);
4441 je(head_setup_finish, T_NEAR);
4442 mov(reg_oj_setup, reg_oj);
4443
4444 L(head_setup);
4445 if (is_dilated) {
4446 inc(reg_ih_shift);
4447 cmp(reg_ih_shift, dilate_h);
4448 jl(oh_dilate_setup_label_shift, T_NEAR);
4449 // unshift src as new kernel element enters
4450 sub(reg_src, src_step_size * (dilate_h - 1));
4451 xor_(reg_ih_shift, reg_ih_shift);
4452 }
4453 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4454 add(reg_kh, stride_h);
4455 sub(reg_kernel, filter_step_size * stride_h);
4456 if (is_dilated) {
4457 jmp(oh_dilate_setup_label_noshift, T_NEAR);
4458 L(oh_dilate_setup_label_shift);
4459 // shift src as old kernel element progresses
4460 add(reg_src, src_step_size * stride_h);
4461 L(oh_dilate_setup_label_noshift);
4462 }
4463 sub(reg_oj_setup, 1);
4464 jg(head_setup, T_NEAR);
4465 L(head_setup_finish);
4466
4467 if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4468 if (oh_head_end < oh_head_overflow_end) {
4469 cmp(reg_oj, oh_head_end);
4470 jge(oh_tpad_label_end, T_NEAR);
4471 }
4472 }
4473
4474 //Setup reg_kernel
4475 // If dilated, shift src ptr
4476 // Loop
4477 L(oh_tpad_label);
4478 compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4479 add(reg_ddst, ddst_step_size);
4480 if (is_dilated) {
4481 mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
4482 inc(reg_ih_shift);
4483 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4484 cmp(reg_ih_shift, dilate_h);
4485 jl(oh_dilate_label_shift, T_NEAR);
4486 // unshift src as new kernel element enters
4487 sub(reg_src, src_step_size * (dilate_h - 1));
4488 xor_(reg_ih_shift, reg_ih_shift);
4489 mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
4490 }
4491 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
4492 add(reg_kh, stride_h);
4493 sub(reg_kernel, filter_step_size * stride_h);
4494 if (is_dilated) {
4495 jmp(oh_dilate_label_noshift, T_NEAR);
4496 L(oh_dilate_label_shift);
4497 // shift src as old kernel element progresses
4498 add(reg_src, src_step_size * stride_h);
4499 L(oh_dilate_label_noshift);
4500 }
4501 inc(reg_oj);
4502
4503 if (is_partial) {
4504 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4505 jge(oh_bpad_label_end, T_NEAR);
4506 }
4507 cmp(reg_oj, oh_head_end);
4508 jl(oh_tpad_label, T_NEAR);
4509
4510 L(oh_tpad_label_end);
4511 // need second loop to process kernel if it is larger than the src
4512 // (does not apply to dilations as they must have unit stride)
4513 if (oh_head_end < oh_head_overflow_end) {
4514 assert(!is_dilated);
4515
4516 cmp(reg_oj, oh_head_overflow_end);
4517 jge(oh_tpad_tail_label_end, T_NEAR);
4518
4519 mov(reg_kh, jcp.ih);
4520 L(oh_tpad_tail_label);
4521 {
4522 compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
4523 add(reg_ddst, ddst_step_size);
4524 sub(reg_kernel, filter_step_size * stride_h);
4525
4526 inc(reg_oj);
4527
4528 if (is_partial) {
4529 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
4530 jge(oh_bpad_label_end, T_NEAR);
4531 }
4532 cmp(reg_oj, oh_head_overflow_end);
4533 jl(oh_tpad_tail_label, T_NEAR);
4534 }
4535 }
4536 if (body_src_start_offset != 0) {
4537 add(reg_kernel, filter_step_size * body_src_start_offset);
4538 add(reg_src, src_step_size * body_src_start_offset);
4539 }
4540 L(oh_tpad_tail_label_end);
4541 }

    if (is_partial) {
        cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
        jge(oh_bpad_label_end, T_NEAR);
    }
    cmp(reg_oj, oh_body_end);
    jge(oh_label_end, T_NEAR);

    /* Compute middle block(s) */
    mov(reg_kh, jcp.kh);
    L(oh_label);
    {
        compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
        add(reg_src, src_step_size * stride_h);
        add(reg_ddst, ddst_step_size);

        inc(reg_oj);

        if (is_partial) {
            cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            jge(oh_bpad_label_end, T_NEAR);
        }

        cmp(reg_oj, oh_body_end);
        jl(oh_label, T_NEAR);
    }
    L(oh_label_end);

    /* Compute bottom edge */
    if (b_pad > 0) {
        if (is_partial) {
            cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
            jge(oh_bpad_label_end, T_NEAR);
        }
        cmp(reg_oj, jcp.oh);
        jge(oh_bpad_label_end, T_NEAR);

        if (is_dilated) {
            // Assumes unit stride for dilations
            mov(reg_kh, jcp.kh - 1);
            xor_(reg_ih_shift, reg_ih_shift);
        } else {
            assert(jcp.dilate_h == 0);
            mov(reg_kh, jcp.ih - ih_body_end);
        }
        if (is_partial) {
            lea(reg_oj_setup,
                    ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
            if (stride_h == 1 && !is_dilated) {
                sub(reg_kh, reg_oj_setup);
            } else {
                Label body_setup, body_setup_finish, dilate_skip;
                cmp(reg_oj_setup, 0);
                je(body_setup_finish, T_NEAR);

                L(body_setup);
                if (is_dilated) {
                    inc(reg_ih_shift);
                    cmp(reg_ih_shift, dilate_h);
                    jl(dilate_skip, T_NEAR);
                    xor_(reg_ih_shift, reg_ih_shift);
                }
                sub(reg_kh, stride_h);
                L(dilate_skip);
                sub(reg_oj_setup, 1);
                jg(body_setup, T_NEAR);
                L(body_setup_finish);
            }
        }

        if (is_dilated) mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
        L(oh_bpad_label);
        {
            compute_oh_step_common(nb_ic_blocking, nb_oc_blocking);
            add(reg_src, src_step_size * stride_h);
            add(reg_ddst, ddst_step_size);

            if (is_dilated) {
                mov(reg_ih_shift, ptr[rsp + ih_dilate_offset]);
                inc(reg_ih_shift);
                mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
                cmp(reg_ih_shift, dilate_h);
                jl(oh_dilate_label_end, T_NEAR);
                xor_(reg_ih_shift, reg_ih_shift);
                mov(ptr[rsp + ih_dilate_offset], reg_ih_shift);
            }
            sub(reg_kh, stride_h);
            L(oh_dilate_label_end);
            inc(reg_oj);
            if (is_partial) {
                cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
                jge(oh_bpad_label_end, T_NEAR);
            }
            cmp(reg_oj, oh_tail_end);
            jl(oh_bpad_label, T_NEAR);
        }
    }
    L(oh_bpad_label_end);
}

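// Depth (od) loop for the 3D harness: walks the output depth indices, calling
// compute_oh_loop_common() once per depth point while adjusting the kernel
// pointer and the kd overlap count for the front/back padding regions.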
void jit_avx512_core_amx_bwd_weights_kernel_t::compute_od_loop_common(
        int nb_ic_blocking, int nb_oc_blocking, bool is_partial) {
    assert(jcp.harness == harness_3d_reduction);

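    // First output-depth index at which the filter begins to overlap the
    // back padding region.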
    const int src_backpad_overlap
            = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);

    const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
    const auto src_shift = get_src_offset(0, 0, jcp.ih);
    const auto ddst_shift = get_ddst_offset(0, jcp.oh);

    const int kd_front_pad = nstl::max(0, jcp.f_pad);
    const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);

    Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
            backpad_end_label, backpad_label;

    /* initially offset 'kd' by f_pad */
    mov(reg_src_d, ptr[param + GET_OFF(src)]);
    mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);

    if (is_partial) {
        add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
        mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
        mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
    } else {
        const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
        const int kd_offset = get_kernel_offset(
                0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
        add(reg_kernel, kd_offset);
        xor_(reg_d_index, reg_d_index);
        mov(reg_kd_count, kd_padding);
    }

    cmp(reg_kd_count, 0);
    jle(loop_end_label, T_NEAR); // no iterations along kd
    if (is_partial)
        cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
    else
        cmp(reg_d_index, jcp.od);
    jge(loop_end_label, T_NEAR); // no iterations along depth dimension

    L(d_loop_label);

    mov(reg_src, reg_src_d);
    mov(reg_ddst, reg_ddst_d);

    mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
    mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
    mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);

    compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);

    mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
    mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
    mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));

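    // The branches below adjust reg_kd_count (the number of kernel rows along
    // kd that overlap the src) plus the kernel/src pointers for the next
    // depth iteration: the overlap grows by stride_d while the filter is
    // leaving the front padding, and shrinks by stride_d once it enters the
    // back padding.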
    /* Compute 'front' edge */
    if (jcp.f_pad > 0) {
        /* Check if within fpad region */
        cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
        jge(fpad_end_label, T_NEAR);

        /* Fpad steps */
        sub(reg_kernel, filter_shift * jcp.stride_d);
        add(reg_kd_count, jcp.stride_d);

        /* Final number of kernel elements that overlap with src */
        const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
        cmp(reg_kd_count, src_ker_overlap);
        jle(common_block_label, T_NEAR);

        /* Correct any excess shifts to kernel and src */
        if (jcp.f_pad <= jcp.od * jcp.stride_d) {
            /* Filter has moved beyond padding (adjust for stride effects) */
            if (jcp.f_pad % jcp.stride_d != 0) {
                int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
                add(reg_kernel, filter_shift * src_corr);
                add(reg_src_d, src_shift * src_corr);
            }
        } else {
            /* Filter still overlaps padding (complete reset) */
            sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
        }

        /* Apply correction */
        mov(reg_kd_count, src_ker_overlap);
        jmp(common_block_label);

        L(fpad_end_label);
    }

    /* Compute 'back' edge (the back_pad region along depth) */
    if (jcp.back_pad > 0) {

        /* Check if within back_pad region */
        cmp(reg_d_index, src_backpad_overlap - 1);
        jl(backpad_end_label, T_NEAR);
        jg(backpad_label, T_NEAR);

        /* Execute overlap correction between the filter and the initial
         * back_pad region. */
        mov(reg_kd_count,
                jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
        jmp(backpad_end_label, T_NEAR);

        L(backpad_label);
        sub(reg_kd_count, jcp.stride_d);
        cmp(reg_kd_count, 0);
        jle(loop_end_label, T_NEAR);

        L(backpad_end_label);
    }

    /* Compute middle block */
    add(reg_src_d, src_shift * jcp.stride_d);

    /* Execute common block and loop */
    L(common_block_label);
    add(reg_ddst_d, ddst_shift);
    inc(reg_d_index);
    if (is_partial)
        cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
    else
        cmp(reg_d_index, jcp.od);
    jl(d_loop_label, T_NEAR);

    L(loop_end_label);
}

void jit_avx512_core_amx_bwd_weights_kernel_t::compute_loop(
        int nb_ic_blocking, int nb_oc_blocking) {
    mov(reg_src, ptr[param + GET_OFF(src)]);
    mov(reg_ddst, ptr[param + GET_OFF(dst)]);
    mov(reg_kernel, ptr[param + GET_OFF(filt)]);

    maybe_zero_kernel(nb_ic_blocking, nb_oc_blocking);
    maybe_compute_diff_bias(nb_oc_blocking);

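    // Pick the spatial loop according to the harness: the 3d/2d reduction
    // harnesses run in 'partial' mode (the os_index_begin/end range comes
    // from the caller), while the mb-reduction harness always covers the
    // full spatial range.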
    switch (jcp.harness) {
        case harness_3d_reduction:
            compute_od_loop_common(nb_ic_blocking, nb_oc_blocking, true);
            break;
        case harness_2d_reduction:
            compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking, true);
            break;
        case harness_mb_reduction:
            compute_oh_loop_common(nb_ic_blocking, nb_oc_blocking);
            break;
        case harness_compute_full_spatial:
            compute_full_spat_loop(nb_ic_blocking, nb_oc_blocking);
            break;
        default: assert(!"Invalid harness type");
    }
}

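// Stack layout used by the kernel: the ic_block_step scratch area occupies
// the bottom of the frame, and the seven 8-byte slots defined below sit
// above it, holding the kd-loop state, the dilation counter, and the saved
// src/ddst pointers.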
void jit_avx512_core_amx_bwd_weights_kernel_t::setup_stack_space() {
    kd_count_offset = ic_block_step_stack_size;
    src_d_offset = ic_block_step_stack_size + 8;
    ddst_d_offset = ic_block_step_stack_size + 16;
    d_index_offset = ic_block_step_stack_size + 24;
    ih_dilate_offset = ic_block_step_stack_size + 32;
    src_save_offset = ic_block_step_stack_size + 40;
    ddst_save_offset = ic_block_step_stack_size + 48;
    stack_space_needed = ic_block_step_stack_size + 56;
}

void jit_avx512_core_amx_bwd_weights_kernel_t::generate() {
    preamble();

    setup_stack_space();

    sub(rsp, stack_space_needed);

    Label last_ic_block_label, last_blocks_done_label;

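    // Four-way dispatch on the runtime tail flags: full vs. reduced
    // nb_ic_blocking combined with full vs. reduced nb_oc_blocking, each
    // combination handled by its own compute_loop() specialization.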
    mov(reg_tmp, ptr[param + GET_OFF(last_ic_block)]);
    cmp(reg_tmp, 0);
    jne(last_ic_block_label, T_NEAR);
    { // full nb_ic_blocking
        Label last_oc_block_label;
        mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
        cmp(reg_tmp, 0);
        jne(last_oc_block_label, T_NEAR);
        { // full nb_oc_blocking
            compute_loop(jcp.nb_ic_blocking, jcp.nb_oc_blocking);
            jmp(last_blocks_done_label, T_NEAR);
        }
        L(last_oc_block_label);
        { // tail of nb_oc_blocking
            compute_loop(jcp.nb_ic_blocking, 1);
            jmp(last_blocks_done_label, T_NEAR);
        }
    }
    L(last_ic_block_label);
    { // tail of nb_ic_blocking
        Label last_oc_block_label;
        mov(reg_tmp, ptr[param + GET_OFF(last_oc_block)]);
        cmp(reg_tmp, 0);
        jne(last_oc_block_label, T_NEAR);
        { // full nb_oc_blocking
            compute_loop(1, jcp.nb_oc_blocking);
            jmp(last_blocks_done_label, T_NEAR);
        }
        L(last_oc_block_label);
        { // tail of nb_oc_blocking
            compute_loop(1, 1);
            jmp(last_blocks_done_label, T_NEAR);
        }
    }

    L(last_blocks_done_label);
    add(rsp, stack_space_needed);

    postamble();
}

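// Populates jcp from the convolution descriptor and the memory descriptors,
// choosing concrete layouts for 'any'-format descriptors, and returns
// unimplemented for shapes or formats this AMX bwd-weights kernel does not
// support.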
status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper diff_bias_d(&diff_bias_md);

    jcp = zero<decltype(jcp)>();

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    if (!mayiuse(avx512_core_bf16_amx_bf16)) return status::unimplemented;
    jcp.isa = avx512_core_bf16_amx_bf16;

    jcp.ver = ver_vnni; // Needed for transpose routines
    jcp.nthr = nthreads;

    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    bool ok = true
            // general condition to simplify dilations
            && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
            && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
            && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
            // special condition to simplify dilations in compute_oh_loop_common
            && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
    if (!ok) return status::unimplemented;

    ok = true && one_of(ndims, 3, 4, 5)
            && everyone_is(
                    data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
            && one_of(diff_weights_d.data_type(), data_type::f32,
                    data_type::bf16);
    if (!ok) return status::unimplemented;

    jcp.transform_to_vnni = diff_weights_d.data_type() == data_type::bf16;

    jcp.r_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
    jcp.b_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
    jcp.back_pad = nstl::max(0,
            calculate_end_padding(
                    jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));

    /* XXX: no support for padding when dilation_d > 0 */
    if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
        return status::unimplemented;

    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
    if (jcp.is_depthwise)
        return status::unimplemented; // TODO: add support of DW convolution

    const int dat_format_tag = ndims - 3;
    format_tag_t dat_tag_nspc = utils::pick(dat_format_tag, format_tag::nwc,
            format_tag::nhwc, format_tag::ndhwc);
    format_tag_t dat_tag_opt = dat_tag_nspc;

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
        jcp.src_tag = dat_tag_opt;
    } else
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag_opt);
    if (!one_of(jcp.src_tag, dat_tag_opt)) return status::unimplemented;
    jcp.is_nspc = jcp.src_tag == dat_tag_nspc;

    if (diff_dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
        jcp.dst_tag = jcp.src_tag;
    } else
        jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
    if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;

    if (!jcp.is_nspc) return status::unimplemented;

    const int wei_format_tag = 2 * ndims - 6 + with_groups;
    format_tag_t wei_tag;
    if (jcp.transform_to_vnni)
        wei_tag = pick(wei_format_tag, format_tag::OIw16i16o2i,
                format_tag::gOIw16i16o2i, format_tag::OIhw16i16o2i,
                format_tag::gOIhw16i16o2i, format_tag::OIdhw16i16o2i,
                format_tag::gOIdhw16i16o2i);
    else
        wei_tag = pick(wei_format_tag, format_tag::OIw16i16o,
                format_tag::gOIw16i16o, format_tag::OIhw16i16o,
                format_tag::gOIhw16i16o, format_tag::OIdhw16i16o,
                format_tag::gOIdhw16i16o);
    if (diff_weights_md.format_kind == format_kind::any) {
        CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
        jcp.wei_tag = wei_tag;
    } else {
        jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }
    jcp.wei_dt = diff_weights_d.data_type();

    /* conditions on bias memory */
    jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
    if (jcp.with_bias) {
        if (diff_bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(diff_bias_md, format_tag::x));
    }
    jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
    jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    const int max_pad_h = ext_kh / 2;
    const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
            && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
            && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd;
    if (!boundaries_ok) return status::unimplemented;

    jcp.ic_block = 16;
    jcp.oc_block = 16;

    jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
    jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);

    jcp.ic_tail = jcp.ic % jcp.ic_block;
    jcp.oc_tail = jcp.oc % jcp.oc_block;

    jcp.nb_oc_blocking = (jcp.nb_oc > 1) ? 2 : 1;
    jcp.nb_ic_blocking = (jcp.nb_ic > 1) ? 2 : 1;

    int max_palette = amx::get_max_palette();
    jcp.max_tiles = amx::get_max_tiles(max_palette);
    jcp.full_tile_width = amx::get_max_rows(max_palette);

    if (jcp.max_tiles != 8 || jcp.full_tile_width != 16)
        return status::unimplemented;

    const bool is_2d = (ndims == 4);
    const bool is_3d = (ndims == 5);
    jcp.typesize_in = sizeof(bfloat16_t);
    jcp.typesize_out = sizeof(float);

    // TODO: Find more shapes (especially 3D with large spatials) for which
    // local transposition will be beneficial. Furthermore, for TBB threads
    // more shapes can potentially benefit from spatial blocking
    int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;

    jcp.global_transpose = dnnl_thr_syncable();
    jcp.spatial_blk_size = optimal_blk_size;

    const int tr_round = 32; // To load full tile register
    int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
    jcp.tr_iw = rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
            * jcp.stride_w;

    jcp.tr_src_num_guard_elems = tr_pad; // upper bound
    jcp.tr_ow = rnd_up(jcp.ow, 2);
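    // Illustrative example (values not from the source): for iw = 56,
    // stride_w = 1, l_pad = r_pad = 1, the formulas above give
    // tr_pad = rnd_up(max(1, 2), 32) = 32 and
    // tr_iw = rnd_up(56 + 32, 32) = 96, i.e. the transposed src row is
    // padded to a multiple of the 32-element tile row; tr_ow = 56 already
    // satisfies the even-width requirement.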

    if (jcp.tr_ow <= max_ur_w) {
        jcp.ur_w = jcp.tr_ow;
        jcp.ur_w_blocks = 1;
    } else {
        jcp.ur_w = 1;
        for (int i = max_ur_w; i >= 1; i -= 2) {
            if (jcp.tr_ow % i == 0) {
                jcp.ur_w = i;
                break;
            }
        }
        jcp.ur_w_blocks = jcp.tr_ow / jcp.ur_w;
    }

    bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    bool use_full_spat_loop = jcp.ndims < 5 && jcp.ih == jcp.oh
            && jcp.iw == jcp.ow && everyone_is(1, jcp.stride_h, jcp.stride_w)
            && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
            // TODO: Remove this constraint: only 3x3 kernel works now
            && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
            && one_of(1, jcp.l_pad, jcp.r_pad) && jcp.kh == jcp.kw
            && jcp.ih >= jcp.kh && jcp.iw >= jcp.kw;

    jcp.harness = ndims == 5
            ? harness_3d_reduction
            : (use_full_spat_loop ? harness_compute_full_spatial
                                  : (ndims == 4) ? harness_2d_reduction
                                                 : harness_mb_reduction);
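    // nthr_mb_work counts the independent reduction chunks available to the
    // 'mb' threading dimension: images times output rows or depths,
    // depending on the harness.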
    switch (jcp.harness) {
        case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
        case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
        case harness_compute_full_spatial:
        case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
        default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
    }
    { // balancing
        int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
        balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
        jcp.nthr = nthr;
        jcp.nthr_mb = nthr_mb;
        jcp.nthr_g = nthr_g;
        jcp.nthr_oc_b = nthr_oc_b;
        jcp.nthr_ic_b = nthr_ic_b;

        // TODO: Optimize memory allocation when threaded on height and depth
        jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
        jcp.tr_src_buf_count = jcp.global_transpose
                ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
                : jcp.nthr;

        jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
        jcp.tr_diff_dst_buf_count = jcp.global_transpose
                ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
                : jcp.nthr;
    }

    return status::success;
}

status_t jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        memory_desc_t &src_md, memory_desc_t &diff_weights_md,
        memory_desc_t &diff_dst_md) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper diff_weights_d(&diff_weights_md);
    const memory_desc_wrapper diff_dst_d(&diff_dst_md);

    // XXX: See the comment about tr_iw and guarding elements in
    // jit_avx512_core_amx_bwd_weights_kernel_t::init_conf()
    const size_t tr_src_size
            = (jcp.tr_src_buf_count * jcp.tr_src_buf_size * jcp.nb_ic_blocking)
            + jcp.tr_src_num_guard_elems;
    scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);

    /* prepare synchronization contexts */
    if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
        const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_tr_src_bctx, tr_src_bctx_size);
    }

    const size_t tr_diff_dst_size = jcp.tr_diff_dst_buf_count
            * jcp.tr_diff_dst_buf_size * jcp.nb_oc_blocking;

    const size_t min_align = 64;
    scratchpad.book(
            key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in, min_align);

    /* prepare synchronization contexts */
    if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
        const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
    }

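    // A reduction buffer is needed either when several mb-threads must
    // accumulate partial weights (nthr_mb > 1) or when the final weights or
    // bias are bf16, so accumulation has to happen in an f32 buffer first.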
    if (IMPLICATION(jcp.nthr_mb == 1,
                (jcp.with_bias && jcp.bia_dt == data_type::bf16)
                        || jcp.wei_dt == data_type::bf16)) {
        const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
                * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
        const size_t bia_size
                = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;

        const int num_wei_buffers
                = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
        const int num_bia_buffers = jcp.with_bias
                ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
                                                 : jcp.nthr_mb - 1)
                : 0;

        const size_t wei_bia_reduction_size
                = wei_size * num_wei_buffers + bia_size * num_bia_buffers;

        scratchpad.book<float>(
                key_conv_wei_bia_reduction, wei_bia_reduction_size);

        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx, 1);
    }

    if (jcp.with_bias
            && ((jcp.oc_without_padding % jcp.oc_block != 0)
                    && jcp.bia_dt == data_type::f32)) {
        scratchpad.book(key_conv_padded_bias,
                jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
    }
    scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
    constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
            << 30; // 32 GB - TODO: maybe this is too large?
    const size_t scratchpad_limit_by_tensor_sizes = (size_t)32 * jcp.nthr
            * (src_d.size() + diff_weights_d.size() + diff_dst_d.size());
    const size_t scratchpad_limit
            = nstl::min(scratchpad_limit_by_absolute_value,
                    scratchpad_limit_by_tensor_sizes);
    if (scratchpad.size() > scratchpad_limit)
        return status::unimplemented;
    else
        return status::success;
}

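// Splits up to dnnl_get_max_threads() threads across the groups, minibatch
// reduction, output-channel block, and input-channel block dimensions by
// minimizing the per-thread memory cost model defined below.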
void jit_avx512_core_amx_bwd_weights_kernel_t::balance(const jit_conv_conf_t &j,
        int &nthr_, int &nthr_mb_, int &nthr_g_, int &nthr_oc_b_,
        int &nthr_ic_b_) {
    nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;

    const int max_threads = dnnl_get_max_threads();

    if (max_threads < j.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        nthr_ = nthr_g_ = max_threads;
        return;
    }

    nthr_g_ = j.ngroups;
    const int nthr = max_threads / nthr_g_;

    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per thread memory cost (read/write). the high level
         * optimizer tries to minimize memory consumption. a few notes:
         * (n1) if the weights tensor size is less than the source and
         *      destination tensor sizes, we apply the ratio of the source
         *      and destination tensor sizes to the weights size as a
         *      compensation coefficient, to avoid parallelization across
         *      batch size only; otherwise we apply an additional coefficient
         *      to the source component based on performance measurements
         * (n2) use scales based on the output-to-input channels ratio for
         *      the source and destination components to improve threading
         *      balance across input and output channels */

        const dim_t src_type_size = 2;
        const dim_t wei_type_size = 4;

        dim_t src_size
                = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
        dim_t dst_size
                = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
        dim_t wei_size
                = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;

        float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
        float oi_channels_ratio = (float)(j.nb_oc / j.nb_oc_blocking)
                / (j.nb_ic / j.nb_ic_blocking);
        auto get_src_coef = [=]() {
            float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
            if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;

            return src_coef;
        };

        auto get_dst_coef
                = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };

        auto get_wei_coef
                = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };

        const float src_coef = get_src_coef();
        const float dst_coef = get_dst_coef();
        const float wei_coef = get_wei_coef();

        float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.mb
                * (j.ic_block * j.nb_ic_blocking) * j.id * j.ih * j.tr_iw
                / j.nthr_mb_work / j.stride_d / j.stride_h / j.stride_w;
        float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_oc / j.nb_oc_blocking),
                        (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
                * div_up((j.nb_ic / j.nb_ic_blocking), nthr_ic_b) * j.kh * j.kw
                * j.kd * (j.ic_block * j.nb_ic_blocking)
                * (j.oc_block * j.nb_oc_blocking);
        float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up((j.nb_oc / j.nb_oc_blocking),
                        (j.oc_block * j.nb_oc_blocking) * nthr_oc_b)
                * j.mb * (j.oc_block * j.nb_oc_blocking) * j.od * j.oh * j.tr_ow
                / j.nthr_mb_work;

        return src_v + dst_v + wei_v;
    };

    float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);

    /* find the best thread distribution with lowest memory cost */

    const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par,
                (j.nb_oc / j.nb_oc_blocking)); // Amount of nb_oc_blocks
        for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            int nthr_ic_b = nstl::min(
                    nthr_par / nthr_oc_b, (j.nb_ic / j.nb_ic_blocking));

            float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                nthr_mb_ = nthr_mb;
                nthr_oc_b_ = nthr_oc_b;
                nthr_ic_b_ = nthr_ic_b;
            }
        }
    }

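    // If the best mb split uses more than half of the available threads but
    // not all of them, round it up to cover as many reduction chunks as the
    // thread budget allows.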
    if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
        nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
    nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;

    assert(nthr_ <= max_threads);
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s