1 /*
2
3 BLIS
4 An object-based framework for developing high-performance BLAS-like
5 libraries.
6
7 Copyright (C) 2014, The University of Texas at Austin
8
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are
11 met:
12 - Redistributions of source code must retain the above copyright
13 notice, this list of conditions and the following disclaimer.
14 - Redistributions in binary form must reproduce the above copyright
15 notice, this list of conditions and the following disclaimer in the
16 documentation and/or other materials provided with the distribution.
17 - Neither the name of The University of Texas at Austin nor the names
18 of its contributors may be used to endorse or promote products
19 derived from this software without specific prior written permission.
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
25 OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
26 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
27 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
28 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
29 OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
33 */
34
35 #include "blis.h"
36 #include <assert.h>
37
38 #include "../knl/bli_avx512_macros.h"
39
40 #define A_L1_PREFETCH_DIST 4 //should be multiple of 2
41
42 #define LOOP_ALIGN ALIGN16
43
/*
 * beta != 0, unit-column-stride update of two rows of C.
 * Scales accumulators R1..R4 by alpha (zmm0), adds beta (zmm1) times the
 * corresponding two rows of C (two 8-double panels per row), stores the
 * result, and advances RCX by two rows. RAX = rs_c in bytes.
 */
#define UPDATE_C(R1,R2,R3,R4) \
\
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VFMADD231PD(ZMM(R1), ZMM(1), MEM(RCX,0*64)) \
    VFMADD231PD(ZMM(R2), ZMM(1), MEM(RCX,1*64)) \
    VFMADD231PD(ZMM(R3), ZMM(1), MEM(RCX,RAX,1,0*64)) \
    VFMADD231PD(ZMM(R4), ZMM(1), MEM(RCX,RAX,1,1*64)) \
    VMOVUPD(MEM(RCX,0*64), ZMM(R1)) \
    VMOVUPD(MEM(RCX,1*64), ZMM(R2)) \
    VMOVUPD(MEM(RCX,RAX,1,0*64), ZMM(R3)) \
    VMOVUPD(MEM(RCX,RAX,1,1*64), ZMM(R4)) \
    LEA(RCX, MEM(RCX,RAX,2))
59
/*
 * beta == 0 variant of UPDATE_C: scales accumulators R1..R4 by alpha
 * (zmm0) and stores two rows of C without reading it first. RCX is
 * advanced by two rows; RAX = rs_c in bytes.
 */
#define UPDATE_C_BZ(R1,R2,R3,R4) \
\
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VMOVUPD(MEM(RCX,0*64), ZMM(R1)) \
    VMOVUPD(MEM(RCX,1*64), ZMM(R2)) \
    VMOVUPD(MEM(RCX,RAX,1,0*64), ZMM(R3)) \
    VMOVUPD(MEM(RCX,RAX,1,1*64), ZMM(R4)) \
    LEA(RCX, MEM(RCX,RAX,2))
71
/*
 * beta != 0, general-stride update of two rows of C via gather/scatter.
 * zmm2/zmm3 hold cs_c-scaled byte offsets for columns 0-7 and 8-15,
 * zmm0 = alpha, zmm1 = beta, zmm6 is a gather temporary, and RAX = rs_c
 * in bytes. k1/k2 are re-filled with all-ones before every use because
 * gather/scatter instructions consume (clear) their mask as they retire.
 */
#define UPDATE_C_ROW_SCATTERED(R1,R2,R3,R4) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(2),8)) \
    VFMADD231PD(ZMM(R1), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(2), ZMM(R1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(3),8)) \
    VFMADD231PD(ZMM(R2), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(2), ZMM(R2)) \
\
    LEA(RCX, MEM(RCX,RAX,1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(2),8)) \
    VFMADD231PD(ZMM(R3), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(2), ZMM(R3)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(3),8)) \
    VFMADD231PD(ZMM(R4), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(2), ZMM(R4)) \
\
    LEA(RCX, MEM(RCX,RAX,1))
105
/*
 * beta == 0, general-stride update of two rows of C: scale by alpha
 * (zmm0) and scatter-store only (no gather of the old C values).
 * zmm2/zmm3 hold the cs_c-scaled byte offsets; RAX = rs_c in bytes.
 * k1 is re-filled before each scatter, which clears its mask on use.
 */
