/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2014, The University of Texas at Austin

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name of The University of Texas at Austin nor the names
      of its contributors may be used to endorse or promote products
      derived from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
   OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
   OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"
#include <assert.h>

#include "../knl/bli_avx512_macros.h"

#define A_L1_PREFETCH_DIST 4 // should be a multiple of 4

#define LOOP_ALIGN ALIGN16

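// Scale one 6x32 accumulator row (four ZMM registers, 32 doubles) by alpha
// (zmm0), add beta (zmm1) times the existing row of C, store the result back
// to C, and advance RCX by one row (RAX holds rs_c in bytes).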
#define UPDATE_C(R1,R2,R3,R4) \
\
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VFMADD231PD(ZMM(R1), ZMM(1), MEM(RCX,0*64)) \
    VFMADD231PD(ZMM(R2), ZMM(1), MEM(RCX,1*64)) \
    VFMADD231PD(ZMM(R3), ZMM(1), MEM(RCX,2*64)) \
    VFMADD231PD(ZMM(R4), ZMM(1), MEM(RCX,3*64)) \
    VMOVUPD(MEM(RCX,0*64), ZMM(R1)) \
    VMOVUPD(MEM(RCX,1*64), ZMM(R2)) \
    VMOVUPD(MEM(RCX,2*64), ZMM(R3)) \
    VMOVUPD(MEM(RCX,3*64), ZMM(R4)) \
    LEA(RCX, MEM(RCX,RAX,1))

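// Same as UPDATE_C, but for the beta == 0 case: C is overwritten with
// alpha times the accumulator and is never read.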
#define UPDATE_C_BZ(R1,R2,R3,R4) \
\
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VMOVUPD(MEM(RCX,0*64), ZMM(R1)) \
    VMOVUPD(MEM(RCX,1*64), ZMM(R2)) \
    VMOVUPD(MEM(RCX,2*64), ZMM(R3)) \
    VMOVUPD(MEM(RCX,3*64), ZMM(R4)) \
    LEA(RCX, MEM(RCX,RAX,1))

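// General-stride update of one 6x32 accumulator row: zmm2-zmm5 hold the
// column offsets 0..31 pre-scaled by cs_c, so each row of C is gathered,
// updated with alpha and beta, and scattered back, eight elements at a time.
// Gather/scatter instructions consume their mask registers, so k1/k2 are
// refilled with all ones before every use.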
#define UPDATE_C_ROW_SCATTERED(R1,R2,R3,R4) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(2),8)) \
    VFMADD231PD(ZMM(R1), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(2), ZMM(R1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(3),8)) \
    VFMADD231PD(ZMM(R2), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(2), ZMM(R2)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(4),8)) \
    VFMADD231PD(ZMM(R3), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(4),8) MASK_K(2), ZMM(R3)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(5),8)) \
    VFMADD231PD(ZMM(R4), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(5),8) MASK_K(2), ZMM(R4)) \
\
    LEA(RCX, MEM(RCX,RAX,1))

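// General-stride update of one 6x32 accumulator row for beta == 0: scale by
// alpha and scatter directly, with no gather of the existing C values.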
#define UPDATE_C_BZ_ROW_SCATTERED(R1,R2,R3,R4) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(1), ZMM(R1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(1), ZMM(R2)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(4),8) MASK_K(1), ZMM(R3)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(5),8) MASK_K(1), ZMM(R4)) \
\
    LEA(RCX, MEM(RCX,RAX,1))

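// Prefetch the 6x32 micro-tile of C into L1: four 64-byte lines per row,
// using R12 = rs_c, R13 = 3*rs_c, and R14 = 5*rs_c as row offsets from RCX.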
#define PREFETCH_C_L1 \
\
    PREFETCHW0(MEM(RCX, 0*64)) \
    PREFETCHW0(MEM(RCX, 1*64)) \
    PREFETCHW0(MEM(RCX, 2*64)) \
    PREFETCHW0(MEM(RCX, 3*64)) \
    PREFETCHW0(MEM(RCX,R12,1,0*64)) \
    PREFETCHW0(MEM(RCX,R12,1,1*64)) \
    PREFETCHW0(MEM(RCX,R12,1,2*64)) \
    PREFETCHW0(MEM(RCX,R12,1,3*64)) \
    PREFETCHW0(MEM(RCX,R12,2,0*64)) \
    PREFETCHW0(MEM(RCX,R12,2,1*64)) \
    PREFETCHW0(MEM(RCX,R12,2,2*64)) \
    PREFETCHW0(MEM(RCX,R12,2,3*64)) \
    PREFETCHW0(MEM(RCX,R13,1,0*64)) \
    PREFETCHW0(MEM(RCX,R13,1,1*64)) \
    PREFETCHW0(MEM(RCX,R13,1,2*64)) \
    PREFETCHW0(MEM(RCX,R13,1,3*64)) \
    PREFETCHW0(MEM(RCX,R12,4,0*64)) \
    PREFETCHW0(MEM(RCX,R12,4,1*64)) \
    PREFETCHW0(MEM(RCX,R12,4,2*64)) \
    PREFETCHW0(MEM(RCX,R12,4,3*64)) \
    PREFETCHW0(MEM(RCX,R14,1,0*64)) \
    PREFETCHW0(MEM(RCX,R14,1,1*64)) \
    PREFETCHW0(MEM(RCX,R14,1,2*64)) \
    PREFETCHW0(MEM(RCX,R14,1,3*64)) \

//
// SUBITER(n) performs the n-th rank-1 update of an unrolled loop iteration:
// it broadcasts the six A elements for that k-step into zmm4/zmm5 (two at a
// time), accumulates them against the current row of B held in zmm0-zmm3,
// and finally reloads zmm0-zmm3 with the next row of B.
//
// n: index within the unrolled loop (selects the offset into the packed
//    A and B micro-panels)
//
#define SUBITER(n) \
\
    VBROADCASTSD(ZMM(4), MEM(RAX,(6*n+0)*8)) \
    VBROADCASTSD(ZMM(5), MEM(RAX,(6*n+1)*8)) \
    VFMADD231PD(ZMM( 8), ZMM(0), ZMM(4)) VFMADD231PD(ZMM(12), ZMM(0), ZMM(5)) \
    VFMADD231PD(ZMM( 9), ZMM(1), ZMM(4)) VFMADD231PD(ZMM(13), ZMM(1), ZMM(5)) \
    VFMADD231PD(ZMM(10), ZMM(2), ZMM(4)) VFMADD231PD(ZMM(14), ZMM(2), ZMM(5)) \
    VFMADD231PD(ZMM(11), ZMM(3), ZMM(4)) VFMADD231PD(ZMM(15), ZMM(3), ZMM(5)) \
\
    VBROADCASTSD(ZMM(4), MEM(RAX,(6*n+2)*8)) \
    VBROADCASTSD(ZMM(5), MEM(RAX,(6*n+3)*8)) \
    VFMADD231PD(ZMM(16), ZMM(0), ZMM(4)) VFMADD231PD(ZMM(20), ZMM(0), ZMM(5)) \
    VFMADD231PD(ZMM(17), ZMM(1), ZMM(4)) VFMADD231PD(ZMM(21), ZMM(1), ZMM(5)) \
    VFMADD231PD(ZMM(18), ZMM(2), ZMM(4)) VFMADD231PD(ZMM(22), ZMM(2), ZMM(5)) \
    VFMADD231PD(ZMM(19), ZMM(3), ZMM(4)) VFMADD231PD(ZMM(23), ZMM(3), ZMM(5)) \
\
    VBROADCASTSD(ZMM(4), MEM(RAX,(6*n+4)*8)) \
    VBROADCASTSD(ZMM(5), MEM(RAX,(6*n+5)*8)) \
    VFMADD231PD(ZMM(24), ZMM(0), ZMM(4)) VFMADD231PD(ZMM(28), ZMM(0), ZMM(5)) \
    VFMADD231PD(ZMM(25), ZMM(1), ZMM(4)) VFMADD231PD(ZMM(29), ZMM(1), ZMM(5)) \
    VFMADD231PD(ZMM(26), ZMM(2), ZMM(4)) VFMADD231PD(ZMM(30), ZMM(2), ZMM(5)) \
    VFMADD231PD(ZMM(27), ZMM(3), ZMM(4)) VFMADD231PD(ZMM(31), ZMM(3), ZMM(5)) \
\
    VMOVAPD(ZMM(0), MEM(RBX,(32*n+ 0)*8)) \
    VMOVAPD(ZMM(1), MEM(RBX,(32*n+ 8)*8)) \
    VMOVAPD(ZMM(2), MEM(RBX,(32*n+16)*8)) \
    VMOVAPD(ZMM(3), MEM(RBX,(32*n+24)*8))

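// Accumulator register map: each of the six rows of the 6x32 micro-tile
// occupies four consecutive ZMM registers (8 doubles each):
//
//   row 0: zmm8-zmm11    row 1: zmm12-zmm15    row 2: zmm16-zmm19
//   row 3: zmm20-zmm23   row 4: zmm24-zmm27    row 5: zmm28-zmm31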
// This is an array used for the scatter/gather instructions.
static int64_t offsets[32] __attribute__((aligned(64))) =
    { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,
     16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};

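// For reference, the update this kernel performs is, roughly, the following
// scalar sketch (assuming the usual BLIS packing of A in 6-element column
// slices and B in 32-element row slices; the assembly below additionally
// special-cases beta == 0 so that C is never read in that path):
//
//   for (int64_t i = 0; i < 6; i++)
//       for (int64_t j = 0; j < 32; j++)
//       {
//           double ab = 0.0;
//           for (int64_t l = 0; l < k; l++)
//               ab += a[l*6 + i] * b[l*32 + j];
//           c[i*rs_c + j*cs_c] = (*beta) * c[i*rs_c + j*cs_c] + (*alpha) * ab;
//       }
//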
void bli_dgemm_opt_6x32_l1(
    dim_t            k_,
    double* restrict alpha,
    double* restrict a,
    double* restrict b,
    double* restrict beta,
    double* restrict c, inc_t rs_c_, inc_t cs_c_,
    auxinfo_t*       data,
    cntx_t* restrict cntx
)
{
    (void)data;
    (void)cntx;

    const int64_t* offsetPtr = &offsets[0];
    const int64_t k = k_;
    const int64_t rs_c = rs_c_;
    const int64_t cs_c = cs_c_;

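    // Register roles inside the asm block: RSI = k (loop counter), RAX = a,
    // RBX = b, RCX = c, R8 = MR*8 and R9 = NR*8 (micro-panel steps in bytes),
    // R12/R13/R14 = rs_c, 3*rs_c, 5*rs_c for C prefetching. During the k loop
    // zmm0-zmm3 hold the current row of B and zmm4/zmm5 hold broadcast
    // elements of A; zmm0/zmm1 are reused for alpha/beta in the update phase.
    // zmm8-zmm31 form the 6x32 accumulator.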
    __asm__ volatile
    (

    VXORPD(YMM(8), YMM(8), YMM(8)) //clear out registers
    VMOVAPD(YMM( 7), YMM(8)) //zmm7 stays zero; used later to test beta == 0
    VMOVAPD(YMM( 9), YMM(8))
    VMOVAPD(YMM(10), YMM(8)) MOV(RSI, VAR(k)) //loop index
    VMOVAPD(YMM(11), YMM(8)) MOV(RAX, VAR(a)) //load address of a
    VMOVAPD(YMM(12), YMM(8)) MOV(RBX, VAR(b)) //load address of b
    VMOVAPD(YMM(13), YMM(8)) MOV(RCX, VAR(c)) //load address of c
    VMOVAPD(YMM(14), YMM(8))
    VMOVAPD(YMM(15), YMM(8)) VMOVAPD(ZMM(0), MEM(RBX, 0*8)) //pre-load b
    VMOVAPD(YMM(16), YMM(8)) VMOVAPD(ZMM(1), MEM(RBX, 8*8)) //pre-load b
    VMOVAPD(YMM(17), YMM(8)) VMOVAPD(ZMM(2), MEM(RBX,16*8)) //pre-load b
    VMOVAPD(YMM(18), YMM(8)) VMOVAPD(ZMM(3), MEM(RBX,24*8)) //pre-load b
    VMOVAPD(YMM(19), YMM(8))
    VMOVAPD(YMM(20), YMM(8))
    VMOVAPD(YMM(21), YMM(8)) MOV(R12, VAR(rs_c)) //rs_c
    VMOVAPD(YMM(22), YMM(8)) LEA(R13, MEM(R12,R12,2)) //rs_c*3
    VMOVAPD(YMM(23), YMM(8)) LEA(R14, MEM(R12,R12,4)) //rs_c*5
    VMOVAPD(YMM(24), YMM(8))
    VMOVAPD(YMM(25), YMM(8)) MOV(R8, IMM( 6*8)) //mr*sizeof(double)
    VMOVAPD(YMM(26), YMM(8)) MOV(R9, IMM(32*8)) //nr*sizeof(double)
    VMOVAPD(YMM(27), YMM(8))
    VMOVAPD(YMM(28), YMM(8)) LEA(RBX, MEM(RBX,R9,1)) //adjust b for pre-load
    VMOVAPD(YMM(29), YMM(8))
    VMOVAPD(YMM(30), YMM(8))
    VMOVAPD(YMM(31), YMM(8))

    TEST(RSI, RSI)
    JZ(POSTACCUM)

    PREFETCH_C_L1

    MOV(RDI, RSI) //RDI = k
    AND(RSI, IMM(3)) //RSI = k % 4 (tail iterations)
    SAR(RDI, IMM(2)) //RDI = k / 4 (main loop iterations)
    JZ(TAIL_LOOP)

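    // Main loop: unrolled by four k-iterations per pass (RDI = k/4); the
    // remaining RSI = k%4 iterations are handled by TAIL_LOOP. Each pass also
    // prefetches the slice of A that is A_L1_PREFETCH_DIST iterations ahead.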
    LOOP_ALIGN
    LABEL(MAIN_LOOP)

        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*6*8))
        SUBITER(0)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*6*8+64))
        SUBITER(1)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*6*8+128))
        SUBITER(2)
        SUBITER(3)

        LEA(RAX, MEM(RAX,R8,4))
        LEA(RBX, MEM(RBX,R9,4))

        DEC(RDI)

    JNZ(MAIN_LOOP)

    TEST(RSI, RSI)
    JZ(POSTACCUM)

    LOOP_ALIGN
    LABEL(TAIL_LOOP)

        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*6*8))
        SUBITER(0)

        ADD(RAX, R8)
        ADD(RBX, R9)

        DEC(RSI)

    JNZ(TAIL_LOOP)

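    // Post-accumulation: broadcast alpha and beta, then choose one of four C
    // update paths depending on whether C is row-stored (cs_c == 1) and
    // whether beta == 0.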
    LABEL(POSTACCUM)

    MOV(RAX, VAR(alpha))
    MOV(RBX, VAR(beta))
    VBROADCASTSD(ZMM(0), MEM(RAX)) //broadcast alpha
    VBROADCASTSD(ZMM(1), MEM(RBX)) //broadcast beta

    MOV(RAX, VAR(rs_c))
    LEA(RAX, MEM(,RAX,8)) //rs_c in bytes
    MOV(RBX, VAR(cs_c))

    // Check whether C is row-stored (cs_c == 1). If not, jump to the slow
    // scattered update.
    CMP(RBX, IMM(1))
    JNE(SCATTEREDUPDATE)

    VCOMISD(XMM(1), XMM(7)) //beta == 0? (zmm7 still holds zero)
    JE(COLSTORBZ)

    UPDATE_C( 8, 9,10,11)
    UPDATE_C(12,13,14,15)
    UPDATE_C(16,17,18,19)
    UPDATE_C(20,21,22,23)
    UPDATE_C(24,25,26,27)
    UPDATE_C(28,29,30,31)

    JMP(END)
    LABEL(COLSTORBZ)

    UPDATE_C_BZ( 8, 9,10,11)
    UPDATE_C_BZ(12,13,14,15)
    UPDATE_C_BZ(16,17,18,19)
    UPDATE_C_BZ(20,21,22,23)
    UPDATE_C_BZ(24,25,26,27)
    UPDATE_C_BZ(28,29,30,31)

    JMP(END)
    LABEL(SCATTEREDUPDATE)

    MOV(RDI, VAR(offsetPtr))
    VMOVDQA64(ZMM(2), MEM(RDI,0*64)) //offsets 0..7
    VMOVDQA64(ZMM(3), MEM(RDI,1*64)) //offsets 8..15
    VMOVDQA64(ZMM(4), MEM(RDI,2*64)) //offsets 16..23
    VMOVDQA64(ZMM(5), MEM(RDI,3*64)) //offsets 24..31
    VPBROADCASTQ(ZMM(6), RBX) //broadcast cs_c
    VPMULLQ(ZMM(2), ZMM(6), ZMM(2)) //scale offsets by cs_c
    VPMULLQ(ZMM(3), ZMM(6), ZMM(3))
    VPMULLQ(ZMM(4), ZMM(6), ZMM(4))
    VPMULLQ(ZMM(5), ZMM(6), ZMM(5))

    VCOMISD(XMM(1), XMM(7)) //beta == 0?
    JE(SCATTERBZ)

    UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
    UPDATE_C_ROW_SCATTERED(12,13,14,15)
    UPDATE_C_ROW_SCATTERED(16,17,18,19)
    UPDATE_C_ROW_SCATTERED(20,21,22,23)
    UPDATE_C_ROW_SCATTERED(24,25,26,27)
    UPDATE_C_ROW_SCATTERED(28,29,30,31)

    JMP(END)
    LABEL(SCATTERBZ)

    UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
    UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
    UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
    UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
    UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
    UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)

    LABEL(END)

    VZEROUPPER()

    : // output operands
    : // input operands
      [k]         "m" (k),
      [a]         "m" (a),
      [b]         "m" (b),
      [alpha]     "m" (alpha),
      [beta]      "m" (beta),
      [c]         "m" (c),
      [rs_c]      "m" (rs_c),
      [cs_c]      "m" (cs_c),
      [offsetPtr] "m" (offsetPtr)
    : // register clobber list
      "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
      "r13", "r14", "r15", "k1", "k2", // k1/k2 are clobbered by the gather/scatter path
      "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
      "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13",
      "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21",
      "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
      "zmm30", "zmm31", "memory"
    );
}
