1 /*******************************************************************************
2 * Copyright 2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "cpu/x64/jit_generator.hpp"
18 
19 #include "cpu/x64/gemm/bf16/common_s16.hpp"
20 
21 namespace dnnl {
22 namespace impl {
23 namespace cpu {
24 namespace x64 {
25 
jit_avx512_core_s16_24x8_copy_bt_kern()26 jit_avx512_core_s16_24x8_copy_bt_kern::jit_avx512_core_s16_24x8_copy_bt_kern()
27     : jit_generator(nullptr, S16_COPY_KERNEL_CODE_SIZE) {}
28 
generate()29 void jit_avx512_core_s16_24x8_copy_bt_kern::generate() {
30 
31 #ifndef _WIN32
32 #define M rdi
33 #define N rsi
34 #define A rdx
35 #define LDA rcx
36 #define ALPHA r8
37 #define B r9
38 
39 #define I rax
40 #define A1 r10
41 #define A2 r8
42 #define LDA3 r11
43 
44 #else
45 
46 #define M rcx
47 #define N rdx
48 #define A r8
49 #define LDA r9
50 #define ALPHA rax
51 #define B rdi
52 
53 #define I rax
54 #define A1 rsi
55 #define A2 r10
56 #define LDA3 r11
57 
58 #define ARG_ALPHA 40 + stacksize + rsp
59 #define ARG_B 48 + stacksize + rsp
60 
61 #endif
62 
63     inLocalLabel();
64     {
65 
66         Xbyak::Label l13c;
67         Xbyak::Label l170;
68         Xbyak::Label l18c;
69         Xbyak::Label l19c;
70         Xbyak::Label l1a8;
71         Xbyak::Label l1b8;
72         Xbyak::Label l234;
73         Xbyak::Label l24;
74         Xbyak::Label l27c;
75         Xbyak::Label l2a8;
76         Xbyak::Label l2c4;
77         Xbyak::Label l2d4;
78         Xbyak::Label l2e0;
79         Xbyak::Label l2f0;
80         Xbyak::Label l368;
81         Xbyak::Label l38;
82         Xbyak::Label l3ac;
83         Xbyak::Label l3d8;
84         Xbyak::Label l3f4;
85         Xbyak::Label l402;
86         Xbyak::Label l40c;
87         Xbyak::Label l41c;
88         Xbyak::Label l494;
89         Xbyak::Label l4dc;
90         Xbyak::Label l50c;
91         Xbyak::Label l524;
92         Xbyak::Label l534;
93         Xbyak::Label le0;
94 
95         preamble();
96 #ifdef _WIN32
97         auto stacksize = get_size_of_abi_save_regs();
98         mov(ALPHA, ptr[ARG_ALPHA]);
99         mov(B, ptr[ARG_B]);
100 #endif
101 
102         mov(M, qword[M]);
103         mov(N, qword[N]);
104         mov(LDA, qword[LDA]);
105         shl(LDA, 1);
106         lea(LDA3, ptr[LDA + LDA * 2]);
107         sub(A, -128);
108         sub(B, -128);
109         cmp(N, 0x8);
110         jl(l19c, T_NEAR);
111         align(4);
112 
113         L(l24);
114         mov(A1, A);
115         add(A, 0x10);
116         mov(I, M);
117         sar(I, 0x3);
118         jle(le0, T_NEAR);
119         align(4);
120 
121         L(l38);
122         vmovdqu(xmm0, xword[A1 - 0x80]);
123         add(A1, LDA);
124         vmovdqu(xmm1, xword[A1 - 0x80]);
125         add(A1, LDA);
126         vmovdqu(xmm2, xword[A1 - 0x80]);
127         add(A1, LDA);
128         vmovdqu(xmm3, xword[A1 - 0x80]);
129         add(A1, LDA);
130         vpunpcklwd(xmm4, xmm0, xmm1);
131         vpunpckhwd(xmm5, xmm0, xmm1);
132         vperm2f128(ymm0, ymm4, ymm5, 0x20);
133         vpunpcklwd(xmm4, xmm2, xmm3);
134         vpunpckhwd(xmm5, xmm2, xmm3);
135         vperm2f128(ymm2, ymm4, ymm5, 0x20);
136         vmovdqu(yword[B - 0x80], ymm0);
137         vmovdqu(yword[B - 0x60], ymm2);
138         vmovdqu(xmm0, xword[A1 - 0x80]);
139         add(A1, LDA);
140         vmovdqu(xmm1, xword[A1 - 0x80]);
141         add(A1, LDA);
142         vmovdqu(xmm2, xword[A1 - 0x80]);
143         add(A1, LDA);
144         vmovdqu(xmm3, xword[A1 - 0x80]);
145         add(A1, LDA);
146         vpunpcklwd(xmm4, xmm0, xmm1);
147         vpunpckhwd(xmm5, xmm0, xmm1);
148         vperm2f128(ymm0, ymm4, ymm5, 0x20);
149         vpunpcklwd(xmm4, xmm2, xmm3);
150         vpunpckhwd(xmm5, xmm2, xmm3);
151         vperm2f128(ymm2, ymm4, ymm5, 0x20);
152         vmovdqu(yword[B - 0x40], ymm0);
153         vmovdqu(yword[B - 0x20], ymm2);
154         sub(B, -128);
155         dec(I);
156         jg(l38, T_NEAR);
157         align(4);
158 
159         L(le0);
160         test(M, 0x4);
161         jle(l13c, T_NEAR);
162         vmovdqu(xmm0, xword[A1 - 0x80]);
163         add(A1, LDA);
164         vmovdqu(xmm1, xword[A1 - 0x80]);
165         add(A1, LDA);
166         vmovdqu(xmm2, xword[A1 - 0x80]);
167         add(A1, LDA);
168         vmovdqu(xmm3, xword[A1 - 0x80]);
169         add(A1, LDA);
170         vpunpcklwd(xmm4, xmm0, xmm1);
171         vpunpckhwd(xmm5, xmm0, xmm1);
172         vperm2f128(ymm0, ymm4, ymm5, 0x20);
173         vpunpcklwd(xmm4, xmm2, xmm3);
174         vpunpckhwd(xmm5, xmm2, xmm3);
175         vperm2f128(ymm2, ymm4, ymm5, 0x20);
176         vmovdqu(yword[B - 0x80], ymm0);
177         vmovdqu(yword[B - 0x60], ymm2);
178         sub(B, -64);
179         align(4);
180 
181         L(l13c);
182         test(M, 0x2);
183         jle(l170, T_NEAR);
184         vmovdqu(xmm0, xword[A1 - 0x80]);
185         add(A1, LDA);
186         vmovdqu(xmm1, xword[A1 - 0x80]);
187         add(A1, LDA);
188         vpunpcklwd(xmm2, xmm0, xmm1);
189         vpunpckhwd(xmm3, xmm0, xmm1);
190         vperm2f128(ymm0, ymm2, ymm3, 0x20);
191         vmovdqu(yword[B - 0x80], ymm0);
192         sub(B, -32);
193         align(4);
194 
195         L(l170);
196         test(M, 0x1);
197         jle(l18c, T_NEAR);
198         vmovdqu(xmm0, xword[A1 - 0x80]);
199         vmovdqu(xword[B - 0x80], xmm0);
200         sub(B, -16);
201         align(4);
202 
203         L(l18c);
204         sub(N, 0x8);
205         cmp(N, 0x8);
206         jge(l24, T_NEAR);
207         align(4);
208 
209         L(l19c);
210         cmp(N, 0x4);
211         jl(l2d4, T_NEAR);
212         align(4);
213 
214         L(l1a8);
215         mov(A1, A);
216         add(A, 0x8);
217         mov(I, M);
218         sar(I, 0x3);
219         jle(l234, T_NEAR);
220         align(4);
221 
222         L(l1b8);
223         vmovq(xmm0, qword[A1 - 0x80]);
224         add(A1, LDA);
225         vmovq(xmm1, qword[A1 - 0x80]);
226         add(A1, LDA);
227         vmovq(xmm2, qword[A1 - 0x80]);
228         add(A1, LDA);
229         vmovq(xmm3, qword[A1 - 0x80]);
230         add(A1, LDA);
231         vpunpcklwd(xmm0, xmm0, xmm1);
232         vpunpcklwd(xmm2, xmm2, xmm3);
233         vperm2f128(ymm0, ymm0, ymm2, 0x20);
234         vmovdqu(yword[B - 0x80], ymm0);
235         vmovq(xmm0, qword[A1 - 0x80]);
236         add(A1, LDA);
237         vmovq(xmm1, qword[A1 - 0x80]);
238         add(A1, LDA);
239         vmovq(xmm2, qword[A1 - 0x80]);
240         add(A1, LDA);
241         vmovq(xmm3, qword[A1 - 0x80]);
242         add(A1, LDA);
243         vpunpcklwd(xmm0, xmm0, xmm1);
244         vpunpcklwd(xmm2, xmm2, xmm3);
245         vperm2f128(ymm0, ymm0, ymm2, 0x20);
246         vmovdqu(yword[B - 0x60], ymm0);
247         sub(B, -64);
248         dec(I);
249         jg(l1b8, T_NEAR);
250         align(4);
251 
252         L(l234);
253         test(M, 0x4);
254         jle(l27c, T_NEAR);
255         vmovq(xmm0, qword[A1 - 0x80]);
256         add(A1, LDA);
257         vmovq(xmm1, qword[A1 - 0x80]);
258         add(A1, LDA);
259         vmovq(xmm2, qword[A1 - 0x80]);
260         add(A1, LDA);
261         vmovq(xmm3, qword[A1 - 0x80]);
262         add(A1, LDA);
263         vpunpcklwd(xmm0, xmm0, xmm1);
264         vpunpcklwd(xmm2, xmm2, xmm3);
265         vmovdqu(xword[B - 0x80], xmm0);
266         vmovdqu(xword[B - 0x70], xmm2);
267         sub(B, -32);
268         align(4);
269 
270         L(l27c);
271         test(M, 0x2);
272         jle(l2a8, T_NEAR);
273         vmovq(xmm0, qword[A1 - 0x80]);
274         add(A1, LDA);
275         vmovq(xmm1, qword[A1 - 0x80]);
276         add(A1, LDA);
277         vpunpcklwd(xmm0, xmm0, xmm1);
278         vmovdqu(xword[B - 0x80], xmm0);
279         sub(B, -16);
280         align(4);
281 
282         L(l2a8);
283         test(M, 0x1);
284         jle(l2c4, T_NEAR);
285         vmovq(xmm0, qword[A1 - 0x80]);
286         vmovq(qword[B - 0x80], xmm0);
287         sub(B, -8);
288         align(4);
289 
290         L(l2c4);
291         sub(N, 0x4);
292         cmp(N, 0x4);
293         jge(l1a8, T_NEAR);
294         align(4);
295 
296         L(l2d4);
297         cmp(N, 0x2);
298         jl(l402, T_NEAR);
299         align(4);
300 
301         L(l2e0);
302         mov(A1, A);
303         add(A, 0x4);
304         mov(I, M);
305         sar(I, 0x3);
306         jle(l368, T_NEAR);
307         align(4);
308 
309         L(l2f0);
310         vmovd(xmm0, dword[A1 - 0x80]);
311         add(A1, LDA);
312         vmovd(xmm1, dword[A1 - 0x80]);
313         add(A1, LDA);
314         vmovd(xmm2, dword[A1 - 0x80]);
315         add(A1, LDA);
316         vmovd(xmm3, dword[A1 - 0x80]);
317         add(A1, LDA);
318         vpunpcklwd(xmm0, xmm0, xmm1);
319         vpunpcklwd(xmm2, xmm2, xmm3);
320         vpunpcklqdq(xmm0, xmm0, xmm2);
321         vmovdqu(xword[B - 0x80], xmm0);
322         vmovd(xmm0, dword[A1 - 0x80]);
323         add(A1, LDA);
324         vmovd(xmm1, dword[A1 - 0x80]);
325         add(A1, LDA);
326         vmovd(xmm2, dword[A1 - 0x80]);
327         add(A1, LDA);
328         vmovd(xmm3, dword[A1 - 0x80]);
329         add(A1, LDA);
330         vpunpcklwd(xmm0, xmm0, xmm1);
331         vpunpcklwd(xmm2, xmm2, xmm3);
332         vpunpcklqdq(xmm0, xmm0, xmm2);
333         vmovdqu(xword[B - 0x70], xmm0);
334         sub(B, -32);
335         dec(I);
336         jg(l2f0, T_NEAR);
337         align(4);
338 
339         L(l368);
340         test(M, 0x4);
341         jle(l3ac, T_NEAR);
342         vmovd(xmm0, dword[A1 - 0x80]);
343         add(A1, LDA);
344         vmovd(xmm1, dword[A1 - 0x80]);
345         add(A1, LDA);
346         vmovd(xmm2, dword[A1 - 0x80]);
347         add(A1, LDA);
348         vmovd(xmm3, dword[A1 - 0x80]);
349         add(A1, LDA);
350         vpunpcklwd(xmm0, xmm0, xmm1);
351         vpunpcklwd(xmm2, xmm2, xmm3);
352         vpunpcklqdq(xmm0, xmm0, xmm2);
353         vmovdqu(xword[B - 0x80], xmm0);
354         sub(B, -16);
355         align(4);
356 
357         L(l3ac);
358         test(M, 0x2);
359         jle(l3d8, T_NEAR);
360         vmovd(xmm0, dword[A1 - 0x80]);
361         add(A1, LDA);
362         vmovd(xmm1, dword[A1 - 0x80]);
363         add(A1, LDA);
364         vpunpcklwd(xmm0, xmm0, xmm1);
365         vmovq(qword[B - 0x80], xmm0);
366         sub(B, -8);
367         align(4);
368 
369         L(l3d8);
370         test(M, 0x1);
371         jle(l3f4, T_NEAR);
372         vmovd(xmm0, dword[A1 - 0x80]);
373         vmovd(dword[B - 0x80], xmm0);
374         sub(B, -4);
375         align(4);
376 
377         L(l3f4);
378         sub(N, 0x2);
379         cmp(N, 0x2);
380         jge(l2e0, T_NEAR);
381         align(4);
382 
383         L(l402);
384         cmp(N, 0x1);
385         jl(l534, T_NEAR);
386         align(4);
387 
388         L(l40c);
389         mov(A1, A);
390         add(A, 0x2);
391         mov(LDA3, M);
392         sar(LDA3, 0x3);
393         jle(l494, T_NEAR);
394         align(4);
395 
396         L(l41c);
397         mov(ax, word[A1 - 0x80]);
398         add(A1, LDA);
399         vpinsrw(xmm0, xmm0, eax, 0x0);
400         mov(ax, word[A1 - 0x80]);
401         add(A1, LDA);
402         vpinsrw(xmm0, xmm0, eax, 0x1);
403         mov(ax, word[A1 - 0x80]);
404         add(A1, LDA);
405         vpinsrw(xmm0, xmm0, eax, 0x2);
406         mov(ax, word[A1 - 0x80]);
407         add(A1, LDA);
408         vpinsrw(xmm0, xmm0, eax, 0x3);
409         mov(ax, word[A1 - 0x80]);
410         add(A1, LDA);
411         vpinsrw(xmm0, xmm0, eax, 0x4);
412         mov(ax, word[A1 - 0x80]);
413         add(A1, LDA);
414         vpinsrw(xmm0, xmm0, eax, 0x5);
415         mov(ax, word[A1 - 0x80]);
416         add(A1, LDA);
417         vpinsrw(xmm0, xmm0, eax, 0x6);
418         mov(ax, word[A1 - 0x80]);
419         add(A1, LDA);
420         vpinsrw(xmm0, xmm0, eax, 0x7);
421         vmovdqu(xword[B - 0x80], xmm0);
422         sub(B, -16);
423         dec(LDA3);
424         jg(l41c, T_NEAR);
425         align(4);
426 
427         L(l494);
428         test(M, 0x4);
429         jle(l4dc, T_NEAR);
430         mov(ax, word[A1 - 0x80]);
431         add(A1, LDA);
432         vpinsrw(xmm0, xmm0, eax, 0x0);
433         mov(ax, word[A1 - 0x80]);
434         add(A1, LDA);
435         vpinsrw(xmm0, xmm0, eax, 0x1);
436         mov(ax, word[A1 - 0x80]);
437         add(A1, LDA);
438         vpinsrw(xmm0, xmm0, eax, 0x2);
439         mov(ax, word[A1 - 0x80]);
440         add(A1, LDA);
441         vpinsrw(xmm0, xmm0, eax, 0x3);
442         vmovq(qword[B - 0x80], xmm0);
443         sub(B, -8);
444         align(4);
445 
446         L(l4dc);
447         test(M, 0x2);
448         jle(l50c, T_NEAR);
449         mov(ax, word[A1 - 0x80]);
450         add(A1, LDA);
451         vpinsrw(xmm0, xmm0, eax, 0x0);
452         mov(ax, word[A1 - 0x80]);
453         add(A1, LDA);
454         vpinsrw(xmm0, xmm0, eax, 0x1);
455         vmovd(dword[B - 0x80], xmm0);
456         sub(B, -4);
457         align(4);
458 
459         L(l50c);
460         test(M, 0x1);
461         jle(l524, T_NEAR);
462         mov(ax, word[A1 - 0x80]);
463         mov(word[B - 0x80], ax);
464         sub(B, -2);
465         align(4);
466 
467         L(l524);
468         sub(N, 0x1);
469         cmp(N, 0x1);
470         jge(l40c, T_NEAR);
471         align(4);
472 
473         L(l534);
474         vzeroupper();
475         postamble();
476     }
477     outLocalLabel();
478 
479 #undef M
480 #undef N
481 #undef A
482 #undef LDA
483 #undef ALPHA
484 #undef B
485 #undef I
486 #undef A1
487 #undef A2
488 #undef LDA3
489 #ifdef _WIN32
490 #undef ARG_ALPHA
491 #undef ARG_B
492 #endif
493 }
494 
495 } // namespace x64
496 } // namespace cpu
497 } // namespace impl
498 } // namespace dnnl
499