#define UPDATE_C_BZ_ROW_SCATTERED(R1,R2,R3,R4) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(1), ZMM(R1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(1), ZMM(R2)) \
\
    LEA(RCX, MEM(RCX,RAX,1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(1), ZMM(R3)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(1), ZMM(R4)) \
\
    LEA(RCX, MEM(RCX,RAX,1))
127
/*
 * Prefetch (with write intent) the 24 cache lines of the 12x16 C
 * micro-tile into L1: rows 0-7 relative to RCX and rows 8-11 relative
 * to RDX, using R12 = rs_c, R13 = 3*rs_c, R14 = 5*rs_c, R15 = 7*rs_c
 * as set up by the caller.
 *
 * NOTE(review): R12 is loaded from rs_c in *elements*, yet the
 * addressing below uses it as a *byte* offset (the store path reloads
 * rs_c and scales it by 8). Prefetch-only, so numerical results are
 * unaffected, but the prefetched lines may be wrong -- verify intent.
 */
#define PREFETCH_C_L1 \
\
    PREFETCHW0(MEM(RCX,      0*64)) \
    PREFETCHW0(MEM(RCX,      1*64)) \
    PREFETCHW0(MEM(RCX,R12,1,0*64)) \
    PREFETCHW0(MEM(RCX,R12,1,1*64)) \
    PREFETCHW0(MEM(RCX,R12,2,0*64)) \
    PREFETCHW0(MEM(RCX,R12,2,1*64)) \
    PREFETCHW0(MEM(RCX,R13,1,0*64)) \
    PREFETCHW0(MEM(RCX,R13,1,1*64)) \
    PREFETCHW0(MEM(RCX,R12,4,0*64)) \
    PREFETCHW0(MEM(RCX,R12,4,1*64)) \
    PREFETCHW0(MEM(RCX,R14,1,0*64)) \
    PREFETCHW0(MEM(RCX,R14,1,1*64)) \
    PREFETCHW0(MEM(RCX,R13,2,0*64)) \
    PREFETCHW0(MEM(RCX,R13,2,1*64)) \
    PREFETCHW0(MEM(RCX,R15,1,0*64)) \
    PREFETCHW0(MEM(RCX,R15,1,1*64)) \
    PREFETCHW0(MEM(RDX,      0*64)) \
    PREFETCHW0(MEM(RDX,      1*64)) \
    PREFETCHW0(MEM(RDX,R12,1,0*64)) \
    PREFETCHW0(MEM(RDX,R12,1,1*64)) \
    PREFETCHW0(MEM(RDX,R12,2,0*64)) \
    PREFETCHW0(MEM(RDX,R12,2,1*64)) \
    PREFETCHW0(MEM(RDX,R13,1,0*64)) \
    PREFETCHW0(MEM(RDX,R13,1,1*64))
154
//
// One rank-1 update of the 12x16 micro-tile for unrolled iteration n.
//
// Broadcasts the 12 elements of column n of the packed A panel (RAX)
// through the rotating temporaries zmm3/zmm4 and FMAs each against the
// pre-loaded 16-element row of B held in zmm0/zmm1, accumulating into
// zmm8-zmm31 (two registers per row of C). Finally pre-loads the row
// of B for iteration n+1 from RBX into zmm0/zmm1.
//
#define SUBITER(n) \
\
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 0)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 1)*8)) \
    VFMADD231PD(ZMM( 8), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM( 9), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(10), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(11), ZMM(1), ZMM(4)) \
\
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 2)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 3)*8)) \
    VFMADD231PD(ZMM(12), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(13), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(14), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(15), ZMM(1), ZMM(4)) \
\
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 4)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 5)*8)) \
    VFMADD231PD(ZMM(16), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(17), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(18), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(19), ZMM(1), ZMM(4)) \
\
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 6)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 7)*8)) \
    VFMADD231PD(ZMM(20), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(21), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(22), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(23), ZMM(1), ZMM(4)) \
\
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 8)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 9)*8)) \
    VFMADD231PD(ZMM(24), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(25), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(26), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(27), ZMM(1), ZMM(4)) \
\
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+10)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+11)*8)) \
    VFMADD231PD(ZMM(28), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(29), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(30), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(31), ZMM(1), ZMM(4)) \
\
    VMOVAPD(ZMM(0), MEM(RBX,(16*n+0)*8)) \
    VMOVAPD(ZMM(1), MEM(RBX,(16*n+8)*8))
