1 /*******************************************************************************
2 * Copyright 2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "cpu/x64/jit_generator.hpp"
18
19 #include "cpu/x64/gemm/bf16/common_s16.hpp"
20
21 namespace dnnl {
22 namespace impl {
23 namespace cpu {
24 namespace x64 {
25
jit_avx512_core_s16_24x8_copy_bt_kern()26 jit_avx512_core_s16_24x8_copy_bt_kern::jit_avx512_core_s16_24x8_copy_bt_kern()
27 : jit_generator(nullptr, S16_COPY_KERNEL_CODE_SIZE) {}
28
generate()29 void jit_avx512_core_s16_24x8_copy_bt_kern::generate() {
30
31 #ifndef _WIN32
32 #define M rdi
33 #define N rsi
34 #define A rdx
35 #define LDA rcx
36 #define ALPHA r8
37 #define B r9
38
39 #define I rax
40 #define A1 r10
41 #define A2 r8
42 #define LDA3 r11
43
44 #else
45
46 #define M rcx
47 #define N rdx
48 #define A r8
49 #define LDA r9
50 #define ALPHA rax
51 #define B rdi
52
53 #define I rax
54 #define A1 rsi
55 #define A2 r10
56 #define LDA3 r11
57
58 #define ARG_ALPHA 40 + stacksize + rsp
59 #define ARG_B 48 + stacksize + rsp
60
61 #endif
62
63 inLocalLabel();
64 {
65
66 Xbyak::Label l13c;
67 Xbyak::Label l170;
68 Xbyak::Label l18c;
69 Xbyak::Label l19c;
70 Xbyak::Label l1a8;
71 Xbyak::Label l1b8;
72 Xbyak::Label l234;
73 Xbyak::Label l24;
74 Xbyak::Label l27c;
75 Xbyak::Label l2a8;
76 Xbyak::Label l2c4;
77 Xbyak::Label l2d4;
78 Xbyak::Label l2e0;
79 Xbyak::Label l2f0;
80 Xbyak::Label l368;
81 Xbyak::Label l38;
82 Xbyak::Label l3ac;
83 Xbyak::Label l3d8;
84 Xbyak::Label l3f4;
85 Xbyak::Label l402;
86 Xbyak::Label l40c;
87 Xbyak::Label l41c;
88 Xbyak::Label l494;
89 Xbyak::Label l4dc;
90 Xbyak::Label l50c;
91 Xbyak::Label l524;
92 Xbyak::Label l534;
93 Xbyak::Label le0;
94
95 preamble();
96 #ifdef _WIN32
97 auto stacksize = get_size_of_abi_save_regs();
98 mov(ALPHA, ptr[ARG_ALPHA]);
99 mov(B, ptr[ARG_B]);
100 #endif
101
102 mov(M, qword[M]);
103 mov(N, qword[N]);
104 mov(LDA, qword[LDA]);
105 shl(LDA, 1);
106 lea(LDA3, ptr[LDA + LDA * 2]);
107 sub(A, -128);
108 sub(B, -128);
109 cmp(N, 0x8);
110 jl(l19c, T_NEAR);
111 align(4);
112
113 L(l24);
114 mov(A1, A);
115 add(A, 0x10);
116 mov(I, M);
117 sar(I, 0x3);
118 jle(le0, T_NEAR);
119 align(4);
120
121 L(l38);
122 vmovdqu(xmm0, xword[A1 - 0x80]);
123 add(A1, LDA);
124 vmovdqu(xmm1, xword[A1 - 0x80]);
125 add(A1, LDA);
126 vmovdqu(xmm2, xword[A1 - 0x80]);
127 add(A1, LDA);
128 vmovdqu(xmm3, xword[A1 - 0x80]);
129 add(A1, LDA);
130 vpunpcklwd(xmm4, xmm0, xmm1);
131 vpunpckhwd(xmm5, xmm0, xmm1);
132 vperm2f128(ymm0, ymm4, ymm5, 0x20);
133 vpunpcklwd(xmm4, xmm2, xmm3);
134 vpunpckhwd(xmm5, xmm2, xmm3);
135 vperm2f128(ymm2, ymm4, ymm5, 0x20);
136 vmovdqu(yword[B - 0x80], ymm0);
137 vmovdqu(yword[B - 0x60], ymm2);
138 vmovdqu(xmm0, xword[A1 - 0x80]);
139 add(A1, LDA);
140 vmovdqu(xmm1, xword[A1 - 0x80]);
141 add(A1, LDA);
142 vmovdqu(xmm2, xword[A1 - 0x80]);
143 add(A1, LDA);
144 vmovdqu(xmm3, xword[A1 - 0x80]);
145 add(A1, LDA);
146 vpunpcklwd(xmm4, xmm0, xmm1);
147 vpunpckhwd(xmm5, xmm0, xmm1);
148 vperm2f128(ymm0, ymm4, ymm5, 0x20);
149 vpunpcklwd(xmm4, xmm2, xmm3);
150 vpunpckhwd(xmm5, xmm2, xmm3);
151 vperm2f128(ymm2, ymm4, ymm5, 0x20);
152 vmovdqu(yword[B - 0x40], ymm0);
153 vmovdqu(yword[B - 0x20], ymm2);
154 sub(B, -128);
155 dec(I);
156 jg(l38, T_NEAR);
157 align(4);
158
159 L(le0);
160 test(M, 0x4);
161 jle(l13c, T_NEAR);
162 vmovdqu(xmm0, xword[A1 - 0x80]);
163 add(A1, LDA);
164 vmovdqu(xmm1, xword[A1 - 0x80]);
165 add(A1, LDA);
166 vmovdqu(xmm2, xword[A1 - 0x80]);
167 add(A1, LDA);
168 vmovdqu(xmm3, xword[A1 - 0x80]);
169 add(A1, LDA);
170 vpunpcklwd(xmm4, xmm0, xmm1);
171 vpunpckhwd(xmm5, xmm0, xmm1);
172 vperm2f128(ymm0, ymm4, ymm5, 0x20);
173 vpunpcklwd(xmm4, xmm2, xmm3);
174 vpunpckhwd(xmm5, xmm2, xmm3);
175 vperm2f128(ymm2, ymm4, ymm5, 0x20);
176 vmovdqu(yword[B - 0x80], ymm0);
177 vmovdqu(yword[B - 0x60], ymm2);
178 sub(B, -64);
179 align(4);
180
181 L(l13c);
182 test(M, 0x2);
183 jle(l170, T_NEAR);
184 vmovdqu(xmm0, xword[A1 - 0x80]);
185 add(A1, LDA);
186 vmovdqu(xmm1, xword[A1 - 0x80]);
187 add(A1, LDA);
188 vpunpcklwd(xmm2, xmm0, xmm1);
189 vpunpckhwd(xmm3, xmm0, xmm1);
190 vperm2f128(ymm0, ymm2, ymm3, 0x20);
191 vmovdqu(yword[B - 0x80], ymm0);
192 sub(B, -32);
193 align(4);
194
195 L(l170);
196 test(M, 0x1);
197 jle(l18c, T_NEAR);
198 vmovdqu(xmm0, xword[A1 - 0x80]);
199 vmovdqu(xword[B - 0x80], xmm0);
200 sub(B, -16);
201 align(4);
202
203 L(l18c);
204 sub(N, 0x8);
205 cmp(N, 0x8);
206 jge(l24, T_NEAR);
207 align(4);
208
209 L(l19c);
210 cmp(N, 0x4);
211 jl(l2d4, T_NEAR);
212 align(4);
213
214 L(l1a8);
215 mov(A1, A);
216 add(A, 0x8);
217 mov(I, M);
218 sar(I, 0x3);
219 jle(l234, T_NEAR);
220 align(4);
221
222 L(l1b8);
223 vmovq(xmm0, qword[A1 - 0x80]);
224 add(A1, LDA);
225 vmovq(xmm1, qword[A1 - 0x80]);
226 add(A1, LDA);
227 vmovq(xmm2, qword[A1 - 0x80]);
228 add(A1, LDA);
229 vmovq(xmm3, qword[A1 - 0x80]);
230 add(A1, LDA);
231 vpunpcklwd(xmm0, xmm0, xmm1);
232 vpunpcklwd(xmm2, xmm2, xmm3);
233 vperm2f128(ymm0, ymm0, ymm2, 0x20);
234 vmovdqu(yword[B - 0x80], ymm0);
235 vmovq(xmm0, qword[A1 - 0x80]);
236 add(A1, LDA);
237 vmovq(xmm1, qword[A1 - 0x80]);
238 add(A1, LDA);
239 vmovq(xmm2, qword[A1 - 0x80]);
240 add(A1, LDA);
241 vmovq(xmm3, qword[A1 - 0x80]);
242 add(A1, LDA);
243 vpunpcklwd(xmm0, xmm0, xmm1);
244 vpunpcklwd(xmm2, xmm2, xmm3);
245 vperm2f128(ymm0, ymm0, ymm2, 0x20);
246 vmovdqu(yword[B - 0x60], ymm0);
247 sub(B, -64);
248 dec(I);
249 jg(l1b8, T_NEAR);
250 align(4);
251
252 L(l234);
253 test(M, 0x4);
254 jle(l27c, T_NEAR);
255 vmovq(xmm0, qword[A1 - 0x80]);
256 add(A1, LDA);
257 vmovq(xmm1, qword[A1 - 0x80]);
258 add(A1, LDA);
259 vmovq(xmm2, qword[A1 - 0x80]);
260 add(A1, LDA);
261 vmovq(xmm3, qword[A1 - 0x80]);
262 add(A1, LDA);
263 vpunpcklwd(xmm0, xmm0, xmm1);
264 vpunpcklwd(xmm2, xmm2, xmm3);
265 vmovdqu(xword[B - 0x80], xmm0);
266 vmovdqu(xword[B - 0x70], xmm2);
267 sub(B, -32);
268 align(4);
269
270 L(l27c);
271 test(M, 0x2);
272 jle(l2a8, T_NEAR);
273 vmovq(xmm0, qword[A1 - 0x80]);
274 add(A1, LDA);
275 vmovq(xmm1, qword[A1 - 0x80]);
276 add(A1, LDA);
277 vpunpcklwd(xmm0, xmm0, xmm1);
278 vmovdqu(xword[B - 0x80], xmm0);
279 sub(B, -16);
280 align(4);
281
282 L(l2a8);
283 test(M, 0x1);
284 jle(l2c4, T_NEAR);
285 vmovq(xmm0, qword[A1 - 0x80]);
286 vmovq(qword[B - 0x80], xmm0);
287 sub(B, -8);
288 align(4);
289
290 L(l2c4);
291 sub(N, 0x4);
292 cmp(N, 0x4);
293 jge(l1a8, T_NEAR);
294 align(4);
295
296 L(l2d4);
297 cmp(N, 0x2);
298 jl(l402, T_NEAR);
299 align(4);
300
301 L(l2e0);
302 mov(A1, A);
303 add(A, 0x4);
304 mov(I, M);
305 sar(I, 0x3);
306 jle(l368, T_NEAR);
307 align(4);
308
309 L(l2f0);
310 vmovd(xmm0, dword[A1 - 0x80]);
311 add(A1, LDA);
312 vmovd(xmm1, dword[A1 - 0x80]);
313 add(A1, LDA);
314 vmovd(xmm2, dword[A1 - 0x80]);
315 add(A1, LDA);
316 vmovd(xmm3, dword[A1 - 0x80]);
317 add(A1, LDA);
318 vpunpcklwd(xmm0, xmm0, xmm1);
319 vpunpcklwd(xmm2, xmm2, xmm3);
320 vpunpcklqdq(xmm0, xmm0, xmm2);
321 vmovdqu(xword[B - 0x80], xmm0);
322 vmovd(xmm0, dword[A1 - 0x80]);
323 add(A1, LDA);
324 vmovd(xmm1, dword[A1 - 0x80]);
325 add(A1, LDA);
326 vmovd(xmm2, dword[A1 - 0x80]);
327 add(A1, LDA);
328 vmovd(xmm3, dword[A1 - 0x80]);
329 add(A1, LDA);
330 vpunpcklwd(xmm0, xmm0, xmm1);
331 vpunpcklwd(xmm2, xmm2, xmm3);
332 vpunpcklqdq(xmm0, xmm0, xmm2);
333 vmovdqu(xword[B - 0x70], xmm0);
334 sub(B, -32);
335 dec(I);
336 jg(l2f0, T_NEAR);
337 align(4);
338
339 L(l368);
340 test(M, 0x4);
341 jle(l3ac, T_NEAR);
342 vmovd(xmm0, dword[A1 - 0x80]);
343 add(A1, LDA);
344 vmovd(xmm1, dword[A1 - 0x80]);
345 add(A1, LDA);
346 vmovd(xmm2, dword[A1 - 0x80]);
347 add(A1, LDA);
348 vmovd(xmm3, dword[A1 - 0x80]);
349 add(A1, LDA);
350 vpunpcklwd(xmm0, xmm0, xmm1);
351 vpunpcklwd(xmm2, xmm2, xmm3);
352 vpunpcklqdq(xmm0, xmm0, xmm2);
353 vmovdqu(xword[B - 0x80], xmm0);
354 sub(B, -16);
355 align(4);
356
357 L(l3ac);
358 test(M, 0x2);
359 jle(l3d8, T_NEAR);
360 vmovd(xmm0, dword[A1 - 0x80]);
361 add(A1, LDA);
362 vmovd(xmm1, dword[A1 - 0x80]);
363 add(A1, LDA);
364 vpunpcklwd(xmm0, xmm0, xmm1);
365 vmovq(qword[B - 0x80], xmm0);
366 sub(B, -8);
367 align(4);
368
369 L(l3d8);
370 test(M, 0x1);
371 jle(l3f4, T_NEAR);
372 vmovd(xmm0, dword[A1 - 0x80]);
373 vmovd(dword[B - 0x80], xmm0);
374 sub(B, -4);
375 align(4);
376
377 L(l3f4);
378 sub(N, 0x2);
379 cmp(N, 0x2);
380 jge(l2e0, T_NEAR);
381 align(4);
382
383 L(l402);
384 cmp(N, 0x1);
385 jl(l534, T_NEAR);
386 align(4);
387
388 L(l40c);
389 mov(A1, A);
390 add(A, 0x2);
391 mov(LDA3, M);
392 sar(LDA3, 0x3);
393 jle(l494, T_NEAR);
394 align(4);
395
396 L(l41c);
397 mov(ax, word[A1 - 0x80]);
398 add(A1, LDA);
399 vpinsrw(xmm0, xmm0, eax, 0x0);
400 mov(ax, word[A1 - 0x80]);
401 add(A1, LDA);
402 vpinsrw(xmm0, xmm0, eax, 0x1);
403 mov(ax, word[A1 - 0x80]);
404 add(A1, LDA);
405 vpinsrw(xmm0, xmm0, eax, 0x2);
406 mov(ax, word[A1 - 0x80]);
407 add(A1, LDA);
408 vpinsrw(xmm0, xmm0, eax, 0x3);
409 mov(ax, word[A1 - 0x80]);
410 add(A1, LDA);
411 vpinsrw(xmm0, xmm0, eax, 0x4);
412 mov(ax, word[A1 - 0x80]);
413 add(A1, LDA);
414 vpinsrw(xmm0, xmm0, eax, 0x5);
415 mov(ax, word[A1 - 0x80]);
416 add(A1, LDA);
417 vpinsrw(xmm0, xmm0, eax, 0x6);
418 mov(ax, word[A1 - 0x80]);
419 add(A1, LDA);
420 vpinsrw(xmm0, xmm0, eax, 0x7);
421 vmovdqu(xword[B - 0x80], xmm0);
422 sub(B, -16);
423 dec(LDA3);
424 jg(l41c, T_NEAR);
425 align(4);
426
427 L(l494);
428 test(M, 0x4);
429 jle(l4dc, T_NEAR);
430 mov(ax, word[A1 - 0x80]);
431 add(A1, LDA);
432 vpinsrw(xmm0, xmm0, eax, 0x0);
433 mov(ax, word[A1 - 0x80]);
434 add(A1, LDA);
435 vpinsrw(xmm0, xmm0, eax, 0x1);
436 mov(ax, word[A1 - 0x80]);
437 add(A1, LDA);
438 vpinsrw(xmm0, xmm0, eax, 0x2);
439 mov(ax, word[A1 - 0x80]);
440 add(A1, LDA);
441 vpinsrw(xmm0, xmm0, eax, 0x3);
442 vmovq(qword[B - 0x80], xmm0);
443 sub(B, -8);
444 align(4);
445
446 L(l4dc);
447 test(M, 0x2);
448 jle(l50c, T_NEAR);
449 mov(ax, word[A1 - 0x80]);
450 add(A1, LDA);
451 vpinsrw(xmm0, xmm0, eax, 0x0);
452 mov(ax, word[A1 - 0x80]);
453 add(A1, LDA);
454 vpinsrw(xmm0, xmm0, eax, 0x1);
455 vmovd(dword[B - 0x80], xmm0);
456 sub(B, -4);
457 align(4);
458
459 L(l50c);
460 test(M, 0x1);
461 jle(l524, T_NEAR);
462 mov(ax, word[A1 - 0x80]);
463 mov(word[B - 0x80], ax);
464 sub(B, -2);
465 align(4);
466
467 L(l524);
468 sub(N, 0x1);
469 cmp(N, 0x1);
470 jge(l40c, T_NEAR);
471 align(4);
472
473 L(l534);
474 vzeroupper();
475 postamble();
476 }
477 outLocalLabel();
478
479 #undef M
480 #undef N
481 #undef A
482 #undef LDA
483 #undef ALPHA
484 #undef B
485 #undef I
486 #undef A1
487 #undef A2
488 #undef LDA3
489 #ifdef _WIN32
490 #undef ARG_ALPHA
491 #undef ARG_B
492 #endif
493 }
494
495 } // namespace x64
496 } // namespace cpu
497 } // namespace impl
498 } // namespace dnnl
499