/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_brgemm_transpose_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::utils;
using namespace Xbyak;

#define GET_OFF(x) offsetof(ctx_t, x)

struct jit_brgemm_trans_m_k_f32_t : public jit_brgemm_trans_src_t,
                                    public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_f32_t)

    jit_brgemm_trans_m_k_f32_t(const jit_brgemm_primitive_conf_t *conf)
        : jit_brgemm_trans_src_t(conf) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;

    enum { typesize = sizeof(float), transpose_size = 16 };
    dim_t src_stride = 0, tr_src_stride = 0;

    opmask_t k3333 = k1;
    opmask_t k5555 = k2;
    opmask_t kAAAA = k3;
    opmask_t kCCCC = k4;
    opmask_t k0F0F = k5;
    opmask_t kF0F0 = k6;
    opmask_t kTail = k7;

    reg64_t reg_src_base = rax;
    reg64_t reg_tr_src_base = rbx;

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_loop_K = r10;
    reg64_t reg_loop_M = r11;
    reg64_t reg_loop_batch = r12;
    reg64_t reg_tr_src_tmp = r13;
    reg32_t regw_tmp = r14d;

    void transpose_16x16(int nrows, int ncolumns = transpose_size);
    void generate() override;
};
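// Transposes one (nrows x ncolumns) f32 tile, nrows/ncolumns <= 16, from
// reg_src (row stride = src_stride) into reg_tr_src (row stride =
// tr_src_stride). The tile lives in zmm0-zmm15, with zmm16-zmm31 used as
// temporaries; rows beyond nrows are zeroed, and partial rows/columns are
// handled through the kTail opmask.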
void jit_brgemm_trans_m_k_f32_t::transpose_16x16(int nrows, int ncolumns) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Zmm(i);
    };

    auto tmp_zmm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Zmm(16 + i);
    };

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    auto load = [=](int i) {
        auto src_load = src_zmm(i);
        if (i >= nrows) {
            vpxord(src_load, src_load, src_load);
            return;
        }

        if (ncolumns < transpose_size) {
            kmovw(kTail, (1 << ncolumns) - 1);
            src_load = src_zmm(i) | kTail | T_z;
        }
        vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
    };

    auto store = [=](Zmm r, int i) {
        mov(reg_tr_src_tmp, reg_tr_src);
        if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);

        // Xbyak does not allow k0 to be specified explicitly via the '|'
        // operator, so we have to do this via a method call (in EVEX
        // encoding, k0 implicitly means 'no mask')
        bool partial_store = nrows < transpose_size;
        auto k = partial_store ? kTail : k0;
        auto base = reg_tr_src_tmp;
        base.setOpmaskIdx(k.getIdx(), true);

        auto addr = EVEX_compress_addr(base, i * tr_src_stride);
        vmovups(addr, r);
    };
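    // Transposes one 16x8 half of the tile (rows base_idx..base_idx+7) as a
    // butterfly network: "swap 1" exchanges neighboring elements via
    // valignd plus masked moves, "swap 2" exchanges element pairs, and
    // "swap 4" exchanges 128-bit quarters with vshuff32x4. Loads of the next
    // row pair are interleaved with the shuffles to hide load latency.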
    auto transpose16x8 = [=](int base_idx) {
        assert(base_idx == 0 || base_idx == 8);

        // swap 1
        for (int i = 0; i < 4; i++) {
            int src_idx0 = base_idx + i * 2;
            int src_idx1 = src_idx0 + 1;

            int next_src_idx0 = src_idx0 + 2;
            int next_src_idx1 = src_idx1 + 2;
            bool load_next = base_idx == 0 || i < 3;

            if (base_idx == 0 && i == 0) {
                load(src_idx0);
                if (src_idx1 < nrows)
                    load(src_idx1);
                else
                    vpxord(src_zmm(src_idx1), src_zmm(src_idx1),
                            src_zmm(src_idx1));
            }

            auto tmp0 = tmp_zmm(src_idx0);
            auto tmp1 = tmp_zmm(src_idx1);
            auto src0 = src_zmm(src_idx0);
            auto src1 = src_zmm(src_idx1);

            if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
            valignd(tmp0, src0, src0, 0x1);

            if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
            valignd(tmp1, src1, src1, 0xf);

            vmovaps(src0 | kAAAA, tmp1);
            vmovaps(src1 | k5555, tmp0);
        }
        // swap 2
        for (int i = 0; i < 4; i++) {
            int select_half = (i < 2) ? 0 : 2;
            int src_idx0 = base_idx + i + select_half + 0;
            int src_idx2 = src_idx0 + 2;

            auto tmp0 = tmp_zmm(src_idx0);
            auto tmp1 = tmp_zmm(src_idx2);
            auto src0 = src_zmm(src_idx0);
            auto src2 = src_zmm(src_idx2);

            valignd(tmp0, src0, src0, 0x2);
            valignd(tmp1, src2, src2, 0xe);
            vmovaps(src2 | k3333, tmp0);
            vmovaps(src0 | kCCCC, tmp1);
        }

        // swap 4
        for (int i = 0; i < 4; i++) {
            int src_idx0 = base_idx + i;
            int src_idx4 = src_idx0 + 4;

            auto tmp0 = tmp_zmm(src_idx0);
            auto src0 = src_zmm(src_idx0);
            auto src4 = src_zmm(src_idx4);

            vmovaps(tmp0, src0);
            vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
            vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
        }
    };

    auto fixup16x16 = [=]() {
        // swap 8
        for (int i = 0; i < 8; i++) {
            auto tmp = tmp_zmm(i);
            auto src0 = src_zmm(i);
            auto src8 = src_zmm(8 + i);
            vshuff64x2(tmp, src0, src8, 0x44);
            store(tmp, i);
        }

        for (int i = 0; i < 8; i++) {
            auto tmp = tmp_zmm(8 + i);
            auto src0 = src_zmm(i);
            auto src8 = src_zmm(8 + i);
            vshuff64x2(tmp, src0, src8, 0xee);
            store(tmp, 8 + i);
        }
    };

    transpose16x8(0);
    transpose16x8(8);
    fixup16x16();
}

void jit_brgemm_trans_m_k_f32_t::generate() {
    preamble();
    assert(conf_->ic_block % transpose_size == 0);
    int os_block = conf_->os_block;
    int last_os_block_tail = conf_->K_tail % transpose_size;
    int ic_tail = conf_->ic % transpose_size;
    src_stride = conf_->ic * typesize;
    tr_src_stride = conf_->LDA * typesize;
    dim_t m_src_shift = transpose_size * typesize;
    dim_t m_tr_src_shift = tr_src_stride * transpose_size;

    dim_t batch_src_shift = src_stride * os_block;
    dim_t batch_tr_src_shift = tr_src_stride * conf_->M;

    mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
    mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
    mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    kmovw(k3333, 0x3333); // 0011001100110011
    kmovw(k5555, 0x5555); // 0101010101010101
    kmovw(kAAAA, 0xaaaa); // 1010101010101010
    kmovw(kCCCC, 0xcccc); // 1100110011001100
    kmovw(k0F0F, 0x0f0f); // 0000111100001111
    kmovw(kF0F0, 0xf0f0); // 1111000011110000

    auto compute_M = [=](bool is_os_tail) {
        auto nrows = is_os_tail ? last_os_block_tail : transpose_size;
        mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]);
        mov(reg_src, reg_src_base);
        mov(reg_tr_src, reg_tr_src_base);
        Label M_loop, M_tail_or_done, M_done;
        if (ic_tail > 0) {
            cmp(reg_loop_M, transpose_size);
            jl(M_tail_or_done, T_NEAR);
        }

        L(M_loop);
        transpose_16x16(nrows, transpose_size);
        if (conf_->ic_block > transpose_size) {
            add(reg_src, m_src_shift);
            add(reg_tr_src, m_tr_src_shift);
            sub(reg_loop_M, transpose_size);
            cmp(reg_loop_M, transpose_size);
            jge(M_loop, T_NEAR);
        } else {
            jmp(M_done, T_NEAR);
        }

        L(M_tail_or_done);
        if (ic_tail > 0) {
            cmp(reg_loop_M, 0);
            jle(M_done, T_NEAR);

            transpose_16x16(nrows, ic_tail);
        }
        L(M_done);
    };

    auto compute_batch = [=](bool is_os_tail) {
        Label batch_loop;
        L(batch_loop);

        compute_M(is_os_tail);
        add(reg_src_base, batch_src_shift);
        add(reg_tr_src_base, batch_tr_src_shift);

        sub(reg_loop_batch, 1);
        jnz(batch_loop, T_NEAR);
    };

    Label K_tail;
    if (last_os_block_tail > 0) {
        cmp(reg_loop_K, transpose_size);
        jl(K_tail, T_NEAR);
    }

    compute_batch(false);

    if (last_os_block_tail > 0) {
        Label K_done;
        jmp(K_done, T_NEAR);

        L(K_tail);
        compute_batch(true);
        L(K_done);
    }

    postamble();
}

struct jit_brgemm_trans_m_k_bf16_t : public jit_brgemm_trans_src_t,
                                     public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_bf16_t)
    jit_brgemm_trans_m_k_bf16_t(const jit_brgemm_primitive_conf_t *conf)
        : jit_brgemm_trans_src_t(conf) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;

    enum {
        typesize = sizeof(int16_t),
        transpose_size = 16,
    };
    dim_t src_stride = 0, tr_src_stride = 0;

    opmask_t kFFFF = k1;
    opmask_t k5555 = k2;
    opmask_t kAAAA = k3;
    opmask_t kAA = k4;
    opmask_t k55 = k5;
    opmask_t kCC = k6;
    opmask_t k33 = k7;
    opmask_t kTail = k1;

    reg32_t regw_tmp = r15d;

    reg64_t reg_k_src = r14;
    reg64_t reg_k_tr_src = r13;

    reg64_t reg_m_src = r12;
    reg64_t reg_m_tr_src = r11;

    reg64_t reg_batch_src = r10;
    reg64_t reg_batch_tr_src = r9;

    reg64_t reg_loop_batch = r8;
    reg64_t reg_loop_K = rax;
    reg64_t reg_loop_M = rbx;

    reg64_t reg_tr_src_tmp = abi_not_param1; // lnx -> rcx
    reg64_t imm_addr64 = rdx;

    Xbyak::Zmm vidx1 = zmm31;
    Xbyak::Zmm vidx2 = zmm30;
    Xbyak::Zmm vidx3 = zmm29;
    Xbyak::Zmm vidx4 = zmm28;
    Xbyak::Zmm vidx5 = zmm27;
    Xbyak::Zmm zmm_tmp = zmm26;

    void transpose(
            reg64_t dst, reg64_t src, int nrows, int ncolumns = transpose_size);
    void generate() override;
};
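// Transposes a (nrows x ncolumns) bf16 tile into VNNI layout: each pair of
// consecutive source rows is packed into one zmm (vinsertf64x4 places the
// second row in the upper half, then vpermw with vidx5 interleaves their
// elements), after which three rounds of masked vpermps/vpermpd swaps
// reorder the packed registers. An odd trailing row is paired with zeroes.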
void jit_brgemm_trans_m_k_bf16_t::transpose(
        reg64_t dst, reg64_t src, int nrows, int ncolumns) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) { return Zmm(i); };

    auto src_ymm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Ymm(i);
    };

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    auto kmovd = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovd(k, regw_tmp);
    };

    auto store = [=](Zmm r, int i) {
        mov(reg_tr_src_tmp, dst);

        auto k = kTail;
        auto base = reg_tr_src_tmp;
        base.setOpmaskIdx(k.getIdx(), true);

        auto addr = EVEX_compress_addr(base, i * tr_src_stride);
        vmovups(addr, r);
    };

    const int ic_block = ncolumns;
    kmovd(kFFFF, ic_block < transpose_size ? (1 << ic_block) - 1 : 0xffff);

    for (int i = 0; i < nrows / 2; i++) {
        auto zmm_src0 = src_zmm(2 * i);
        auto zmm_src1 = src_zmm(2 * i + 1);
        auto src1 = src_ymm(2 * i + 1);
        vmovdqu16(zmm_src0 | kFFFF | T_z,
                EVEX_compress_addr(src, 2 * i * src_stride));
        vmovdqu16(zmm_src1 | kFFFF | T_z,
                EVEX_compress_addr(src, (2 * i + 1) * src_stride));
        vinsertf64x4(zmm_src0, zmm_src0, src1, 1);
        vpermw(zmm_src0, vidx5, zmm_src0);
    }

    // for an odd number of rows, pair the last row with zeroes
    if (nrows % 2) {
        int i = nrows / 2;
        auto zmm_src0 = src_zmm(2 * i);
        vmovdqu16(zmm_src0 | kFFFF | T_z,
                EVEX_compress_addr(src, 2 * i * src_stride));
        vpermw(zmm_src0, vidx5, zmm_src0);
    }

    for (int i = rnd_up(nrows, 2); i < 16; i += 2) {
        vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
    }

    // swap 1
    for (int i = 0; i < 4; i++) {
        auto zmm0 = src_zmm(4 * i);
        auto zmm1 = src_zmm(4 * i + 2);
        auto tmp0 = src_zmm(4 * i + 1);
        auto tmp1 = src_zmm(4 * i + 3);

        vmovups(tmp0, zmm0);
        vmovups(tmp1, zmm1);

        vpermps(tmp0 | kAAAA, vidx3, zmm1);
        vpermps(tmp1 | k5555, vidx3, zmm0);
    }
    // swap 2
    int base_idx;
    base_idx = 0;
    for (int i = 0; i < 2; i++) {
        auto zmm0 = src_zmm(base_idx + 2 * i + 1);
        auto zmm1 = src_zmm(base_idx + 2 * i + 5);

        auto tmp0 = src_zmm(base_idx + 2 * i);
        auto tmp1 = src_zmm(base_idx + 2 * i + 4);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kAA, vidx2, zmm1);
        vpermpd(tmp1 | k55, vidx2, zmm0);
    }
    base_idx = 8;
    for (int i = 0; i < 2; i++) {
        auto zmm0 = src_zmm(base_idx + 2 * i + 1);
        auto zmm1 = src_zmm(base_idx + 2 * i + 5);

        auto tmp0 = src_zmm(base_idx + 2 * i);
        auto tmp1 = src_zmm(base_idx + 2 * i + 4);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kAA, vidx2, zmm1);
        vpermpd(tmp1 | k55, vidx2, zmm0);
    }

    // swap 3
    for (int i = 0; i < 4; i++) {
        auto zmm0 = src_zmm(2 * i);
        auto zmm1 = src_zmm(2 * i + 8);

        auto tmp0 = src_zmm(2 * i + 1);
        auto tmp1 = src_zmm(2 * i + 9);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kCC, vidx1, zmm1);
        vpermpd(tmp1 | k33, vidx1, zmm0);
    }

    // all stores
    for (int i = 0; i < 8; i++)
        vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1);
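    // Maps a logical output row (ic index) to the zmm register that holds
    // it after the swap rounds above; the shuffle network leaves results in
    // a permuted register order.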
    auto get_vec_idx = [=](int ic_idx) {
        assert(ic_idx < 16 && ic_idx >= 0);
        switch (ic_idx) {
            case 0: return 1;
            case 1: return 0;
            case 2: return 3;
            case 3: return 2;
            case 4: return 9;
            case 5: return 8;
            case 6: return 11;
            case 7: return 10;
            case 8: return 5;
            case 9: return 4;
            case 10: return 7;
            case 11: return 6;
            case 12: return 13;
            case 13: return 12;
            case 14: return 15;
            default: return 14;
        }
    };

    int store_tail = rnd_up(nrows, 2);
    kmovw(kTail, (1 << store_tail / 2) - 1);

    for (int ic = 0; ic < ic_block; ic++)
        store(src_zmm(get_vec_idx(ic)), ic);
}

void jit_brgemm_trans_m_k_bf16_t::generate() {
    preamble();
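    // Permutation tables for the shuffle rounds in transpose(): idx5 drives
    // the initial vpermw interleave of row pairs into VNNI order, idx3 the
    // vpermps "swap 1", and idx2/idx1 the vpermpd "swap 2"/"swap 3" rounds
    // (idx4 is loaded alongside into vidx4).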
    alignas(64) static constexpr const int64_t idx1[8]
            = {2, 3, 0, 1, 6, 7, 4, 5};
    alignas(64) static constexpr const int64_t idx2[8]
            = {1, 0, 3, 2, 5, 4, 7, 6};
    alignas(64) static constexpr const int32_t idx3[16]
            = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14};
    alignas(64) static constexpr const int32_t idx4[16]
            = {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7};
    alignas(64) static constexpr const uint16_t idx5[32]
            = {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17,
                    3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31};

    constexpr int amx_bf16_granularity = 2;
    const bool last_row_padded = conf_->isa == avx512_core_bf16_amx_bf16
            && conf_->os % amx_bf16_granularity != 0;
    const int eff_K_tail = conf_->K_tail - (last_row_padded ? 1 : 0);

    const int os_block = conf_->os_block;
    const int last_os_block_tail = eff_K_tail % transpose_size;
    const int ic_tail = conf_->ic % transpose_size;
    src_stride = conf_->ic * typesize;
    tr_src_stride = conf_->LDA * typesize;

    dim_t batch_src_shift = src_stride * os_block;
    dim_t batch_tr_src_shift = tr_src_stride * conf_->M;

    dim_t M_src_shift = transpose_size * typesize;
    dim_t M_tr_src_shift = transpose_size * conf_->LDA * typesize;

    dim_t K_src_shift = transpose_size * conf_->ic * typesize;
    dim_t K_tr_src_shift = transpose_size * typesize;

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    kmovw(kFFFF, 0xffff);
    kmovw(k5555, 0x5555);
    kmovw(kAAAA, 0xaaaa);
    kmovw(kAA, 0xaa);
    kmovw(k55, 0x55);
    kmovw(kCC, 0xcc);
    kmovw(k33, 0x33);

    auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa64(z, ptr[imm_addr64]);
    };

    auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa32(z, ptr[imm_addr64]);
    };

    vmovdqa64(vidx1, idx1);
    vmovdqa64(vidx2, idx2);
    vmovdqa32(vidx3, idx3);
    vmovdqa32(vidx4, idx4);
    vmovdqa32(vidx5, (const int32_t *)idx5);

    auto compute_m_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base,
                                  bool is_os_tail) {
        mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]);
        mov(reg_m_src, reg_base);
        mov(reg_m_tr_src, reg_tr_base);

        Label M_loop_tail, M_loop;
        if (ic_tail > 0) {
            cmp(reg_loop_M, transpose_size);
            jl(M_loop_tail, T_NEAR);
        }
        L(M_loop);
        {
            transpose(reg_m_tr_src, reg_m_src,
                    is_os_tail ? last_os_block_tail : transpose_size,
                    transpose_size);
            add(reg_m_src, M_src_shift);
            add(reg_m_tr_src, M_tr_src_shift);
        }
        sub(reg_loop_M, transpose_size);
        cmp(reg_loop_M, transpose_size);
        jge(M_loop, T_NEAR);

        if (ic_tail > 0) {
            Label M_loop_done;
            L(M_loop_tail);
            cmp(reg_loop_M, 0);
            jle(M_loop_done, T_NEAR);

            transpose(reg_m_tr_src, reg_m_src,
                    is_os_tail ? last_os_block_tail : transpose_size, ic_tail);
            L(M_loop_done);
        }
    };

    auto compute_k_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base) {
        mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
        mov(reg_k_src, reg_base);
        mov(reg_k_tr_src, reg_tr_base);

        Label K_tail, K_loop, K_done;
        if (last_os_block_tail > 0) {
            cmp(reg_loop_K, transpose_size);
            jl(K_tail, T_NEAR);
        }
        L(K_loop);
        {
            compute_m_loop(reg_k_src, reg_k_tr_src, false);
            add(reg_k_src, K_src_shift);
            add(reg_k_tr_src, K_tr_src_shift);
        }
        sub(reg_loop_K, transpose_size);
        cmp(reg_loop_K, transpose_size);
        jge(K_loop, T_NEAR);

        cmp(reg_loop_K, 0);
        je(K_done, T_NEAR);

        if (last_os_block_tail > 0) {
            L(K_tail);
            compute_m_loop(reg_k_src, reg_k_tr_src, true);
        }
        L(K_done);
    };

    mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
    mov(reg_batch_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_batch_tr_src, ptr[param1 + GET_OFF(tr_src)]);

    Label batch_loop;
    L(batch_loop);
    {
        compute_k_loop(reg_batch_src, reg_batch_tr_src);

        add(reg_batch_src, batch_src_shift);
        add(reg_batch_tr_src, batch_tr_src_shift);
    }
    sub(reg_loop_batch, 1);
    jnz(batch_loop, T_NEAR);

    postamble();
}
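// Copies rows from reg_data to reg_tr_data in zmm-sized (64-byte) chunks:
// the loop is unrolled by row_loop_unroll, and each iteration issues one
// unmasked vmovdqu8 load/store pair.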
void jit_brgemm_copy_to_coarse_t::copy_row_blk_loop(int copy_row_iters) {
    int row_blks = div_up(copy_row_iters, row_loop_unroll);

    for (int row_b = 0; row_b < row_blks; ++row_b) {
        const int row_start = 0;
        const int row_end = nstl::min(static_cast<int>(row_loop_unroll),
                copy_row_iters - row_b * static_cast<int>(row_loop_unroll));

        for (int row = row_start; row < row_end; ++row) {
            const int row_idx = row_b * row_loop_unroll + row;
            const auto offset = addr_offset(row_idx);

            const auto zmm = get_zmm_copy(row);
            const auto addr = EVEX_compress_addr(reg_data, offset);
            const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);

            vmovdqu8(zmm, addr);
            vmovdqu8(addr_tr, zmm);
        }
    }
}

void jit_brgemm_copy_to_coarse_t::copy_row_tail(int row_offset) {
    // Masks for the row tail load and store are already set up
    const auto zmm_data = zmm_tail | reg_m_row_tail_load | T_z;
    const auto zmm_tr_data = zmm_tail | reg_m_row_tail_store;

    const auto offset = addr_offset(row_offset);
    const auto addr = EVEX_compress_addr(reg_data, offset);
    const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);

    vmovdqu8(zmm_data, addr);
    vmovdqu8(addr_tr, zmm_tr_data);
}

void jit_brgemm_copy_to_coarse_t::copy_row_loop() {
    Xbyak::Label label_row_tail, label_row_exit;

    // Note: copying is done in chunks of size row_step_
    const auto copy_row = [&](bool is_last_blk) {
        const int row_blk
                = is_last_blk ? (row_size_ % tr_row_size_) : tr_row_size_;
        const int row_iters = row_blk / row_step_;
        const int row_iters_tail = row_blk % row_step_;

        copy_row_blk_loop(row_iters);
        if (row_iters_tail != 0) copy_row_tail(/* row_offset = */ row_iters);
    };

    const bool only_row_tail = row_size_ < tr_row_size_;

    if (!only_row_tail) {
        cmp(reg_last_row_blk, 0);
        jne(label_row_tail, T_NEAR);

        copy_row(/* is_last_blk = */ false);
        jmp(label_row_exit, T_NEAR);
    }

    L(label_row_tail);
    copy_row(/* is_last_blk = */ true);

    L(label_row_exit);
}

void jit_brgemm_copy_to_coarse_t::copy_os_loop() {

    Label loop_os;
    L(loop_os);

    copy_row_loop();
    add(reg_data, data_stride_);
    add(reg_tr_data, tr_data_stride_);

    dec(reg_os_work);
    jnz(loop_os, T_NEAR);
}
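// Prepares the two opmasks used by copy_row_tail(): the load mask covers
// exactly the row_tail elements (typesize_ * row_tail bytes), while the
// store mask is rounded up to a multiple of row_block_size_ so the
// destination stays padded to full blocks.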
void jit_brgemm_copy_to_coarse_t::set_tail_mask() {
    const int row_tail = row_size_ % row_step_;
    assert(row_tail > 0 && "kernel is meant to be used with tail processing");

    // Set load mask
    const size_t tail_mask_load
            = (static_cast<size_t>(1) << (typesize_ * row_tail)) - 1;
    mov(reg_tail_mask, tail_mask_load);
    kmovq(reg_m_row_tail_load, reg_tail_mask);

    // Caution: since a ZMM register holds 64 bytes, we need different store
    // masks for tails with a smaller row_block_size_
    constexpr auto full_mask = size_t {0xffffffffffffffff};
    constexpr auto half_mask = size_t {0x00000000ffffffff};
    constexpr auto quad_mask = size_t {0x000000000000ffff};

    const auto num_bytes = [](size_t mask) -> int {
        // Given by 1 + position of leftmost 1 bit
        return 1 + math::ilog2q(mask);
    };

    const int row_tail_store_size
            = utils::rnd_up(row_tail, row_block_size_) * typesize_;
    if (row_tail_store_size >= num_bytes(full_mask))
        mov(reg_tail_mask, full_mask);
    else if (row_tail_store_size >= num_bytes(half_mask))
        mov(reg_tail_mask, half_mask);
    else {
        assert(row_tail_store_size == num_bytes(quad_mask));
        mov(reg_tail_mask, quad_mask);
    }
    kmovq(reg_m_row_tail_store, reg_tail_mask);
}

void jit_brgemm_copy_to_coarse_t::generate() {
    preamble();

    set_tail_mask();
    mov(reg_data, ptr[param1 + GET_OFF(data)]);
    mov(reg_tr_data, ptr[param1 + GET_OFF(tr_data)]);
    mov(reg_os_work, ptr[param1 + GET_OFF(os_work)]);
    mov(reg_last_row_blk, ptr[param1 + GET_OFF(last_row_blk)]);

    copy_os_loop();

    postamble();
}

struct jit_trans_to_vnni_t : public jit_brgemm_trans_to_vnni_t,
                             public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_to_vnni_t)
    jit_trans_to_vnni_t(const jit_brgemm_primitive_conf_t *conf,
            jit_brgemm_trans_to_vnni_t::matrix_to_transform_t
                    matrix_to_transform)
        : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using zmm = const Xbyak::Zmm;

    enum {
        typesize_data = sizeof(int16_t),
        typesize_acc = sizeof(float),
        transpose_size = 16,
    };

    int last_row_block_tail = 0, col_tail = 0;
    dim_t src_stride = 0, tr_src_stride = 0;
    dim_t src_col_shift = 0, tr_src_col_shift = 0;
    dim_t src_row_shift = 0, tr_src_row_shift = 0;
    dim_t src_batch_shift = 0, tr_src_batch_shift = 0;

    opmask_t kFFFF = k1;
    opmask_t mask_tail = k2;

    zmm vidx1 = zmm31;

    reg32_t regw_tmp = r15d;

    reg64_t reg_batch_src = r14;
    reg64_t reg_batch_tr_src = r13;

    reg64_t reg_row_src = r12;
    reg64_t reg_row_tr_src = r11;

    reg64_t reg_col_src = r10;
    reg64_t reg_col_tr_src = r9;

    reg64_t reg_loop_batch = r8;
    reg64_t reg_loop_row = rax;
    reg64_t reg_loop_col = rbx;

    reg64_t imm_addr64 = abi_not_param1; // lnx -> rcx

    void maybe_zero_pad_col(reg64_t dst);
    void transpose(reg64_t dst, reg64_t src, int nrows,
            int ncolumns = transpose_size, bool pad_by_zeroes = false);
    void generate() override;
};

void jit_trans_to_vnni_t::maybe_zero_pad_col(reg64_t dst) {
    auto zmm_zero = Xbyak::Zmm(0);
    vpxord(zmm_zero, zmm_zero, zmm_zero);
    const int oc_utilized = rnd_up(conf_->oc % conf_->oc_block, transpose_size);
    const int iters = (conf_->oc_block - oc_utilized) / transpose_size;
    for (int n = 0; n < iters; ++n) {
        for (int i = 0; i < transpose_size; i += 2) {
            auto addr = EVEX_compress_addr(dst, i * tr_src_stride);
            vmovups(addr, zmm_zero);
        }
        add(reg_col_tr_src, tr_src_col_shift);
    }
}
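// Packs row pairs into VNNI layout. For matrix_B the source is already
// bf16, so two rows are merged with vinsertf64x4; for matrix_C the source
// holds f32 accumulators, so the row pair is converted and packed in one
// vcvtne2ps2bf16. In both cases vpermw with vidx1 interleaves the pair
// element-wise before the store.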
void jit_trans_to_vnni_t::transpose(
        reg64_t dst, reg64_t src, int nrows, int ncolumns, bool pad_by_zeroes) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) { return Zmm(i); };

    auto src_ymm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Ymm(i);
    };

    auto store = [=](Zmm r, int i) {
        auto addr = EVEX_compress_addr(dst, i * tr_src_stride);
        vmovups(addr, r);
    };
    auto mask = ncolumns == transpose_size ? kFFFF : mask_tail;

    int i = 0;
    for (; i < nrows / 2; i++) {
        auto src1 = src_ymm(2 * i + 1);
        auto zmm_src0 = src_zmm(2 * i);
        auto zmm_src1 = src_zmm(2 * i + 1);
        if (matrix_to_transform_ == matrix_B) {
            vmovdqu16(zmm_src0 | mask | T_z,
                    EVEX_compress_addr(src, 2 * i * src_stride));
            vmovdqu16(zmm_src1 | mask | T_z,
                    EVEX_compress_addr(src, (2 * i + 1) * src_stride));
            vinsertf64x4(zmm_src0, zmm_src0, src1, 1);
        } else {
            vmovups(zmm_src0 | mask | T_z,
                    EVEX_compress_addr(src, 2 * i * src_stride));
            vmovups(zmm_src1 | mask | T_z,
                    EVEX_compress_addr(src, (2 * i + 1) * src_stride));
            vcvtne2ps2bf16(zmm_src0, zmm_src1, zmm_src0);
        }
        vpermw(zmm_src0, vidx1, zmm_src0);
        store(zmm_src0, 2 * i);
    }

    if (nrows % 2) {
        auto zmm_src0 = src_zmm(2 * i);
        if (matrix_to_transform_ == matrix_B) {
            vmovdqu16(zmm_src0 | mask | T_z,
                    EVEX_compress_addr(src, 2 * i * src_stride));
        } else {
            auto zmm_zero = src_zmm(2 * i + 1);
            vmovups(zmm_src0 | mask | T_z,
                    EVEX_compress_addr(src, 2 * i * src_stride));
            vpxord(zmm_zero, zmm_zero, zmm_zero);
            vcvtne2ps2bf16(zmm_src0, zmm_zero, zmm_src0);
        }
        vpermw(zmm_src0, vidx1, zmm_src0);
        store(zmm_src0, 2 * i);
        i++;
    }

    if (pad_by_zeroes && i < transpose_size / 2) {
        auto zmm_zero = src_zmm(2 * i);
        vpxord(zmm_zero, zmm_zero, zmm_zero);
        for (; i < transpose_size / 2; i++)
            store(zmm_zero, 2 * i);
    }
}

void jit_trans_to_vnni_t::generate() {
    preamble();

    alignas(64) static constexpr const int16_t idx1[32]
            = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
                    25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};

    if (matrix_to_transform_ == matrix_B) {
        int row_block = conf_->os_block;

        constexpr int amx_bf16_granularity = 2;
        const bool last_row_padded = conf_->isa == avx512_core_bf16_amx_bf16
                && conf_->os % amx_bf16_granularity != 0;
        const int eff_K_tail = conf_->K_tail - (last_row_padded ? 1 : 0);

        last_row_block_tail = eff_K_tail % transpose_size;
        col_tail = conf_->oc % transpose_size;
        src_stride = conf_->oc * typesize_data;
        tr_src_stride = conf_->LDB * typesize_data;

        src_batch_shift = src_stride * row_block;
        tr_src_batch_shift = tr_src_stride * rnd_up(conf_->K, 2);

        src_col_shift = transpose_size * typesize_data;
        tr_src_col_shift = 2 * transpose_size * typesize_data;

        src_row_shift = transpose_size * conf_->oc * typesize_data;
        tr_src_row_shift = transpose_size * conf_->LDB * typesize_data;

    } else { // matrix_to_transform_ == matrix_C
        int row_block = conf_->ic_block;
        last_row_block_tail = conf_->M_tail % transpose_size;
        assert(row_block == transpose_size);
        col_tail = conf_->oc % transpose_size;
        src_stride = conf_->LDC * typesize_acc;
        tr_src_stride = conf_->LDD * typesize_data;

        src_batch_shift = src_stride * row_block;
        tr_src_batch_shift = tr_src_stride * rnd_up(conf_->M, 2);

        src_col_shift = transpose_size * typesize_acc;
        tr_src_col_shift = 2 * transpose_size * typesize_data;
    }

    //    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    //    mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
    //    mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };
    auto kmovd = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovd(k, regw_tmp);
    };

    kmovw(kFFFF, 0xffff); // 1111111111111111
    kmovd(mask_tail, (1 << col_tail) - 1);

    auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa64(z, ptr[imm_addr64]);
    };

    vmovdqa64(vidx1, (const int64_t *)idx1);

    auto compute_col_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base,
                                    bool is_row_tail) {
        const bool pad_by_zeroes = matrix_to_transform_ == matrix_C;
        int nrows = is_row_tail ? last_row_block_tail : transpose_size;

        mov(reg_col_src, reg_base);
        mov(reg_col_tr_src, reg_tr_base);
        mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);

        Label col_loop, col_loop_tail;
        cmp(reg_loop_col, transpose_size);
        jl(col_loop_tail, T_NEAR);

        L(col_loop);
        {
            transpose(reg_col_tr_src, reg_col_src, nrows, transpose_size,
                    pad_by_zeroes);
            add(reg_col_src, src_col_shift);
            add(reg_col_tr_src, tr_src_col_shift);
        }
        sub(reg_loop_col, transpose_size);
        cmp(reg_loop_col, transpose_size);
        jge(col_loop, T_NEAR);

        L(col_loop_tail);
        if (col_tail > 0) {
            Label col_loop_done;
            cmp(reg_loop_col, 0);
            jle(col_loop_done, T_NEAR);
            transpose(reg_col_tr_src, reg_col_src, nrows, col_tail,
                    pad_by_zeroes);
            L(col_loop_done);
        }
        const int oc_block_tail = conf_->oc % conf_->oc_block;
        const bool full_oc_block_utilized = oc_block_tail == 0
                || rnd_up(oc_block_tail, transpose_size) == conf_->oc_block;
        const bool col_pad_required = pad_by_zeroes && !full_oc_block_utilized;

        if (col_pad_required) {
            Label col_pad_done;
            mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);
            cmp(reg_loop_col, conf_->oc_block);
            je(col_pad_done, T_NEAR);
            if (col_tail > 0) add(reg_col_tr_src, tr_src_col_shift);
            maybe_zero_pad_col(reg_col_tr_src);
            L(col_pad_done);
        }
    };

    auto compute_row_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base) {
        mov(reg_row_src, reg_base);
        mov(reg_row_tr_src, reg_tr_base);
        mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);

        Label row_tail, row_loop, row_done;
        if (last_row_block_tail > 0) {
            cmp(reg_loop_row, transpose_size);
            jl(row_tail, T_NEAR);
        }
        L(row_loop);
        {
            compute_col_loop(reg_row_src, reg_row_tr_src, false);

            add(reg_row_src, src_row_shift);
            add(reg_row_tr_src, tr_src_row_shift);
        }
        sub(reg_loop_row, transpose_size);
        cmp(reg_loop_row, transpose_size);
        jge(row_loop, T_NEAR);

        cmp(reg_loop_row, 0);
        je(row_done, T_NEAR);

        if (last_row_block_tail > 0) {
            L(row_tail);
            compute_col_loop(reg_row_src, reg_row_tr_src, true);
        }
        L(row_done);
    };

    mov(reg_batch_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_batch_tr_src, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);

    Label batch_loop;
    L(batch_loop);
    {
        compute_row_loop(reg_batch_src, reg_batch_tr_src);

        add(reg_batch_src, src_batch_shift);
        add(reg_batch_tr_src, tr_src_batch_shift);
    }
    sub(reg_loop_batch, 1);
    jnz(batch_loop, T_NEAR);

    postamble();
}

struct jit_copy_f32_t : public jit_brgemm_trans_to_vnni_t,
                        public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_copy_f32_t)
    jit_copy_f32_t(const jit_brgemm_primitive_conf_t *conf,
            jit_brgemm_trans_to_vnni_t::matrix_to_transform_t
                    matrix_to_transform)
        : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using zmm = const Xbyak::Zmm;

    enum {
        typesize_data = sizeof(float),
        column_step = 16,
        num_regs = 32,
    };

    dim_t src_stride = 0, tr_src_stride = 0;
    dim_t src_batch_shift = 0, tr_src_batch_shift = 0;
    dim_t col_shift = column_step * typesize_data;

    opmask_t mask_tail = k2;

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_loop_batch = r10;
    reg64_t reg_loop_row = r11;
    reg64_t reg_loop_col = r12;
    reg32_t regw_tmp = r14d;
    reg64_t reg_long_offt = r15;

    void copy_block(int nrows, int ncolumns);
    void generate() override;
};
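// Plain row-major copy of an f32 block from reg_src to reg_tr_src (no
// transpose): each row is moved in column_step-wide zmm chunks, with
// mask_tail covering a partial final chunk.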
void jit_copy_f32_t::copy_block(int nrows, int ncolumns) {

    auto kmovd = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovd(k, regw_tmp);
    };

    const int nc_tail = ncolumns % column_step;
    if (nc_tail > 0) kmovd(mask_tail, (1 << nc_tail) - 1);

    auto get_zmm = [=](int i) { return Zmm(i % num_regs); };

    auto load = [=](int r, int cb) {
        auto src_reg = get_zmm(r * cb);
        const bool is_tail
                = nc_tail > 0 && ncolumns - cb * column_step < column_step;
        auto src_load = is_tail ? src_reg | mask_tail | T_z : src_reg;
        const dim_t offset = r * src_stride + cb * col_shift;
        auto addr = EVEX_compress_addr_safe(reg_src, offset, reg_long_offt);
        vmovups(src_load, addr);
    };

    auto store = [=](int r, int cb) {
        auto reg = get_zmm(r * cb);
        const dim_t offset = r * tr_src_stride + cb * col_shift;
        auto addr = EVEX_compress_addr_safe(reg_tr_src, offset, reg_long_offt);
        vmovups(addr, reg);
    };

    for_(int r = 0; r < nrows; r++)
    for (int cb = 0; cb < div_up(ncolumns, column_step); cb++) {
        load(r, cb);
        store(r, cb);
    }
}

void jit_copy_f32_t::generate() {
    preamble();

    const int row_block = conf_->os_block;
    const int row_tail = conf_->os % row_block;
    const int col_block = conf_->oc_block * conf_->nb_oc_blocking;
    const int col_tail = conf_->oc % col_block;
    src_stride = conf_->oc * typesize_data;
    tr_src_stride = conf_->LDB * typesize_data;
    src_batch_shift = src_stride * row_block;
    tr_src_batch_shift = tr_src_stride * row_block;

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
    mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);
    mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);

    auto compute_batch = [=](int nrows, int ncolumns) {
        Label batch_loop;
        L(batch_loop);

        copy_block(nrows, ncolumns);
        add(reg_src, src_batch_shift);
        add(reg_tr_src, tr_src_batch_shift);

        sub(reg_loop_batch, 1);
        jnz(batch_loop, T_NEAR);
    };

    auto compute_rows = [=](int ncolumns) {
        Label row_done;
        if (row_tail > 0) {
            Label row_common;
            cmp(reg_loop_row, row_block);
            je(row_common, T_NEAR);

            compute_batch(row_tail, ncolumns);
            jmp(row_done, T_NEAR);

            L(row_common);
        }

        compute_batch(row_block, ncolumns);
        L(row_done);
    };

    Label col_done;
    if (col_tail > 0) {
        Label col_common;
        cmp(reg_loop_col, col_block);
        je(col_common, T_NEAR);

        compute_rows(col_tail);
        jmp(col_done, T_NEAR);

        L(col_common);
    }

    compute_rows(col_block);
    L(col_done);

    postamble();
}

struct jit_brgemm_trans_wei_f32_t : public jit_brgemm_trans_wei_t,
                                    public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_f32_t)

    jit_brgemm_trans_wei_f32_t(const jit_brgemm_primitive_conf_t *conf)
        : jit_brgemm_trans_wei_t(conf) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;

    enum { typesize = sizeof(float), transpose_size = 16 };
    dim_t src_stride = 0, tr_src_stride = 0;

    opmask_t k3333 = k1;
    opmask_t k5555 = k2;
    opmask_t kAAAA = k3;
    opmask_t kCCCC = k4;
    opmask_t k0F0F = k5;
    opmask_t kF0F0 = k6;
    opmask_t kTail = k7;

    reg64_t reg_src_base = rax;
    reg64_t reg_tr_src_base = rbx;

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_loop_N = r10;
    reg64_t reg_loop_K = r11;
    reg64_t reg_loop_batch = r12;
    reg64_t reg_tr_src_tmp = r13;
    reg32_t regw_tmp = r14d;

    void transpose_16x16(int nrows, int ncolumns = transpose_size);
    void generate() override;
};

void jit_brgemm_trans_wei_f32_t::transpose_16x16(int nrows, int ncolumns) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Zmm(i);
    };

    auto tmp_zmm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Zmm(16 + i);
    };

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    auto load = [=](int i) {
        auto src_load = src_zmm(i);
        if (ncolumns < transpose_size) {
            kmovw(kTail, (1 << ncolumns) - 1);
            src_load = src_zmm(i) | kTail | T_z;
        }
        vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
    };

    auto store = [=](Zmm r, int i) {
        mov(reg_tr_src_tmp, reg_tr_src);
        if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);

        // Xbyak does not allow k0 to be specified explicitly via the '|'
        // operator, so we have to do this via a method call (in EVEX
        // encoding, k0 implicitly means 'no mask')
        bool partial_store = nrows < transpose_size;
        auto k = partial_store ? kTail : k0;
        auto base = reg_tr_src_tmp;
        base.setOpmaskIdx(k.getIdx(), true);

        auto addr = EVEX_compress_addr(base, i * tr_src_stride);
        vmovups(addr, r);
    };

    auto transpose16x8 = [=](int base_idx) {
        assert(base_idx == 0 || base_idx == 8);

        // swap 1
        for (int i = 0; i < 4; i++) {
            int src_idx0 = base_idx + i * 2;
            int src_idx1 = src_idx0 + 1;

            int next_src_idx0 = src_idx0 + 2;
            int next_src_idx1 = src_idx1 + 2;
            bool load_next = base_idx == 0 || i < 3;

            if (base_idx == 0 && i == 0) {
                load(src_idx0);
                if (src_idx1 < nrows)
                    load(src_idx1);
                else
                    vpxord(src_zmm(src_idx1), src_zmm(src_idx1),
                            src_zmm(src_idx1));
            }

            auto tmp0 = tmp_zmm(src_idx0);
            auto tmp1 = tmp_zmm(src_idx1);
            auto src0 = src_zmm(src_idx0);
            auto src1 = src_zmm(src_idx1);

            if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
            valignd(tmp0, src0, src0, 0x1);

            if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
            valignd(tmp1, src1, src1, 0xf);

            vmovaps(src0 | kAAAA, tmp1);
            vmovaps(src1 | k5555, tmp0);
        }
        // swap 2
        for (int i = 0; i < 4; i++) {
            int select_half = (i < 2) ? 0 : 2;
            int src_idx0 = base_idx + i + select_half + 0;
            int src_idx2 = src_idx0 + 2;

            auto tmp0 = tmp_zmm(src_idx0);
            auto tmp1 = tmp_zmm(src_idx2);
            auto src0 = src_zmm(src_idx0);
            auto src2 = src_zmm(src_idx2);

            valignd(tmp0, src0, src0, 0x2);
            valignd(tmp1, src2, src2, 0xe);
            vmovaps(src2 | k3333, tmp0);
            vmovaps(src0 | kCCCC, tmp1);
        }

        // swap 4
        for (int i = 0; i < 4; i++) {
            int src_idx0 = base_idx + i;
            int src_idx4 = src_idx0 + 4;

            auto tmp0 = tmp_zmm(src_idx0);
            auto src0 = src_zmm(src_idx0);
            auto src4 = src_zmm(src_idx4);

            vmovaps(tmp0, src0);
            vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
            vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
        }
    };

    auto fixup16x16 = [=]() {
        // swap 8
        for (int i = 0; i < 8; i++) {
            auto tmp = tmp_zmm(i);
            auto src0 = src_zmm(i);
            auto src8 = src_zmm(8 + i);
            vshuff64x2(tmp, src0, src8, 0x44);
            store(tmp, i);
        }

        for (int i = 0; i < 8; i++) {
            auto tmp = tmp_zmm(8 + i);
            auto src0 = src_zmm(i);
            auto src8 = src_zmm(8 + i);
            vshuff64x2(tmp, src0, src8, 0xee);
            store(tmp, 8 + i);
        }
    };

    transpose16x8(0);
    transpose16x8(8);
    fixup16x16();
}

void jit_brgemm_trans_wei_f32_t::generate() {
    preamble();
    assert(conf_->oc_block % transpose_size == 0);
    int fwd_ic_block = conf_->simd_w;
    int fwd_oc_block = 0;
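    // Recover the forward weights' oc blocking from the tag: *64o* tags use
    // an oc block of 4 * simd_w, *32o* tags 2 * simd_w, and the remaining
    // tags a single simd_w.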
1426     switch (conf_->wei_tag) {
1427         case OI16i64o:
1428         case OIw16i64o:
1429         case OIhw16i64o:
1430         case OIdhw16i64o:
1431         case OI8i64o2i:
1432         case OIw8i64o2i:
1433         case OIhw8i64o2i:
1434         case OIdhw8i64o2i:
1435         case OI16i64o2i:
1436         case OIw16i64o2i:
1437         case OIhw16i64o2i:
1438         case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break;
1439         case OI16i32o:
1440         case OIw16i32o:
1441         case OIhw16i32o:
1442         case OIdhw16i32o:
1443         case OI8i32o2i:
1444         case OIw8i32o2i:
1445         case OIhw8i32o2i:
1446         case OIdhw8i32o2i:
1447         case OI16i32o2i:
1448         case OIw16i32o2i:
1449         case OIhw16i32o2i:
1450         case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break;
1451         default: fwd_oc_block = conf_->simd_w;
1452     };
1453 
1454     int oc_tail = conf_->K_tail % transpose_size;
1455     int ic_block = conf_->ic_block;
1456     int ic_tail = conf_->N_tail % transpose_size;
1457     src_stride = fwd_oc_block * typesize;
1458     tr_src_stride = ic_block * typesize;
1459     dim_t N_src_shift = conf_->kd * conf_->kh * conf_->kw * fwd_ic_block
1460             * fwd_oc_block * typesize;
1461     dim_t N_tr_src_shift = conf_->simd_w * typesize;
1462     dim_t K_src_shift = conf_->simd_w * typesize;
1463     dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize;
1464 
1465     mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
1466     mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
1467     mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
1468     mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
1469 
1470     auto kmovw = [=](Opmask k, unsigned w) {
1471         mov(regw_tmp, w);
1472         jit_generator::kmovw(k, regw_tmp);
1473     };
1474 
1475     kmovw(k3333, 0x3333); // 0011001100110011
1476     kmovw(k5555, 0x5555); // 0101010101010101
1477     kmovw(kAAAA, 0xaaaa); // 1010101010101010
1478     kmovw(kCCCC, 0xcccc); // 1100110011001100
1479     kmovw(k0F0F, 0x0f0f); // 0000111100001111
1480     kmovw(kF0F0, 0xf0f0); // 1111000011110000
1481 
1482     auto compute_N = [=](bool is_oc_tail) {
1483         mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]);
1484         mov(reg_src, reg_src_base);
1485         mov(reg_tr_src, reg_tr_src_base);
1486         Label N_loop, N_loop_tail;
1487 
1488         cmp(reg_loop_N, transpose_size);
1489         jl(N_loop_tail, T_NEAR);
1490 
1491         L(N_loop);
1492 
1493         transpose_16x16(transpose_size, is_oc_tail ? oc_tail : transpose_size);
1494         add(reg_src, N_src_shift);
1495         add(reg_tr_src, N_tr_src_shift);
1496 
1497         sub(reg_loop_N, transpose_size);
1498         cmp(reg_loop_N, transpose_size);
1499         jge(N_loop, T_NEAR);
1500 
1501         L(N_loop_tail);
1502         if (ic_tail > 0) {
1503             Label N_loop_done;
1504             cmp(reg_loop_N, 0);
1505             jle(N_loop_done, T_NEAR);
1506             transpose_16x16(ic_tail, is_oc_tail ? oc_tail : transpose_size);
1507             L(N_loop_done);
1508         }
1509     };
1510 
1511     Label K_loop, K_tail;
1512     if (oc_tail > 0) {
1513         cmp(reg_loop_K, transpose_size);
1514         jl(K_tail, T_NEAR);
1515     }
1516 
1517     L(K_loop);
1518     compute_N(false);
1519     add(reg_src_base, K_src_shift);
1520     add(reg_tr_src_base, K_tr_src_shift);
1521 
1522     sub(reg_loop_K, transpose_size);
1523     cmp(reg_loop_K, transpose_size);
1524     jge(K_loop, T_NEAR);
1525 
1526     L(K_tail);
1527     if (oc_tail > 0) {
1528         Label K_loop_done;
1529         cmp(reg_loop_K, 0);
1530         jle(K_loop_done, T_NEAR);
1531 
1532         compute_N(true);
1533         L(K_loop_done);
1534     }
1535 
1536     postamble();
1537 }

struct jit_brgemm_trans_wei_bf16_t : public jit_brgemm_trans_wei_t,
                                     public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_bf16_t)

    jit_brgemm_trans_wei_bf16_t(const jit_brgemm_primitive_conf_t *conf)
        : jit_brgemm_trans_wei_t(conf) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using zmm = const Xbyak::Zmm;

    enum { typesize = sizeof(int16_t), transpose_size = 16 };
    dim_t src_stride = 0, tr_src_stride = 0;

    opmask_t kTail = k7;

    reg64_t reg_src_base = rax;
    reg64_t reg_tr_src_base = rbx;

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_loop_N = r10;
    reg64_t reg_loop_K = r11;
    reg64_t reg_loop_batch = r12;
    reg64_t reg_tr_src_tmp = r13;
    reg32_t regw_tmp = r14d;
    reg64_t imm_addr64 = r15;

    zmm v_abcdefgh_to_abefcdgh = zmm31;

    void transpose_16x16_vnni(int nrows, int ncolumns = transpose_size);
    void generate() override;
};

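// Transposes a 16x16 bf16 tile into VNNI layout: the tile is loaded into
// 8 zmm registers (each 64-byte load covers two 16-element bf16 rows),
// a byte shuffle pre-pairs the elements, and qword unpacks plus 128-bit
// lane shuffles complete the transpose.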
void jit_brgemm_trans_wei_bf16_t::transpose_16x16_vnni(
        int nrows, int ncolumns) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) {
        assert(i >= 0 && i < 8);
        return Zmm(i);
    };

    auto tmp_zmm = [=](int i) {
        assert(i >= 0 && i < 8);
        return Zmm(8 + i);
    };

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    auto load = [=](int i) {
        auto src_load = src_zmm(i);
        if (ncolumns < transpose_size) {
            kmovw(kTail, (1 << ncolumns) - 1);
            src_load = src_zmm(i) | kTail | T_z;
        }
        vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
    };

    auto store = [=](Zmm r, int i) {
        mov(reg_tr_src_tmp, reg_tr_src);
        if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);

        // Xbyak does not allow k0 to be specified explicitly via the '|'
        // operator, so we attach the mask via a method call instead (in
        // EVEX encoding, k0 implicitly means 'no mask').
        bool partial_store = nrows < transpose_size;
        auto k = partial_store ? kTail : k0;
        auto base = reg_tr_src_tmp;
        base.setOpmaskIdx(k.getIdx(), true);

        auto addr = EVEX_compress_addr(base, i * tr_src_stride);
        vmovups(addr, r);
    };

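    // Stage 0: load the tile into 8 zmm registers; when ncolumns is a
    // tail, the loads are masked and zero-filled (T_z).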
    for (int i = 0; i < 8; i++)
        load(i);

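    // Stage 1: within every 8-byte group, vpshufb reorders the bytes
    // abcdefgh -> abefcdgh, i.e. swaps the middle two bf16 elements of
    // each group of four, pre-pairing values for the VNNI layout.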
    for (int i = 0; i < 8; i++)
        vpshufb(src_zmm(i), src_zmm(i), v_abcdefgh_to_abefcdgh);

    for (int i = 0; i < 2; i++) {
        vpunpcklqdq(tmp_zmm(2 * i + 0), src_zmm(2 * i), src_zmm(2 * i + 1));
        vpunpckhqdq(tmp_zmm(2 * i + 1), src_zmm(2 * i), src_zmm(2 * i + 1));
    }

    for (int i = 0; i < 2; i++) {
        vpunpcklqdq(
                src_zmm(2 * i + 0), src_zmm(4 + 2 * i), src_zmm(4 + 2 * i + 1));
        vpunpckhqdq(
                src_zmm(2 * i + 1), src_zmm(4 + 2 * i), src_zmm(4 + 2 * i + 1));
    }

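    // Remaining stages: vshufi32x4 recombines 128-bit lanes across
    // register pairs; immediate 0x88 selects the even lanes of both
    // sources and 0xdd the odd ones, completing the transpose across
    // the four lane quadrants.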
    for (int i = 0; i < 2; i++) {
        vshufi32x4(src_zmm(4 + 0 + i), tmp_zmm(i), tmp_zmm(2 + i), 0x88);
        vshufi32x4(src_zmm(4 + 2 + i), tmp_zmm(i), tmp_zmm(2 + i), 0xdd);
    }

    for (int i = 0; i < 2; i++) {
        vshufi32x4(tmp_zmm(0 + i), src_zmm(i), src_zmm(2 + i), 0x88);
        vshufi32x4(tmp_zmm(2 + i), src_zmm(i), src_zmm(2 + i), 0xdd);
    }

    for (int i = 0; i < 4; i++)
        vshufi32x4(src_zmm(i), src_zmm(4 + i), tmp_zmm(i), 0x88);

    for (int i = 0; i < 4; i++)
        vshufi32x4(src_zmm(4 + i), src_zmm(4 + i), tmp_zmm(i), 0xdd);

    for (int i = 0; i < 8; i++)
        store(src_zmm(i), i);
}

void jit_brgemm_trans_wei_bf16_t::generate() {
    preamble();
    int fwd_oc_block = 0;
    switch (conf_->wei_tag) {
        case OI16i64o:
        case OIw16i64o:
        case OIhw16i64o:
        case OIdhw16i64o:
        case OI8i64o2i:
        case OIw8i64o2i:
        case OIhw8i64o2i:
        case OIdhw8i64o2i:
        case OI16i64o2i:
        case OIw16i64o2i:
        case OIhw16i64o2i:
        case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break;
        case OI16i32o:
        case OIw16i32o:
        case OIhw16i32o:
        case OIdhw16i32o:
        case OI8i32o2i:
        case OIw8i32o2i:
        case OIhw8i32o2i:
        case OIdhw8i32o2i:
        case OI16i32o2i:
        case OIw16i32o2i:
        case OIhw16i32o2i:
        case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break;
        default: fwd_oc_block = conf_->simd_w;
    }

    int oc_tail = conf_->K_tail % transpose_size;
    int ic_block = conf_->ic_block;
    int ic_tail = conf_->N_tail % transpose_size;
    src_stride = 2 * fwd_oc_block * typesize;
    tr_src_stride = 2 * ic_block * typesize;
    dim_t N_src_shift = conf_->simd_w * fwd_oc_block * typesize;
    dim_t N_tr_src_shift = 2 * conf_->simd_w * typesize;
    dim_t K_src_shift = 2 * conf_->simd_w * typesize;
    dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize;

    mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
    mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
    mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);

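    // vpshufb control table for v_abcdefgh_to_abefcdgh: the four dwords
    // encode the byte indices for one 128-bit lane, and the pattern is
    // replicated across all four lanes of the zmm.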
    alignas(64) static constexpr const int32_t abcdefgh_to_abefcdgh[16]
            = {0x05040100, 0x07060302, 0x0d0c0908, 0x0f0e0b0a, 0x05040100,
                    0x07060302, 0x0d0c0908, 0x0f0e0b0a, 0x05040100, 0x07060302,
                    0x0d0c0908, 0x0f0e0b0a, 0x05040100, 0x07060302, 0x0d0c0908,
                    0x0f0e0b0a};

    auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa64(z, ptr[imm_addr64]);
    };

    vmovdqa64(v_abcdefgh_to_abefcdgh, (const int64_t *)abcdefgh_to_abefcdgh);

    auto compute_N = [=](bool is_oc_tail) {
        mov(reg_src, reg_src_base);
        mov(reg_tr_src, reg_tr_src_base);
        mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]);

        Label N_loop, N_loop_tail;
        cmp(reg_loop_N, transpose_size);
        jl(N_loop_tail, T_NEAR);

        L(N_loop);

        transpose_16x16_vnni(
                transpose_size, is_oc_tail ? oc_tail : transpose_size);
        add(reg_src, N_src_shift);
        add(reg_tr_src, N_tr_src_shift);

        sub(reg_loop_N, transpose_size);
        cmp(reg_loop_N, transpose_size);
        jge(N_loop, T_NEAR);

        L(N_loop_tail);
        if (ic_tail > 0) {
            Label N_loop_done;
            cmp(reg_loop_N, 0);
            jle(N_loop_done, T_NEAR);
            transpose_16x16_vnni(
                    ic_tail, is_oc_tail ? oc_tail : transpose_size);
            L(N_loop_done);
        }
    };

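    // Same K blocking as the f32 kernel above: full steps of
    // transpose_size, then an optional masked tail step for the remaining
    // oc_tail columns.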
    Label K_loop, K_tail;
    if (oc_tail > 0) {
        cmp(reg_loop_K, transpose_size);
        jl(K_tail, T_NEAR);
    }

    L(K_loop);
    compute_N(false);
    add(reg_src_base, K_src_shift);
    add(reg_tr_src_base, K_tr_src_shift);

    sub(reg_loop_K, transpose_size);
    cmp(reg_loop_K, transpose_size);
    jge(K_loop, T_NEAR);

    L(K_tail);
    if (oc_tail > 0) {
        Label K_loop_done;
        cmp(reg_loop_K, 0);
        jle(K_loop_done, T_NEAR);
        compute_N(true);
        L(K_loop_done);
    }

    postamble();
}

struct jit_amx_ip_trans_diff_wei_to_vnni_t : public jit_amx_ip_trans_diff_wei,
                                             public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_amx_ip_trans_diff_wei_to_vnni)

    jit_amx_ip_trans_diff_wei_to_vnni_t(const jit_brgemm_primitive_conf_t *jbgp,
            const int ext_ic_block, const int ext_oc_block)
        : jit_amx_ip_trans_diff_wei(jbgp, ext_ic_block, ext_oc_block) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    void generate() override;
};

void jit_amx_ip_trans_diff_wei_to_vnni_t::generate() {
    const int typesize_out = 2;
    const int typesize_acc = 4;
    const int simd_w = 16;

    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;

    const reg64_t &reg_output = r15;
    const reg64_t &reg_input = r14;
    const reg64_t &reg_prm_table = r13;
    const reg64_t &reg_last_ic_block = r12;
    const reg64_t &reg_last_oc_block = r11;
    const reg32_t &regw_tmp = r10d;

    const Xbyak::Zmm &zmm_idx = Xbyak::Zmm(31);
    auto get_zmm_src = [&](int ic) { return Xbyak::Zmm(ic % 8); };

    Xbyak::Label prm_table;
    Xbyak::Label skip_oc_tail, to_exit;

    Xbyak::Opmask load_mask = k4;

    int tail_mask = (jbgp_->N_tail % simd_w)
            ? (1 << (jbgp_->N_tail % simd_w)) - 1
            : 0xffff;
    auto kmovw = [=](Xbyak::Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

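    // Reorders one ext_ic_block x oc_block slice of f32 diff weights into
    // bf16 VNNI layout: row pairs are converted with vcvtne2ps2bf16,
    // interleaved into [2i] pairs by vpermw, and any range past the ic/oc
    // tails is zero-padded.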
    auto reorder_oc_block = [&](int icb, int ic_block, bool is_oc_tail) {
        // INP:      [64i][No]         : FP32
        // OUT: [OCB][ICB][16i][No][2i]: BF16
        if (ic_block <= 0) return;

        dim_t inp_icb_offset = typesize_acc
                * (icb * ext_ic_block_ * jbgp_->oc_block); // Internal
        dim_t out_icb_offset = typesize_out
                * (icb * div_up(ext_ic_block_, 2) * ext_oc_block_
                        * 2); // External

        const int oc_padded = rnd_up(jbgp_->oc, jbgp_->oc_block);
        const int oc_padded_ext = rnd_up(jbgp_->oc, ext_oc_block_);

        bool tailing_done = false;
        for (int oc = 0; oc < jbgp_->oc_block; oc += simd_w) {
            int ext_oc = oc % ext_oc_block_;
            int ext_ocb = oc / ext_oc_block_;
            dim_t ext_ocb_offset = typesize_out
                    * (ext_ocb * div_up(jbgp_->ic, ext_ic_block_)
                            * div_up(ext_ic_block_, 2) * ext_oc_block_ * 2);
            if (is_oc_tail && oc_padded != oc_padded_ext
                    && oc + simd_w > ext_oc_block_)
                break;
            dim_t inp_offset = inp_icb_offset + typesize_acc * (oc); // Internal
            dim_t out_offset = out_icb_offset + typesize_out * (ext_oc * 2)
                    + ext_ocb_offset; // External
            kmovw(load_mask, 0xffff);
            if (is_oc_tail) {
                if (jbgp_->N_tail && (oc + simd_w) >= jbgp_->N_tail) {
                    if (!tailing_done) {
                        kmovw(load_mask, tail_mask);
                        tailing_done = true;
                    } else {
                        auto zmm_src_0 = get_zmm_src(0);
                        vpxord(zmm_src_0, zmm_src_0, zmm_src_0);
                        for (int ic = 0; ic < ext_ic_block_ / 2; ic++) {
                            vmovups(ptr[reg_output + out_offset
                                            + typesize_out
                                                    * (ic * ext_oc_block_ * 2)],
                                    zmm_src_0);
                        }
                        continue;
                    }
                }
            }

            int ic = 0;
            for (; ic < ic_block / 2; ic++) {
                int ic1 = 2 * ic;
                int ic2 = 2 * ic + 1;

                auto zmm_src_0 = get_zmm_src(ic1);
                auto zmm_src_1 = get_zmm_src(ic2);

                vmovups(zmm_src_0 | load_mask | T_z,
                        ptr[reg_input + inp_offset
                                + typesize_acc * (ic1 * jbgp_->oc_block)]);
                vmovups(zmm_src_1 | load_mask | T_z,
                        ptr[reg_input + inp_offset
                                + typesize_acc * (ic2 * jbgp_->oc_block)]);

                vcvtne2ps2bf16(zmm_src_0, zmm_src_1, zmm_src_0);
                vpermw(zmm_src_0, zmm_idx, zmm_src_0);

                vmovups(ptr[reg_output + out_offset
                                + typesize_out * (ic * ext_oc_block_ * 2)],
                        zmm_src_0);
            }
            if (ic_block % 2) {
                int ic1 = 2 * ic;
                int ic2 = 2 * ic + 1;

                auto zmm_src_0 = get_zmm_src(ic1);
                auto zmm_src_1 = get_zmm_src(ic2);

                vmovups(zmm_src_0 | load_mask | T_z,
                        ptr[reg_input + inp_offset
                                + typesize_acc * (ic1 * jbgp_->oc_block)]);
                vpxord(zmm_src_1, zmm_src_1, zmm_src_1);

                vcvtne2ps2bf16(zmm_src_0, zmm_src_1, zmm_src_0);
                vpermw(zmm_src_0, zmm_idx, zmm_src_0);

                vmovups(ptr[reg_output + out_offset
                                + typesize_out * (ic * ext_oc_block_ * 2)],
                        zmm_src_0);
                ic++;
            }
            if (ic < ext_ic_block_ / 2) {
                auto zmm_src_0 = get_zmm_src(0);
                vpxord(zmm_src_0, zmm_src_0, zmm_src_0);
                for (; ic < ext_ic_block_ / 2; ic++) {
                    vmovups(ptr[reg_output + out_offset
                                    + typesize_out * (ic * ext_oc_block_ * 2)],
                            zmm_src_0);
                }
            }
        }
    };

    auto reorder_ic_block = [&](bool is_oc_tail, bool is_ic_tail) {
        int nb_ic = div_up(jbgp_->ic_block, ext_ic_block_);
        for (int icb = 0; icb < nb_ic; icb++) {
            int ic_0 = icb * ext_ic_block_;
            int ic_1 = (icb + 1) * ext_ic_block_;
            if (is_ic_tail) {
                int ext_ic_tail = (jbgp_->ic % ext_ic_block_)
                        ? (jbgp_->ic % ext_ic_block_)
                        : ext_ic_block_;
                if (jbgp_->M_tail && ic_0 >= jbgp_->M_tail) break;
                if (jbgp_->M_tail && ic_0 <= jbgp_->M_tail
                        && jbgp_->M_tail <= ic_1) {
                    reorder_oc_block(icb, ext_ic_tail, is_oc_tail);
                } else {
                    reorder_oc_block(icb, ext_ic_block_, is_oc_tail);
                }
            } else {
                reorder_oc_block(icb, ext_ic_block_, is_oc_tail);
            }
        }
    };

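    // Emits both the ic-tail and full-ic variants of the reorder and
    // selects between them at run time on reg_last_ic_block; note the
    // early-out jmp targets the function-scope to_exit label bound just
    // before the postamble (to_exit_1 is unused).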
    auto reorder = [&](bool is_oc_tail) {
        Xbyak::Label skip_ic_tail, to_exit_1;

        cmp(reg_last_ic_block, 0);
        je(skip_ic_tail, T_NEAR);

        reorder_ic_block(is_oc_tail, true);
        jmp(to_exit, T_NEAR);

        L(skip_ic_tail);
        reorder_ic_block(is_oc_tail, false);

        L(to_exit_1);
    };

    preamble();

    mov(reg_input, ptr[abi_param1 + GET_OFF(src)]);
    mov(reg_output, ptr[abi_param1 + GET_OFF(dst)]);
    mov(reg_last_ic_block, ptr[abi_param1 + GET_OFF(last_ic_block)]);
    mov(reg_last_oc_block, ptr[abi_param1 + GET_OFF(last_oc_block)]);

    mov(reg_prm_table, prm_table);
    vmovups(zmm_idx, ptr[reg_prm_table]);

    cmp(reg_last_oc_block, 0);
    je(skip_oc_tail, T_NEAR);

    reorder(true);
    jmp(to_exit, T_NEAR);

    L(skip_oc_tail);
    reorder(false);

    L(to_exit);
    postamble();

    align(64);
    L(prm_table);
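    // vpermw index table that interleaves the low and high 16-word halves
    // of a zmm (0,16,1,17,...): element i of the low half is paired with
    // element i of the high half, producing the [2i] VNNI pairs.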
    const uint16_t prm_array[32]
            = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
                    25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
    for (size_t i = 0; i < 32; ++i)
        dw(prm_array[i]);
}

#undef GET_OFF

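// Factory helpers: each picks the JIT kernel matching prop_kind and the
// relevant data type, then compiles it via create_kernel(); unsupported
// combinations return status::invalid_arguments.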
status_t create_brgemm_trans_src(
        std::unique_ptr<jit_brgemm_trans_src_t> &trans_ker,
        const jit_brgemm_primitive_conf_t *conf) {
    if (conf->prop_kind == dnnl_backward_weights
            && conf->src_dt == data_type::f32)
        CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_m_k_f32_t(conf)));
    else if (conf->prop_kind == dnnl_backward_weights
            && conf->src_dt == data_type::bf16)
        CHECK(safe_ptr_assign(
                trans_ker, new jit_brgemm_trans_m_k_bf16_t(conf)));
    else
        return status::invalid_arguments;

    return trans_ker->create_kernel();
}

status_t create_brgemm_copy_to_coarse(
        std::unique_ptr<jit_brgemm_copy_to_coarse_t> &copy_ker,
        const jit_brgemm_primitive_conf_t *conf) {
    if (conf->isa == avx512_core_bf16_amx_int8
            || conf->isa == avx512_core_bf16_amx_bf16)
        CHECK(safe_ptr_assign(copy_ker, new jit_brgemm_copy_to_coarse_t(conf)));
    else
        return status::invalid_arguments;

    return copy_ker->create_kernel();
}

status_t create_brgemm_trans_to_vnni(
        std::unique_ptr<jit_brgemm_trans_to_vnni_t> &trans_ker,
        const jit_brgemm_primitive_conf_t *conf,
        jit_brgemm_trans_to_vnni_t::matrix_to_transform_t matrix_to_transform) {
    if (conf->prop_kind == dnnl_backward_weights
            && conf->dst_dt == data_type::bf16)
        CHECK(safe_ptr_assign(
                trans_ker, new jit_trans_to_vnni_t(conf, matrix_to_transform)));
    else if (conf->prop_kind == dnnl_backward_weights
            && conf->dst_dt == data_type::f32)
        CHECK(safe_ptr_assign(
                trans_ker, new jit_copy_f32_t(conf, matrix_to_transform)));
    else
        return status::invalid_arguments;

    return trans_ker->create_kernel();
}

status_t create_brgemm_trans_wei(
        std::unique_ptr<jit_brgemm_trans_wei_t> &trans_ker,
        const jit_brgemm_primitive_conf_t *conf) {
    if (conf->prop_kind == dnnl_backward_data && conf->wei_dt == data_type::f32)
        CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_wei_f32_t(conf)));
    else if (conf->prop_kind == dnnl_backward_data
            && conf->wei_dt == data_type::bf16)
        CHECK(safe_ptr_assign(
                trans_ker, new jit_brgemm_trans_wei_bf16_t(conf)));
    else
        return status::invalid_arguments;

    return trans_ker->create_kernel();
}

status_t create_brgemm_amx_ip_trans_wei(
        std::unique_ptr<jit_amx_ip_trans_diff_wei> &trans_ker,
        const jit_brgemm_primitive_conf_t *conf, const int ext_ic_block,
        const int ext_oc_block) {
    if (conf->prop_kind == dnnl_backward_weights
            && conf->wei_dt == data_type::bf16) {
        CHECK(safe_ptr_assign(trans_ker,
                new jit_amx_ip_trans_diff_wei_to_vnni_t(
                        conf, ext_ic_block, ext_oc_block)));
    } else
        return status::invalid_arguments;

    return trans_ker->create_kernel();
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl