1 /*******************************************************************************
2 * Copyright 2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "cpu/x64/jit_generator.hpp"
18
19 #include "cpu/x64/gemm/bf16/common_s16.hpp"
20
21 namespace dnnl {
22 namespace impl {
23 namespace cpu {
24 namespace x64 {
25
jit_avx512_core_s16_24x8_copy_bn_kern()26 jit_avx512_core_s16_24x8_copy_bn_kern::jit_avx512_core_s16_24x8_copy_bn_kern()
27 : jit_generator(nullptr, S16_COPY_KERNEL_CODE_SIZE) {}
28
generate()29 void jit_avx512_core_s16_24x8_copy_bn_kern::generate() {
30
31 #ifndef _WIN32
32 #define M rdi
33 #define N rsi
34 #define A rdx
35 #define LDA rcx
36 #define ALPHA r8
37 #define B r9
38
39 #define I rax
40 #define A1 r10
41 #define A2 r8
42 #define LDA3 r11
43
44 #else
45
46 #define M rcx
47 #define N rdx
48 #define A r8
49 #define LDA r9
50 #define ALPHA rax
51 #define B rdi
52
53 #define I rax
54 #define A1 rsi
55 #define A2 r10
56 #define LDA3 r11
57
58 #define ARG_ALPHA 40 + stacksize + rsp
59 #define ARG_B 48 + stacksize + rsp
60
61 #endif
62
63 inLocalLabel();
64 {
65
66 Xbyak::Label l158;
67 Xbyak::Label l1c8;
68 Xbyak::Label l23c;
69 Xbyak::Label l24;
70 Xbyak::Label l24c;
71 Xbyak::Label l258;
72 Xbyak::Label l270;
73 Xbyak::Label l2e0;
74 Xbyak::Label l32c;
75 Xbyak::Label l370;
76 Xbyak::Label l3b4;
77 Xbyak::Label l3c4;
78 Xbyak::Label l3d0;
79 Xbyak::Label l3e8;
80 Xbyak::Label l40;
81 Xbyak::Label l41c;
82 Xbyak::Label l448;
83 Xbyak::Label l474;
84 Xbyak::Label l49c;
85 Xbyak::Label l4aa;
86 Xbyak::Label l4b4;
87 Xbyak::Label l4c4;
88 Xbyak::Label l4e0;
89 Xbyak::Label l500;
90 Xbyak::Label l520;
91 Xbyak::Label l540;
92 Xbyak::Label l558;
93 Xbyak::Label l568;
94 Xbyak::Label ldc;
95
96 preamble();
97 #ifdef _WIN32
98 auto stacksize = get_size_of_abi_save_regs();
99 mov(ALPHA, ptr[ARG_ALPHA]);
100 mov(B, ptr[ARG_B]);
101 #endif
102
103 mov(N, qword[N]);
104 mov(M, qword[M]);
105 mov(LDA, qword[LDA]);
106 shl(LDA, 1);
107 lea(LDA3, ptr[LDA + LDA * 2]);
108 sub(A, -128);
109 sub(B, -128);
110 cmp(N, 0x8);
111 jl(l24c, T_NEAR);
112 align(4);
113
114 L(l24);
115 mov(A1, A);
116 lea(A2, ptr[A1 + LDA * 4]);
117 lea(I, ptr[A1 + LDA * 8]);
118 mov(A, I);
119 mov(I, M);
120 sar(I, 0x3);
121 jle(ldc, T_NEAR);
122 align(4);
123
124 L(l40);
125 vmovdqu(xmm4, xword[A1 - 0x80]);
126 vmovdqu(xmm5, xword[A1 + LDA * 1 - 0x80]);
127 vmovdqu(xmm0, xword[A1 + LDA * 2 - 0x80]);
128 vmovdqu(xmm1, xword[A1 + LDA3 * 1 - 0x80]);
129 sub(A1, -16);
130 vmovdqu(xmm2, xword[A2 - 0x80]);
131 vperm2f128(ymm4, ymm4, ymm2, 0x20);
132 vmovdqu(xmm3, xword[A2 + LDA * 1 - 0x80]);
133 vperm2f128(ymm5, ymm5, ymm3, 0x20);
134 vmovdqu(xmm2, xword[A2 + LDA * 2 - 0x80]);
135 vperm2f128(ymm0, ymm0, ymm2, 0x20);
136 vmovdqu(xmm3, xword[A2 + LDA3 * 1 - 0x80]);
137 vperm2f128(ymm1, ymm1, ymm3, 0x20);
138 sub(A2, -16);
139 vunpcklps(ymm2, ymm4, ymm0);
140 vunpckhps(ymm3, ymm4, ymm0);
141 vunpcklps(ymm4, ymm5, ymm1);
142 vunpckhps(ymm5, ymm5, ymm1);
143 vunpcklps(ymm0, ymm2, ymm4);
144 vunpckhps(ymm1, ymm2, ymm4);
145 vunpcklps(ymm2, ymm3, ymm5);
146 vunpckhps(ymm3, ymm3, ymm5);
147 vmovdqu(yword[B - 0x80], ymm0);
148 vmovdqu(yword[B - 0x60], ymm1);
149 vmovdqu(yword[B - 0x40], ymm2);
150 vmovdqu(yword[B - 0x20], ymm3);
151 sub(B, -128);
152 dec(I);
153 jg(l40, T_NEAR);
154 align(4);
155
156 L(ldc);
157 test(M, 0x4);
158 jle(l158, T_NEAR);
159 vmovq(xmm0, qword[A1 - 0x80]);
160 vmovq(xmm1, qword[A1 + LDA * 1 - 0x80]);
161 vmovq(xmm2, qword[A1 + LDA * 2 - 0x80]);
162 vmovq(xmm3, qword[A1 + LDA3 * 1 - 0x80]);
163 sub(A1, -8);
164 vunpcklps(xmm0, xmm0, xmm2);
165 vunpcklps(xmm1, xmm1, xmm3);
166 vmovq(xmm2, qword[A2 - 0x80]);
167 vmovq(xmm3, qword[A2 + LDA * 1 - 0x80]);
168 vmovq(xmm4, qword[A2 + LDA * 2 - 0x80]);
169 vmovq(xmm5, qword[A2 + LDA3 * 1 - 0x80]);
170 sub(A2, -8);
171 vunpcklps(xmm2, xmm2, xmm4);
172 vunpcklps(xmm3, xmm3, xmm5);
173 vperm2f128(ymm0, ymm0, ymm2, 0x20);
174 vperm2f128(ymm1, ymm1, ymm3, 0x20);
175 vunpcklps(ymm2, ymm0, ymm1);
176 vunpckhps(ymm3, ymm0, ymm1);
177 vmovdqu(yword[B - 0x80], ymm2);
178 vmovdqu(yword[B - 0x60], ymm3);
179 sub(B, -64);
180 align(4);
181
182 L(l158);
183 test(M, 0x2);
184 jle(l1c8, T_NEAR);
185 vmovd(xmm0, dword[A1 - 0x80]);
186 vmovd(xmm1, dword[A1 + LDA * 1 - 0x80]);
187 vmovd(xmm2, dword[A1 + LDA * 2 - 0x80]);
188 vmovd(xmm3, dword[A1 + LDA3 * 1 - 0x80]);
189 sub(A1, -4);
190 vunpcklps(xmm0, xmm0, xmm1);
191 vunpcklps(xmm2, xmm2, xmm3);
192 vpunpcklqdq(xmm0, xmm0, xmm2);
193 vmovd(xmm1, dword[A2 - 0x80]);
194 vmovd(xmm2, dword[A2 + LDA * 1 - 0x80]);
195 vmovd(xmm3, dword[A2 + LDA * 2 - 0x80]);
196 vmovd(xmm4, dword[A2 + LDA3 * 1 - 0x80]);
197 sub(A2, -4);
198 vunpcklps(xmm1, xmm1, xmm2);
199 vunpcklps(xmm3, xmm3, xmm4);
200 vpunpcklqdq(xmm1, xmm1, xmm3);
201 vinsertf128(ymm0, ymm0, xmm1, 0x1);
202 vmovdqu(yword[B - 0x80], ymm0);
203 sub(B, -32);
204 align(4);
205
206 L(l1c8);
207 test(M, 0x1);
208 jle(l23c, T_NEAR);
209 mov(ax, word[A1 - 0x80]);
210 vpinsrw(xmm0, xmm0, eax, 0x0);
211 mov(ax, word[A1 + LDA * 1 - 0x80]);
212 vpinsrw(xmm0, xmm0, eax, 0x1);
213 mov(ax, word[A1 + LDA * 2 - 0x80]);
214 vpinsrw(xmm0, xmm0, eax, 0x2);
215 mov(ax, word[A1 + LDA3 * 1 - 0x80]);
216 vpinsrw(xmm0, xmm0, eax, 0x3);
217 lea(A2, ptr[A1 + LDA * 4]);
218 mov(ax, word[A2 - 0x80]);
219 vpinsrw(xmm0, xmm0, eax, 0x4);
220 mov(ax, word[A2 + LDA * 1 - 0x80]);
221 vpinsrw(xmm0, xmm0, eax, 0x5);
222 mov(ax, word[A2 + LDA * 2 - 0x80]);
223 vpinsrw(xmm0, xmm0, eax, 0x6);
224 mov(ax, word[A2 + LDA3 * 1 - 0x80]);
225 vpinsrw(xmm0, xmm0, eax, 0x7);
226 lea(A2, ptr[A2 + LDA * 4]);
227 vmovdqu(xword[B - 0x80], xmm0);
228 sub(B, -16);
229 align(4);
230
231 L(l23c);
232 sub(N, 0x8);
233 cmp(N, 0x8);
234 jge(l24, T_NEAR);
235 align(4);
236
237 L(l24c);
238 cmp(N, 0x4);
239 jl(l3c4, T_NEAR);
240 align(4);
241
242 L(l258);
243 mov(A1, A);
244 lea(A2, ptr[A1 + LDA * 2]);
245 lea(I, ptr[A1 + LDA * 4]);
246 mov(A, I);
247 mov(I, M);
248 sar(I, 0x3);
249 jle(l2e0, T_NEAR);
250 align(4);
251
252 L(l270);
253 vmovdqu(xmm0, xword[A1 - 0x80]);
254 vmovdqu(xmm1, xword[A1 + LDA * 1 - 0x80]);
255 sub(A1, -16);
256 vmovdqu(xmm2, xword[A2 - 0x80]);
257 vmovdqu(xmm3, xword[A2 + LDA * 1 - 0x80]);
258 sub(A2, -16);
259 vperm2f128(ymm0, ymm0, ymm2, 0x20);
260 vperm2f128(ymm1, ymm1, ymm3, 0x20);
261 vunpcklps(ymm2, ymm0, ymm1);
262 vunpckhps(ymm3, ymm0, ymm1);
263 vperm2f128(ymm0, ymm2, ymm2, 0x1);
264 vperm2f128(ymm1, ymm3, ymm3, 0x1);
265 vshufpd(ymm0, ymm2, ymm0, 0xc);
266 vshufpd(ymm1, ymm3, ymm1, 0xc);
267 vpermilpd(ymm0, ymm0, 0x6);
268 vpermilpd(ymm1, ymm1, 0x6);
269 vmovdqu(yword[B - 0x80], ymm0);
270 vmovdqu(yword[B - 0x60], ymm1);
271 sub(B, -64);
272 dec(I);
273 jg(l270, T_NEAR);
274 align(4);
275
276 L(l2e0);
277 test(M, 0x4);
278 jle(l32c, T_NEAR);
279 vmovq(xmm0, qword[A1 - 0x80]);
280 vmovq(xmm1, qword[A1 + LDA * 1 - 0x80]);
281 sub(A1, -8);
282 vmovq(xmm2, qword[A2 - 0x80]);
283 vmovq(xmm3, qword[A2 + LDA * 1 - 0x80]);
284 sub(A2, -8);
285 vunpcklps(xmm0, xmm0, xmm2);
286 vunpcklps(xmm1, xmm1, xmm3);
287 vunpcklps(xmm2, xmm0, xmm1);
288 vunpckhps(xmm3, xmm0, xmm1);
289 vmovdqu(xword[B - 0x80], xmm2);
290 vmovdqu(xword[B - 0x70], xmm3);
291 sub(B, -32);
292 align(4);
293
294 L(l32c);
295 test(M, 0x2);
296 jle(l370, T_NEAR);
297 vmovd(xmm0, dword[A1 - 0x80]);
298 vmovd(xmm1, dword[A1 + LDA * 1 - 0x80]);
299 sub(A1, -4);
300 vmovd(xmm2, dword[A2 - 0x80]);
301 vmovd(xmm3, dword[A2 + LDA * 1 - 0x80]);
302 sub(A2, -4);
303 vunpcklps(xmm0, xmm0, xmm1);
304 vunpcklps(xmm2, xmm2, xmm3);
305 vpunpcklqdq(xmm0, xmm0, xmm2);
306 vmovdqu(xword[B - 0x80], xmm0);
307 sub(B, -16);
308 align(4);
309
310 L(l370);
311 test(M, 0x1);
312 jle(l3b4, T_NEAR);
313 mov(ax, word[A1 - 0x80]);
314 vpinsrw(xmm0, xmm0, eax, 0x0);
315 mov(ax, word[A1 + LDA * 1 - 0x80]);
316 vpinsrw(xmm0, xmm0, eax, 0x1);
317 mov(ax, word[A1 + LDA * 2 - 0x80]);
318 vpinsrw(xmm0, xmm0, eax, 0x2);
319 mov(ax, word[A1 + LDA3 * 1 - 0x80]);
320 vpinsrw(xmm0, xmm0, eax, 0x3);
321 lea(A2, ptr[A1 + LDA * 4]);
322 vmovq(qword[B - 0x80], xmm0);
323 sub(B, -8);
324 align(4);
325
326 L(l3b4);
327 sub(N, 0x4);
328 cmp(N, 0x4);
329 jge(l258, T_NEAR);
330 align(4);
331
332 L(l3c4);
333 cmp(N, 0x2);
334 jl(l4aa, T_NEAR);
335 align(4);
336
337 L(l3d0);
338 mov(A1, A);
339 lea(A2, ptr[A1 + LDA * 1]);
340 lea(I, ptr[A1 + LDA * 2]);
341 mov(A, I);
342 mov(I, M);
343 sar(I, 0x3);
344 jle(l41c, T_NEAR);
345 align(4);
346
347 L(l3e8);
348 vmovdqu(xmm0, xword[A1 - 0x80]);
349 sub(A1, -16);
350 vmovdqu(xmm1, xword[A2 - 0x80]);
351 sub(A2, -16);
352 vunpcklps(xmm2, xmm0, xmm1);
353 vunpckhps(xmm3, xmm0, xmm1);
354 vmovdqu(xword[B - 0x80], xmm2);
355 vmovdqu(xword[B - 0x70], xmm3);
356 sub(B, -32);
357 dec(I);
358 jg(l3e8, T_NEAR);
359 align(4);
360
361 L(l41c);
362 test(M, 0x4);
363 jle(l448, T_NEAR);
364 vmovq(xmm0, qword[A1 - 0x80]);
365 sub(A1, -8);
366 vmovq(xmm1, qword[A2 - 0x80]);
367 sub(A2, -8);
368 vunpcklps(xmm0, xmm0, xmm1);
369 vmovdqu(xword[B - 0x80], xmm0);
370 sub(B, -16);
371 align(4);
372
373 L(l448);
374 test(M, 0x2);
375 jle(l474, T_NEAR);
376 vmovd(xmm0, dword[A1 - 0x80]);
377 sub(A1, -4);
378 vmovd(xmm1, dword[A2 - 0x80]);
379 sub(A2, -4);
380 vunpcklps(xmm0, xmm0, xmm1);
381 vmovq(qword[B - 0x80], xmm0);
382 sub(B, -8);
383 align(4);
384
385 L(l474);
386 test(M, 0x1);
387 jle(l49c, T_NEAR);
388 mov(ax, word[A1 - 0x80]);
389 vpinsrw(xmm0, xmm0, eax, 0x0);
390 mov(ax, word[A1 + LDA * 1 - 0x80]);
391 vpinsrw(xmm0, xmm0, eax, 0x1);
392 vmovd(dword[B - 0x80], xmm0);
393 sub(B, -4);
394 align(4);
395
396 L(l49c);
397 sub(N, 0x2);
398 cmp(N, 0x2);
399 jge(l3d0, T_NEAR);
400 align(4);
401
402 L(l4aa);
403 cmp(N, 0x1);
404 jl(l568, T_NEAR);
405 align(4);
406
407 L(l4b4);
408 mov(A1, A);
409 add(A, LDA);
410 mov(I, M);
411 sar(I, 0x4);
412 jle(l4e0, T_NEAR);
413 align(4);
414
415 L(l4c4);
416 vmovdqu(ymm0, yword[A1 - 0x80]);
417 sub(A1, -32);
418 vmovdqu(yword[B - 0x80], ymm0);
419 sub(B, -32);
420 dec(I);
421 jg(l4c4, T_NEAR);
422 align(4);
423
424 L(l4e0);
425 test(M, 0x8);
426 jle(l500, T_NEAR);
427 vmovdqu(xmm0, xword[A1 - 0x80]);
428 sub(A1, -16);
429 vmovdqu(xword[B - 0x80], xmm0);
430 sub(B, -16);
431 align(4);
432
433 L(l500);
434 test(M, 0x4);
435 jle(l520, T_NEAR);
436 vmovq(xmm0, qword[A1 - 0x80]);
437 sub(A1, -8);
438 vmovq(qword[B - 0x80], xmm0);
439 sub(B, -8);
440 align(4);
441
442 L(l520);
443 test(M, 0x2);
444 jle(l540, T_NEAR);
445 vmovd(xmm0, dword[A1 - 0x80]);
446 sub(A1, -4);
447 vmovd(dword[B - 0x80], xmm0);
448 sub(B, -4);
449 align(4);
450
451 L(l540);
452 test(M, 0x1);
453 jle(l558, T_NEAR);
454 mov(ax, word[A1 - 0x80]);
455 mov(word[B - 0x80], ax);
456 sub(B, -2);
457 align(4);
458
459 L(l558);
460 sub(N, 0x1);
461 cmp(N, 0x1);
462 jge(l4b4, T_NEAR);
463 align(4);
464
465 L(l568);
466 vzeroupper();
467 postamble();
468 }
469 outLocalLabel();
470
471 #undef M
472 #undef N
473 #undef A
474 #undef LDA
475 #undef ALPHA
476 #undef B
477 #undef I
478 #undef A1
479 #undef A2
480 #undef LDA3
481 #ifdef _WIN32
482 #undef ARG_ALPHA
483 #undef ARG_B
484 #endif
485 }
486
487 } // namespace x64
488 } // namespace cpu
489 } // namespace impl
490 } // namespace dnnl
491