1 /*******************************************************************************
2 * Copyright 2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "cpu/x64/jit_generator.hpp"
18 
19 #include "cpu/x64/gemm/bf16/common_s16.hpp"
20 
21 namespace dnnl {
22 namespace impl {
23 namespace cpu {
24 namespace x64 {
25 
jit_avx512_core_s16_24x8_copy_bn_kern()26 jit_avx512_core_s16_24x8_copy_bn_kern::jit_avx512_core_s16_24x8_copy_bn_kern()
27     : jit_generator(nullptr, S16_COPY_KERNEL_CODE_SIZE) {}
28 
generate()29 void jit_avx512_core_s16_24x8_copy_bn_kern::generate() {
30 
31 #ifndef _WIN32
32 #define M rdi
33 #define N rsi
34 #define A rdx
35 #define LDA rcx
36 #define ALPHA r8
37 #define B r9
38 
39 #define I rax
40 #define A1 r10
41 #define A2 r8
42 #define LDA3 r11
43 
44 #else
45 
46 #define M rcx
47 #define N rdx
48 #define A r8
49 #define LDA r9
50 #define ALPHA rax
51 #define B rdi
52 
53 #define I rax
54 #define A1 rsi
55 #define A2 r10
56 #define LDA3 r11
57 
58 #define ARG_ALPHA 40 + stacksize + rsp
59 #define ARG_B 48 + stacksize + rsp
60 
61 #endif
62 
63     inLocalLabel();
64     {
65 
66         Xbyak::Label l158;
67         Xbyak::Label l1c8;
68         Xbyak::Label l23c;
69         Xbyak::Label l24;
70         Xbyak::Label l24c;
71         Xbyak::Label l258;
72         Xbyak::Label l270;
73         Xbyak::Label l2e0;
74         Xbyak::Label l32c;
75         Xbyak::Label l370;
76         Xbyak::Label l3b4;
77         Xbyak::Label l3c4;
78         Xbyak::Label l3d0;
79         Xbyak::Label l3e8;
80         Xbyak::Label l40;
81         Xbyak::Label l41c;
82         Xbyak::Label l448;
83         Xbyak::Label l474;
84         Xbyak::Label l49c;
85         Xbyak::Label l4aa;
86         Xbyak::Label l4b4;
87         Xbyak::Label l4c4;
88         Xbyak::Label l4e0;
89         Xbyak::Label l500;
90         Xbyak::Label l520;
91         Xbyak::Label l540;
92         Xbyak::Label l558;
93         Xbyak::Label l568;
94         Xbyak::Label ldc;
95 
96         preamble();
97 #ifdef _WIN32
98         auto stacksize = get_size_of_abi_save_regs();
99         mov(ALPHA, ptr[ARG_ALPHA]);
100         mov(B, ptr[ARG_B]);
101 #endif
102 
103         mov(N, qword[N]);
104         mov(M, qword[M]);
105         mov(LDA, qword[LDA]);
106         shl(LDA, 1);
107         lea(LDA3, ptr[LDA + LDA * 2]);
108         sub(A, -128);
109         sub(B, -128);
110         cmp(N, 0x8);
111         jl(l24c, T_NEAR);
112         align(4);
113 
114         L(l24);
115         mov(A1, A);
116         lea(A2, ptr[A1 + LDA * 4]);
117         lea(I, ptr[A1 + LDA * 8]);
118         mov(A, I);
119         mov(I, M);
120         sar(I, 0x3);
121         jle(ldc, T_NEAR);
122         align(4);
123 
124         L(l40);
125         vmovdqu(xmm4, xword[A1 - 0x80]);
126         vmovdqu(xmm5, xword[A1 + LDA * 1 - 0x80]);
127         vmovdqu(xmm0, xword[A1 + LDA * 2 - 0x80]);
128         vmovdqu(xmm1, xword[A1 + LDA3 * 1 - 0x80]);
129         sub(A1, -16);
130         vmovdqu(xmm2, xword[A2 - 0x80]);
131         vperm2f128(ymm4, ymm4, ymm2, 0x20);
132         vmovdqu(xmm3, xword[A2 + LDA * 1 - 0x80]);
133         vperm2f128(ymm5, ymm5, ymm3, 0x20);
134         vmovdqu(xmm2, xword[A2 + LDA * 2 - 0x80]);
135         vperm2f128(ymm0, ymm0, ymm2, 0x20);
136         vmovdqu(xmm3, xword[A2 + LDA3 * 1 - 0x80]);
137         vperm2f128(ymm1, ymm1, ymm3, 0x20);
138         sub(A2, -16);
139         vunpcklps(ymm2, ymm4, ymm0);
140         vunpckhps(ymm3, ymm4, ymm0);
141         vunpcklps(ymm4, ymm5, ymm1);
142         vunpckhps(ymm5, ymm5, ymm1);
143         vunpcklps(ymm0, ymm2, ymm4);
144         vunpckhps(ymm1, ymm2, ymm4);
145         vunpcklps(ymm2, ymm3, ymm5);
146         vunpckhps(ymm3, ymm3, ymm5);
147         vmovdqu(yword[B - 0x80], ymm0);
148         vmovdqu(yword[B - 0x60], ymm1);
149         vmovdqu(yword[B - 0x40], ymm2);
150         vmovdqu(yword[B - 0x20], ymm3);
151         sub(B, -128);
152         dec(I);
153         jg(l40, T_NEAR);
154         align(4);
155 
156         L(ldc);
157         test(M, 0x4);
158         jle(l158, T_NEAR);
159         vmovq(xmm0, qword[A1 - 0x80]);
160         vmovq(xmm1, qword[A1 + LDA * 1 - 0x80]);
161         vmovq(xmm2, qword[A1 + LDA * 2 - 0x80]);
162         vmovq(xmm3, qword[A1 + LDA3 * 1 - 0x80]);
163         sub(A1, -8);
164         vunpcklps(xmm0, xmm0, xmm2);
165         vunpcklps(xmm1, xmm1, xmm3);
166         vmovq(xmm2, qword[A2 - 0x80]);
167         vmovq(xmm3, qword[A2 + LDA * 1 - 0x80]);
168         vmovq(xmm4, qword[A2 + LDA * 2 - 0x80]);
169         vmovq(xmm5, qword[A2 + LDA3 * 1 - 0x80]);
170         sub(A2, -8);
171         vunpcklps(xmm2, xmm2, xmm4);
172         vunpcklps(xmm3, xmm3, xmm5);
173         vperm2f128(ymm0, ymm0, ymm2, 0x20);
174         vperm2f128(ymm1, ymm1, ymm3, 0x20);
175         vunpcklps(ymm2, ymm0, ymm1);
176         vunpckhps(ymm3, ymm0, ymm1);
177         vmovdqu(yword[B - 0x80], ymm2);
178         vmovdqu(yword[B - 0x60], ymm3);
179         sub(B, -64);
180         align(4);
181 
182         L(l158);
183         test(M, 0x2);
184         jle(l1c8, T_NEAR);
185         vmovd(xmm0, dword[A1 - 0x80]);
186         vmovd(xmm1, dword[A1 + LDA * 1 - 0x80]);
187         vmovd(xmm2, dword[A1 + LDA * 2 - 0x80]);
188         vmovd(xmm3, dword[A1 + LDA3 * 1 - 0x80]);
189         sub(A1, -4);
190         vunpcklps(xmm0, xmm0, xmm1);
191         vunpcklps(xmm2, xmm2, xmm3);
192         vpunpcklqdq(xmm0, xmm0, xmm2);
193         vmovd(xmm1, dword[A2 - 0x80]);
194         vmovd(xmm2, dword[A2 + LDA * 1 - 0x80]);
195         vmovd(xmm3, dword[A2 + LDA * 2 - 0x80]);
196         vmovd(xmm4, dword[A2 + LDA3 * 1 - 0x80]);
197         sub(A2, -4);
198         vunpcklps(xmm1, xmm1, xmm2);
199         vunpcklps(xmm3, xmm3, xmm4);
200         vpunpcklqdq(xmm1, xmm1, xmm3);
201         vinsertf128(ymm0, ymm0, xmm1, 0x1);
202         vmovdqu(yword[B - 0x80], ymm0);
203         sub(B, -32);
204         align(4);
205 
206         L(l1c8);
207         test(M, 0x1);
208         jle(l23c, T_NEAR);
209         mov(ax, word[A1 - 0x80]);
210         vpinsrw(xmm0, xmm0, eax, 0x0);
211         mov(ax, word[A1 + LDA * 1 - 0x80]);
212         vpinsrw(xmm0, xmm0, eax, 0x1);
213         mov(ax, word[A1 + LDA * 2 - 0x80]);
214         vpinsrw(xmm0, xmm0, eax, 0x2);
215         mov(ax, word[A1 + LDA3 * 1 - 0x80]);
216         vpinsrw(xmm0, xmm0, eax, 0x3);
217         lea(A2, ptr[A1 + LDA * 4]);
218         mov(ax, word[A2 - 0x80]);
219         vpinsrw(xmm0, xmm0, eax, 0x4);
220         mov(ax, word[A2 + LDA * 1 - 0x80]);
221         vpinsrw(xmm0, xmm0, eax, 0x5);
222         mov(ax, word[A2 + LDA * 2 - 0x80]);
223         vpinsrw(xmm0, xmm0, eax, 0x6);
224         mov(ax, word[A2 + LDA3 * 1 - 0x80]);
225         vpinsrw(xmm0, xmm0, eax, 0x7);
226         lea(A2, ptr[A2 + LDA * 4]);
227         vmovdqu(xword[B - 0x80], xmm0);
228         sub(B, -16);
229         align(4);
230 
231         L(l23c);
232         sub(N, 0x8);
233         cmp(N, 0x8);
234         jge(l24, T_NEAR);
235         align(4);
236 
237         L(l24c);
238         cmp(N, 0x4);
239         jl(l3c4, T_NEAR);
240         align(4);
241 
242         L(l258);
243         mov(A1, A);
244         lea(A2, ptr[A1 + LDA * 2]);
245         lea(I, ptr[A1 + LDA * 4]);
246         mov(A, I);
247         mov(I, M);
248         sar(I, 0x3);
249         jle(l2e0, T_NEAR);
250         align(4);
251 
252         L(l270);
253         vmovdqu(xmm0, xword[A1 - 0x80]);
254         vmovdqu(xmm1, xword[A1 + LDA * 1 - 0x80]);
255         sub(A1, -16);
256         vmovdqu(xmm2, xword[A2 - 0x80]);
257         vmovdqu(xmm3, xword[A2 + LDA * 1 - 0x80]);
258         sub(A2, -16);
259         vperm2f128(ymm0, ymm0, ymm2, 0x20);
260         vperm2f128(ymm1, ymm1, ymm3, 0x20);
261         vunpcklps(ymm2, ymm0, ymm1);
262         vunpckhps(ymm3, ymm0, ymm1);
263         vperm2f128(ymm0, ymm2, ymm2, 0x1);
264         vperm2f128(ymm1, ymm3, ymm3, 0x1);
265         vshufpd(ymm0, ymm2, ymm0, 0xc);
266         vshufpd(ymm1, ymm3, ymm1, 0xc);
267         vpermilpd(ymm0, ymm0, 0x6);
268         vpermilpd(ymm1, ymm1, 0x6);
269         vmovdqu(yword[B - 0x80], ymm0);
270         vmovdqu(yword[B - 0x60], ymm1);
271         sub(B, -64);
272         dec(I);
273         jg(l270, T_NEAR);
274         align(4);
275 
276         L(l2e0);
277         test(M, 0x4);
278         jle(l32c, T_NEAR);
279         vmovq(xmm0, qword[A1 - 0x80]);
280         vmovq(xmm1, qword[A1 + LDA * 1 - 0x80]);
281         sub(A1, -8);
282         vmovq(xmm2, qword[A2 - 0x80]);
283         vmovq(xmm3, qword[A2 + LDA * 1 - 0x80]);
284         sub(A2, -8);
285         vunpcklps(xmm0, xmm0, xmm2);
286         vunpcklps(xmm1, xmm1, xmm3);
287         vunpcklps(xmm2, xmm0, xmm1);
288         vunpckhps(xmm3, xmm0, xmm1);
289         vmovdqu(xword[B - 0x80], xmm2);
290         vmovdqu(xword[B - 0x70], xmm3);
291         sub(B, -32);
292         align(4);
293 
294         L(l32c);
295         test(M, 0x2);
296         jle(l370, T_NEAR);
297         vmovd(xmm0, dword[A1 - 0x80]);
298         vmovd(xmm1, dword[A1 + LDA * 1 - 0x80]);
299         sub(A1, -4);
300         vmovd(xmm2, dword[A2 - 0x80]);
301         vmovd(xmm3, dword[A2 + LDA * 1 - 0x80]);
302         sub(A2, -4);
303         vunpcklps(xmm0, xmm0, xmm1);
304         vunpcklps(xmm2, xmm2, xmm3);
305         vpunpcklqdq(xmm0, xmm0, xmm2);
306         vmovdqu(xword[B - 0x80], xmm0);
307         sub(B, -16);
308         align(4);
309 
310         L(l370);
311         test(M, 0x1);
312         jle(l3b4, T_NEAR);
313         mov(ax, word[A1 - 0x80]);
314         vpinsrw(xmm0, xmm0, eax, 0x0);
315         mov(ax, word[A1 + LDA * 1 - 0x80]);
316         vpinsrw(xmm0, xmm0, eax, 0x1);
317         mov(ax, word[A1 + LDA * 2 - 0x80]);
318         vpinsrw(xmm0, xmm0, eax, 0x2);
319         mov(ax, word[A1 + LDA3 * 1 - 0x80]);
320         vpinsrw(xmm0, xmm0, eax, 0x3);
321         lea(A2, ptr[A1 + LDA * 4]);
322         vmovq(qword[B - 0x80], xmm0);
323         sub(B, -8);
324         align(4);
325 
326         L(l3b4);
327         sub(N, 0x4);
328         cmp(N, 0x4);
329         jge(l258, T_NEAR);
330         align(4);
331 
332         L(l3c4);
333         cmp(N, 0x2);
334         jl(l4aa, T_NEAR);
335         align(4);
336 
337         L(l3d0);
338         mov(A1, A);
339         lea(A2, ptr[A1 + LDA * 1]);
340         lea(I, ptr[A1 + LDA * 2]);
341         mov(A, I);
342         mov(I, M);
343         sar(I, 0x3);
344         jle(l41c, T_NEAR);
345         align(4);
346 
347         L(l3e8);
348         vmovdqu(xmm0, xword[A1 - 0x80]);
349         sub(A1, -16);
350         vmovdqu(xmm1, xword[A2 - 0x80]);
351         sub(A2, -16);
352         vunpcklps(xmm2, xmm0, xmm1);
353         vunpckhps(xmm3, xmm0, xmm1);
354         vmovdqu(xword[B - 0x80], xmm2);
355         vmovdqu(xword[B - 0x70], xmm3);
356         sub(B, -32);
357         dec(I);
358         jg(l3e8, T_NEAR);
359         align(4);
360 
361         L(l41c);
362         test(M, 0x4);
363         jle(l448, T_NEAR);
364         vmovq(xmm0, qword[A1 - 0x80]);
365         sub(A1, -8);
366         vmovq(xmm1, qword[A2 - 0x80]);
367         sub(A2, -8);
368         vunpcklps(xmm0, xmm0, xmm1);
369         vmovdqu(xword[B - 0x80], xmm0);
370         sub(B, -16);
371         align(4);
372 
373         L(l448);
374         test(M, 0x2);
375         jle(l474, T_NEAR);
376         vmovd(xmm0, dword[A1 - 0x80]);
377         sub(A1, -4);
378         vmovd(xmm1, dword[A2 - 0x80]);
379         sub(A2, -4);
380         vunpcklps(xmm0, xmm0, xmm1);
381         vmovq(qword[B - 0x80], xmm0);
382         sub(B, -8);
383         align(4);
384 
385         L(l474);
386         test(M, 0x1);
387         jle(l49c, T_NEAR);
388         mov(ax, word[A1 - 0x80]);
389         vpinsrw(xmm0, xmm0, eax, 0x0);
390         mov(ax, word[A1 + LDA * 1 - 0x80]);
391         vpinsrw(xmm0, xmm0, eax, 0x1);
392         vmovd(dword[B - 0x80], xmm0);
393         sub(B, -4);
394         align(4);
395 
396         L(l49c);
397         sub(N, 0x2);
398         cmp(N, 0x2);
399         jge(l3d0, T_NEAR);
400         align(4);
401 
402         L(l4aa);
403         cmp(N, 0x1);
404         jl(l568, T_NEAR);
405         align(4);
406 
407         L(l4b4);
408         mov(A1, A);
409         add(A, LDA);
410         mov(I, M);
411         sar(I, 0x4);
412         jle(l4e0, T_NEAR);
413         align(4);
414 
415         L(l4c4);
416         vmovdqu(ymm0, yword[A1 - 0x80]);
417         sub(A1, -32);
418         vmovdqu(yword[B - 0x80], ymm0);
419         sub(B, -32);
420         dec(I);
421         jg(l4c4, T_NEAR);
422         align(4);
423 
424         L(l4e0);
425         test(M, 0x8);
426         jle(l500, T_NEAR);
427         vmovdqu(xmm0, xword[A1 - 0x80]);
428         sub(A1, -16);
429         vmovdqu(xword[B - 0x80], xmm0);
430         sub(B, -16);
431         align(4);
432 
433         L(l500);
434         test(M, 0x4);
435         jle(l520, T_NEAR);
436         vmovq(xmm0, qword[A1 - 0x80]);
437         sub(A1, -8);
438         vmovq(qword[B - 0x80], xmm0);
439         sub(B, -8);
440         align(4);
441 
442         L(l520);
443         test(M, 0x2);
444         jle(l540, T_NEAR);
445         vmovd(xmm0, dword[A1 - 0x80]);
446         sub(A1, -4);
447         vmovd(dword[B - 0x80], xmm0);
448         sub(B, -4);
449         align(4);
450 
451         L(l540);
452         test(M, 0x1);
453         jle(l558, T_NEAR);
454         mov(ax, word[A1 - 0x80]);
455         mov(word[B - 0x80], ax);
456         sub(B, -2);
457         align(4);
458 
459         L(l558);
460         sub(N, 0x1);
461         cmp(N, 0x1);
462         jge(l4b4, T_NEAR);
463         align(4);
464 
465         L(l568);
466         vzeroupper();
467         postamble();
468     }
469     outLocalLabel();
470 
471 #undef M
472 #undef N
473 #undef A
474 #undef LDA
475 #undef ALPHA
476 #undef B
477 #undef I
478 #undef A1
479 #undef A2
480 #undef LDA3
481 #ifdef _WIN32
482 #undef ARG_ALPHA
483 #undef ARG_B
484 #endif
485 }
486 
487 } // namespace x64
488 } // namespace cpu
489 } // namespace impl
490 } // namespace dnnl
491