1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "common/c_types_map.hpp"
18 #include "common/nstl.hpp"
19 #include "common/type_helpers.hpp"
20 #include "common/utils.hpp"
21 #include "cpu/x64/jit_generator.hpp"
22
23 #include "cpu/x64/jit_brgemm_transpose_utils.hpp"
24
25 namespace dnnl {
26 namespace impl {
27 namespace cpu {
28 namespace x64 {
29
30 using namespace dnnl::impl::format_tag;
31 using namespace dnnl::impl::utils;
32 using namespace Xbyak;
33
34 #define GET_OFF(x) offsetof(ctx_t, x)
35
// JIT kernel that transposes 16x16 f32 tiles of the source matrix:
// rows are read with byte stride conf_->ic * sizeof(float) and stored
// transposed with stride conf_->LDA * sizeof(float) (set in generate()).
struct jit_brgemm_trans_m_k_f32_t : public jit_brgemm_trans_src_t,
                                    public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_f32_t)

    jit_brgemm_trans_m_k_f32_t(const jit_brgemm_primitive_conf_t *conf)
        : jit_brgemm_trans_src_t(conf) {}

    // Run the generated code on the given call context.
    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;

    // f32 elements; one zmm holds a full 16-element tile row.
    enum { typesize = sizeof(float), transpose_size = 16 };
    // Byte strides of a source/transposed row; assigned in generate().
    dim_t src_stride = 0, tr_src_stride = 0;

    // Constant shuffle masks, named after their bit patterns and loaded
    // once in generate(); kTail is rewritten on demand for partial tiles.
    opmask_t k3333 = k1;
    opmask_t k5555 = k2;
    opmask_t kAAAA = k3;
    opmask_t kCCCC = k4;
    opmask_t k0F0F = k5;
    opmask_t kF0F0 = k6;
    opmask_t kTail = k7;

    reg64_t reg_src_base = rax;
    reg64_t reg_tr_src_base = rbx;

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_loop_K = r10;
    reg64_t reg_loop_M = r11;
    reg64_t reg_loop_batch = r12;
    reg64_t reg_tr_src_tmp = r13;
    reg32_t regw_tmp = r14d;

    // Emit one (possibly partial) 16x16 tile transpose.
    void transpose_16x16(int nrows, int ncolumns = transpose_size);
    void generate() override;
};
76
// Emits code transposing one nrows x ncolumns f32 tile (at most 16x16)
// from reg_src (row stride src_stride) into reg_tr_src (row stride
// tr_src_stride), entirely in registers: two 16x8 half-transposes
// (butterfly of 1-, 2- and 4-lane swaps) followed by an 8-row fixup.
void jit_brgemm_trans_m_k_f32_t::transpose_16x16(int nrows, int ncolumns) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    // zmm0..15 hold the tile rows; zmm16..31 are scratch.
    auto src_zmm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Zmm(i);
    };

    auto tmp_zmm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Zmm(16 + i);
    };

    // Loading an immediate into an opmask needs a GPR round-trip.
    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    // Load row i: zeroed when past the valid rows; column-masked (with
    // zeroing) when the tile is narrower than 16.
    auto load = [=](int i) {
        auto src_load = src_zmm(i);
        if (i >= nrows) {
            vpxord(src_load, src_load, src_load);
            return;
        }

        if (ncolumns < transpose_size) {
            kmovw(kTail, (1 << ncolumns) - 1);
            src_load = src_zmm(i) | kTail | T_z;
        }
        vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
    };

    // Store transposed row i; a partial tile (nrows < 16) is stored with
    // a column mask of nrows bits.
    auto store = [=](Zmm r, int i) {
        mov(reg_tr_src_tmp, reg_tr_src);
        if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);

        // Xbyak does not allow k0 to be specified explicitly via the '|'
        // operator, so we have to do this via a method call (implicitly
        // EVEX encoding uses k0 to mean 'no mask')
        bool partial_store = nrows < transpose_size;
        auto k = partial_store ? kTail : k0;
        auto base = reg_tr_src_tmp;
        base.setOpmaskIdx(k.getIdx(), true);

        auto addr = EVEX_compress_addr(base, i * tr_src_stride);
        vmovups(addr, r);
    };

    // Transposes the 8-row half starting at base_idx; loads of upcoming
    // rows are interleaved with the shuffles to hide memory latency.
    auto transpose16x8 = [=](int base_idx) {
        assert(base_idx == 0 || base_idx == 8);

        // swap 1
        for (int i = 0; i < 4; i++) {
            int src_idx0 = base_idx + i * 2;
            int src_idx1 = src_idx0 + 1;

            int next_src_idx0 = src_idx0 + 2;
            int next_src_idx1 = src_idx1 + 2;
            bool load_next = base_idx == 0 || i < 3;

            // First pair of the first half must be loaded up front.
            if (base_idx == 0 && i == 0) {
                load(src_idx0);
                if (src_idx1 < nrows)
                    load(src_idx1);
                else
                    vpxord(src_zmm(src_idx1), src_zmm(src_idx1),
                            src_zmm(src_idx1));
            }

            auto tmp0 = tmp_zmm(src_idx0);
            auto tmp1 = tmp_zmm(src_idx1);
            auto src0 = src_zmm(src_idx0);
            auto src1 = src_zmm(src_idx1);

            if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
            valignd(tmp0, src0, src0, 0x1);

            if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
            valignd(tmp1, src1, src1, 0xf);

            vmovaps(src0 | kAAAA, tmp1);
            vmovaps(src1 | k5555, tmp0);
        }
        // swap 2
        for (int i = 0; i < 4; i++) {
            int select_half = (i < 2) ? 0 : 2;
            int src_idx0 = base_idx + i + select_half + 0;
            int src_idx2 = src_idx0 + 2;

            auto tmp0 = tmp_zmm(src_idx0);
            auto tmp1 = tmp_zmm(src_idx2);
            auto src0 = src_zmm(src_idx0);
            auto src2 = src_zmm(src_idx2);

            valignd(tmp0, src0, src0, 0x2);
            valignd(tmp1, src2, src2, 0xe);
            vmovaps(src2 | k3333, tmp0);
            vmovaps(src0 | kCCCC, tmp1);
        }

        // swap 4
        for (int i = 0; i < 4; i++) {
            int src_idx0 = base_idx + i;
            int src_idx4 = src_idx0 + 4;

            auto tmp0 = tmp_zmm(src_idx0);
            auto src0 = src_zmm(src_idx0);
            auto src4 = src_zmm(src_idx4);

            vmovaps(tmp0, src0);
            vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
            vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
        }
    };

    // Combines the two transposed halves and stores all 16 output rows.
    auto fixup16x16 = [=]() {
        // swap 8
        for (int i = 0; i < 8; i++) {
            auto tmp = tmp_zmm(i);
            auto src0 = src_zmm(i);
            auto src8 = src_zmm(8 + i);
            vshuff64x2(tmp, src0, src8, 0x44);
            store(tmp, i);
        }

        for (int i = 0; i < 8; i++) {
            auto tmp = tmp_zmm(8 + i);
            auto src0 = src_zmm(i);
            auto src8 = src_zmm(8 + i);
            vshuff64x2(tmp, src0, src8, 0xee);
            store(tmp, 8 + i);
        }
    };

    transpose16x8(0);
    transpose16x8(8);
    fixup16x16();
}
217
// Kernel driver for the f32 transpose.  Per call it processes one K (os)
// slab of up to 16 rows: current_K selects the full-slab or tail path,
// then loops over the gemm batch and over M (ic) in 16-column tiles.
// Tail sizes are compile-time constants from conf; runtime trip counts
// come from the call context (current_K / current_M / current_gemm_batch).
void jit_brgemm_trans_m_k_f32_t::generate() {
    preamble();
    assert(conf_->ic_block % transpose_size == 0);
    int os_block = conf_->os_block;
    int last_os_block_tail = conf_->K_tail % transpose_size;
    int ic_tail = conf_->ic % transpose_size;
    // Source rows are contiguous in ic; transposed rows use LDA stride.
    src_stride = conf_->ic * typesize;
    tr_src_stride = conf_->LDA * typesize;
    dim_t m_src_shift = transpose_size * typesize;
    dim_t m_tr_src_shift = tr_src_stride * transpose_size;

    dim_t batch_src_shift = src_stride * os_block;
    dim_t batch_tr_src_shift = tr_src_stride * conf_->M;

    mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
    mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
    mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    // Initialize the constant shuffle masks used by transpose_16x16.
    kmovw(k3333, 0x3333); // 0011001100110011
    kmovw(k5555, 0x5555); // 0101010101010101
    kmovw(kAAAA, 0xaaaa); // 1010101010101010
    kmovw(kCCCC, 0xcccc); // 1100110011001100
    kmovw(k0F0F, 0x0f0f); // 0000111100001111
    kmovw(kF0F0, 0xf0f0); // 1111000011110000

    // M (ic) loop: full 16-column tiles, then an optional ic_tail tile.
    auto compute_M = [=](bool is_os_tail) {
        auto nrows = is_os_tail ? last_os_block_tail : transpose_size;
        mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]);
        mov(reg_src, reg_src_base);
        mov(reg_tr_src, reg_tr_src_base);
        Label M_loop, M_tail_or_done, M_done;
        if (ic_tail > 0) {
            cmp(reg_loop_M, transpose_size);
            jl(M_tail_or_done, T_NEAR);
        }

        L(M_loop);
        transpose_16x16(nrows, transpose_size);
        if (conf_->ic_block > transpose_size) {
            add(reg_src, m_src_shift);
            add(reg_tr_src, m_tr_src_shift);
            sub(reg_loop_M, transpose_size);
            cmp(reg_loop_M, transpose_size);
            jge(M_loop, T_NEAR);
        } else {
            // ic_block fits in one tile: no pointer advancing required.
            jmp(M_done, T_NEAR);
        }

        L(M_tail_or_done);
        if (ic_tail > 0) {
            cmp(reg_loop_M, 0);
            jle(M_done, T_NEAR);

            transpose_16x16(nrows, ic_tail);
        }
        L(M_done);
    };

    // Batch loop: one M-sweep per gemm batch element, advancing the base
    // pointers by a full os_block / M panel each iteration.
    auto compute_batch = [=](bool is_os_tail) {
        Label batch_loop;
        L(batch_loop);

        compute_M(is_os_tail);
        add(reg_src_base, batch_src_shift);
        add(reg_tr_src_base, batch_tr_src_shift);

        sub(reg_loop_batch, 1);
        jnz(batch_loop, T_NEAR);
    };

    // Select full vs. tail K slab once per kernel invocation.
    Label K_tail;
    if (last_os_block_tail > 0) {
        cmp(reg_loop_K, transpose_size);
        jl(K_tail, T_NEAR);
    }

    compute_batch(false);

    if (last_os_block_tail > 0) {
        Label K_done;
        jmp(K_done, T_NEAR);

        L(K_tail);
        compute_batch(true);
        L(K_done);
    }

    postamble();
}
313
// JIT kernel that transposes bf16 source tiles: rows are read with byte
// stride conf_->ic * sizeof(int16_t) and stored with stride
// conf_->LDA * sizeof(int16_t) (set in generate()).  Two consecutive
// source rows share one zmm during the permute network.
struct jit_brgemm_trans_m_k_bf16_t : public jit_brgemm_trans_src_t,
                                     public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_bf16_t)
    jit_brgemm_trans_m_k_bf16_t(const jit_brgemm_primitive_conf_t *conf)
        : jit_brgemm_trans_src_t(conf) {}

    // Run the generated code on the given call context.
    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;

    enum {
        typesize = sizeof(int16_t), // bf16 payload is stored as int16
        transpose_size = 16,
    };
    // Byte strides of a source/transposed row; assigned in generate().
    dim_t src_stride = 0, tr_src_stride = 0;

    // Permute masks.  NOTE: kTail deliberately aliases kFFFF (both k1) --
    // the tail mask is only written after all kFFFF-masked loads are done
    // (see transpose()).
    opmask_t kFFFF = k1;
    opmask_t k5555 = k2;
    opmask_t kAAAA = k3;
    opmask_t kAA = k4;
    opmask_t k55 = k5;
    opmask_t kCC = k6;
    opmask_t k33 = k7;
    opmask_t kTail = k1;

    reg32_t regw_tmp = r15d;

    reg64_t reg_k_src = r14;
    reg64_t reg_k_tr_src = r13;

    reg64_t reg_m_src = r12;
    reg64_t reg_m_tr_src = r11;

    reg64_t reg_batch_src = r10;
    reg64_t reg_batch_tr_src = r9;

    reg64_t reg_loop_batch = r8;
    reg64_t reg_loop_K = rax;
    reg64_t reg_loop_M = rbx;

    reg64_t reg_tr_src_tmp = abi_not_param1; // lnx -> rcx
    reg64_t imm_addr64 = rdx;

    // Permutation tables, loaded from idx1..idx5 in generate().
    Xbyak::Zmm vidx1 = zmm31;
    Xbyak::Zmm vidx2 = zmm30;
    Xbyak::Zmm vidx3 = zmm29;
    Xbyak::Zmm vidx4 = zmm28;
    Xbyak::Zmm vidx5 = zmm27;
    Xbyak::Zmm zmm_tmp = zmm26;

    // Emit one (possibly partial) 16x16 tile transpose from src to dst.
    void transpose(
            reg64_t dst, reg64_t src, int nrows, int ncolumns = transpose_size);
    void generate() override;
};
372
// Emits code transposing one nrows x ncolumns bf16 tile (at most 16x16)
// from src to dst.  Rows are packed two per zmm, run through a three-level
// permute network (swap 1/2/3) and stored; the mapping from output row to
// the register that ends up holding it is given by get_vec_idx below.
void jit_brgemm_trans_m_k_bf16_t::transpose(
        reg64_t dst, reg64_t src, int nrows, int ncolumns) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) { return Zmm(i); };

    auto src_ymm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Ymm(i);
    };

    // Opmask immediates go through a GPR.
    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    auto kmovd = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovd(k, regw_tmp);
    };

    // Masked store of one output row; kTail holds the store mask.
    auto store = [=](Zmm r, int i) {
        mov(reg_tr_src_tmp, dst);

        auto k = kTail;
        auto base = reg_tr_src_tmp;
        base.setOpmaskIdx(k.getIdx(), true);

        auto addr = EVEX_compress_addr(base, i * tr_src_stride);
        vmovups(addr, r);
    };

    // Column (load) mask covering ncolumns bf16 elements per row.
    const int ic_block = ncolumns;
    kmovd(kFFFF, ic_block < transpose_size ? (1 << ic_block) - 1 : 0xffff);

    // Load rows in pairs: row 2i+1 is inserted into the upper half of the
    // zmm holding row 2i, then vpermw/vidx5 interleaves the pair.
    for (int i = 0; i < nrows / 2; i++) {
        auto zmm_src0 = src_zmm(2 * i);
        auto zmm_src1 = src_zmm(2 * i + 1);
        auto src1 = src_ymm(2 * i + 1);
        vmovdqu16(zmm_src0 | kFFFF | T_z,
                EVEX_compress_addr(src, 2 * i * src_stride));
        vmovdqu16(zmm_src1 | kFFFF | T_z,
                EVEX_compress_addr(src, (2 * i + 1) * src_stride));
        vinsertf64x4(zmm_src0, zmm_src0, src1, 1);
        vpermw(zmm_src0, vidx5, zmm_src0);
    }

    // for odd numbers we need to mix row with zeroes
    if (nrows % 2) {
        int i = nrows / 2;
        auto zmm_src0 = src_zmm(2 * i);
        vmovdqu16(zmm_src0 | kFFFF | T_z,
                EVEX_compress_addr(src, 2 * i * src_stride));
        vpermw(zmm_src0, vidx5, zmm_src0);
    }

    // Zero the remaining (even-indexed) pair registers.
    for (int i = rnd_up(nrows, 2); i < 16; i += 2) {
        vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
    }

    // swap 1
    for (int i = 0; i < 4; i++) {
        auto zmm0 = src_zmm(4 * i);
        auto zmm1 = src_zmm(4 * i + 2);
        auto tmp0 = src_zmm(4 * i + 1);
        auto tmp1 = src_zmm(4 * i + 3);

        vmovups(tmp0, zmm0);
        vmovups(tmp1, zmm1);

        vpermps(tmp0 | kAAAA, vidx3, zmm1);
        vpermps(tmp1 | k5555, vidx3, zmm0);
    }
    // swap 2
    int base_idx;
    base_idx = 0;
    for (int i = 0; i < 2; i++) {
        auto zmm0 = src_zmm(base_idx + 2 * i + 1);
        auto zmm1 = src_zmm(base_idx + 2 * i + 5);

        auto tmp0 = src_zmm(base_idx + 2 * i);
        auto tmp1 = src_zmm(base_idx + 2 * i + 4);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kAA, vidx2, zmm1);
        vpermpd(tmp1 | k55, vidx2, zmm0);
    }
    base_idx = 8;
    for (int i = 0; i < 2; i++) {
        auto zmm0 = src_zmm(base_idx + 2 * i + 1);
        auto zmm1 = src_zmm(base_idx + 2 * i + 5);

        auto tmp0 = src_zmm(base_idx + 2 * i);
        auto tmp1 = src_zmm(base_idx + 2 * i + 4);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kAA, vidx2, zmm1);
        vpermpd(tmp1 | k55, vidx2, zmm0);
    }

    // swap 3
    for (int i = 0; i < 4; i++) {
        auto zmm0 = src_zmm(2 * i);
        auto zmm1 = src_zmm(2 * i + 8);

        auto tmp0 = src_zmm(2 * i + 1);
        auto tmp1 = src_zmm(2 * i + 9);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kCC, vidx1, zmm1);
        vpermpd(tmp1 | k33, vidx1, zmm0);
    }

    // all stores
    for (int i = 0; i < 8; i++)
        vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1);

    // Maps an output (ic) row index to the register holding it after the
    // permute network above.
    auto get_vec_idx = [=](int ic_idx) {
        assert(ic_idx < 16 && ic_idx >= 0);
        switch (ic_idx) {
            case 0: return 1;
            case 1: return 0;
            case 2: return 3;
            case 3: return 2;
            case 4: return 9;
            case 5: return 8;
            case 6: return 11;
            case 7: return 10;
            case 8: return 5;
            case 9: return 4;
            case 10: return 7;
            case 11: return 6;
            case 12: return 13;
            case 13: return 12;
            case 14: return 15;
            default: return 14;
        }
    };

    // Store mask: one bit per dword, i.e. per packed pair of bf16 values,
    // so it covers rnd_up(nrows, 2) input rows.  Safe to overwrite k1
    // (kFFFF) here -- all masked loads are already emitted.
    int store_tail = rnd_up(nrows, 2);
    kmovw(kTail, (1 << store_tail / 2) - 1);

    for (int ic = 0; ic < ic_block; ic++)
        store(src_zmm(get_vec_idx(ic)), ic);
}
526
// Kernel driver for the bf16 transpose.  Loop structure, outer to inner:
// batch -> K (os, 16 rows per step) -> M (ic, 16 columns per step); each
// step emits one transpose().  idx1..idx5 are the permutation tables.
void jit_brgemm_trans_m_k_bf16_t::generate() {
    preamble();

    alignas(64) static constexpr const int64_t idx1[8]
            = {2, 3, 0, 1, 6, 7, 4, 5};
    alignas(64) static constexpr const int64_t idx2[8]
            = {1, 0, 3, 2, 5, 4, 7, 6};
    alignas(64) static constexpr const int32_t idx3[16]
            = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14};
    alignas(64) static constexpr const int32_t idx4[16]
            = {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7};
    alignas(64) static constexpr const uint16_t idx5[32]
            = {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17,
                    3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31};

    // With amx bf16 an odd os is padded to an even row count; the padded
    // row is excluded from the effective K tail.
    constexpr int amx_bf16_granularity = 2;
    const bool last_row_padded = conf_->isa == avx512_core_bf16_amx_bf16
            && conf_->os % amx_bf16_granularity != 0;
    const int eff_K_tail = conf_->K_tail - (last_row_padded ? 1 : 0);

    const int os_block = conf_->os_block;
    const int last_os_block_tail = eff_K_tail % transpose_size;
    const int ic_tail = conf_->ic % transpose_size;
    // Source rows are contiguous in ic; transposed rows use LDA stride.
    src_stride = conf_->ic * typesize;
    tr_src_stride = conf_->LDA * typesize;

    dim_t batch_src_shift = src_stride * os_block;
    dim_t batch_tr_src_shift = tr_src_stride * conf_->M;

    dim_t M_src_shift = transpose_size * typesize;
    dim_t M_tr_src_shift = transpose_size * conf_->LDA * typesize;

    dim_t K_src_shift = transpose_size * conf_->ic * typesize;
    dim_t K_tr_src_shift = transpose_size * typesize;

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    // Constant permute masks used by transpose().
    kmovw(kFFFF, 0xffff);
    kmovw(k5555, 0x5555);
    kmovw(kAAAA, 0xaaaa);
    kmovw(kAA, 0xaa);
    kmovw(k55, 0x55);
    kmovw(kCC, 0xcc);
    kmovw(k33, 0x33);

    // Loads of the permutation tables go through imm_addr64.
    auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa64(z, ptr[imm_addr64]);
    };

    auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa32(z, ptr[imm_addr64]);
    };

    vmovdqa64(vidx1, idx1);
    vmovdqa64(vidx2, idx2);
    vmovdqa32(vidx3, idx3);
    // NOTE(review): vidx4 is not referenced by transpose() in this file --
    // looks like a removal candidate; confirm against other users.
    vmovdqa32(vidx4, idx4);
    vmovdqa32(vidx5, (const int32_t *)idx5);

    // M (ic) loop: full 16-column tiles, then an optional ic_tail tile.
    auto compute_m_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base,
                                  bool is_os_tail) {
        mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]);
        mov(reg_m_src, reg_base);
        mov(reg_m_tr_src, reg_tr_base);

        Label M_loop_tail, M_loop;
        if (ic_tail > 0) {
            cmp(reg_loop_M, transpose_size);
            jl(M_loop_tail, T_NEAR);
        }
        L(M_loop);
        {
            transpose(reg_m_tr_src, reg_m_src,
                    is_os_tail ? last_os_block_tail : transpose_size,
                    transpose_size);
            add(reg_m_src, M_src_shift);
            add(reg_m_tr_src, M_tr_src_shift);
        }
        sub(reg_loop_M, transpose_size);
        cmp(reg_loop_M, transpose_size);
        jge(M_loop, T_NEAR);

        if (ic_tail > 0) {
            Label M_loop_done;
            L(M_loop_tail);
            cmp(reg_loop_M, 0);
            jle(M_loop_done, T_NEAR);

            transpose(reg_m_tr_src, reg_m_src,
                    is_os_tail ? last_os_block_tail : transpose_size, ic_tail);
            L(M_loop_done);
        }
    };

    // K (os) loop: full 16-row slabs, then an optional K tail slab.
    auto compute_k_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base) {
        mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
        mov(reg_k_src, reg_base);
        mov(reg_k_tr_src, reg_tr_base);

        Label K_tail, K_loop, K_done;
        if (last_os_block_tail > 0) {
            cmp(reg_loop_K, transpose_size);
            jl(K_tail, T_NEAR);
        }
        L(K_loop);
        {
            compute_m_loop(reg_k_src, reg_k_tr_src, false);
            add(reg_k_src, K_src_shift);
            add(reg_k_tr_src, K_tr_src_shift);
        }
        sub(reg_loop_K, transpose_size);
        cmp(reg_loop_K, transpose_size);
        jge(K_loop, T_NEAR);

        cmp(reg_loop_K, 0);
        je(K_done, T_NEAR);

        if (last_os_block_tail > 0) {
            L(K_tail);
            compute_m_loop(reg_k_src, reg_k_tr_src, true);
        }
        L(K_done);
    };

    mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
    mov(reg_batch_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_batch_tr_src, ptr[param1 + GET_OFF(tr_src)]);

    // Batch loop: one K-sweep per gemm batch element.
    Label batch_loop;
    L(batch_loop);
    {
        compute_k_loop(reg_batch_src, reg_batch_tr_src);

        add(reg_batch_src, batch_src_shift);
        add(reg_batch_tr_src, batch_tr_src_shift);
    }
    sub(reg_loop_batch, 1);
    jnz(batch_loop, T_NEAR);

    postamble();
}
673
copy_row_blk_loop(int copy_row_iters)674 void jit_brgemm_copy_to_coarse_t::copy_row_blk_loop(int copy_row_iters) {
675 int row_blks = div_up(copy_row_iters, row_loop_unroll);
676
677 for (int row_b = 0; row_b < row_blks; ++row_b) {
678 const int row_start = 0;
679 const int row_end = nstl::min(static_cast<int>(row_loop_unroll),
680 copy_row_iters - row_b * static_cast<int>(row_loop_unroll));
681
682 for (int row = row_start; row < row_end; ++row) {
683 const int row_idx = row_b * row_loop_unroll + row;
684 const auto offset = addr_offset(row_idx);
685
686 const auto zmm = get_zmm_copy(row);
687 const auto addr = EVEX_compress_addr(reg_data, offset);
688 const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);
689
690 vmovdqu8(zmm, addr);
691 vmovdqu8(addr_tr, zmm);
692 }
693 }
694 }
695
copy_row_tail(int row_offset)696 void jit_brgemm_copy_to_coarse_t::copy_row_tail(int row_offset) {
697 // Mask for row tail load and store are already set up
698 const auto zmm_data = zmm_tail | reg_m_row_tail_load | T_z;
699 const auto zmm_tr_data = zmm_tail | reg_m_row_tail_store;
700
701 const auto offset = addr_offset(row_offset);
702 const auto addr = EVEX_compress_addr(reg_data, offset);
703 const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);
704
705 vmovdqu8(zmm_data, addr);
706 vmovdqu8(addr_tr, zmm_tr_data);
707 }
708
// Emits the copy of one row: full tr_row_size_ chunks for regular blocks
// and a row_size_ % tr_row_size_ chunk for the last block, the choice
// being made at runtime via reg_last_row_blk.
void jit_brgemm_copy_to_coarse_t::copy_row_loop() {
    Xbyak::Label label_row_tail, label_row_exit;

    // Note: copying is done in chunks of size row_step_
    const auto copy_row = [&](bool is_last_blk) {
        const int row_blk
                = is_last_blk ? (row_size_ % tr_row_size_) : tr_row_size_;
        const int row_iters = row_blk / row_step_;
        const int row_iters_tail = row_blk % row_step_;

        copy_row_blk_loop(row_iters);
        if (row_iters_tail != 0) copy_row_tail(/* row_offset = */ row_iters);
    };

    // If the whole row is smaller than tr_row_size_, only the tail
    // version is ever needed; skip emitting the runtime branch.
    const bool only_row_tail = row_size_ < tr_row_size_;

    if (!only_row_tail) {
        cmp(reg_last_row_blk, 0);
        jne(label_row_tail, T_NEAR);

        copy_row(/* is_last_blk = */ false);
        jmp(label_row_exit, T_NEAR);
    }

    L(label_row_tail);
    copy_row(/* is_last_blk = */ true);

    L(label_row_exit);
}
738
copy_os_loop()739 void jit_brgemm_copy_to_coarse_t::copy_os_loop() {
740
741 Label loop_os;
742 L(loop_os);
743
744 copy_row_loop();
745 add(reg_data, data_stride_);
746 add(reg_tr_data, tr_data_stride_);
747
748 dec(reg_os_work);
749 jnz(loop_os, T_NEAR);
750 }
751
// Prepares the two opmasks used by copy_row_tail(): a load mask covering
// exactly row_tail elements (typesize_ * row_tail bytes) and a store mask
// whose width is row_tail rounded up to a multiple of row_block_size_
// elements, so the destination is written in whole blocks.
void jit_brgemm_copy_to_coarse_t::set_tail_mask() {
    const int row_tail = row_size_ % row_step_;
    assert(row_tail > 0 && "kernel is meant to be used with tail processing");

    // Set load mask
    const size_t tail_mask_load
            = (static_cast<size_t>(1) << (typesize_ * row_tail)) - 1;
    mov(reg_tail_mask, tail_mask_load);
    kmovq(reg_m_row_tail_load, reg_tail_mask);

    // Caution: Since size of ZMM equals 64 bytes therefore we need
    // different masks to store tails with smaller row_block_size_
    constexpr auto full_mask = size_t {0xffffffffffffffff};
    constexpr auto half_mask = size_t {0x00000000ffffffff};
    constexpr auto quad_mask = size_t {0x000000000000ffff};

    const auto num_bytes = [](size_t mask) -> int {
        // Given by 1 + position of leftmost 1 bit
        return 1 + math::ilog2q(mask);
    };

    // Pick the smallest canonical mask that covers the rounded-up store
    // size in bytes.
    const int row_tail_store_size
            = utils::rnd_up(row_tail, row_block_size_) * typesize_;
    if (row_tail_store_size >= num_bytes(full_mask))
        mov(reg_tail_mask, full_mask);
    else if (row_tail_store_size >= num_bytes(half_mask))
        mov(reg_tail_mask, half_mask);
    else {
        assert(row_tail_store_size == num_bytes(quad_mask));
        mov(reg_tail_mask, quad_mask);
    }
    kmovq(reg_m_row_tail_store, reg_tail_mask);
}
785
// Kernel entry point: sets up the tail masks, loads the runtime arguments
// (src/dst pointers, os work amount, last-row-block flag) from the call
// context, and emits the copy loop.
void jit_brgemm_copy_to_coarse_t::generate() {
    preamble();

    set_tail_mask();
    mov(reg_data, ptr[param1 + GET_OFF(data)]);
    mov(reg_tr_data, ptr[param1 + GET_OFF(tr_data)]);
    mov(reg_os_work, ptr[param1 + GET_OFF(os_work)]);
    mov(reg_last_row_blk, ptr[param1 + GET_OFF(last_row_blk)]);

    copy_os_loop();

    postamble();
}
799
// JIT kernel that re-lays a matrix into row-pair-interleaved (vnni)
// format: for matrix_B the bf16 data is reordered as-is; for matrix_C
// the f32 data is first down-converted to bf16 via vcvtne2ps2bf16
// (see transpose() and the two setup branches in generate()).
struct jit_trans_to_vnni_t : public jit_brgemm_trans_to_vnni_t,
                             public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_to_vnni_t)
    jit_trans_to_vnni_t(const jit_brgemm_primitive_conf_t *conf,
            jit_brgemm_trans_to_vnni_t::matrix_to_transform_t
                    matrix_to_transform)
        : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform) {}

    // Run the generated code on the given call context.
    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using zmm = const Xbyak::Zmm;

    enum {
        typesize_data = sizeof(int16_t), // bf16 elements
        typesize_acc = sizeof(float), // f32 accumulator (matrix_C) elements
        transpose_size = 16,
    };

    // Compile-time tails and byte strides/shifts; assigned in generate()
    // depending on which matrix is being transformed.
    int last_row_block_tail = 0, col_tail = 0;
    dim_t src_stride = 0, tr_src_stride = 0;
    dim_t src_col_shift = 0, tr_src_col_shift = 0;
    dim_t src_row_shift = 0, tr_src_row_shift = 0;
    dim_t src_batch_shift = 0, tr_src_batch_shift = 0;

    opmask_t kFFFF = k1; // full 16-element mask
    opmask_t mask_tail = k2; // column-tail mask ((1 << col_tail) - 1)

    zmm vidx1 = zmm31; // row-pair interleaving permutation table

    reg32_t regw_tmp = r15d;

    reg64_t reg_batch_src = r14;
    reg64_t reg_batch_tr_src = r13;

    reg64_t reg_row_src = r12;
    reg64_t reg_row_tr_src = r11;

    reg64_t reg_col_src = r10;
    reg64_t reg_col_tr_src = r9;

    reg64_t reg_loop_batch = r8;
    reg64_t reg_loop_row = rax;
    reg64_t reg_loop_col = rbx;

    reg64_t imm_addr64 = abi_not_param1; // lnx -> rcx

    // Zero-fill trailing column blocks of an oc_block (matrix_C path).
    void maybe_zero_pad_col(reg64_t dst);
    // Emit one nrows x ncolumns tile re-layout from src to dst.
    void transpose(reg64_t dst, reg64_t src, int nrows,
            int ncolumns = transpose_size, bool pad_by_zeroes = false);
    void generate() override;
};
856
maybe_zero_pad_col(reg64_t dst)857 void jit_trans_to_vnni_t::maybe_zero_pad_col(reg64_t dst) {
858 auto zmm_zero = Xbyak::Zmm(0);
859 vpxord(zmm_zero, zmm_zero, zmm_zero);
860 const int oc_utilized = rnd_up(conf_->oc % conf_->oc_block, transpose_size);
861 const int iters = (conf_->oc_block - oc_utilized) / transpose_size;
862 for (int n = 0; n < iters; ++n) {
863 for (int i = 0; i < transpose_size; i += 2) {
864 auto addr = EVEX_compress_addr(dst, i * tr_src_stride);
865 vmovups(addr, zmm_zero);
866 }
867 add(reg_col_tr_src, tr_src_col_shift);
868 }
869 }
870
// Emits the vnni re-layout of one nrows x ncolumns tile.  Rows are
// processed in pairs: for matrix_B the second bf16 row is inserted into
// the upper half of the first row's zmm; for matrix_C the f32 pair is
// down-converted with vcvtne2ps2bf16.  vpermw with vidx1 then interleaves
// the two rows element-wise.  With pad_by_zeroes, remaining row pairs up
// to transpose_size are stored as zeros.
void jit_trans_to_vnni_t::transpose(
        reg64_t dst, reg64_t src, int nrows, int ncolumns, bool pad_by_zeroes) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) { return Zmm(i); };

    auto src_ymm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Ymm(i);
    };

    auto store = [=](Zmm r, int i) {
        auto addr = EVEX_compress_addr(dst, i * tr_src_stride);
        vmovups(addr, r);
    };
    // Column mask: full tile vs. (1 << col_tail) - 1 (set in generate()).
    auto mask = ncolumns == transpose_size ? kFFFF : mask_tail;

    int i = 0;
    for (; i < nrows / 2; i++) {
        auto src1 = src_ymm(2 * i + 1);
        auto zmm_src0 = src_zmm(2 * i);
        auto zmm_src1 = src_zmm(2 * i + 1);
        if (matrix_to_transform_ == matrix_B) {
            vmovdqu16(zmm_src0 | mask | T_z,
                    EVEX_compress_addr(src, 2 * i * src_stride));
            vmovdqu16(zmm_src1 | mask | T_z,
                    EVEX_compress_addr(src, (2 * i + 1) * src_stride));
            vinsertf64x4(zmm_src0, zmm_src0, src1, 1);
        } else {
            vmovups(zmm_src0 | mask | T_z,
                    EVEX_compress_addr(src, 2 * i * src_stride));
            vmovups(zmm_src1 | mask | T_z,
                    EVEX_compress_addr(src, (2 * i + 1) * src_stride));
            vcvtne2ps2bf16(zmm_src0, zmm_src1, zmm_src0);
        }
        vpermw(zmm_src0, vidx1, zmm_src0);
        store(zmm_src0, 2 * i);
    }

    // Odd number of rows: pair the last row with zeros.
    if (nrows % 2) {
        auto zmm_src0 = src_zmm(2 * i);
        if (matrix_to_transform_ == matrix_B) {
            vmovdqu16(zmm_src0 | mask | T_z,
                    EVEX_compress_addr(src, 2 * i * src_stride));
        } else {
            auto zmm_zero = src_zmm(2 * i + 1);
            vmovups(zmm_src0 | mask | T_z,
                    EVEX_compress_addr(src, 2 * i * src_stride));
            vpxord(zmm_zero, zmm_zero, zmm_zero);
            vcvtne2ps2bf16(zmm_src0, zmm_zero, zmm_src0);
        }
        vpermw(zmm_src0, vidx1, zmm_src0);
        store(zmm_src0, 2 * i);
        i++;
    }

    // Optional zero padding of the remaining output row pairs.
    if (pad_by_zeroes && i < transpose_size / 2) {
        auto zmm_zero = src_zmm(2 * i);
        vpxord(zmm_zero, zmm_zero, zmm_zero);
        for (; i < transpose_size / 2; i++)
            store(zmm_zero, 2 * i);
    }
}
936
// Kernel driver for the to-vnni transform.  Configures tails and byte
// shifts for the chosen matrix (B: bf16 reorder; C: f32 -> bf16 convert
// with zero padding), then loops: batch -> row (16 per step) -> column
// (16 per step), emitting one transpose() per tile.
void jit_trans_to_vnni_t::generate() {
    preamble();

    // vpermw table interleaving the elements of two rows: 0,16,1,17,...
    alignas(64) static constexpr const int16_t idx1[32]
            = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
                    25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};

    if (matrix_to_transform_ == matrix_B) {
        int row_block = conf_->os_block;

        // With amx bf16 an odd os is padded to an even row count; the
        // padded row is excluded from the effective K tail.
        constexpr int amx_bf16_granularity = 2;
        const bool last_row_padded = conf_->isa == avx512_core_bf16_amx_bf16
                && conf_->os % amx_bf16_granularity != 0;
        const int eff_K_tail = conf_->K_tail - (last_row_padded ? 1 : 0);

        last_row_block_tail = eff_K_tail % transpose_size;
        col_tail = conf_->oc % transpose_size;
        src_stride = conf_->oc * typesize_data;
        tr_src_stride = conf_->LDB * typesize_data;

        src_batch_shift = src_stride * row_block;
        tr_src_batch_shift = tr_src_stride * rnd_up(conf_->K, 2);

        src_col_shift = transpose_size * typesize_data;
        // Destination columns are twice as wide: rows stored in pairs.
        tr_src_col_shift = 2 * transpose_size * typesize_data;

        src_row_shift = transpose_size * conf_->oc * typesize_data;
        tr_src_row_shift = transpose_size * conf_->LDB * typesize_data;

    } else { // matrix_to_transform_ == matrix_C
        int row_block = conf_->ic_block;
        last_row_block_tail = conf_->M_tail % transpose_size;
        assert(row_block == transpose_size);
        col_tail = conf_->oc % transpose_size;
        src_stride = conf_->LDC * typesize_acc;
        tr_src_stride = conf_->LDD * typesize_data;

        src_batch_shift = src_stride * row_block;
        tr_src_batch_shift = tr_src_stride * rnd_up(conf_->M, 2);

        src_col_shift = transpose_size * typesize_acc;
        tr_src_col_shift = 2 * transpose_size * typesize_data;
    }

    // mov(reg_src, ptr[param1 + GET_OFF(src)]);
    // mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
    // mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };
    auto kmovd = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovd(k, regw_tmp);
    };

    kmovw(kFFFF, 0xffff); // 1111111111111111
    kmovd(mask_tail, (1 << col_tail) - 1);

    auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa64(z, ptr[imm_addr64]);
    };

    vmovdqa64(vidx1, (const int64_t *)idx1);

    // Column loop: full 16-column tiles, a col_tail tile, and (matrix_C
    // only) zero padding of the oc_block remainder.
    auto compute_col_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base,
                                    bool is_row_tail) {
        const bool pad_by_zeroes = matrix_to_transform_ == matrix_C;
        int nrows = is_row_tail ? last_row_block_tail : transpose_size;

        mov(reg_col_src, reg_base);
        mov(reg_col_tr_src, reg_tr_base);
        mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);

        Label col_loop, col_loop_tail;
        cmp(reg_loop_col, transpose_size);
        jl(col_loop_tail, T_NEAR);

        L(col_loop);
        {
            transpose(reg_col_tr_src, reg_col_src, nrows, transpose_size,
                    pad_by_zeroes);
            add(reg_col_src, src_col_shift);
            add(reg_col_tr_src, tr_src_col_shift);
        }
        sub(reg_loop_col, transpose_size);
        cmp(reg_loop_col, transpose_size);
        jge(col_loop, T_NEAR);

        L(col_loop_tail);
        if (col_tail > 0) {
            Label col_loop_done;
            cmp(reg_loop_col, 0);
            jle(col_loop_done, T_NEAR);
            transpose(reg_col_tr_src, reg_col_src, nrows, col_tail,
                    pad_by_zeroes);
            L(col_loop_done);
        }
        // Pad the rest of the oc_block with zeros when the valid columns
        // do not fill it (matrix_C only).
        const int oc_block_tail = conf_->oc % conf_->oc_block;
        const bool full_oc_block_utilized = oc_block_tail == 0
                || rnd_up(oc_block_tail, transpose_size) == conf_->oc_block;
        const bool col_pad_required = pad_by_zeroes && !full_oc_block_utilized;

        if (col_pad_required) {
            Label col_pad_done;
            mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);
            cmp(reg_loop_col, conf_->oc_block);
            je(col_pad_done, T_NEAR);
            if (col_tail > 0) add(reg_col_tr_src, tr_src_col_shift);
            maybe_zero_pad_col(reg_col_tr_src);
            L(col_pad_done);
        }
    };

    // Row loop: full 16-row slabs, then an optional row tail slab.
    auto compute_row_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base) {
        mov(reg_row_src, reg_base);
        mov(reg_row_tr_src, reg_tr_base);
        mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);

        Label row_tail, row_loop, row_done;
        if (last_row_block_tail > 0) {
            cmp(reg_loop_row, transpose_size);
            jl(row_tail, T_NEAR);
        }
        L(row_loop);
        {
            compute_col_loop(reg_row_src, reg_row_tr_src, false);

            add(reg_row_src, src_row_shift);
            add(reg_row_tr_src, tr_src_row_shift);
        }
        sub(reg_loop_row, transpose_size);
        cmp(reg_loop_row, transpose_size);
        jge(row_loop, T_NEAR);

        cmp(reg_loop_row, 0);
        je(row_done, T_NEAR);

        if (last_row_block_tail > 0) {
            L(row_tail);
            compute_col_loop(reg_row_src, reg_row_tr_src, true);
        }
        L(row_done);
    };

    mov(reg_batch_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_batch_tr_src, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);

    // Batch loop: one row-sweep per gemm batch element.
    Label batch_loop;
    L(batch_loop);
    {
        compute_row_loop(reg_batch_src, reg_batch_tr_src);

        add(reg_batch_src, src_batch_shift);
        add(reg_batch_tr_src, tr_src_batch_shift);
    }
    sub(reg_loop_batch, 1);
    jnz(batch_loop, T_NEAR);

    postamble();
}
1101
// JIT kernel that performs a plain (non-transposing) row-major copy of f32
// tiles from src to tr_src, used on backward-by-weights paths where the
// destination data type stays f32 and only a re-layout to LDB stride is
// needed. Handles row/column tails via an opmask.
struct jit_copy_f32_t : public jit_brgemm_trans_to_vnni_t,
                        public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_copy_f32_t)
    jit_copy_f32_t(const jit_brgemm_primitive_conf_t *conf,
            jit_brgemm_trans_to_vnni_t::matrix_to_transform_t
                    matrix_to_transform)
        : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using zmm = const Xbyak::Zmm;

    enum {
        typesize_data = sizeof(float), // element size: f32 in, f32 out
        column_step = 16, // floats moved per zmm load/store
        num_regs = 32, // size of the zmm register pool
    };

    // Strides and batch shifts are computed in generate() from conf_.
    dim_t src_stride = 0, tr_src_stride = 0;
    dim_t src_batch_shift = 0, tr_src_batch_shift = 0;
    dim_t col_shift = column_step * typesize_data; // bytes per column block

    opmask_t mask_tail = k2; // masks the partial last column block

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_loop_batch = r10;
    reg64_t reg_loop_row = r11;
    reg64_t reg_loop_col = r12;
    reg32_t regw_tmp = r14d;
    reg64_t reg_long_offt = r15; // scratch for large EVEX displacements

    void copy_block(int nrows, int ncolumns);
    void generate() override;
};
1142
copy_block(int nrows,int ncolumns)1143 void jit_copy_f32_t::copy_block(int nrows, int ncolumns) {
1144
1145 auto kmovd = [=](Opmask k, unsigned w) {
1146 mov(regw_tmp, w);
1147 jit_generator::kmovd(k, regw_tmp);
1148 };
1149
1150 const int nc_tail = ncolumns % column_step;
1151 if (nc_tail > 0) kmovd(mask_tail, (1 << nc_tail) - 1);
1152
1153 auto get_zmm = [=](int i) { return Zmm(i % num_regs); };
1154
1155 auto load = [=](int r, int cb) {
1156 auto src_reg = get_zmm(r * cb);
1157 const bool is_tail
1158 = nc_tail > 0 && ncolumns - cb * column_step < column_step;
1159 auto src_load = is_tail ? src_reg | mask_tail | T_z : src_reg;
1160 const dim_t offset = r * src_stride + cb * col_shift;
1161 auto addr = EVEX_compress_addr_safe(reg_src, offset, reg_long_offt);
1162 vmovups(src_load, addr);
1163 };
1164
1165 auto store = [=](int r, int cb) {
1166 auto reg = get_zmm(r * cb);
1167 const dim_t offset = r * tr_src_stride + cb * col_shift;
1168 auto addr = EVEX_compress_addr_safe(reg_tr_src, offset, reg_long_offt);
1169 vmovups(addr, reg);
1170 };
1171
1172 for_(int r = 0; r < nrows; r++)
1173 for (int cb = 0; cb < div_up(ncolumns, column_step); cb++) {
1174 load(r, cb);
1175 store(r, cb);
1176 }
1177 }
1178
generate()1179 void jit_copy_f32_t::generate() {
1180 preamble();
1181
1182 const int row_block = conf_->os_block;
1183 const int row_tail = conf_->os % row_block;
1184 const int col_block = conf_->oc_block * conf_->nb_oc_blocking;
1185 const int col_tail = conf_->oc % col_block;
1186 src_stride = conf_->oc * typesize_data;
1187 tr_src_stride = conf_->LDB * typesize_data;
1188 src_batch_shift = src_stride * row_block;
1189 tr_src_batch_shift = tr_src_stride * row_block;
1190
1191 mov(reg_src, ptr[param1 + GET_OFF(src)]);
1192 mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
1193 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
1194 mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);
1195 mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);
1196
1197 auto compute_batch = [=](int nrows, int ncolumns) {
1198 Label batch_loop;
1199 L(batch_loop);
1200
1201 copy_block(nrows, ncolumns);
1202 add(reg_src, src_batch_shift);
1203 add(reg_tr_src, tr_src_batch_shift);
1204
1205 sub(reg_loop_batch, 1);
1206 jnz(batch_loop, T_NEAR);
1207 };
1208
1209 auto compute_rows = [=](int ncolumns) {
1210 Label row_done;
1211 if (row_tail > 0) {
1212 Label row_common;
1213 cmp(reg_loop_row, row_block);
1214 je(row_common, T_NEAR);
1215
1216 compute_batch(row_tail, ncolumns);
1217 jmp(row_done, T_NEAR);
1218
1219 L(row_common);
1220 }
1221
1222 compute_batch(row_block, ncolumns);
1223 L(row_done);
1224 };
1225
1226 Label col_done;
1227 if (col_tail > 0) {
1228 Label col_common;
1229 cmp(reg_loop_col, col_block);
1230 je(col_common, T_NEAR);
1231
1232 compute_rows(col_tail);
1233 jmp(col_done, T_NEAR);
1234
1235 L(col_common);
1236 }
1237
1238 compute_rows(col_block);
1239 L(col_done);
1240
1241 postamble();
1242 }
1243
// JIT kernel transposing f32 weights for the brgemm backward-by-data path:
// converts forward-layout weight blocks into the transposed layout expected
// by the brgemm kernel, 16x16 tile at a time.
struct jit_brgemm_trans_wei_f32_t : public jit_brgemm_trans_wei_t,
                                    public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_f32_t)

    jit_brgemm_trans_wei_f32_t(const jit_brgemm_primitive_conf_t *conf)
        : jit_brgemm_trans_wei_t(conf) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;

    enum { typesize = sizeof(float), transpose_size = 16 };
    dim_t src_stride = 0, tr_src_stride = 0; // set in generate()

    // Constant bit patterns used by the in-register 16x16 transpose.
    opmask_t k3333 = k1;
    opmask_t k5555 = k2;
    opmask_t kAAAA = k3;
    opmask_t kCCCC = k4;
    opmask_t k0F0F = k5;
    opmask_t kF0F0 = k6;
    opmask_t kTail = k7; // row/column tail mask, reloaded as needed

    reg64_t reg_src_base = rax;
    reg64_t reg_tr_src_base = rbx;

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_loop_N = r10;
    reg64_t reg_loop_K = r11;
    reg64_t reg_loop_batch = r12;
    reg64_t reg_tr_src_tmp = r13;
    reg32_t regw_tmp = r14d;

    void transpose_16x16(int nrows, int ncolumns = transpose_size);
    void generate() override;
};
1284
transpose_16x16(int nrows,int ncolumns)1285 void jit_brgemm_trans_wei_f32_t::transpose_16x16(int nrows, int ncolumns) {
1286 assert(nrows >= 0 && nrows <= transpose_size);
1287 static_assert(transpose_size == 16, "Unsupported transpose size");
1288 if (!nrows) return;
1289
1290 auto src_zmm = [=](int i) {
1291 assert(i >= 0 && i < 16);
1292 return Zmm(i);
1293 };
1294
1295 auto tmp_zmm = [=](int i) {
1296 assert(i >= 0 && i < 16);
1297 return Zmm(16 + i);
1298 };
1299
1300 auto kmovw = [=](Opmask k, unsigned w) {
1301 mov(regw_tmp, w);
1302 jit_generator::kmovw(k, regw_tmp);
1303 };
1304
1305 auto load = [=](int i) {
1306 auto src_load = src_zmm(i);
1307 if (ncolumns < transpose_size) {
1308 kmovw(kTail, (1 << ncolumns) - 1);
1309 src_load = src_zmm(i) | kTail | T_z;
1310 }
1311 vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
1312 };
1313
1314 auto store = [=](Zmm r, int i) {
1315 mov(reg_tr_src_tmp, reg_tr_src);
1316 if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);
1317
1318 // Xbyak does not allow k0 to be specified explicitly via the '|'
1319 // operator, so we have to do this via a method call (implicitly
1320 // EVEX encoding uses k0 to mean 'no mask')
1321 bool partial_store = nrows < transpose_size;
1322 auto k = partial_store ? kTail : k0;
1323 auto base = reg_tr_src_tmp;
1324 base.setOpmaskIdx(k.getIdx(), true);
1325
1326 auto addr = EVEX_compress_addr(base, i * tr_src_stride);
1327 vmovups(addr, r);
1328 };
1329
1330 auto transpose16x8 = [=](int base_idx) {
1331 assert(base_idx == 0 || base_idx == 8);
1332
1333 // swap 1
1334 for (int i = 0; i < 4; i++) {
1335 int src_idx0 = base_idx + i * 2;
1336 int src_idx1 = src_idx0 + 1;
1337
1338 int next_src_idx0 = src_idx0 + 2;
1339 int next_src_idx1 = src_idx1 + 2;
1340 bool load_next = base_idx == 0 || i < 3;
1341
1342 if (base_idx == 0 && i == 0) {
1343 load(src_idx0);
1344 if (src_idx1 < nrows)
1345 load(src_idx1);
1346 else
1347 vpxord(src_zmm(src_idx1), src_zmm(src_idx1),
1348 src_zmm(src_idx1));
1349 }
1350
1351 auto tmp0 = tmp_zmm(src_idx0);
1352 auto tmp1 = tmp_zmm(src_idx1);
1353 auto src0 = src_zmm(src_idx0);
1354 auto src1 = src_zmm(src_idx1);
1355
1356 if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
1357 valignd(tmp0, src0, src0, 0x1);
1358
1359 if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
1360 valignd(tmp1, src1, src1, 0xf);
1361
1362 vmovaps(src0 | kAAAA, tmp1);
1363 vmovaps(src1 | k5555, tmp0);
1364 }
1365 // swap 2
1366 for (int i = 0; i < 4; i++) {
1367 int select_half = (i < 2) ? 0 : 2;
1368 int src_idx0 = base_idx + i + select_half + 0;
1369 int src_idx2 = src_idx0 + 2;
1370
1371 auto tmp0 = tmp_zmm(src_idx0);
1372 auto tmp1 = tmp_zmm(src_idx2);
1373 auto src0 = src_zmm(src_idx0);
1374 auto src2 = src_zmm(src_idx2);
1375
1376 valignd(tmp0, src0, src0, 0x2);
1377 valignd(tmp1, src2, src2, 0xe);
1378 vmovaps(src2 | k3333, tmp0);
1379 vmovaps(src0 | kCCCC, tmp1);
1380 }
1381
1382 // swap 4
1383 for (int i = 0; i < 4; i++) {
1384 int src_idx0 = base_idx + i;
1385 int src_idx4 = src_idx0 + 4;
1386
1387 auto tmp0 = tmp_zmm(src_idx0);
1388 auto src0 = src_zmm(src_idx0);
1389 auto src4 = src_zmm(src_idx4);
1390
1391 vmovaps(tmp0, src0);
1392 vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
1393 vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
1394 }
1395 };
1396
1397 auto fixup16x16 = [=]() {
1398 // swap 8
1399 for (int i = 0; i < 8; i++) {
1400 auto tmp = tmp_zmm(i);
1401 auto src0 = src_zmm(i);
1402 auto src8 = src_zmm(8 + i);
1403 vshuff64x2(tmp, src0, src8, 0x44);
1404 store(tmp, i);
1405 }
1406
1407 for (int i = 0; i < 8; i++) {
1408 auto tmp = tmp_zmm(8 + i);
1409 auto src0 = src_zmm(i);
1410 auto src8 = src_zmm(8 + i);
1411 vshuff64x2(tmp, src0, src8, 0xee);
1412 store(tmp, 8 + i);
1413 }
1414 };
1415
1416 transpose16x8(0);
1417 transpose16x8(8);
1418 fixup16x16();
1419 }
1420
generate()1421 void jit_brgemm_trans_wei_f32_t::generate() {
1422 preamble();
1423 assert(conf_->oc_block % transpose_size == 0);
1424 int fwd_ic_block = conf_->simd_w;
1425 int fwd_oc_block = 0;
1426 switch (conf_->wei_tag) {
1427 case OI16i64o:
1428 case OIw16i64o:
1429 case OIhw16i64o:
1430 case OIdhw16i64o:
1431 case OI8i64o2i:
1432 case OIw8i64o2i:
1433 case OIhw8i64o2i:
1434 case OIdhw8i64o2i:
1435 case OI16i64o2i:
1436 case OIw16i64o2i:
1437 case OIhw16i64o2i:
1438 case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break;
1439 case OI16i32o:
1440 case OIw16i32o:
1441 case OIhw16i32o:
1442 case OIdhw16i32o:
1443 case OI8i32o2i:
1444 case OIw8i32o2i:
1445 case OIhw8i32o2i:
1446 case OIdhw8i32o2i:
1447 case OI16i32o2i:
1448 case OIw16i32o2i:
1449 case OIhw16i32o2i:
1450 case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break;
1451 default: fwd_oc_block = conf_->simd_w;
1452 };
1453
1454 int oc_tail = conf_->K_tail % transpose_size;
1455 int ic_block = conf_->ic_block;
1456 int ic_tail = conf_->N_tail % transpose_size;
1457 src_stride = fwd_oc_block * typesize;
1458 tr_src_stride = ic_block * typesize;
1459 dim_t N_src_shift = conf_->kd * conf_->kh * conf_->kw * fwd_ic_block
1460 * fwd_oc_block * typesize;
1461 dim_t N_tr_src_shift = conf_->simd_w * typesize;
1462 dim_t K_src_shift = conf_->simd_w * typesize;
1463 dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize;
1464
1465 mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
1466 mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
1467 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
1468 mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
1469
1470 auto kmovw = [=](Opmask k, unsigned w) {
1471 mov(regw_tmp, w);
1472 jit_generator::kmovw(k, regw_tmp);
1473 };
1474
1475 kmovw(k3333, 0x3333); // 0011001100110011
1476 kmovw(k5555, 0x5555); // 0101010101010101
1477 kmovw(kAAAA, 0xaaaa); // 1010101010101010
1478 kmovw(kCCCC, 0xcccc); // 1100110011001100
1479 kmovw(k0F0F, 0x0f0f); // 0000111100001111
1480 kmovw(kF0F0, 0xf0f0); // 1111000011110000
1481
1482 auto compute_N = [=](bool is_oc_tail) {
1483 mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]);
1484 mov(reg_src, reg_src_base);
1485 mov(reg_tr_src, reg_tr_src_base);
1486 Label N_loop, N_loop_tail;
1487
1488 cmp(reg_loop_N, transpose_size);
1489 jl(N_loop_tail, T_NEAR);
1490
1491 L(N_loop);
1492
1493 transpose_16x16(transpose_size, is_oc_tail ? oc_tail : transpose_size);
1494 add(reg_src, N_src_shift);
1495 add(reg_tr_src, N_tr_src_shift);
1496
1497 sub(reg_loop_N, transpose_size);
1498 cmp(reg_loop_N, transpose_size);
1499 jge(N_loop, T_NEAR);
1500
1501 L(N_loop_tail);
1502 if (ic_tail > 0) {
1503 Label N_loop_done;
1504 cmp(reg_loop_N, 0);
1505 jle(N_loop_done, T_NEAR);
1506 transpose_16x16(ic_tail, is_oc_tail ? oc_tail : transpose_size);
1507 L(N_loop_done);
1508 }
1509 };
1510
1511 Label K_loop, K_tail;
1512 if (oc_tail > 0) {
1513 cmp(reg_loop_K, transpose_size);
1514 jl(K_tail, T_NEAR);
1515 }
1516
1517 L(K_loop);
1518 compute_N(false);
1519 add(reg_src_base, K_src_shift);
1520 add(reg_tr_src_base, K_tr_src_shift);
1521
1522 sub(reg_loop_K, transpose_size);
1523 cmp(reg_loop_K, transpose_size);
1524 jge(K_loop, T_NEAR);
1525
1526 L(K_tail);
1527 if (oc_tail > 0) {
1528 Label K_loop_done;
1529 cmp(reg_loop_K, 0);
1530 jle(K_loop_done, T_NEAR);
1531
1532 compute_N(true);
1533 L(K_loop_done);
1534 }
1535
1536 postamble();
1537 }
1538
// JIT kernel transposing bf16 weights for the brgemm backward-by-data path.
// Works on VNNI-packed (2i-interleaved) data, so a logical 16x16 transpose
// is realized as byte shuffles plus qword/lane interleaves over 8 zmms.
struct jit_brgemm_trans_wei_bf16_t : public jit_brgemm_trans_wei_t,
                                     public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_bf16_t)

    jit_brgemm_trans_wei_bf16_t(const jit_brgemm_primitive_conf_t *conf)
        : jit_brgemm_trans_wei_t(conf) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using zmm = const Xbyak::Zmm;

    enum { typesize = sizeof(int16_t), transpose_size = 16 };
    dim_t src_stride = 0, tr_src_stride = 0; // set in generate()

    opmask_t kTail = k7; // row/column tail mask

    reg64_t reg_src_base = rax;
    reg64_t reg_tr_src_base = rbx;

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_loop_N = r10;
    reg64_t reg_loop_K = r11;
    reg64_t reg_loop_batch = r12;
    reg64_t reg_tr_src_tmp = r13;
    reg32_t regw_tmp = r14d;
    reg64_t imm_addr64 = r15; // scratch for table addresses

    // Byte-permutation constant (loaded in generate()) that regroups each
    // 16-byte lane from abcdefgh word order to abefcdgh.
    zmm v_abcdefgh_to_abefcdgh = zmm31;

    void transpose_16x16_vnni(int nrows, int ncolumns = transpose_size);
    void generate() override;
};
1577
transpose_16x16_vnni(int nrows,int ncolumns)1578 void jit_brgemm_trans_wei_bf16_t::transpose_16x16_vnni(
1579 int nrows, int ncolumns) {
1580 assert(nrows >= 0 && nrows <= transpose_size);
1581 static_assert(transpose_size == 16, "Unsupported transpose size");
1582 if (!nrows) return;
1583
1584 auto src_zmm = [=](int i) {
1585 assert(i >= 0 && i < 8);
1586 return Zmm(i);
1587 };
1588
1589 auto tmp_zmm = [=](int i) {
1590 assert(i >= 0 && i < 8);
1591 return Zmm(8 + i);
1592 };
1593
1594 auto kmovw = [=](Opmask k, unsigned w) {
1595 mov(regw_tmp, w);
1596 jit_generator::kmovw(k, regw_tmp);
1597 };
1598
1599 auto load = [=](int i) {
1600 auto src_load = src_zmm(i);
1601 if (ncolumns < transpose_size) {
1602 kmovw(kTail, (1 << ncolumns) - 1);
1603 src_load = src_zmm(i) | kTail | T_z;
1604 }
1605 vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
1606 };
1607
1608 auto store = [=](Zmm r, int i) {
1609 mov(reg_tr_src_tmp, reg_tr_src);
1610 if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);
1611
1612 // Xbyak does not allow k0 to be specified explicitly via the '|'
1613 // operator, so we have to do this via a method call (implicitly
1614 // EVEX encoding uses k0 to mean 'no mask')
1615 bool partial_store = nrows < transpose_size;
1616 auto k = partial_store ? kTail : k0;
1617 auto base = reg_tr_src_tmp;
1618 base.setOpmaskIdx(k.getIdx(), true);
1619
1620 auto addr = EVEX_compress_addr(base, i * tr_src_stride);
1621 vmovups(addr, r);
1622 };
1623
1624 for (int i = 0; i < 8; i++)
1625 load(i);
1626
1627 for (int i = 0; i < 8; i++)
1628 vpshufb(src_zmm(i), src_zmm(i), v_abcdefgh_to_abefcdgh);
1629
1630 for (int i = 0; i < 2; i++) {
1631 vpunpcklqdq(tmp_zmm(2 * i + 0), src_zmm(2 * i), src_zmm(2 * i + 1));
1632 vpunpckhqdq(tmp_zmm(2 * i + 1), src_zmm(2 * i), src_zmm(2 * i + 1));
1633 }
1634
1635 for (int i = 0; i < 2; i++) {
1636 vpunpcklqdq(
1637 src_zmm(2 * i + 0), src_zmm(4 + 2 * i), src_zmm(4 + 2 * i + 1));
1638 vpunpckhqdq(
1639 src_zmm(2 * i + 1), src_zmm(4 + 2 * i), src_zmm(4 + 2 * i + 1));
1640 }
1641
1642 for (int i = 0; i < 2; i++) {
1643 vshufi32x4(src_zmm(4 + 0 + i), tmp_zmm(i), tmp_zmm(2 + i), 0x88);
1644 vshufi32x4(src_zmm(4 + 2 + i), tmp_zmm(i), tmp_zmm(2 + i), 0xdd);
1645 }
1646
1647 for (int i = 0; i < 2; i++) {
1648 vshufi32x4(tmp_zmm(0 + i), src_zmm(i), src_zmm(2 + i), 0x88);
1649 vshufi32x4(tmp_zmm(2 + i), src_zmm(i), src_zmm(2 + i), 0xdd);
1650 }
1651
1652 for (int i = 0; i < 4; i++)
1653 vshufi32x4(src_zmm(i), src_zmm(4 + i), tmp_zmm(i), 0x88);
1654
1655 for (int i = 0; i < 4; i++)
1656 vshufi32x4(src_zmm(4 + i), src_zmm(4 + i), tmp_zmm(i), 0xdd);
1657
1658 for (int i = 0; i < 8; i++)
1659 store(src_zmm(i), i);
1660 }
1661
generate()1662 void jit_brgemm_trans_wei_bf16_t::generate() {
1663 preamble();
1664 int fwd_oc_block = 0;
1665 switch (conf_->wei_tag) {
1666 case OI16i64o:
1667 case OIw16i64o:
1668 case OIhw16i64o:
1669 case OIdhw16i64o:
1670 case OI8i64o2i:
1671 case OIw8i64o2i:
1672 case OIhw8i64o2i:
1673 case OIdhw8i64o2i:
1674 case OI16i64o2i:
1675 case OIw16i64o2i:
1676 case OIhw16i64o2i:
1677 case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break;
1678 case OI16i32o:
1679 case OIw16i32o:
1680 case OIhw16i32o:
1681 case OIdhw16i32o:
1682 case OI8i32o2i:
1683 case OIw8i32o2i:
1684 case OIhw8i32o2i:
1685 case OIdhw8i32o2i:
1686 case OI16i32o2i:
1687 case OIw16i32o2i:
1688 case OIhw16i32o2i:
1689 case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break;
1690 default: fwd_oc_block = conf_->simd_w;
1691 };
1692
1693 int oc_tail = conf_->K_tail % transpose_size;
1694 int ic_block = conf_->ic_block;
1695 int ic_tail = conf_->N_tail % transpose_size;
1696 src_stride = 2 * fwd_oc_block * typesize;
1697 tr_src_stride = 2 * ic_block * typesize;
1698 dim_t N_src_shift = conf_->simd_w * fwd_oc_block * typesize;
1699 dim_t N_tr_src_shift = 2 * conf_->simd_w * typesize;
1700 dim_t K_src_shift = 2 * conf_->simd_w * typesize;
1701 dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize;
1702
1703 mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
1704 mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
1705 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
1706 mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
1707
1708 alignas(64) static constexpr const int32_t abcdefgh_to_abefcdgh[16]
1709 = {0x05040100, 0x07060302, 0x0d0c0908, 0x0f0e0b0a, 0x05040100,
1710 0x07060302, 0x0d0c0908, 0x0f0e0b0a, 0x05040100, 0x07060302,
1711 0x0d0c0908, 0x0f0e0b0a, 0x05040100, 0x07060302, 0x0d0c0908,
1712 0x0f0e0b0a};
1713
1714 auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
1715 mov(imm_addr64, reinterpret_cast<size_t>(addr));
1716 jit_generator::vmovdqa64(z, ptr[imm_addr64]);
1717 };
1718
1719 vmovdqa64(v_abcdefgh_to_abefcdgh, (const int64_t *)abcdefgh_to_abefcdgh);
1720 auto compute_N = [=](bool is_oc_tail) {
1721 mov(reg_src, reg_src_base);
1722 mov(reg_tr_src, reg_tr_src_base);
1723 mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]);
1724
1725 Label N_loop, N_loop_tail;
1726 cmp(reg_loop_N, transpose_size);
1727 jl(N_loop_tail, T_NEAR);
1728
1729 L(N_loop);
1730
1731 transpose_16x16_vnni(
1732 transpose_size, is_oc_tail ? oc_tail : transpose_size);
1733 add(reg_src, N_src_shift);
1734 add(reg_tr_src, N_tr_src_shift);
1735
1736 sub(reg_loop_N, transpose_size);
1737 cmp(reg_loop_N, transpose_size);
1738 jge(N_loop, T_NEAR);
1739
1740 L(N_loop_tail);
1741 if (ic_tail > 0) {
1742 Label N_loop_done;
1743 cmp(reg_loop_N, 0);
1744 jle(N_loop_done, T_NEAR);
1745 transpose_16x16_vnni(
1746 ic_tail, is_oc_tail ? oc_tail : transpose_size);
1747 L(N_loop_done);
1748 }
1749 };
1750
1751 Label K_loop, K_tail;
1752 if (oc_tail > 0) {
1753 cmp(reg_loop_K, transpose_size);
1754 jl(K_tail, T_NEAR);
1755 }
1756
1757 L(K_loop);
1758 compute_N(false);
1759 add(reg_src_base, K_src_shift);
1760 add(reg_tr_src_base, K_tr_src_shift);
1761
1762 sub(reg_loop_K, transpose_size);
1763 cmp(reg_loop_K, transpose_size);
1764 jge(K_loop, T_NEAR);
1765
1766 L(K_tail);
1767 if (oc_tail > 0) {
1768 Label K_loop_done;
1769 cmp(reg_loop_K, 0);
1770 jle(K_loop_done, T_NEAR);
1771 compute_N(true);
1772 L(K_loop_done);
1773 }
1774
1775 postamble();
1776 }
1777
// JIT kernel for the AMX inner-product backward-by-weights path: converts
// the f32 diff-weights accumulator block into the external bf16 VNNI
// ([OCB][ICB][16i][No][2i]) layout, including oc/ic tail zero-padding.
struct jit_amx_ip_trans_diff_wei_to_vnni_t : public jit_amx_ip_trans_diff_wei,
                                             public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_amx_ip_trans_diff_wei_to_vnni)

    jit_amx_ip_trans_diff_wei_to_vnni_t(const jit_brgemm_primitive_conf_t *jbgp,
            const int ext_ic_block, const int ext_oc_block)
        : jit_amx_ip_trans_diff_wei(jbgp, ext_ic_block, ext_oc_block) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    void generate() override;
};
1792
generate()1793 void jit_amx_ip_trans_diff_wei_to_vnni_t::generate() {
1794 const int typesize_out = 2;
1795 const int typesize_acc = 4;
1796 const int simd_w = 16;
1797
1798 using reg64_t = const Xbyak::Reg64;
1799 using reg32_t = const Xbyak::Reg32;
1800
1801 const reg64_t ®_output = r15;
1802 const reg64_t ®_input = r14;
1803 const reg64_t ®_prm_table = r13;
1804 const reg64_t ®_last_ic_block = r12;
1805 const reg64_t ®_last_oc_block = r11;
1806 const reg32_t ®w_tmp = r10d;
1807
1808 const Xbyak::Zmm &zmm_idx = Xbyak::Zmm(31);
1809 auto get_zmm_src = [&](int ic) { return Xbyak::Zmm(ic % 8); };
1810
1811 Xbyak::Label prm_table;
1812 Xbyak::Label skip_oc_tail, to_exit;
1813
1814 Xbyak::Opmask load_mask = k4;
1815
1816 int tail_mask = (jbgp_->N_tail % simd_w)
1817 ? (1 << (jbgp_->N_tail % simd_w)) - 1
1818 : 0xffff;
1819 auto kmovw = [=](Xbyak::Opmask k, unsigned w) {
1820 mov(regw_tmp, w);
1821 jit_generator::kmovw(k, regw_tmp);
1822 };
1823
1824 auto reorder_oc_block = [&](int icb, int ic_block, bool is_oc_tail) {
1825 // INP: [64i][No] : FP32
1826 // OUT: [OCB][ICB][16i][No][2i]: BF16
1827 if (ic_block <= 0) return;
1828
1829 dim_t inp_icb_offset = typesize_acc
1830 * (icb * ext_ic_block_ * jbgp_->oc_block); // Internal
1831 dim_t out_icb_offset = typesize_out
1832 * (icb * div_up(ext_ic_block_, 2) * ext_oc_block_
1833 * 2); // External
1834
1835 const int oc_padded = rnd_up(jbgp_->oc, jbgp_->oc_block);
1836 const int oc_padded_ext = rnd_up(jbgp_->oc, ext_oc_block_);
1837
1838 bool tailing_done = false;
1839 for (int oc = 0; oc < jbgp_->oc_block; oc += simd_w) {
1840 int ext_oc = oc % ext_oc_block_;
1841 int ext_ocb = oc / ext_oc_block_;
1842 dim_t ext_ocb_offset = typesize_out
1843 * (ext_ocb * div_up(jbgp_->ic, ext_ic_block_)
1844 * div_up(ext_ic_block_, 2) * ext_oc_block_ * 2);
1845 if (is_oc_tail && oc_padded != oc_padded_ext
1846 && oc + simd_w > ext_oc_block_)
1847 break;
1848 dim_t inp_offset = inp_icb_offset + typesize_acc * (oc); // Internal
1849 dim_t out_offset = out_icb_offset + typesize_out * (ext_oc * 2)
1850 + ext_ocb_offset; // External
1851 kmovw(load_mask, 0xffff);
1852 if (is_oc_tail) {
1853 if (jbgp_->N_tail && (oc + simd_w) >= jbgp_->N_tail) {
1854 if (tailing_done == false) {
1855 kmovw(load_mask, tail_mask);
1856 tailing_done = true;
1857 } else {
1858 auto zmm_src_0 = get_zmm_src(0);
1859 vpxord(zmm_src_0, zmm_src_0, zmm_src_0);
1860 for (int ic = 0; ic < ext_ic_block_ / 2; ic++) {
1861 vmovups(ptr[reg_output + out_offset
1862 + typesize_out
1863 * (ic * ext_oc_block_ * 2)],
1864 zmm_src_0);
1865 }
1866 continue;
1867 }
1868 }
1869 }
1870
1871 int ic = 0;
1872 for (; ic < ic_block / 2; ic++) {
1873 int ic1 = 2 * ic;
1874 int ic2 = 2 * ic + 1;
1875
1876 auto zmm_src_0 = get_zmm_src(ic1);
1877 auto zmm_src_1 = get_zmm_src(ic2);
1878
1879 vmovups(zmm_src_0 | load_mask | T_z,
1880 ptr[reg_input + inp_offset
1881 + typesize_acc * (ic1 * jbgp_->oc_block)]);
1882 vmovups(zmm_src_1 | load_mask | T_z,
1883 ptr[reg_input + inp_offset
1884 + typesize_acc * (ic2 * jbgp_->oc_block)]);
1885
1886 vcvtne2ps2bf16(zmm_src_0, zmm_src_1, zmm_src_0);
1887 vpermw(zmm_src_0, zmm_idx, zmm_src_0);
1888
1889 vmovups(ptr[reg_output + out_offset
1890 + typesize_out * (ic * ext_oc_block_ * 2)],
1891 zmm_src_0);
1892 }
1893 if (ic_block % 2) {
1894 int ic1 = 2 * ic;
1895 int ic2 = 2 * ic + 1;
1896
1897 auto zmm_src_0 = get_zmm_src(ic1);
1898 auto zmm_src_1 = get_zmm_src(ic2);
1899
1900 vmovups(zmm_src_0 | load_mask | T_z,
1901 ptr[reg_input + inp_offset
1902 + typesize_acc * (ic1 * jbgp_->oc_block)]);
1903 vpxord(zmm_src_1, zmm_src_1, zmm_src_1);
1904
1905 vcvtne2ps2bf16(zmm_src_0, zmm_src_1, zmm_src_0);
1906 vpermw(zmm_src_0, zmm_idx, zmm_src_0);
1907
1908 vmovups(ptr[reg_output + out_offset
1909 + typesize_out * (ic * ext_oc_block_ * 2)],
1910 zmm_src_0);
1911 ic++;
1912 }
1913 if (ic < ext_ic_block_ / 2) {
1914 auto zmm_src_0 = get_zmm_src(0);
1915 vpxord(zmm_src_0, zmm_src_0, zmm_src_0);
1916 for (; ic < ext_ic_block_ / 2; ic++) {
1917 vmovups(ptr[reg_output + out_offset
1918 + typesize_out * (ic * ext_oc_block_ * 2)],
1919 zmm_src_0);
1920 }
1921 }
1922 }
1923 };
1924
1925 auto reorder_ic_block = [&](bool is_oc_tail, bool is_ic_tail) {
1926 int nb_ic = div_up(jbgp_->ic_block, ext_ic_block_);
1927 for (int icb = 0; icb < nb_ic; icb++) {
1928 int ic_0 = icb * ext_ic_block_;
1929 int ic_1 = (icb + 1) * ext_ic_block_;
1930 if (is_ic_tail) {
1931 int ext_ic_tail = (jbgp_->ic % ext_ic_block_)
1932 ? (jbgp_->ic % ext_ic_block_)
1933 : ext_ic_block_;
1934 if (jbgp_->M_tail && ic_0 >= jbgp_->M_tail) break;
1935 if (jbgp_->M_tail && ic_0 <= jbgp_->M_tail
1936 && jbgp_->M_tail <= ic_1) {
1937 reorder_oc_block(icb, ext_ic_tail, is_oc_tail);
1938 } else {
1939 reorder_oc_block(icb, ext_ic_block_, is_oc_tail);
1940 }
1941 } else {
1942 reorder_oc_block(icb, ext_ic_block_, is_oc_tail);
1943 }
1944 }
1945 };
1946
1947 auto reorder = [&](bool is_oc_tail) {
1948 Xbyak::Label skip_ic_tail, to_exit_1;
1949
1950 cmp(reg_last_ic_block, 0);
1951 je(skip_ic_tail, T_NEAR);
1952
1953 reorder_ic_block(is_oc_tail, true);
1954 jmp(to_exit, T_NEAR);
1955
1956 L(skip_ic_tail);
1957 reorder_ic_block(is_oc_tail, false);
1958
1959 L(to_exit_1);
1960 };
1961
1962 preamble();
1963
1964 mov(reg_input, ptr[abi_param1 + GET_OFF(src)]);
1965 mov(reg_output, ptr[abi_param1 + GET_OFF(dst)]);
1966 mov(reg_last_ic_block, ptr[abi_param1 + GET_OFF(last_ic_block)]);
1967 mov(reg_last_oc_block, ptr[abi_param1 + GET_OFF(last_oc_block)]);
1968
1969 mov(reg_prm_table, prm_table);
1970 vmovups(zmm_idx, ptr[reg_prm_table]);
1971
1972 cmp(reg_last_oc_block, 0);
1973 je(skip_oc_tail, T_NEAR);
1974
1975 reorder(true);
1976 jmp(to_exit, T_NEAR);
1977
1978 L(skip_oc_tail);
1979 reorder(false);
1980
1981 L(to_exit);
1982 postamble();
1983
1984 align(64);
1985 L(prm_table);
1986 const uint16_t prm_array[32]
1987 = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
1988 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
1989 for (size_t i = 0; i < 32; ++i)
1990 dw(prm_array[i]);
1991 }
1992
1993 #undef GET_OFF
1994
create_brgemm_trans_src(std::unique_ptr<jit_brgemm_trans_src_t> & trans_ker,const jit_brgemm_primitive_conf_t * conf)1995 status_t create_brgemm_trans_src(
1996 std::unique_ptr<jit_brgemm_trans_src_t> &trans_ker,
1997 const jit_brgemm_primitive_conf_t *conf) {
1998 if (conf->prop_kind == dnnl_backward_weights
1999 && conf->src_dt == data_type::f32)
2000 CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_m_k_f32_t(conf)));
2001 else if (conf->prop_kind == dnnl_backward_weights
2002 && conf->src_dt == data_type::bf16)
2003 CHECK(safe_ptr_assign(
2004 trans_ker, new jit_brgemm_trans_m_k_bf16_t(conf)));
2005 else
2006 return status::invalid_arguments;
2007
2008 return trans_ker->create_kernel();
2009 }
2010
create_brgemm_copy_to_coarse(std::unique_ptr<jit_brgemm_copy_to_coarse_t> & copy_ker,const jit_brgemm_primitive_conf_t * conf)2011 status_t create_brgemm_copy_to_coarse(
2012 std::unique_ptr<jit_brgemm_copy_to_coarse_t> ©_ker,
2013 const jit_brgemm_primitive_conf_t *conf) {
2014 if (conf->isa == avx512_core_bf16_amx_int8
2015 || conf->isa == avx512_core_bf16_amx_bf16)
2016 CHECK(safe_ptr_assign(copy_ker, new jit_brgemm_copy_to_coarse_t(conf)));
2017 else
2018 return status::invalid_arguments;
2019
2020 return copy_ker->create_kernel();
2021 }
2022
create_brgemm_trans_to_vnni(std::unique_ptr<jit_brgemm_trans_to_vnni_t> & trans_ker,const jit_brgemm_primitive_conf_t * conf,jit_brgemm_trans_to_vnni_t::matrix_to_transform_t matrix_to_transform)2023 status_t create_brgemm_trans_to_vnni(
2024 std::unique_ptr<jit_brgemm_trans_to_vnni_t> &trans_ker,
2025 const jit_brgemm_primitive_conf_t *conf,
2026 jit_brgemm_trans_to_vnni_t::matrix_to_transform_t matrix_to_transform) {
2027 if (conf->prop_kind == dnnl_backward_weights
2028 && conf->dst_dt == data_type::bf16)
2029 CHECK(safe_ptr_assign(
2030 trans_ker, new jit_trans_to_vnni_t(conf, matrix_to_transform)));
2031 else if (conf->prop_kind == dnnl_backward_weights
2032 && conf->dst_dt == data_type::f32)
2033 CHECK(safe_ptr_assign(
2034 trans_ker, new jit_copy_f32_t(conf, matrix_to_transform)));
2035 else
2036 return status::invalid_arguments;
2037
2038 return trans_ker->create_kernel();
2039 }
2040
create_brgemm_trans_wei(std::unique_ptr<jit_brgemm_trans_wei_t> & trans_ker,const jit_brgemm_primitive_conf_t * conf)2041 status_t create_brgemm_trans_wei(
2042 std::unique_ptr<jit_brgemm_trans_wei_t> &trans_ker,
2043 const jit_brgemm_primitive_conf_t *conf) {
2044 if (conf->prop_kind == dnnl_backward_data && conf->wei_dt == data_type::f32)
2045 CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_wei_f32_t(conf)));
2046 else if (conf->prop_kind == dnnl_backward_data
2047 && conf->wei_dt == data_type::bf16)
2048 CHECK(safe_ptr_assign(
2049 trans_ker, new jit_brgemm_trans_wei_bf16_t(conf)));
2050 else
2051 return status::invalid_arguments;
2052
2053 return trans_ker->create_kernel();
2054 }
2055
create_brgemm_amx_ip_trans_wei(std::unique_ptr<jit_amx_ip_trans_diff_wei> & trans_ker,const jit_brgemm_primitive_conf_t * conf,const int ext_ic_block,const int ext_oc_block)2056 status_t create_brgemm_amx_ip_trans_wei(
2057 std::unique_ptr<jit_amx_ip_trans_diff_wei> &trans_ker,
2058 const jit_brgemm_primitive_conf_t *conf, const int ext_ic_block,
2059 const int ext_oc_block) {
2060 if (conf->prop_kind == dnnl_backward_weights
2061 && conf->wei_dt == data_type::bf16) {
2062 CHECK(safe_ptr_assign(trans_ker,
2063 new jit_amx_ip_trans_diff_wei_to_vnni_t(
2064 conf, ext_ic_block, ext_oc_block)));
2065 } else
2066 return status::invalid_arguments;
2067
2068 return trans_ker->create_kernel();
2069 }
2070
2071 } // namespace x64
2072 } // namespace cpu
2073 } // namespace impl
2074 } // namespace dnnl
2075