/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2014, The University of Texas at Austin

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"

// Select AT&T operand ordering for the BLIS x86 inline-asm macro DSL
// before pulling in the macro definitions.
#define BLIS_ASM_SYNTAX_ATT
#include "bli_x86_asm_macros.h"

// Gather one general-stride column of 8 floats of C (rows spaced rsi bytes
// apart; r13 = 3*rs_c, r15 = 5*rs_c, r10 = 7*rs_c) into ymm0.
// Clobbers xmm0-xmm2 only; ymm3-ymm15 (beta + microtile) are preserved.
#define SGEMM_INPUT_GS_BETA_NZ \
	vmovlps(mem(rcx), xmm0, xmm0) \
	vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) \
	vmovlps(mem(rcx, rsi, 2), xmm1, xmm1) \
	vmovhps(mem(rcx, r13, 1), xmm1, xmm1) \
	vshufps(imm(0x88), xmm1, xmm0, xmm0) \
	vmovlps(mem(rcx, rsi, 4), xmm2, xmm2) \
	vmovhps(mem(rcx, r15, 1), xmm2, xmm2) \
	/* We can't use vmovhps for loading the last element because that
	   might result in reading beyond valid memory. (vmov[lh]ps load
	   pairs of adjacent floats at a time.) So we need to use vmovss
	   instead. But since we're limited to using ymm0 through ymm2
	   (ymm3 contains beta and ymm4 through ymm15 contain the microtile)
	   and due to the way vmovss zeros out all bits above 31, we have to
	   load element 7 before element 6. */ \
	vmovss(mem(rcx, r10, 1), xmm1) \
	vpermilps(imm(0xcf), xmm1, xmm1) \
	vmovlps(mem(rcx, r13, 2), xmm1, xmm1) \
	/*vmovhps(mem(rcx, r10, 1), xmm1, xmm1)*/ \
	vshufps(imm(0x88), xmm1, xmm2, xmm2) \
	vperm2f128(imm(0x20), ymm2, ymm0, ymm0)

// Scatter the 8 floats in ymm0 back to a general-stride column of C
// (same stride registers as SGEMM_INPUT_GS_BETA_NZ). Uses single-element
// vmovss stores so no out-of-bounds bytes are ever written.
// Clobbers xmm0-xmm2.
#define SGEMM_OUTPUT_GS_BETA_NZ \
	vextractf128(imm(1), ymm0, xmm2) \
	vmovss(xmm0, mem(rcx)) \
	vpermilps(imm(0x39), xmm0, xmm1) \
	vmovss(xmm1, mem(rcx, rsi, 1)) \
	vpermilps(imm(0x39), xmm1, xmm0) \
	vmovss(xmm0, mem(rcx, rsi, 2)) \
	vpermilps(imm(0x39), xmm0, xmm1) \
	vmovss(xmm1, mem(rcx, r13, 1)) \
	vmovss(xmm2, mem(rcx, rsi, 4)) \
	vpermilps(imm(0x39), xmm2, xmm1) \
	vmovss(xmm1, mem(rcx, r15, 1)) \
	vpermilps(imm(0x39), xmm1, xmm2) \
	vmovss(xmm2, mem(rcx, r13, 2)) \
	vpermilps(imm(0x39), xmm2, xmm1) \
	vmovss(xmm1, mem(rcx, r10, 1))

