/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2014, The University of Texas at Austin

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name of The University of Texas at Austin nor the names
      of its contributors may be used to endorse or promote products
      derived from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
   OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
   OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"
#include <assert.h>

#include "../knl/bli_avx512_macros.h"

#define A_L1_PREFETCH_DIST 4 // should be a multiple of 4

#define LOOP_ALIGN ALIGN16

// Update one row of the 6x32 microtile of C (four vectors): scale the
// accumulators by alpha (ZMM0), add beta (ZMM1) times the existing row of C,
// store, then advance RCX by one row (RAX holds rs_c in bytes).
#define UPDATE_C(R1,R2,R3,R4) \
\
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VFMADD231PD(ZMM(R1), ZMM(1), MEM(RCX,0*64)) \
    VFMADD231PD(ZMM(R2), ZMM(1), MEM(RCX,1*64)) \
    VFMADD231PD(ZMM(R3), ZMM(1), MEM(RCX,2*64)) \
    VFMADD231PD(ZMM(R4), ZMM(1), MEM(RCX,3*64)) \
    VMOVUPD(MEM(RCX,0*64), ZMM(R1)) \
    VMOVUPD(MEM(RCX,1*64), ZMM(R2)) \
    VMOVUPD(MEM(RCX,2*64), ZMM(R3)) \
    VMOVUPD(MEM(RCX,3*64), ZMM(R4)) \
    LEA(RCX, MEM(RCX,RAX,1))

// Same as UPDATE_C, but for beta == 0: C := alpha*AB without reading C.
#define UPDATE_C_BZ(R1,R2,R3,R4) \
\
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VMOVUPD(MEM(RCX,0*64), ZMM(R1)) \
    VMOVUPD(MEM(RCX,1*64), ZMM(R2)) \
    VMOVUPD(MEM(RCX,2*64), ZMM(R3)) \
    VMOVUPD(MEM(RCX,3*64), ZMM(R4)) \
    LEA(RCX, MEM(RCX,RAX,1))

// Update one row of C when C has a general column stride: gather the existing
// elements of C, apply C := alpha*AB + beta*C, and scatter the results back
// using the column offsets held in ZMM2-ZMM5.
#define UPDATE_C_ROW_SCATTERED(R1,R2,R3,R4) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(2),8)) \
    VFMADD231PD(ZMM(R1), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(2), ZMM(R1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(3),8)) \
    VFMADD231PD(ZMM(R2), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(2), ZMM(R2)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(4),8)) \
    VFMADD231PD(ZMM(R3), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(4),8) MASK_K(2), ZMM(R3)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(5),8)) \
    VFMADD231PD(ZMM(R4), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(5),8) MASK_K(2), ZMM(R4)) \
\
    LEA(RCX, MEM(RCX,RAX,1))

// Scattered update for beta == 0: scatter C := alpha*AB without reading C.
#define UPDATE_C_BZ_ROW_SCATTERED(R1,R2,R3,R4) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(1), ZMM(R1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(1), ZMM(R2)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(4),8) MASK_K(1), ZMM(R3)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(5),8) MASK_K(1), ZMM(R4)) \
\
    LEA(RCX, MEM(RCX,RAX,1))
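
// Roughly (a sketch for clarity, not code used by the kernel), the scattered
// path performs, for each row i of the 6x32 microtile, where ab[i][j] denotes
// the accumulated products held in the ZMM accumulators:
//
//     for ( dim_t j = 0; j < 32; j++ )
//         c[ i*rs_c + j*cs_c ] = alpha*ab[i][j] + beta*c[ i*rs_c + j*cs_c ];
//
// The element offsets j*cs_c for j = 0..31 are built once (see the
// SCATTEREDUPDATE path below) from the 'offsets' array and kept in ZMM2-ZMM5,
// eight 64-bit indices per register, consumed by VGATHERQPD/VSCATTERQPD.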

// Prefetch the 6x32 microtile of C for writing: four cache lines per row,
// with rows addressed via R12 (rs_c), R13 (3*rs_c), and R14 (5*rs_c).
#define PREFETCH_C_L1 \
\
    PREFETCHW0(MEM(RCX,      0*64)) \
    PREFETCHW0(MEM(RCX,      1*64)) \
    PREFETCHW0(MEM(RCX,      2*64)) \
    PREFETCHW0(MEM(RCX,      3*64)) \
    PREFETCHW0(MEM(RCX,R12,1,0*64)) \
    PREFETCHW0(MEM(RCX,R12,1,1*64)) \
    PREFETCHW0(MEM(RCX,R12,1,2*64)) \
    PREFETCHW0(MEM(RCX,R12,1,3*64)) \
    PREFETCHW0(MEM(RCX,R12,2,0*64)) \
    PREFETCHW0(MEM(RCX,R12,2,1*64)) \
    PREFETCHW0(MEM(RCX,R12,2,2*64)) \
    PREFETCHW0(MEM(RCX,R12,2,3*64)) \
    PREFETCHW0(MEM(RCX,R13,1,0*64)) \
    PREFETCHW0(MEM(RCX,R13,1,1*64)) \
    PREFETCHW0(MEM(RCX,R13,1,2*64)) \
    PREFETCHW0(MEM(RCX,R13,1,3*64)) \
    PREFETCHW0(MEM(RCX,R12,4,0*64)) \
    PREFETCHW0(MEM(RCX,R12,4,1*64)) \
    PREFETCHW0(MEM(RCX,R12,4,2*64)) \
    PREFETCHW0(MEM(RCX,R12,4,3*64)) \
    PREFETCHW0(MEM(RCX,R14,1,0*64)) \
    PREFETCHW0(MEM(RCX,R14,1,1*64)) \
    PREFETCHW0(MEM(RCX,R14,1,2*64)) \
    PREFETCHW0(MEM(RCX,R14,1,3*64)) \

//
// SUBITER(n): one unrolled iteration of the k loop.
//
// n: index within the unrolled loop, used to offset into the packed
//    panels of A and B for this iteration.
//
// Broadcasts six elements of A (two at a time into ZMM4/ZMM5), accumulates
// their outer product with the 32 B elements in ZMM0-ZMM3 into the
// accumulators ZMM8-ZMM31, then pre-loads the next 32 elements of B.
//
#define SUBITER(n) \
\
    VBROADCASTSD(ZMM(4), MEM(RAX,(6*n+0)*8)) \
    VBROADCASTSD(ZMM(5), MEM(RAX,(6*n+1)*8)) \
    VFMADD231PD(ZMM( 8), ZMM(0), ZMM(4))  VFMADD231PD(ZMM(12), ZMM(0), ZMM(5)) \
    VFMADD231PD(ZMM( 9), ZMM(1), ZMM(4))  VFMADD231PD(ZMM(13), ZMM(1), ZMM(5)) \
    VFMADD231PD(ZMM(10), ZMM(2), ZMM(4))  VFMADD231PD(ZMM(14), ZMM(2), ZMM(5)) \
    VFMADD231PD(ZMM(11), ZMM(3), ZMM(4))  VFMADD231PD(ZMM(15), ZMM(3), ZMM(5)) \
    \
    VBROADCASTSD(ZMM(4), MEM(RAX,(6*n+2)*8)) \
    VBROADCASTSD(ZMM(5), MEM(RAX,(6*n+3)*8)) \
    VFMADD231PD(ZMM(16), ZMM(0), ZMM(4))  VFMADD231PD(ZMM(20), ZMM(0), ZMM(5)) \
    VFMADD231PD(ZMM(17), ZMM(1), ZMM(4))  VFMADD231PD(ZMM(21), ZMM(1), ZMM(5)) \
    VFMADD231PD(ZMM(18), ZMM(2), ZMM(4))  VFMADD231PD(ZMM(22), ZMM(2), ZMM(5)) \
    VFMADD231PD(ZMM(19), ZMM(3), ZMM(4))  VFMADD231PD(ZMM(23), ZMM(3), ZMM(5)) \
    \
    VBROADCASTSD(ZMM(4), MEM(RAX,(6*n+4)*8)) \
    VBROADCASTSD(ZMM(5), MEM(RAX,(6*n+5)*8)) \
    VFMADD231PD(ZMM(24), ZMM(0), ZMM(4))  VFMADD231PD(ZMM(28), ZMM(0), ZMM(5)) \
    VFMADD231PD(ZMM(25), ZMM(1), ZMM(4))  VFMADD231PD(ZMM(29), ZMM(1), ZMM(5)) \
    VFMADD231PD(ZMM(26), ZMM(2), ZMM(4))  VFMADD231PD(ZMM(30), ZMM(2), ZMM(5)) \
    VFMADD231PD(ZMM(27), ZMM(3), ZMM(4))  VFMADD231PD(ZMM(31), ZMM(3), ZMM(5)) \
    \
    VMOVAPD(ZMM(0), MEM(RBX,(32*n+ 0)*8)) \
    VMOVAPD(ZMM(1), MEM(RBX,(32*n+ 8)*8)) \
    VMOVAPD(ZMM(2), MEM(RBX,(32*n+16)*8)) \
    VMOVAPD(ZMM(3), MEM(RBX,(32*n+24)*8))

// This is an array used for the scatter/gather instructions.
static int64_t offsets[32] __attribute__((aligned(64))) =
    { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15,
     16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};

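// For clarity, a minimal C sketch of what this microkernel computes. It is
// not part of the kernel and is not compiled; the function name is
// illustrative, and the packing layouts (A in panels of MR = 6, B in panels
// of NR = 32) are as implied by SUBITER above.
#if 0
static void bli_dgemm_ref_6x32( dim_t k,
                                double alpha, const double* a, const double* b,
                                double beta, double* c,
                                inc_t rs_c, inc_t cs_c )
{
    for ( dim_t i = 0; i < 6; i++ )
        for ( dim_t j = 0; j < 32; j++ )
        {
            // Dot product over the packed panels of A and B.
            double ab = 0.0;
            for ( dim_t p = 0; p < k; p++ )
                ab += a[ p*6 + i ] * b[ p*32 + j ];

            // The kernel special-cases beta == 0 so that C is never read.
            if ( beta == 0.0 ) c[ i*rs_c + j*cs_c ] = alpha*ab;
            else               c[ i*rs_c + j*cs_c ] = alpha*ab +
                                                      beta*c[ i*rs_c + j*cs_c ];
        }
}
#endif
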
void bli_dgemm_opt_6x32_l1(
                             dim_t            k_,
                             double* restrict alpha,
                             double* restrict a,
                             double* restrict b,
                             double* restrict beta,
                             double* restrict c, inc_t rs_c_, inc_t cs_c_,
                             auxinfo_t*       data,
                             cntx_t* restrict cntx
                           )
{
    (void)data;
    (void)cntx;

    const int64_t* offsetPtr = &offsets[0];
    const int64_t k = k_;
    const int64_t rs_c = rs_c_;
    const int64_t cs_c = cs_c_;

    __asm__ volatile
    (

    // Zero the accumulators ZMM8-ZMM31 (and ZMM7, which stays zero and is
    // later compared against beta), interleaved with scalar setup.
    VXORPD(YMM(8), YMM(8), YMM(8)) //clear out registers
    VMOVAPD(YMM( 7), YMM(8))
    VMOVAPD(YMM( 9), YMM(8))
    VMOVAPD(YMM(10), YMM(8))   MOV(RSI, VAR(k)) //loop index
    VMOVAPD(YMM(11), YMM(8))   MOV(RAX, VAR(a)) //load address of a
    VMOVAPD(YMM(12), YMM(8))   MOV(RBX, VAR(b)) //load address of b
    VMOVAPD(YMM(13), YMM(8))   MOV(RCX, VAR(c)) //load address of c
    VMOVAPD(YMM(14), YMM(8))
    VMOVAPD(YMM(15), YMM(8))   VMOVAPD(ZMM(0), MEM(RBX, 0*8)) //pre-load b
    VMOVAPD(YMM(16), YMM(8))   VMOVAPD(ZMM(1), MEM(RBX, 8*8)) //pre-load b
    VMOVAPD(YMM(17), YMM(8))   VMOVAPD(ZMM(2), MEM(RBX,16*8)) //pre-load b
    VMOVAPD(YMM(18), YMM(8))   VMOVAPD(ZMM(3), MEM(RBX,24*8)) //pre-load b
    VMOVAPD(YMM(19), YMM(8))
    VMOVAPD(YMM(20), YMM(8))
    VMOVAPD(YMM(21), YMM(8))   MOV(R12, VAR(rs_c))      //rs_c
    VMOVAPD(YMM(22), YMM(8))   LEA(R13, MEM(R12,R12,2)) //*3
    VMOVAPD(YMM(23), YMM(8))   LEA(R14, MEM(R12,R12,4)) //*5
    VMOVAPD(YMM(24), YMM(8))
    VMOVAPD(YMM(25), YMM(8))   MOV(R8, IMM( 6*8)) //mr*sizeof(double)
    VMOVAPD(YMM(26), YMM(8))   MOV(R9, IMM(32*8)) //nr*sizeof(double)
    VMOVAPD(YMM(27), YMM(8))
    VMOVAPD(YMM(28), YMM(8))   LEA(RBX, MEM(RBX,R9,1)) //adjust b for pre-load
    VMOVAPD(YMM(29), YMM(8))
    VMOVAPD(YMM(30), YMM(8))
    VMOVAPD(YMM(31), YMM(8))

    TEST(RSI, RSI)
    JZ(POSTACCUM)

    PREFETCH_C_L1

    MOV(RDI, RSI)
    AND(RSI, IMM(3)) //RSI = k % 4 (remainder iterations)
    SAR(RDI, IMM(2)) //RDI = k / 4 (unrolled-by-4 iterations)
    JZ(TAIL_LOOP)

    LOOP_ALIGN
    LABEL(MAIN_LOOP)

        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*6*8))
        SUBITER(0)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*6*8+64))
        SUBITER(1)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*6*8+128))
        SUBITER(2)
        SUBITER(3)

        LEA(RAX, MEM(RAX,R8,4))
        LEA(RBX, MEM(RBX,R9,4))

        DEC(RDI)

    JNZ(MAIN_LOOP)

    TEST(RSI, RSI)
    JZ(POSTACCUM)

    LOOP_ALIGN
    LABEL(TAIL_LOOP)

        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*6*8))
        SUBITER(0)

        ADD(RAX, R8)
        ADD(RBX, R9)

        DEC(RSI)

    JNZ(TAIL_LOOP)

    LABEL(POSTACCUM)

    MOV(RAX, VAR(alpha))
    MOV(RBX, VAR(beta))
    VBROADCASTSD(ZMM(0), MEM(RAX))
    VBROADCASTSD(ZMM(1), MEM(RBX))

    MOV(RAX, VAR(rs_c))
    LEA(RAX, MEM(,RAX,8)) //rs_c in bytes
    MOV(RBX, VAR(cs_c))

    // Check whether C is row-stored (cs_c == 1). If not, jump to the slower
    // scattered update.
    CMP(RBX, IMM(1))
    JNE(SCATTEREDUPDATE)

        VCOMISD(XMM(1), XMM(7)) //beta == 0?
        JE(COLSTORBZ)

            UPDATE_C( 8, 9,10,11)
            UPDATE_C(12,13,14,15)
            UPDATE_C(16,17,18,19)
            UPDATE_C(20,21,22,23)
            UPDATE_C(24,25,26,27)
            UPDATE_C(28,29,30,31)

        JMP(END)
        LABEL(COLSTORBZ)

            UPDATE_C_BZ( 8, 9,10,11)
            UPDATE_C_BZ(12,13,14,15)
            UPDATE_C_BZ(16,17,18,19)
            UPDATE_C_BZ(20,21,22,23)
            UPDATE_C_BZ(24,25,26,27)
            UPDATE_C_BZ(28,29,30,31)

    JMP(END)
    LABEL(SCATTEREDUPDATE)

        // Build the 64-bit element offsets j*cs_c for columns j = 0..31
        // in ZMM2-ZMM5 (eight indices per register).
        MOV(RDI, VAR(offsetPtr))
        VMOVDQA64(ZMM(2), MEM(RDI,0*64))
        VMOVDQA64(ZMM(3), MEM(RDI,1*64))
        VMOVDQA64(ZMM(4), MEM(RDI,2*64))
        VMOVDQA64(ZMM(5), MEM(RDI,3*64))
        VPBROADCASTQ(ZMM(6), RBX)
        VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
        VPMULLQ(ZMM(3), ZMM(6), ZMM(3))
        VPMULLQ(ZMM(4), ZMM(6), ZMM(4))
        VPMULLQ(ZMM(5), ZMM(6), ZMM(5))

        VCOMISD(XMM(1), XMM(7)) //beta == 0?
        JE(SCATTERBZ)

            UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
            UPDATE_C_ROW_SCATTERED(12,13,14,15)
            UPDATE_C_ROW_SCATTERED(16,17,18,19)
            UPDATE_C_ROW_SCATTERED(20,21,22,23)
            UPDATE_C_ROW_SCATTERED(24,25,26,27)
            UPDATE_C_ROW_SCATTERED(28,29,30,31)

        JMP(END)
        LABEL(SCATTERBZ)

            UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
            UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
            UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
            UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
            UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
            UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)

    LABEL(END)

    VZEROUPPER()

    : // output operands
    : // input operands
      [k]         "m" (k),
      [a]         "m" (a),
      [b]         "m" (b),
      [alpha]     "m" (alpha),
      [beta]      "m" (beta),
      [c]         "m" (c),
      [rs_c]      "m" (rs_c),
      [cs_c]      "m" (cs_c),
      [offsetPtr] "m" (offsetPtr)
    : // register clobber list
      "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
      "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
      "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13",
      "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21",
      "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
      "zmm30", "zmm31", "memory"
    );
}