/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2014, The University of Texas at Austin

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name of The University of Texas at Austin nor the names
      of its contributors may be used to endorse or promote products
      derived from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
   OF TEXAS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
   OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"
#include <assert.h>

#include "../knl/bli_avx512_macros.h"

#define A_L1_PREFETCH_DIST 4 // should be a multiple of 2

#define LOOP_ALIGN ALIGN16

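// UPDATE_C(R1,R2,R3,R4): write two rows of C from four accumulator
// registers. Each accumulator is scaled by alpha (ZMM0), beta*C is added
// via FMA (beta in ZMM1), the result is stored back, and RCX advances by
// two rows. RAX holds the row stride of C in bytes.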
#define UPDATE_C(R1,R2,R3,R4) \
\
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VFMADD231PD(ZMM(R1), ZMM(1), MEM(RCX,0*64)) \
    VFMADD231PD(ZMM(R2), ZMM(1), MEM(RCX,1*64)) \
    VFMADD231PD(ZMM(R3), ZMM(1), MEM(RCX,RAX,1,0*64)) \
    VFMADD231PD(ZMM(R4), ZMM(1), MEM(RCX,RAX,1,1*64)) \
    VMOVUPD(MEM(RCX,0*64), ZMM(R1)) \
    VMOVUPD(MEM(RCX,1*64), ZMM(R2)) \
    VMOVUPD(MEM(RCX,RAX,1,0*64), ZMM(R3)) \
    VMOVUPD(MEM(RCX,RAX,1,1*64), ZMM(R4)) \
    LEA(RCX, MEM(RCX,RAX,2))

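// UPDATE_C_BZ(R1,R2,R3,R4): same as UPDATE_C but for beta == 0; C is
// overwritten with alpha times the accumulators without being read.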
#define UPDATE_C_BZ(R1,R2,R3,R4) \
\
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VMOVUPD(MEM(RCX,0*64), ZMM(R1)) \
    VMOVUPD(MEM(RCX,1*64), ZMM(R2)) \
    VMOVUPD(MEM(RCX,RAX,1,0*64), ZMM(R3)) \
    VMOVUPD(MEM(RCX,RAX,1,1*64), ZMM(R4)) \
    LEA(RCX, MEM(RCX,RAX,2))

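// UPDATE_C_ROW_SCATTERED(R1,R2,R3,R4): update two rows of C for general
// column stride using gather/scatter. ZMM2/ZMM3 hold the element offsets
// of columns 0-7 and 8-15, pre-scaled by cs_c. Gathers and scatters consume
// their mask register, so k1/k2 are refilled with all ones (KXNORW) before
// each use.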
#define UPDATE_C_ROW_SCATTERED(R1,R2,R3,R4) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(2),8)) \
    VFMADD231PD(ZMM(R1), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(2), ZMM(R1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(3),8)) \
    VFMADD231PD(ZMM(R2), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(2), ZMM(R2)) \
\
    LEA(RCX, MEM(RCX,RAX,1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(2),8)) \
    VFMADD231PD(ZMM(R3), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(2), ZMM(R3)) \
\
    KXNORW(K(1), K(0), K(0)) \
    KXNORW(K(2), K(0), K(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VGATHERQPD(ZMM(6) MASK_K(1), MEM(RCX,ZMM(3),8)) \
    VFMADD231PD(ZMM(R4), ZMM(6), ZMM(1)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(2), ZMM(R4)) \
\
    LEA(RCX, MEM(RCX,RAX,1))

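// UPDATE_C_BZ_ROW_SCATTERED(R1,R2,R3,R4): beta == 0 variant of the
// scattered update; scales by alpha and scatters directly without first
// gathering the old contents of C.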
#define UPDATE_C_BZ_ROW_SCATTERED(R1,R2,R3,R4) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R1), ZMM(R1), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(1), ZMM(R1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R2), ZMM(R2), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(1), ZMM(R2)) \
\
    LEA(RCX, MEM(RCX,RAX,1)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R3), ZMM(R3), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(2),8) MASK_K(1), ZMM(R3)) \
\
    KXNORW(K(1), K(0), K(0)) \
    VMULPD(ZMM(R4), ZMM(R4), ZMM(0)) \
    VSCATTERQPD(MEM(RCX,ZMM(3),8) MASK_K(1), ZMM(R4)) \
\
    LEA(RCX, MEM(RCX,RAX,1))

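// PREFETCH_C_L1: prefetch the full 12x16 micro-tile of C (two 64-byte cache
// lines per row) into L1 with write intent. Assumes RCX points at row 0 and
// RDX at row 8 of C, with R12/R13/R14/R15 holding 1x/3x/5x/7x the row stride
// in bytes.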
#define PREFETCH_C_L1 \
\
    PREFETCHW0(MEM(RCX,      0*64)) \
    PREFETCHW0(MEM(RCX,      1*64)) \
    PREFETCHW0(MEM(RCX,R12,1,0*64)) \
    PREFETCHW0(MEM(RCX,R12,1,1*64)) \
    PREFETCHW0(MEM(RCX,R12,2,0*64)) \
    PREFETCHW0(MEM(RCX,R12,2,1*64)) \
    PREFETCHW0(MEM(RCX,R13,1,0*64)) \
    PREFETCHW0(MEM(RCX,R13,1,1*64)) \
    PREFETCHW0(MEM(RCX,R12,4,0*64)) \
    PREFETCHW0(MEM(RCX,R12,4,1*64)) \
    PREFETCHW0(MEM(RCX,R14,1,0*64)) \
    PREFETCHW0(MEM(RCX,R14,1,1*64)) \
    PREFETCHW0(MEM(RCX,R13,2,0*64)) \
    PREFETCHW0(MEM(RCX,R13,2,1*64)) \
    PREFETCHW0(MEM(RCX,R15,1,0*64)) \
    PREFETCHW0(MEM(RCX,R15,1,1*64)) \
    PREFETCHW0(MEM(RDX,      0*64)) \
    PREFETCHW0(MEM(RDX,      1*64)) \
    PREFETCHW0(MEM(RDX,R12,1,0*64)) \
    PREFETCHW0(MEM(RDX,R12,1,1*64)) \
    PREFETCHW0(MEM(RDX,R12,2,0*64)) \
    PREFETCHW0(MEM(RDX,R12,2,1*64)) \
    PREFETCHW0(MEM(RDX,R13,1,0*64)) \
    PREFETCHW0(MEM(RDX,R13,1,1*64))

//
// SUBITER(n): one rank-1 update of the 12x16 accumulator tile.
//
// n: index within the unrolled loop, selecting which column of the packed
//    A panel and which row of the packed B panel to use.
//
// Elements of A are broadcast into ZMM3/ZMM4 and multiplied by the current
// row of B held in ZMM0/ZMM1; the next row of B is loaded at the end, ready
// for the following subiteration.
//
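// In scalar terms, each subiteration computes (a sketch of the math only;
// the actual code interleaves the broadcasts, FMAs, and the B pre-load):
//
//   for (int i = 0; i < 12; i++)
//       for (int j = 0; j < 16; j++)
//           ab[i][j] += a[12*n + i] * b[16*n + j];
//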
#define SUBITER(n) \
\
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 0)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 1)*8)) \
    VFMADD231PD(ZMM( 8), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM( 9), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(10), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(11), ZMM(1), ZMM(4)) \
    \
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 2)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 3)*8)) \
    VFMADD231PD(ZMM(12), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(13), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(14), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(15), ZMM(1), ZMM(4)) \
    \
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 4)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 5)*8)) \
    VFMADD231PD(ZMM(16), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(17), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(18), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(19), ZMM(1), ZMM(4)) \
    \
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 6)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 7)*8)) \
    VFMADD231PD(ZMM(20), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(21), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(22), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(23), ZMM(1), ZMM(4)) \
    \
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+ 8)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+ 9)*8)) \
    VFMADD231PD(ZMM(24), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(25), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(26), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(27), ZMM(1), ZMM(4)) \
    \
    VBROADCASTSD(ZMM(3), MEM(RAX,(12*n+10)*8)) \
    VBROADCASTSD(ZMM(4), MEM(RAX,(12*n+11)*8)) \
    VFMADD231PD(ZMM(28), ZMM(0), ZMM(3)) \
    VFMADD231PD(ZMM(29), ZMM(1), ZMM(3)) \
    VFMADD231PD(ZMM(30), ZMM(0), ZMM(4)) \
    VFMADD231PD(ZMM(31), ZMM(1), ZMM(4)) \
    \
    VMOVAPD(ZMM(0), MEM(RBX,(16*n+0)*8)) \
    VMOVAPD(ZMM(1), MEM(RBX,(16*n+8)*8))

// Offsets 0..15, used to build the index vectors for the scatter/gather
// instructions.
static int64_t offsets[16] __attribute__((aligned(64))) =
    { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};

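//
// bli_dgemm_opt_12x16_l1: 12x16 double-precision GEMM microkernel, computing
//
//     C := beta*C + alpha*A*B
//
// where A is a packed 12 x k micro-panel (12 contiguous elements per
// iteration), B is a packed k x 16 micro-panel (16 contiguous elements per
// iteration), and C is a 12x16 micro-tile with row stride rs_c and column
// stride cs_c, both in units of elements.
//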
void bli_dgemm_opt_12x16_l1(
                             dim_t            k_,
                             double* restrict alpha,
                             double* restrict a,
                             double* restrict b,
                             double* restrict beta,
                             double* restrict c, inc_t rs_c_, inc_t cs_c_,
                             auxinfo_t*       data,
                             cntx_t* restrict cntx
                           )
{
    (void)data;
    (void)cntx;

    const int64_t* offsetPtr = &offsets[0];
    const int64_t k = k_;
    const int64_t rs_c = rs_c_;
    const int64_t cs_c = cs_c_;

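    // Register plan for the k-loop:
    //   ZMM0, ZMM1  -- current row of B (two 8-element vectors)
    //   ZMM3, ZMM4  -- broadcast elements of A
    //   ZMM8-ZMM31  -- 12x16 accumulator tile, two registers per row
    //   ZMM7        -- kept at zero for the beta == 0 test after the loop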
    __asm__ volatile
    (

    VXORPD(YMM(8), YMM(8), YMM(8)) //clear out registers
    VMOVAPD(YMM( 7), YMM(8))
    VMOVAPD(YMM( 9), YMM(8))
    VMOVAPD(YMM(10), YMM(8))   MOV(RSI, VAR(k)) //loop index
    VMOVAPD(YMM(11), YMM(8))   MOV(RAX, VAR(a)) //load address of a
    VMOVAPD(YMM(12), YMM(8))   MOV(RBX, VAR(b)) //load address of b
    VMOVAPD(YMM(13), YMM(8))   MOV(RCX, VAR(c)) //load address of c
    VMOVAPD(YMM(14), YMM(8))
    VMOVAPD(YMM(15), YMM(8))   VMOVAPD(ZMM(0), MEM(RBX, 0*8)) //pre-load b
    VMOVAPD(YMM(16), YMM(8))   VMOVAPD(ZMM(1), MEM(RBX, 8*8)) //pre-load b
    VMOVAPD(YMM(17), YMM(8))
    VMOVAPD(YMM(18), YMM(8))   MOV(R12, VAR(rs_c))      //rs_c
    VMOVAPD(YMM(19), YMM(8))   LEA(R12, MEM(,R12,8))    //rs_c*sizeof(double)
    VMOVAPD(YMM(20), YMM(8))   LEA(R13, MEM(R12,R12,2)) //*3
    VMOVAPD(YMM(21), YMM(8))   LEA(R14, MEM(R12,R12,4)) //*5
    VMOVAPD(YMM(22), YMM(8))   LEA(R15, MEM(R14,R12,2)) //*7
    VMOVAPD(YMM(23), YMM(8))   LEA(RDX, MEM(RCX,R12,8)) //c + 8*rs_c
    VMOVAPD(YMM(24), YMM(8))
    VMOVAPD(YMM(25), YMM(8))   MOV(R8, IMM(12*8)) //mr*sizeof(double)
    VMOVAPD(YMM(26), YMM(8))   MOV(R9, IMM(16*8)) //nr*sizeof(double)
    VMOVAPD(YMM(27), YMM(8))
    VMOVAPD(YMM(28), YMM(8))   LEA(RBX, MEM(RBX,R9,1)) //adjust b for pre-load
    VMOVAPD(YMM(29), YMM(8))
    VMOVAPD(YMM(30), YMM(8))
    VMOVAPD(YMM(31), YMM(8))

    TEST(RSI, RSI)
    JZ(POSTACCUM)

    PREFETCH_C_L1

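    // Split k: RDI = k/4 passes of the 4x-unrolled main loop, RSI = k%4
    // leftover iterations handled by the tail loop.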
    MOV(RDI, RSI)
    AND(RSI, IMM(3))
    SAR(RDI, IMM(2))
    JZ(TAIL_LOOP)

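    // Main loop: four rank-1 updates per pass. The A panel is prefetched
    // into L1 A_L1_PREFETCH_DIST iterations (12 doubles each) ahead.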
    LOOP_ALIGN
    LABEL(MAIN_LOOP)

        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8))
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+64))
        SUBITER(0)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+128))
        SUBITER(1)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+192))
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+256))
        SUBITER(2)
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+320))
        SUBITER(3)

        LEA(RAX, MEM(RAX,R8,4))
        LEA(RBX, MEM(RBX,R9,4))

        DEC(RDI)

    JNZ(MAIN_LOOP)

    TEST(RSI, RSI)
    JZ(POSTACCUM)

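    // Tail loop: the remaining k%4 rank-1 updates, one at a time.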
    LOOP_ALIGN
    LABEL(TAIL_LOOP)

        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8))
        PREFETCH(0, MEM(RAX,A_L1_PREFETCH_DIST*12*8+64))
        SUBITER(0)

        ADD(RAX, R8)
        ADD(RBX, R9)

        DEC(RSI)

    JNZ(TAIL_LOOP)

    LABEL(POSTACCUM)

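    // Scale by alpha and combine with beta*C. The dispatch below takes a
    // contiguous vector path when C is row-stored (cs_c == 1) and a
    // gather/scatter path otherwise, each with a beta == 0 variant
    // (VCOMISD compares beta against ZMM7, which still holds zero).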
    MOV(RAX, VAR(alpha))
    MOV(RBX, VAR(beta))
    VBROADCASTSD(ZMM(0), MEM(RAX))
    VBROADCASTSD(ZMM(1), MEM(RBX))

    MOV(RAX, VAR(rs_c))
    LEA(RAX, MEM(,RAX,8))
    MOV(RBX, VAR(cs_c))

    // Check whether C is row-stored (cs_c == 1). If not, jump to the
    // slower scattered update.
    CMP(RBX, IMM(1))
    JNE(SCATTEREDUPDATE)

        VCOMISD(XMM(1), XMM(7))
        JE(COLSTORBZ)

            UPDATE_C( 8, 9,10,11)
            UPDATE_C(12,13,14,15)
            UPDATE_C(16,17,18,19)
            UPDATE_C(20,21,22,23)
            UPDATE_C(24,25,26,27)
            UPDATE_C(28,29,30,31)

        JMP(END)
        LABEL(COLSTORBZ)

            UPDATE_C_BZ( 8, 9,10,11)
            UPDATE_C_BZ(12,13,14,15)
            UPDATE_C_BZ(16,17,18,19)
            UPDATE_C_BZ(20,21,22,23)
            UPDATE_C_BZ(24,25,26,27)
            UPDATE_C_BZ(28,29,30,31)

    JMP(END)
    LABEL(SCATTEREDUPDATE)

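        // Load the offsets {0,...,7} and {8,...,15} into ZMM2/ZMM3 and scale
        // them by cs_c so they index the 16 columns of one row of C.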
        MOV(RDI, VAR(offsetPtr))
        VMOVDQA64(ZMM(2), MEM(RDI,0*64))
        VMOVDQA64(ZMM(3), MEM(RDI,1*64))
        VPBROADCASTQ(ZMM(6), RBX)
        VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
        VPMULLQ(ZMM(3), ZMM(6), ZMM(3))

        VCOMISD(XMM(1), XMM(7))
        JE(SCATTERBZ)

            UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
            UPDATE_C_ROW_SCATTERED(12,13,14,15)
            UPDATE_C_ROW_SCATTERED(16,17,18,19)
            UPDATE_C_ROW_SCATTERED(20,21,22,23)
            UPDATE_C_ROW_SCATTERED(24,25,26,27)
            UPDATE_C_ROW_SCATTERED(28,29,30,31)

        JMP(END)
        LABEL(SCATTERBZ)

            UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
            UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
            UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
            UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
            UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
            UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)

    LABEL(END)

    VZEROUPPER()

    : // output operands
    : // input operands
      [k]         "m" (k),
      [a]         "m" (a),
      [b]         "m" (b),
      [alpha]     "m" (alpha),
      [beta]      "m" (beta),
      [c]         "m" (c),
      [rs_c]      "m" (rs_c),
      [cs_c]      "m" (cs_c),
      [offsetPtr] "m" (offsetPtr)
    : // register clobber list
      "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11",
      "r12", "r13", "r14", "r15", "k1", "k2", "zmm0", "zmm1", "zmm2", "zmm3",
      "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11",
      "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19",
      "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27",
      "zmm28", "zmm29", "zmm30", "zmm31", "memory"
    );
}