//
// bli_sgemm_haswell_asm_16x6: single-precision gemm microkernel for
// Haswell-class cores (AVX2 + FMA3).
//
// Computes  c := beta*c + alpha * a*b  for one 16x6 microtile, where
// a is a packed 16 x k0 micropanel, b is a packed k0 x 6 micropanel,
// and c has strides (rs_c0, cs_c0) in elements. The 16x6 accumulator
// lives in ymm4-ymm15 (two 8-float registers per column of b).
//
// The store phase dispatches on rs_c: a column-stored C (rs_c == 1)
// uses full-vector loads/stores, otherwise a general-stride gather/
// scatter path is used; both have beta == 0 variants that skip
// reading C entirely.
//
void bli_sgemm_haswell_asm_16x6
     (
       dim_t               k0,
       float*     restrict alpha,
       float*     restrict a,
       float*     restrict b,
       float*     restrict beta,
       float*     restrict c, inc_t rs_c0, inc_t cs_c0,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
	//void* a_next = bli_auxinfo_next_a( data );
	//void* b_next = bli_auxinfo_next_b( data );

	// Typecast local copies of integers in case dim_t and inc_t are a
	// different size than is expected by load instructions.
	uint64_t k_iter = k0 / 4;       // number of unroll-by-4 iterations
	uint64_t k_left = k0 % 4;       // leftover rank-1 updates
	uint64_t rs_c   = rs_c0;
	uint64_t cs_c   = cs_c0;

	begin_asm()

	vzeroall() // zero all xmm/ymm registers.

	mov(var(a), rax) // load address of a.
	mov(var(b), rbx) // load address of b.
	//mov(%9, r15) // load address of b_next.

	add(imm(32*4), rax)
	 // initialize loop by pre-loading
	vmovaps(mem(rax, -4*32), ymm0)
	vmovaps(mem(rax, -3*32), ymm1)

	mov(var(c), rcx) // load address of c
	mov(var(cs_c), rdi) // load cs_c
	lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float)

	lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c;
	lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c;
	prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c
	prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*cs_c
	prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*cs_c
	prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c
	prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*cs_c
	prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*cs_c

	mov(var(k_iter), rsi) // i = k_iter;
	test(rsi, rsi) // check i via logical AND.
	je(.SCONSIDKLEFT) // if i == 0, jump to code that
	 // contains the k_left loop.

	label(.SLOOPKITER) // MAIN LOOP

	// iteration 0
	prefetch(0, mem(rax, 128*4))

	vbroadcastss(mem(rbx, 0*4), ymm2)
	vbroadcastss(mem(rbx, 1*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 2*4), ymm2)
	vbroadcastss(mem(rbx, 3*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 4*4), ymm2)
	vbroadcastss(mem(rbx, 5*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	vmovaps(mem(rax, -2*32), ymm0)
	vmovaps(mem(rax, -1*32), ymm1)

	// iteration 1
	vbroadcastss(mem(rbx, 6*4), ymm2)
	vbroadcastss(mem(rbx, 7*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 8*4), ymm2)
	vbroadcastss(mem(rbx, 9*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 10*4), ymm2)
	vbroadcastss(mem(rbx, 11*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	vmovaps(mem(rax, 0*32), ymm0)
	vmovaps(mem(rax, 1*32), ymm1)

	// iteration 2
	prefetch(0, mem(rax, 152*4))

	vbroadcastss(mem(rbx, 12*4), ymm2)
	vbroadcastss(mem(rbx, 13*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 14*4), ymm2)
	vbroadcastss(mem(rbx, 15*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 16*4), ymm2)
	vbroadcastss(mem(rbx, 17*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	vmovaps(mem(rax, 2*32), ymm0)
	vmovaps(mem(rax, 3*32), ymm1)

	// iteration 3
	vbroadcastss(mem(rbx, 18*4), ymm2)
	vbroadcastss(mem(rbx, 19*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 20*4), ymm2)
	vbroadcastss(mem(rbx, 21*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 22*4), ymm2)
	vbroadcastss(mem(rbx, 23*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	add(imm(4*16*4), rax) // a += 4*16 (unroll x mr)
	add(imm(4*6*4), rbx) // b += 4*6  (unroll x nr)

	vmovaps(mem(rax, -4*32), ymm0)
	vmovaps(mem(rax, -3*32), ymm1)

	dec(rsi) // i -= 1;
	jne(.SLOOPKITER) // iterate again if i != 0.

	label(.SCONSIDKLEFT)

	mov(var(k_left), rsi) // i = k_left;
	test(rsi, rsi) // check i via logical AND.
	je(.SPOSTACCUM) // if i == 0, we're done; jump to end.
	 // else, we prepare to enter k_left loop.

	label(.SLOOPKLEFT) // EDGE LOOP

	prefetch(0, mem(rax, 128*4))

	vbroadcastss(mem(rbx, 0*4), ymm2)
	vbroadcastss(mem(rbx, 1*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 2*4), ymm2)
	vbroadcastss(mem(rbx, 3*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 4*4), ymm2)
	vbroadcastss(mem(rbx, 5*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	add(imm(1*16*4), rax) // a += 1*16 (unroll x mr)
	add(imm(1*6*4), rbx) // b += 1*6  (unroll x nr)

	vmovaps(mem(rax, -4*32), ymm0)
	vmovaps(mem(rax, -3*32), ymm1)

	dec(rsi) // i -= 1;
	jne(.SLOOPKLEFT) // iterate again if i != 0.

	label(.SPOSTACCUM)

	mov(var(alpha), rax) // load address of alpha
	mov(var(beta), rbx) // load address of beta
	vbroadcastss(mem(rax), ymm0) // load alpha and duplicate
	vbroadcastss(mem(rbx), ymm3) // load beta and duplicate

	vmulps(ymm0, ymm4, ymm4) // scale by alpha
	vmulps(ymm0, ymm5, ymm5)
	vmulps(ymm0, ymm6, ymm6)
	vmulps(ymm0, ymm7, ymm7)
	vmulps(ymm0, ymm8, ymm8)
	vmulps(ymm0, ymm9, ymm9)
	vmulps(ymm0, ymm10, ymm10)
	vmulps(ymm0, ymm11, ymm11)
	vmulps(ymm0, ymm12, ymm12)
	vmulps(ymm0, ymm13, ymm13)
	vmulps(ymm0, ymm14, ymm14)
	vmulps(ymm0, ymm15, ymm15)

	mov(var(rs_c), rsi) // load rs_c
	lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float)

	lea(mem(rcx, rsi, 8), rdx) // load address of c + 8*rs_c;

	lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c;
	lea(mem(rsi, rsi, 4), r15) // r15 = 5*rs_c;
	lea(mem(r13, rsi, 4), r10) // r10 = 7*rs_c;

	// now avoid loading C if beta == 0

	vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero.
	vucomiss(xmm0, xmm3) // set ZF if beta == 0.
	je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case

	cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4.
	jz(.SCOLSTORED) // jump to column storage case

	label(.SGENSTORED)

	// General-stride update: for each of the 6 columns, gather the top
	// 8 rows of C, fuse beta*C + microtile, and scatter back; then do
	// the same for the bottom 8 rows starting at c + 8*rs_c (rdx).
	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm4, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm6, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm8, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm10, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm12, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm14, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	//add(rdi, rcx) // c += cs_c;

	mov(rdx, rcx) // rcx = c + 8*rs_c

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm5, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm7, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm9, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm11, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm13, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	SGEMM_INPUT_GS_BETA_NZ
	vfmadd213ps(ymm15, ymm3, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	//add(rdi, rcx) // c += cs_c;

	jmp(.SDONE) // jump to end.

	label(.SCOLSTORED)

	// Column-stored update: top half of each column via rcx, bottom
	// half via rdx, with full-width unaligned vector accesses.
	vfmadd231ps(mem(rcx), ymm3, ymm4)
	vmovups(ymm4, mem(rcx))
	add(rdi, rcx)
	vfmadd231ps(mem(rdx), ymm3, ymm5)
	vmovups(ymm5, mem(rdx))
	add(rdi, rdx)

	vfmadd231ps(mem(rcx), ymm3, ymm6)
	vmovups(ymm6, mem(rcx))
	add(rdi, rcx)
	vfmadd231ps(mem(rdx), ymm3, ymm7)
	vmovups(ymm7, mem(rdx))
	add(rdi, rdx)

	vfmadd231ps(mem(rcx), ymm3, ymm8)
	vmovups(ymm8, mem(rcx))
	add(rdi, rcx)
	vfmadd231ps(mem(rdx), ymm3, ymm9)
	vmovups(ymm9, mem(rdx))
	add(rdi, rdx)

	vfmadd231ps(mem(rcx), ymm3, ymm10)
	vmovups(ymm10, mem(rcx))
	add(rdi, rcx)
	vfmadd231ps(mem(rdx), ymm3, ymm11)
	vmovups(ymm11, mem(rdx))
	add(rdi, rdx)

	vfmadd231ps(mem(rcx), ymm3, ymm12)
	vmovups(ymm12, mem(rcx))
	add(rdi, rcx)
	vfmadd231ps(mem(rdx), ymm3, ymm13)
	vmovups(ymm13, mem(rdx))
	add(rdi, rdx)

	vfmadd231ps(mem(rcx), ymm3, ymm14)
	vmovups(ymm14, mem(rcx))
	//add(rdi, rcx)
	vfmadd231ps(mem(rdx), ymm3, ymm15)
	vmovups(ymm15, mem(rdx))
	//add(rdi, rdx)

	jmp(.SDONE) // jump to end.

	label(.SBETAZERO)

	cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4.
	jz(.SCOLSTORBZ) // jump to column storage case

	label(.SGENSTORBZ)

	// beta == 0, general stride: scatter the microtile without
	// reading C.
	vmovaps(ymm4, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm6, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm8, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm10, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm12, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm14, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	//add(rdi, rcx) // c += cs_c;

	mov(rdx, rcx) // rcx = c + 8*rs_c

	vmovaps(ymm5, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm7, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm9, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm11, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm13, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovaps(ymm15, ymm0)
	SGEMM_OUTPUT_GS_BETA_NZ
	//add(rdi, rcx) // c += cs_c;

	jmp(.SDONE) // jump to end.

	label(.SCOLSTORBZ)

	// beta == 0, column-stored: plain vector stores, no reads of C.
	vmovups(ymm4, mem(rcx))
	add(rdi, rcx)
	vmovups(ymm5, mem(rdx))
	add(rdi, rdx)

	vmovups(ymm6, mem(rcx))
	add(rdi, rcx)
	vmovups(ymm7, mem(rdx))
	add(rdi, rdx)

	vmovups(ymm8, mem(rcx))
	add(rdi, rcx)
	vmovups(ymm9, mem(rdx))
	add(rdi, rdx)

	vmovups(ymm10, mem(rcx))
	add(rdi, rcx)
	vmovups(ymm11, mem(rdx))
	add(rdi, rdx)

	vmovups(ymm12, mem(rcx))
	add(rdi, rcx)
	vmovups(ymm13, mem(rdx))
	add(rdi, rdx)

	vmovups(ymm14, mem(rcx))
	//add(rdi, rcx)
	vmovups(ymm15, mem(rdx))
	//add(rdi, rdx)

	label(.SDONE)

	end_asm(
	: // output operands (none)
	: // input operands
	  [k_iter] "m" (k_iter), // 0
	  [k_left] "m" (k_left), // 1
	  [a]      "m" (a),      // 2
	  [b]      "m" (b),      // 3
	  [alpha]  "m" (alpha),  // 4
	  [beta]   "m" (beta),   // 5
	  [c]      "m" (c),      // 6
	  [rs_c]   "m" (rs_c),   // 7
	  [cs_c]   "m" (cs_c)/*, // 8
	  [b_next] "m" (b_next), // 9
	  [a_next] "m" (a_next)*/ // 10
	: // register clobber list
	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
	  "xmm0", "xmm1", "xmm2", "xmm3",
	  "xmm4", "xmm5", "xmm6", "xmm7",
	  "xmm8", "xmm9", "xmm10", "xmm11",
	  "xmm12", "xmm13", "xmm14", "xmm15",
	  "memory"
	)
}

// Gather one general-stride column of 4 doubles of C (rows spaced rsi
// bytes apart; r13 = 3*rs_c) into ymm0. The commented-out tail would
// extend the gather to 8 rows (ymm2) if ever needed.
// Clobbers xmm0/xmm1 (and ymm2 if the tail were enabled).
#define DGEMM_INPUT_GS_BETA_NZ \
	vmovlpd(mem(rcx), xmm0, xmm0) \
	vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \
	vmovlpd(mem(rcx, rsi, 2), xmm1, xmm1) \
	vmovhpd(mem(rcx, r13, 1), xmm1, xmm1) \
	vperm2f128(imm(0x20), ymm1, ymm0, ymm0) /*\
	vmovlpd(mem(rcx, rsi, 4), xmm2, xmm2) \
	vmovhpd(mem(rcx, r15, 1), xmm2, xmm2) \
	vmovlpd(mem(rcx, r13, 2), xmm1, xmm1) \
	vmovhpd(mem(rcx, r10, 1), xmm1, xmm1) \
	vperm2f128(imm(0x20), ymm1, ymm2, ymm2)*/

// Scatter the 4 doubles in ymm0 back to a general-stride column of C
// (same stride registers as DGEMM_INPUT_GS_BETA_NZ). Clobbers xmm1.
#define DGEMM_OUTPUT_GS_BETA_NZ \
	vextractf128(imm(1), ymm0, xmm1) \
	vmovlpd(xmm0, mem(rcx)) \
	vmovhpd(xmm0, mem(rcx, rsi, 1)) \
	vmovlpd(xmm1, mem(rcx, rsi, 2)) \
	vmovhpd(xmm1, mem(rcx, r13, 1)) /*\
	vextractf128(imm(1), ymm2, xmm1) \
	vmovlpd(xmm2, mem(rcx, rsi, 4)) \
	vmovhpd(xmm2, mem(rcx, r15, 1)) \
	vmovlpd(xmm1, mem(rcx, r13, 2)) \
	vmovhpd(xmm1, mem(rcx, r10, 1))*/

//
// bli_dgemm_haswell_asm_8x6: double-precision gemm microkernel for
// Haswell-class cores (AVX2 + FMA3).
//
// Computes  c := beta*c + alpha * a*b  for one 8x6 microtile, where
// a is a packed 8 x k0 micropanel, b is a packed k0 x 6 micropanel,
// and c has strides (rs_c0, cs_c0) in elements. The 8x6 accumulator
// lives in ymm4-ymm15 (two 4-double registers per column of b).
//
// The store phase mirrors the sgemm kernel: a column-stored C
// (rs_c == 1) uses vector loads/stores, general stride uses the
// DGEMM_*_GS macros, and both have beta == 0 variants that never
// read C.
//
void bli_dgemm_haswell_asm_8x6
     (
       dim_t               k0,
       double*    restrict alpha,
       double*    restrict a,
       double*    restrict b,
       double*    restrict beta,
       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
	//void* a_next = bli_auxinfo_next_a( data );
	//void* b_next = bli_auxinfo_next_b( data );

	// Typecast local copies of integers in case dim_t and inc_t are a
	// different size than is expected by load instructions.
	uint64_t k_iter = k0 / 4;       // number of unroll-by-4 iterations
	uint64_t k_left = k0 % 4;       // leftover rank-1 updates
	uint64_t rs_c   = rs_c0;
	uint64_t cs_c   = cs_c0;

	begin_asm()

	vzeroall() // zero all xmm/ymm registers.

	mov(var(a), rax) // load address of a.
	mov(var(b), rbx) // load address of b.
	//mov(%9, r15) // load address of b_next.

	add(imm(32*4), rax)
	 // initialize loop by pre-loading
	vmovapd(mem(rax, -4*32), ymm0)
	vmovapd(mem(rax, -3*32), ymm1)

	mov(var(c), rcx) // load address of c
	mov(var(cs_c), rdi) // load cs_c
	lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double)

	lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c;
	lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c;
	prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c
	prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*cs_c
	prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*cs_c
	prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c
	prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*cs_c
	prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*cs_c

	mov(var(k_iter), rsi) // i = k_iter;
	test(rsi, rsi) // check i via logical AND.
	je(.DCONSIDKLEFT) // if i == 0, jump to code that
	 // contains the k_left loop.

	label(.DLOOPKITER) // MAIN LOOP

	// iteration 0
	prefetch(0, mem(rax, 64*8))

	vbroadcastsd(mem(rbx, 0*8), ymm2)
	vbroadcastsd(mem(rbx, 1*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 2*8), ymm2)
	vbroadcastsd(mem(rbx, 3*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 4*8), ymm2)
	vbroadcastsd(mem(rbx, 5*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	vmovapd(mem(rax, -2*32), ymm0)
	vmovapd(mem(rax, -1*32), ymm1)

	// iteration 1
	vbroadcastsd(mem(rbx, 6*8), ymm2)
	vbroadcastsd(mem(rbx, 7*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 8*8), ymm2)
	vbroadcastsd(mem(rbx, 9*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 10*8), ymm2)
	vbroadcastsd(mem(rbx, 11*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	vmovapd(mem(rax, 0*32), ymm0)
	vmovapd(mem(rax, 1*32), ymm1)

	// iteration 2
	prefetch(0, mem(rax, 76*8))

	vbroadcastsd(mem(rbx, 12*8), ymm2)
	vbroadcastsd(mem(rbx, 13*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 14*8), ymm2)
	vbroadcastsd(mem(rbx, 15*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 16*8), ymm2)
	vbroadcastsd(mem(rbx, 17*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	vmovapd(mem(rax, 2*32), ymm0)
	vmovapd(mem(rax, 3*32), ymm1)

	// iteration 3
	vbroadcastsd(mem(rbx, 18*8), ymm2)
	vbroadcastsd(mem(rbx, 19*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 20*8), ymm2)
	vbroadcastsd(mem(rbx, 21*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 22*8), ymm2)
	vbroadcastsd(mem(rbx, 23*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	add(imm(4*8*8), rax) // a += 4*8 (unroll x mr)
	add(imm(4*6*8), rbx) // b += 4*6 (unroll x nr)

	vmovapd(mem(rax, -4*32), ymm0)
	vmovapd(mem(rax, -3*32), ymm1)

	dec(rsi) // i -= 1;
	jne(.DLOOPKITER) // iterate again if i != 0.

	label(.DCONSIDKLEFT)

	mov(var(k_left), rsi) // i = k_left;
	test(rsi, rsi) // check i via logical AND.
	je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
	 // else, we prepare to enter k_left loop.

	label(.DLOOPKLEFT) // EDGE LOOP

	prefetch(0, mem(rax, 64*8))

	vbroadcastsd(mem(rbx, 0*8), ymm2)
	vbroadcastsd(mem(rbx, 1*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 2*8), ymm2)
	vbroadcastsd(mem(rbx, 3*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 4*8), ymm2)
	vbroadcastsd(mem(rbx, 5*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	add(imm(1*8*8), rax) // a += 1*8 (unroll x mr)
	add(imm(1*6*8), rbx) // b += 1*6 (unroll x nr)

	vmovapd(mem(rax, -4*32), ymm0)
	vmovapd(mem(rax, -3*32), ymm1)

	dec(rsi) // i -= 1;
	jne(.DLOOPKLEFT) // iterate again if i != 0.

	label(.DPOSTACCUM)

	mov(var(alpha), rax) // load address of alpha
	mov(var(beta), rbx) // load address of beta
	vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
	vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate

	vmulpd(ymm0, ymm4, ymm4) // scale by alpha
	vmulpd(ymm0, ymm5, ymm5)
	vmulpd(ymm0, ymm6, ymm6)
	vmulpd(ymm0, ymm7, ymm7)
	vmulpd(ymm0, ymm8, ymm8)
	vmulpd(ymm0, ymm9, ymm9)
	vmulpd(ymm0, ymm10, ymm10)
	vmulpd(ymm0, ymm11, ymm11)
	vmulpd(ymm0, ymm12, ymm12)
	vmulpd(ymm0, ymm13, ymm13)
	vmulpd(ymm0, ymm14, ymm14)
	vmulpd(ymm0, ymm15, ymm15)

	mov(var(rs_c), rsi) // load rs_c
	lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double)

	lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c;

	lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c;
	//lea(mem(rsi, rsi, 4), r15) // r15 = 5*rs_c;
	//lea(mem(r13, rsi, 4), r10) // r10 = 7*rs_c;

	// now avoid loading C if beta == 0

	vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
	vucomisd(xmm0, xmm3) // set ZF if beta == 0.
	je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case

	cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8.
	jz(.DCOLSTORED) // jump to column storage case

	label(.DGENSTORED)

	// General-stride update: for each of the 6 columns, gather the top
	// 4 rows of C, fuse beta*C + microtile, and scatter back; then do
	// the same for the bottom 4 rows starting at c + 4*rs_c (rdx).
	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm4, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm6, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm8, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm10, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm12, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm14, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	//add(rdi, rcx) // c += cs_c;

	mov(rdx, rcx) // rcx = c + 4*rs_c

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm5, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm7, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm9, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm11, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm13, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	DGEMM_INPUT_GS_BETA_NZ
	vfmadd213pd(ymm15, ymm3, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	//add(rdi, rcx) // c += cs_c;

	jmp(.DDONE) // jump to end.

	label(.DCOLSTORED)

	// Column-stored update: top half of each column via rcx, bottom
	// half via rdx, with full-width unaligned vector accesses.
	vfmadd231pd(mem(rcx), ymm3, ymm4)
	vmovupd(ymm4, mem(rcx))
	add(rdi, rcx)
	vfmadd231pd(mem(rdx), ymm3, ymm5)
	vmovupd(ymm5, mem(rdx))
	add(rdi, rdx)

	vfmadd231pd(mem(rcx), ymm3, ymm6)
	vmovupd(ymm6, mem(rcx))
	add(rdi, rcx)
	vfmadd231pd(mem(rdx), ymm3, ymm7)
	vmovupd(ymm7, mem(rdx))
	add(rdi, rdx)

	vfmadd231pd(mem(rcx), ymm3, ymm8)
	vmovupd(ymm8, mem(rcx))
	add(rdi, rcx)
	vfmadd231pd(mem(rdx), ymm3, ymm9)
	vmovupd(ymm9, mem(rdx))
	add(rdi, rdx)

	vfmadd231pd(mem(rcx), ymm3, ymm10)
	vmovupd(ymm10, mem(rcx))
	add(rdi, rcx)
	vfmadd231pd(mem(rdx), ymm3, ymm11)
	vmovupd(ymm11, mem(rdx))
	add(rdi, rdx)

	vfmadd231pd(mem(rcx), ymm3, ymm12)
	vmovupd(ymm12, mem(rcx))
	add(rdi, rcx)
	vfmadd231pd(mem(rdx), ymm3, ymm13)
	vmovupd(ymm13, mem(rdx))
	add(rdi, rdx)

	vfmadd231pd(mem(rcx), ymm3, ymm14)
	vmovupd(ymm14, mem(rcx))
	//add(rdi, rcx)
	vfmadd231pd(mem(rdx), ymm3, ymm15)
	vmovupd(ymm15, mem(rdx))
	//add(rdi, rdx)

	jmp(.DDONE) // jump to end.

	label(.DBETAZERO)

	cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8.
	jz(.DCOLSTORBZ) // jump to column storage case

	label(.DGENSTORBZ)

	// beta == 0, general stride: scatter the microtile without
	// reading C.
	vmovapd(ymm4, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm6, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm8, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm10, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm12, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm14, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	//add(rdi, rcx) // c += cs_c;

	mov(rdx, rcx) // rcx = c + 4*rs_c

	vmovapd(ymm5, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm7, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm9, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm11, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm13, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	add(rdi, rcx) // c += cs_c;

	vmovapd(ymm15, ymm0)
	DGEMM_OUTPUT_GS_BETA_NZ
	//add(rdi, rcx) // c += cs_c;

	jmp(.DDONE) // jump to end.

	label(.DCOLSTORBZ)

	// beta == 0, column-stored: plain vector stores, no reads of C.
	vmovupd(ymm4, mem(rcx))
	add(rdi, rcx)
	vmovupd(ymm5, mem(rdx))
	add(rdi, rdx)

	vmovupd(ymm6, mem(rcx))
	add(rdi, rcx)
	vmovupd(ymm7, mem(rdx))
	add(rdi, rdx)

	vmovupd(ymm8, mem(rcx))
	add(rdi, rcx)
	vmovupd(ymm9, mem(rdx))
	add(rdi, rdx)

	vmovupd(ymm10, mem(rcx))
	add(rdi, rcx)
	vmovupd(ymm11, mem(rdx))
	add(rdi, rdx)

	vmovupd(ymm12, mem(rcx))
	add(rdi, rcx)
	vmovupd(ymm13, mem(rdx))
	add(rdi, rdx)

	vmovupd(ymm14, mem(rcx))
	//add(rdi, rcx)
	vmovupd(ymm15, mem(rdx))
	//add(rdi, rdx)

	label(.DDONE)

	end_asm(
	: // output operands (none)
	: // input operands
	  [k_iter] "m" (k_iter), // 0
	  [k_left] "m" (k_left), // 1
	  [a]      "m" (a),      // 2
	  [b]      "m" (b),      // 3
	  [alpha]  "m" (alpha),  // 4
	  [beta]   "m" (beta),   // 5
	  [c]      "m" (c),      // 6
	  [rs_c]   "m" (rs_c),   // 7
	  [cs_c]   "m" (cs_c)/*, // 8
	  [b_next] "m" (b_next), // 9
	  [a_next] "m" (a_next)*/ // 10
	: // register clobber list
	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
	  "xmm0", "xmm1", "xmm2", "xmm3",
	  "xmm4", "xmm5", "xmm6", "xmm7",
	  "xmm8", "xmm9", "xmm10", "xmm11",
	  "xmm12", "xmm13", "xmm14", "xmm15",
	  "memory"
	)
}

// Gather 4 general-stride scomplex elements of C into ymm0 and scale
// them by beta, whose real and imaginary parts are assumed to have been
// broadcast into ymm1 and ymm2 respectively. The vpermilps/vaddsubps
// pair implements the complex multiply (beta.r*x -+ beta.i*swap(x)).
// Outputs to ymm0; clobbers ymm3.
#define CGEMM_INPUT_SCALE_GS_BETA_NZ \
	vmovlpd(mem(rcx), xmm0, xmm0) \
	vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \
	vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \
	vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \
	vinsertf128(imm(1), xmm3, ymm0, ymm0) \
	vpermilps(imm(0xb1), ymm0, ymm3) \
	vmulps(ymm1, ymm0, ymm0) \
	vmulps(ymm2, ymm3, ymm3) \
	vaddsubps(ymm3, ymm0, ymm0)

// Scatter the 4 scomplex values in ymm0 back to general-stride C
// (one 8-byte real/imag pair per store). Clobbers xmm3.
#define CGEMM_OUTPUT_GS \
	vextractf128(imm(1), ymm0, xmm3) \
	vmovlpd(xmm0, mem(rcx)) \
	vmovhpd(xmm0, mem(rcx, rsi, 1)) \
	vmovlpd(xmm3, mem(rcx, rsi, 2)) \
	vmovhpd(xmm3, mem(rcx, r13, 1))

1250 #define CGEMM_INPUT_SCALE_CS_BETA_NZ \
1251 vmovups(mem(rcx), ymm0) \
1252 vpermilps(imm(0xb1), ymm0, ymm3) \
1253 vmulps(ymm1, ymm0, ymm0) \
1254 vmulps(ymm2, ymm3, ymm3) \
1255 vaddsubps(ymm3, ymm0, ymm0)
1256
1257 #define CGEMM_OUTPUT_CS \
1258 vmovups(ymm0, mem(rcx)) \
1259
// 8x3 single-precision-complex GEMM microkernel for Haswell (AVX2 + FMA3).
//
//   c := beta*c + alpha*(a * b)
//
// a is a packed MR=8 micro-panel, b is a packed NR=3 micro-panel, and c is
// an 8x3 tile of scomplex with row/column strides rs_c0/cs_c0 (in elements).
//
// Register roles inside the asm block:
//   rax = a, rbx = b, rcx = current column of c, rdi = cs_c (bytes),
//   rsi = loop counter, later rs_c (bytes); rdx = 4*rs_c (bytes);
//   r11 = c + 1*cs_c, r12 = c + 2*cs_c, r13 = 3*rs_c (bytes).
//   ymm4..ymm15 accumulate the tile, two registers per (column, half):
//   even-numbered regs hold products with the broadcast b element used for
//   the "real" combination, odd-numbered regs the "imaginary" one; the two
//   are merged with vpermilps+vaddsubps at .CPOSTACCUM.
void bli_cgemm_haswell_asm_8x3
     (
       dim_t               k0,
       scomplex*  restrict alpha,
       scomplex*  restrict a,
       scomplex*  restrict b,
       scomplex*  restrict beta,
       scomplex*  restrict c, inc_t rs_c0, inc_t cs_c0,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
	//void* a_next = bli_auxinfo_next_a( data );
	//void* b_next = bli_auxinfo_next_b( data );

	// Typecast local copies of integers in case dim_t and inc_t are a
	// different size than is expected by load instructions.
	uint64_t k_iter = k0 / 4;             // unrolled-by-4 iterations
	uint64_t k_left = k0 % 4;             // leftover iterations
	uint64_t rs_c   = rs_c0;
	uint64_t cs_c   = cs_c0;

	begin_asm()

	vzeroall() // zero all xmm/ymm registers.


	mov(var(a), rax) // load address of a.
	mov(var(b), rbx) // load address of b.
	//mov(%9, r15) // load address of b_next.

	add(imm(32*4), rax) // bias a so that -4*32..3*32 displacements reach the panel
	 // initialize loop by pre-loading
	vmovaps(mem(rax, -4*32), ymm0)
	vmovaps(mem(rax, -3*32), ymm1)

	mov(var(c), rcx) // load address of c
	mov(var(cs_c), rdi) // load cs_c
	lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex)

	lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*cs_c;
	lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*cs_c;

	prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c
	prefetch(0, mem(r11, 7*8)) // prefetch c + 1*cs_c
	prefetch(0, mem(r12, 7*8)) // prefetch c + 2*cs_c




	mov(var(k_iter), rsi) // i = k_iter;
	test(rsi, rsi) // check i via logical AND.
	je(.CCONSIDKLEFT) // if i == 0, jump to code that
	 // contains the k_left loop.


	label(.CLOOPKITER) // MAIN LOOP


	 // iteration 0
	prefetch(0, mem(rax, 32*8))

	vbroadcastss(mem(rbx, 0*4), ymm2)
	vbroadcastss(mem(rbx, 1*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 2*4), ymm2)
	vbroadcastss(mem(rbx, 3*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 4*4), ymm2)
	vbroadcastss(mem(rbx, 5*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	vmovaps(mem(rax, -2*32), ymm0)
	vmovaps(mem(rax, -1*32), ymm1)

	 // iteration 1
	vbroadcastss(mem(rbx, 6*4), ymm2)
	vbroadcastss(mem(rbx, 7*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 8*4), ymm2)
	vbroadcastss(mem(rbx, 9*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 10*4), ymm2)
	vbroadcastss(mem(rbx, 11*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	vmovaps(mem(rax, 0*32), ymm0)
	vmovaps(mem(rax, 1*32), ymm1)

	 // iteration 2
	prefetch(0, mem(rax, 38*8))

	vbroadcastss(mem(rbx, 12*4), ymm2)
	vbroadcastss(mem(rbx, 13*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 14*4), ymm2)
	vbroadcastss(mem(rbx, 15*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 16*4), ymm2)
	vbroadcastss(mem(rbx, 17*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	vmovaps(mem(rax, 2*32), ymm0)
	vmovaps(mem(rax, 3*32), ymm1)

	 // iteration 3
	vbroadcastss(mem(rbx, 18*4), ymm2)
	vbroadcastss(mem(rbx, 19*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 20*4), ymm2)
	vbroadcastss(mem(rbx, 21*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 22*4), ymm2)
	vbroadcastss(mem(rbx, 23*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	add(imm(4*8*8), rax) // a += 4*8 (unroll x mr)
	add(imm(4*3*8), rbx) // b += 4*3 (unroll x nr)

	vmovaps(mem(rax, -4*32), ymm0)
	vmovaps(mem(rax, -3*32), ymm1)


	dec(rsi) // i -= 1;
	jne(.CLOOPKITER) // iterate again if i != 0.






	label(.CCONSIDKLEFT)

	mov(var(k_left), rsi) // i = k_left;
	test(rsi, rsi) // check i via logical AND.
	je(.CPOSTACCUM) // if i == 0, we're done; jump to end.
	 // else, we prepare to enter k_left loop.


	label(.CLOOPKLEFT) // EDGE LOOP

	prefetch(0, mem(rax, 32*8))

	vbroadcastss(mem(rbx, 0*4), ymm2)
	vbroadcastss(mem(rbx, 1*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm4)
	vfmadd231ps(ymm1, ymm2, ymm5)
	vfmadd231ps(ymm0, ymm3, ymm6)
	vfmadd231ps(ymm1, ymm3, ymm7)

	vbroadcastss(mem(rbx, 2*4), ymm2)
	vbroadcastss(mem(rbx, 3*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm8)
	vfmadd231ps(ymm1, ymm2, ymm9)
	vfmadd231ps(ymm0, ymm3, ymm10)
	vfmadd231ps(ymm1, ymm3, ymm11)

	vbroadcastss(mem(rbx, 4*4), ymm2)
	vbroadcastss(mem(rbx, 5*4), ymm3)
	vfmadd231ps(ymm0, ymm2, ymm12)
	vfmadd231ps(ymm1, ymm2, ymm13)
	vfmadd231ps(ymm0, ymm3, ymm14)
	vfmadd231ps(ymm1, ymm3, ymm15)

	add(imm(1*8*8), rax) // a += 1*8 (unroll x mr)
	add(imm(1*3*8), rbx) // b += 1*3 (unroll x nr)

	vmovaps(mem(rax, -4*32), ymm0)
	vmovaps(mem(rax, -3*32), ymm1)


	dec(rsi) // i -= 1;
	jne(.CLOOPKLEFT) // iterate again if i != 0.



	label(.CPOSTACCUM)


	 // permute even and odd elements
	 // of ymm6/7, ymm10/11, ymm14/15
	vpermilps(imm(0xb1), ymm6, ymm6)
	vpermilps(imm(0xb1), ymm7, ymm7)
	vpermilps(imm(0xb1), ymm10, ymm10)
	vpermilps(imm(0xb1), ymm11, ymm11)
	vpermilps(imm(0xb1), ymm14, ymm14)
	vpermilps(imm(0xb1), ymm15, ymm15)


	 // subtract/add even/odd elements
	vaddsubps(ymm6, ymm4, ymm4)
	vaddsubps(ymm7, ymm5, ymm5)

	vaddsubps(ymm10, ymm8, ymm8)
	vaddsubps(ymm11, ymm9, ymm9)

	vaddsubps(ymm14, ymm12, ymm12)
	vaddsubps(ymm15, ymm13, ymm13)




	 // scale the accumulated tile by alpha (complex multiply per element)
	mov(var(alpha), rax) // load address of alpha
	vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate
	vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate


	vpermilps(imm(0xb1), ymm4, ymm3)
	vmulps(ymm0, ymm4, ymm4)
	vmulps(ymm1, ymm3, ymm3)
	vaddsubps(ymm3, ymm4, ymm4)

	vpermilps(imm(0xb1), ymm5, ymm3)
	vmulps(ymm0, ymm5, ymm5)
	vmulps(ymm1, ymm3, ymm3)
	vaddsubps(ymm3, ymm5, ymm5)


	vpermilps(imm(0xb1), ymm8, ymm3)
	vmulps(ymm0, ymm8, ymm8)
	vmulps(ymm1, ymm3, ymm3)
	vaddsubps(ymm3, ymm8, ymm8)

	vpermilps(imm(0xb1), ymm9, ymm3)
	vmulps(ymm0, ymm9, ymm9)
	vmulps(ymm1, ymm3, ymm3)
	vaddsubps(ymm3, ymm9, ymm9)


	vpermilps(imm(0xb1), ymm12, ymm3)
	vmulps(ymm0, ymm12, ymm12)
	vmulps(ymm1, ymm3, ymm3)
	vaddsubps(ymm3, ymm12, ymm12)

	vpermilps(imm(0xb1), ymm13, ymm3)
	vmulps(ymm0, ymm13, ymm13)
	vmulps(ymm1, ymm3, ymm3)
	vaddsubps(ymm3, ymm13, ymm13)





	mov(var(beta), rbx) // load address of beta
	vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate
	vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate




	mov(var(rs_c), rsi) // load rs_c
	lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex)
	lea(mem(, rsi, 4), rdx) // rdx = 4*rs_c;
	lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c;



	 // now avoid loading C if beta == 0
	vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero.
	vucomiss(xmm0, xmm1) // set ZF if beta_r == 0.
	sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 );
	vucomiss(xmm0, xmm2) // set ZF if beta_i == 0.
	sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 );
	and(r8b, r9b) // set ZF if r8b & r9b == 1.
	jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case


	cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8.
	jz(.CCOLSTORED) // jump to column storage case



	label(.CGENSTORED)


	CGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddps(ymm4, ymm0, ymm0)
	CGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 4*rs_c;


	CGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddps(ymm5, ymm0, ymm0)
	CGEMM_OUTPUT_GS
	mov(r11, rcx) // rcx = c + 1*cs_c



	CGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddps(ymm8, ymm0, ymm0)
	CGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 4*rs_c;


	CGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddps(ymm9, ymm0, ymm0)
	CGEMM_OUTPUT_GS
	mov(r12, rcx) // rcx = c + 2*cs_c



	CGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddps(ymm12, ymm0, ymm0)
	CGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 4*rs_c;


	CGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddps(ymm13, ymm0, ymm0)
	CGEMM_OUTPUT_GS



	jmp(.CDONE) // jump to end.



	label(.CCOLSTORED)


	CGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddps(ymm4, ymm0, ymm0)
	CGEMM_OUTPUT_CS
	add(rdx, rcx) // c += 4*rs_c;


	CGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddps(ymm5, ymm0, ymm0)
	CGEMM_OUTPUT_CS
	mov(r11, rcx) // rcx = c + 1*cs_c



	CGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddps(ymm8, ymm0, ymm0)
	CGEMM_OUTPUT_CS
	add(rdx, rcx) // c += 4*rs_c;


	CGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddps(ymm9, ymm0, ymm0)
	CGEMM_OUTPUT_CS
	mov(r12, rcx) // rcx = c + 2*cs_c



	CGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddps(ymm12, ymm0, ymm0)
	CGEMM_OUTPUT_CS
	add(rdx, rcx) // c += 4*rs_c;


	CGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddps(ymm13, ymm0, ymm0)
	CGEMM_OUTPUT_CS



	jmp(.CDONE) // jump to end.



	label(.CBETAZERO)

	cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8.
	jz(.CCOLSTORBZ) // jump to column storage case



	label(.CGENSTORBZ)


	vmovaps(ymm4, ymm0)
	CGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 4*rs_c;


	vmovaps(ymm5, ymm0)
	CGEMM_OUTPUT_GS
	mov(r11, rcx) // rcx = c + 1*cs_c



	vmovaps(ymm8, ymm0)
	CGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 4*rs_c;


	vmovaps(ymm9, ymm0)
	CGEMM_OUTPUT_GS
	mov(r12, rcx) // rcx = c + 2*cs_c



	vmovaps(ymm12, ymm0)
	CGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 4*rs_c;


	vmovaps(ymm13, ymm0)
	CGEMM_OUTPUT_GS



	jmp(.CDONE) // jump to end.



	label(.CCOLSTORBZ)


	vmovups(ymm4, mem(rcx))
	vmovups(ymm5, mem(rcx, rdx, 1))

	vmovups(ymm8, mem(r11))
	vmovups(ymm9, mem(r11, rdx, 1))

	vmovups(ymm12, mem(r12))
	vmovups(ymm13, mem(r12, rdx, 1))






	label(.CDONE)



	end_asm(
	: // output operands (none)
	: // input operands
	  [k_iter] "m" (k_iter), // 0
	  [k_left] "m" (k_left), // 1
	  [a]      "m" (a),      // 2
	  [b]      "m" (b),      // 3
	  [alpha]  "m" (alpha),  // 4
	  [beta]   "m" (beta),   // 5
	  [c]      "m" (c),      // 6
	  [rs_c]   "m" (rs_c),   // 7
	  [cs_c]   "m" (cs_c)/*, // 8
	  [b_next] "m" (b_next), // 9
	  [a_next] "m" (a_next)*/ // 10
	: // register clobber list
	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
	  "xmm0", "xmm1", "xmm2", "xmm3",
	  "xmm4", "xmm5", "xmm6", "xmm7",
	  "xmm8", "xmm9", "xmm10", "xmm11",
	  "xmm12", "xmm13", "xmm14", "xmm15",
	  "memory"
	)
}
1757
1758
1759
1760
// Beta-scale a partial column of c stored with general stride.
// Gathers two dcomplex elements from rcx + {0,1}*rs_c (rsi holds rs_c in
// bytes), then forms beta*c via the swap/multiply/addsub complex-multiply
// idiom, analogous to the CGEMM macros above but in double precision.
// assumes beta.r, beta.i have been broadcast into ymm1, ymm2.
// outputs to ymm0
#define ZGEMM_INPUT_SCALE_GS_BETA_NZ \
	vmovupd(mem(rcx), xmm0) \
	vmovupd(mem(rcx, rsi, 1), xmm3) \
	vinsertf128(imm(1), xmm3, ymm0, ymm0) \
	vpermilpd(imm(0x5), ymm0, ymm3) \
	vmulpd(ymm1, ymm0, ymm0) \
	vmulpd(ymm2, ymm3, ymm3) \
	vaddsubpd(ymm3, ymm0, ymm0)

// Scatter two dcomplex elements back to rcx + {0,1}*rs_c.
// assumes values to output are in ymm0
#define ZGEMM_OUTPUT_GS \
	vextractf128(imm(1), ymm0, xmm3) \
	vmovupd(xmm0, mem(rcx)) \
	vmovupd(xmm3, mem(rcx, rsi, 1)) \

// Same beta-scaling as the GS variant, but for contiguous
// (column-stored, rs_c == 1) data: one unaligned 32-byte load.
// NOTE(review): vmovups (not vmovupd) on double data — functionally
// equivalent bitwise move; presumably intentional, but inconsistent
// with the stores below.
#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \
	vmovups(mem(rcx), ymm0) \
	vpermilpd(imm(0x5), ymm0, ymm3) \
	vmulpd(ymm1, ymm0, ymm0) \
	vmulpd(ymm2, ymm3, ymm3) \
	vaddsubpd(ymm3, ymm0, ymm0)

// Contiguous store of two dcomplex elements from ymm0.
#define ZGEMM_OUTPUT_CS \
	vmovupd(ymm0, mem(rcx)) \
1787
// 4x3 double-precision-complex GEMM microkernel for Haswell (AVX2 + FMA3).
//
//   c := beta*c + alpha*(a * b)
//
// a is a packed MR=4 micro-panel, b is a packed NR=3 micro-panel, and c is
// a 4x3 tile of dcomplex with row/column strides rs_c0/cs_c0 (in elements).
//
// Register roles inside the asm block:
//   rax = a, rbx = b, rcx = current column of c, rdi = cs_c (bytes),
//   rsi = loop counter, later rs_c (bytes); rdx = 2*rs_c (bytes);
//   r11 = c + 1*cs_c, r12 = c + 2*cs_c.
//   ymm4..ymm15 accumulate the tile, two registers per (column, half);
//   even/odd register pairs are merged with vpermilpd+vaddsubpd at
//   .ZPOSTACCUM (same scheme as the cgemm kernel above).
void bli_zgemm_haswell_asm_4x3
     (
       dim_t               k0,
       dcomplex*  restrict alpha,
       dcomplex*  restrict a,
       dcomplex*  restrict b,
       dcomplex*  restrict beta,
       dcomplex*  restrict c, inc_t rs_c0, inc_t cs_c0,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
	//void* a_next = bli_auxinfo_next_a( data );
	//void* b_next = bli_auxinfo_next_b( data );

	// Typecast local copies of integers in case dim_t and inc_t are a
	// different size than is expected by load instructions.
	uint64_t k_iter = k0 / 4;             // unrolled-by-4 iterations
	uint64_t k_left = k0 % 4;             // leftover iterations
	uint64_t rs_c   = rs_c0;
	uint64_t cs_c   = cs_c0;

	begin_asm()

	vzeroall() // zero all xmm/ymm registers.


	mov(var(a), rax) // load address of a.
	mov(var(b), rbx) // load address of b.
	//mov(%9, r15) // load address of b_next.

	add(imm(32*4), rax) // bias a so that -4*32..3*32 displacements reach the panel
	 // initialize loop by pre-loading
	vmovapd(mem(rax, -4*32), ymm0)
	vmovapd(mem(rax, -3*32), ymm1)

	mov(var(c), rcx) // load address of c
	mov(var(cs_c), rdi) // load cs_c
	lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex)
	lea(mem(, rdi, 2), rdi) // (scale by 8 then 2 = 16 bytes total)

	lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*cs_c;
	lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*cs_c;

	prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c
	prefetch(0, mem(r11, 7*8)) // prefetch c + 1*cs_c
	prefetch(0, mem(r12, 7*8)) // prefetch c + 2*cs_c




	mov(var(k_iter), rsi) // i = k_iter;
	test(rsi, rsi) // check i via logical AND.
	je(.ZCONSIDKLEFT) // if i == 0, jump to code that
	 // contains the k_left loop.


	label(.ZLOOPKITER) // MAIN LOOP


	 // iteration 0
	prefetch(0, mem(rax, 32*16))

	vbroadcastsd(mem(rbx, 0*8), ymm2)
	vbroadcastsd(mem(rbx, 1*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 2*8), ymm2)
	vbroadcastsd(mem(rbx, 3*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 4*8), ymm2)
	vbroadcastsd(mem(rbx, 5*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	vmovapd(mem(rax, -2*32), ymm0)
	vmovapd(mem(rax, -1*32), ymm1)

	 // iteration 1
	vbroadcastsd(mem(rbx, 6*8), ymm2)
	vbroadcastsd(mem(rbx, 7*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 8*8), ymm2)
	vbroadcastsd(mem(rbx, 9*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 10*8), ymm2)
	vbroadcastsd(mem(rbx, 11*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	vmovapd(mem(rax, 0*32), ymm0)
	vmovapd(mem(rax, 1*32), ymm1)

	 // iteration 2
	prefetch(0, mem(rax, 38*16))

	vbroadcastsd(mem(rbx, 12*8), ymm2)
	vbroadcastsd(mem(rbx, 13*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 14*8), ymm2)
	vbroadcastsd(mem(rbx, 15*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 16*8), ymm2)
	vbroadcastsd(mem(rbx, 17*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	vmovapd(mem(rax, 2*32), ymm0)
	vmovapd(mem(rax, 3*32), ymm1)

	 // iteration 3
	vbroadcastsd(mem(rbx, 18*8), ymm2)
	vbroadcastsd(mem(rbx, 19*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 20*8), ymm2)
	vbroadcastsd(mem(rbx, 21*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 22*8), ymm2)
	vbroadcastsd(mem(rbx, 23*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	add(imm(4*4*16), rax) // a += 4*4 (unroll x mr)
	add(imm(4*3*16), rbx) // b += 4*3 (unroll x nr)

	vmovapd(mem(rax, -4*32), ymm0)
	vmovapd(mem(rax, -3*32), ymm1)


	dec(rsi) // i -= 1;
	jne(.ZLOOPKITER) // iterate again if i != 0.






	label(.ZCONSIDKLEFT)

	mov(var(k_left), rsi) // i = k_left;
	test(rsi, rsi) // check i via logical AND.
	je(.ZPOSTACCUM) // if i == 0, we're done; jump to end.
	 // else, we prepare to enter k_left loop.


	label(.ZLOOPKLEFT) // EDGE LOOP

	prefetch(0, mem(rax, 32*16))

	vbroadcastsd(mem(rbx, 0*8), ymm2)
	vbroadcastsd(mem(rbx, 1*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm4)
	vfmadd231pd(ymm1, ymm2, ymm5)
	vfmadd231pd(ymm0, ymm3, ymm6)
	vfmadd231pd(ymm1, ymm3, ymm7)

	vbroadcastsd(mem(rbx, 2*8), ymm2)
	vbroadcastsd(mem(rbx, 3*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm8)
	vfmadd231pd(ymm1, ymm2, ymm9)
	vfmadd231pd(ymm0, ymm3, ymm10)
	vfmadd231pd(ymm1, ymm3, ymm11)

	vbroadcastsd(mem(rbx, 4*8), ymm2)
	vbroadcastsd(mem(rbx, 5*8), ymm3)
	vfmadd231pd(ymm0, ymm2, ymm12)
	vfmadd231pd(ymm1, ymm2, ymm13)
	vfmadd231pd(ymm0, ymm3, ymm14)
	vfmadd231pd(ymm1, ymm3, ymm15)

	add(imm(1*4*16), rax) // a += 1*4 (unroll x mr)
	add(imm(1*3*16), rbx) // b += 1*3 (unroll x nr)

	vmovapd(mem(rax, -4*32), ymm0)
	vmovapd(mem(rax, -3*32), ymm1)


	dec(rsi) // i -= 1;
	jne(.ZLOOPKLEFT) // iterate again if i != 0.



	label(.ZPOSTACCUM)

	 // permute even and odd elements
	 // of ymm6/7, ymm10/11, ymm14/15
	vpermilpd(imm(0x5), ymm6, ymm6)
	vpermilpd(imm(0x5), ymm7, ymm7)
	vpermilpd(imm(0x5), ymm10, ymm10)
	vpermilpd(imm(0x5), ymm11, ymm11)
	vpermilpd(imm(0x5), ymm14, ymm14)
	vpermilpd(imm(0x5), ymm15, ymm15)


	 // subtract/add even/odd elements
	vaddsubpd(ymm6, ymm4, ymm4)
	vaddsubpd(ymm7, ymm5, ymm5)

	vaddsubpd(ymm10, ymm8, ymm8)
	vaddsubpd(ymm11, ymm9, ymm9)

	vaddsubpd(ymm14, ymm12, ymm12)
	vaddsubpd(ymm15, ymm13, ymm13)




	 // scale the accumulated tile by alpha (complex multiply per element)
	mov(var(alpha), rax) // load address of alpha
	vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate
	vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate


	vpermilpd(imm(0x5), ymm4, ymm3)
	vmulpd(ymm0, ymm4, ymm4)
	vmulpd(ymm1, ymm3, ymm3)
	vaddsubpd(ymm3, ymm4, ymm4)

	vpermilpd(imm(0x5), ymm5, ymm3)
	vmulpd(ymm0, ymm5, ymm5)
	vmulpd(ymm1, ymm3, ymm3)
	vaddsubpd(ymm3, ymm5, ymm5)


	vpermilpd(imm(0x5), ymm8, ymm3)
	vmulpd(ymm0, ymm8, ymm8)
	vmulpd(ymm1, ymm3, ymm3)
	vaddsubpd(ymm3, ymm8, ymm8)

	vpermilpd(imm(0x5), ymm9, ymm3)
	vmulpd(ymm0, ymm9, ymm9)
	vmulpd(ymm1, ymm3, ymm3)
	vaddsubpd(ymm3, ymm9, ymm9)


	vpermilpd(imm(0x5), ymm12, ymm3)
	vmulpd(ymm0, ymm12, ymm12)
	vmulpd(ymm1, ymm3, ymm3)
	vaddsubpd(ymm3, ymm12, ymm12)

	vpermilpd(imm(0x5), ymm13, ymm3)
	vmulpd(ymm0, ymm13, ymm13)
	vmulpd(ymm1, ymm3, ymm3)
	vaddsubpd(ymm3, ymm13, ymm13)





	mov(var(beta), rbx) // load address of beta
	vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate
	vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate




	mov(var(rs_c), rsi) // load rs_c
	lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex)
	lea(mem(, rsi, 2), rsi) // (scale by 8 then 2 = 16 bytes total)
	lea(mem(, rsi, 2), rdx) // rdx = 2*rs_c;



	 // now avoid loading C if beta == 0
	vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
	vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
	sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 );
	vucomisd(xmm0, xmm2) // set ZF if beta_i == 0.
	sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 );
	and(r8b, r9b) // set ZF if r8b & r9b == 1.
	jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case


	cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16.
	jz(.ZCOLSTORED) // jump to column storage case



	label(.ZGENSTORED)


	ZGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddpd(ymm4, ymm0, ymm0)
	ZGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 2*rs_c;


	ZGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddpd(ymm5, ymm0, ymm0)
	ZGEMM_OUTPUT_GS
	mov(r11, rcx) // rcx = c + 1*cs_c



	ZGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddpd(ymm8, ymm0, ymm0)
	ZGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 2*rs_c;


	ZGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddpd(ymm9, ymm0, ymm0)
	ZGEMM_OUTPUT_GS
	mov(r12, rcx) // rcx = c + 2*cs_c



	ZGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddpd(ymm12, ymm0, ymm0)
	ZGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 2*rs_c;


	ZGEMM_INPUT_SCALE_GS_BETA_NZ
	vaddpd(ymm13, ymm0, ymm0)
	ZGEMM_OUTPUT_GS



	jmp(.ZDONE) // jump to end.



	label(.ZCOLSTORED)


	ZGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddpd(ymm4, ymm0, ymm0)
	ZGEMM_OUTPUT_CS
	add(rdx, rcx) // c += 2*rs_c;


	ZGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddpd(ymm5, ymm0, ymm0)
	ZGEMM_OUTPUT_CS
	mov(r11, rcx) // rcx = c + 1*cs_c



	ZGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddpd(ymm8, ymm0, ymm0)
	ZGEMM_OUTPUT_CS
	add(rdx, rcx) // c += 2*rs_c;


	ZGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddpd(ymm9, ymm0, ymm0)
	ZGEMM_OUTPUT_CS
	mov(r12, rcx) // rcx = c + 2*cs_c



	ZGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddpd(ymm12, ymm0, ymm0)
	ZGEMM_OUTPUT_CS
	add(rdx, rcx) // c += 2*rs_c;


	ZGEMM_INPUT_SCALE_CS_BETA_NZ
	vaddpd(ymm13, ymm0, ymm0)
	ZGEMM_OUTPUT_CS



	jmp(.ZDONE) // jump to end.



	label(.ZBETAZERO)

	cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16.
	jz(.ZCOLSTORBZ) // jump to column storage case



	label(.ZGENSTORBZ)


	vmovapd(ymm4, ymm0)
	ZGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 2*rs_c;


	vmovapd(ymm5, ymm0)
	ZGEMM_OUTPUT_GS
	mov(r11, rcx) // rcx = c + 1*cs_c



	vmovapd(ymm8, ymm0)
	ZGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 2*rs_c;


	vmovapd(ymm9, ymm0)
	ZGEMM_OUTPUT_GS
	mov(r12, rcx) // rcx = c + 2*cs_c



	vmovapd(ymm12, ymm0)
	ZGEMM_OUTPUT_GS
	add(rdx, rcx) // c += 2*rs_c;


	vmovapd(ymm13, ymm0)
	ZGEMM_OUTPUT_GS



	jmp(.ZDONE) // jump to end.



	label(.ZCOLSTORBZ)


	vmovupd(ymm4, mem(rcx))
	vmovupd(ymm5, mem(rcx, rdx, 1))

	vmovupd(ymm8, mem(r11))
	vmovupd(ymm9, mem(r11, rdx, 1))

	vmovupd(ymm12, mem(r12))
	vmovupd(ymm13, mem(r12, rdx, 1))






	label(.ZDONE)



	end_asm(
	: // output operands (none)
	: // input operands
	  [k_iter] "m" (k_iter), // 0
	  [k_left] "m" (k_left), // 1
	  [a]      "m" (a),      // 2
	  [b]      "m" (b),      // 3
	  [alpha]  "m" (alpha),  // 4
	  [beta]   "m" (beta),   // 5
	  [c]      "m" (c),      // 6
	  [rs_c]   "m" (rs_c),   // 7
	  [cs_c]   "m" (cs_c)/*, // 8
	  [b_next] "m" (b_next), // 9
	  [a_next] "m" (a_next)*/ // 10
	: // register clobber list
	  "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
	  "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
	  "xmm0", "xmm1", "xmm2", "xmm3",
	  "xmm4", "xmm5", "xmm6", "xmm7",
	  "xmm8", "xmm9", "xmm10", "xmm11",
	  "xmm12", "xmm13", "xmm14", "xmm15",
	  "memory"
	)
}
2285
2286
2287