209
// Index vector {0, ..., 15} used to build the scatter/gather offsets for
// the general-stride C update; it is scaled by cs_c at run time.
static int64_t offsets[16] __attribute__((aligned(64))) =
    { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
213
/*
 * 12x16 double-precision GEMM micro-kernel (AVX-512):
 *
 *     C := beta * C + alpha * A * B
 *
 * a: packed micro-panel of A; 12 doubles consumed per k iteration
 * b: packed micro-panel of B; 16 doubles consumed per k iteration
 * c: 12x16 block of C with row stride rs_c and column stride cs_c
 *
 * The k loop is unrolled by 4 (MAIN_LOOP) with a scalar remainder
 * (TAIL_LOOP). Accumulators live in zmm8-zmm31, the current row of B is
 * pre-loaded into zmm0/zmm1, and zmm3/zmm4 are broadcast temporaries.
 * When cs_c == 1 the C update uses contiguous vector loads/stores (with
 * a dedicated beta == 0 path); otherwise it falls back to gather/scatter
 * using the `offsets` index vector scaled by cs_c.
 */
void bli_dgemm_opt_12x16_l1(
                             dim_t k_,
                             double* restrict alpha,
                             double* restrict a,
                             double* restrict b,
                             double* restrict beta,
                             double* restrict c, inc_t rs_c_, inc_t cs_c_,
                             auxinfo_t* data,
                             cntx_t* restrict cntx
                           )
{
    (void)data;
    (void)cntx;

    const int64_t* offsetPtr = &offsets[0];
    const int64_t k = k_;
    const int64_t rs_c = rs_c_;
    const int64_t cs_c = cs_c_;

    __asm__ volatile
    (

    VXORPD(YMM(8), YMM(8), YMM(8)) //clear out registers
    VMOVAPD(YMM( 7), YMM(8)) //zmm7 stays zero; used for the beta == 0 test
    VMOVAPD(YMM( 9), YMM(8))
    VMOVAPD(YMM(10), YMM(8)) MOV(RSI, VAR(k)) //loop index
    VMOVAPD(YMM(11), YMM(8)) MOV(RAX, VAR(a)) //load address of a
    VMOVAPD(YMM(12), YMM(8)) MOV(RBX, VAR(b)) //load address of b
    VMOVAPD(YMM(13), YMM(8)) MOV(RCX, VAR(c)) //load address of c
    VMOVAPD(YMM(14), YMM(8))
    VMOVAPD(YMM(15), YMM(8)) VMOVAPD(ZMM(0), MEM(RBX, 0*8)) //pre-load b
    VMOVAPD(YMM(16), YMM(8)) VMOVAPD(ZMM(1), MEM(RBX, 8*8)) //pre-load b
    VMOVAPD(YMM(17), YMM(8))
    VMOVAPD(YMM(18), YMM(8))
    VMOVAPD(YMM(19), YMM(8)) MOV(R12, VAR(rs_c)) //rs_c
    VMOVAPD(YMM(20), YMM(8)) LEA(R13, MEM(R12,R12,2)) //*3
    VMOVAPD(YMM(21), YMM(8)) LEA(R14, MEM(R12,R12,4)) //*5
    VMOVAPD(YMM(22), YMM(8)) LEA(R15, MEM(R14,R12,2)) //*7
    VMOVAPD(YMM(23), YMM(8)) LEA(RDX, MEM(RCX,R12,8)) //c + 8*rs_c
    VMOVAPD(YMM(24), YMM(8))
    VMOVAPD(YMM(25), YMM(8)) MOV(R8, IMM(12*8)) //mr*sizeof(double)
    VMOVAPD(YMM(26), YMM(8)) MOV(R9, IMM(16*8)) //nr*sizeof(double)
    VMOVAPD(YMM(27), YMM(8))
    VMOVAPD(YMM(28), YMM(8)) LEA(RBX, MEM(RBX,R9,1)) //adjust b for pre-load
    VMOVAPD(YMM(29), YMM(8))
    VMOVAPD(YMM(30), YMM(8))
    VMOVAPD(YMM(31), YMM(8))

    TEST(RSI, RSI) //k == 0? skip straight to the C update
    JZ(POSTACCUM)

    PREFETCH_C_L1

    MOV(RDI, RSI)
    AND(RSI, IMM(3)) //RSI = k % 4 (remainder iterations)
    SAR(RDI, IMM(2)) //RDI = k / 4 (unrolled iterations)
    JZ(TAIL_LOOP)

    LOOP_ALIGN
    LABEL(MAIN_LOOP)

        //prefetch A ahead by A_L1_PREFETCH_DIST unrolled iterations
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8))
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+64))
        SUBITER(0)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+128))
        SUBITER(1)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+192))
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+256))
        SUBITER(2)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+320))
        SUBITER(3)

        LEA(RAX, MEM(RAX,R8,4)) //a += 4*mr
        LEA(RBX, MEM(RBX,R9,4)) //b += 4*nr

        DEC(RDI)

    JNZ(MAIN_LOOP)

    TEST(RSI, RSI)
    JZ(POSTACCUM)

    LOOP_ALIGN
    LABEL(TAIL_LOOP)

        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8))
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+64))
        SUBITER(0)

        ADD(RAX, R8) //a += mr
        ADD(RBX, R9) //b += nr

        DEC(RSI)

    JNZ(TAIL_LOOP)

    LABEL(POSTACCUM)

    MOV(RAX, VAR(alpha))
    MOV(RBX, VAR(beta))
    VBROADCASTSD(ZMM(0), MEM(RAX)) //zmm0 = alpha (all lanes)
    VBROADCASTSD(ZMM(1), MEM(RBX)) //zmm1 = beta (all lanes)

    MOV(RAX, VAR(rs_c))
    LEA(RAX, MEM(,RAX,8)) //RAX = rs_c in bytes
    MOV(RBX, VAR(cs_c))

    // Check if C is row stride. If not, jump to the slow scattered update
    CMP(RBX, IMM(1))
    JNE(SCATTEREDUPDATE)

    VCOMISD(XMM(1), XMM(7)) //beta == 0.0 (zmm7 is zero)?
    JE(COLSTORBZ)

        UPDATE_C( 8, 9,10,11)
        UPDATE_C(12,13,14,15)
        UPDATE_C(16,17,18,19)
        UPDATE_C(20,21,22,23)
        UPDATE_C(24,25,26,27)
        UPDATE_C(28,29,30,31)

    JMP(END)
    LABEL(COLSTORBZ)

        UPDATE_C_BZ( 8, 9,10,11)
        UPDATE_C_BZ(12,13,14,15)
        UPDATE_C_BZ(16,17,18,19)
        UPDATE_C_BZ(20,21,22,23)
        UPDATE_C_BZ(24,25,26,27)
        UPDATE_C_BZ(28,29,30,31)

    JMP(END)
    LABEL(SCATTEREDUPDATE)

    //build byte offsets of columns 0-7 (zmm2) and 8-15 (zmm3): index*cs_c
    MOV(RDI, VAR(offsetPtr))
    VMOVDQA64(ZMM(2), MEM(RDI,0*64))
    VMOVDQA64(ZMM(3), MEM(RDI,1*64))
    VPBROADCASTQ(ZMM(6), RBX)
    VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
    VPMULLQ(ZMM(3), ZMM(6), ZMM(3))

    VCOMISD(XMM(1), XMM(7)) //beta == 0.0?
    JE(SCATTERBZ)

        UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
        UPDATE_C_ROW_SCATTERED(12,13,14,15)
        UPDATE_C_ROW_SCATTERED(16,17,18,19)
        UPDATE_C_ROW_SCATTERED(20,21,22,23)
        UPDATE_C_ROW_SCATTERED(24,25,26,27)
        UPDATE_C_ROW_SCATTERED(28,29,30,31)

    JMP(END)
    LABEL(SCATTERBZ)

        UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
        UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
        UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
        UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
        UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
        UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)

    LABEL(END)

    VZEROUPPER()

    : // output operands
    : // input operands
      [k]         "m" (k),
      [a]         "m" (a),
      [b]         "m" (b),
      [alpha]     "m" (alpha),
      [beta]      "m" (beta),
      [c]         "m" (c),
      [rs_c]      "m" (rs_c),
      [cs_c]      "m" (cs_c),
      [offsetPtr] "m" (offsetPtr)
    : // register clobber list
      // k1/k2 are written by KXNORW and the gather/scatter macros, so they
      // must be declared clobbered as well.
      "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
      "r13", "r14", "r15", "k1", "k2", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4",
      "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13",
      "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21",
      "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
      "zmm30", "zmm31", "memory"
    );
}
399