1 /*********************************************************************************
2 Copyright (c) 2013, The OpenBLAS Project
3 All rights reserved.
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions are
6 met:
7 1. Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 2. Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in
11 the documentation and/or other materials provided with the
12 distribution.
13 3. Neither the name of the OpenBLAS project nor the names of
14 its contributors may be used to endorse or promote products
15 derived from this software without specific prior written permission.
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 **********************************************************************************/
27 
28 
/* The comment below is kept for history only; its data does not describe the implementation in this file. */
30 
31 /*********************************************************************
32 * 2014/07/28 Saar
33 *        BLASTEST               : OK
34 *        CTEST                  : OK
35 *        TEST                   : OK
36 *
37 * 2013/10/28 Saar
38 * Parameter:
39 *	SGEMM_DEFAULT_UNROLL_N	4
40 *	SGEMM_DEFAULT_UNROLL_M	16
41 *	SGEMM_DEFAULT_P		768
42 *	SGEMM_DEFAULT_Q		384
43 *	A_PR1			512
44 *	B_PR1			512
45 *
46 *
47 * 2014/07/28 Saar
48 * Performance at 9216x9216x9216:
49 *       1 thread:      102 GFLOPS       (SANDYBRIDGE:  59)      (MKL:   83)
50 *       2 threads:     195 GFLOPS       (SANDYBRIDGE: 116)      (MKL:  155)
51 *       3 threads:     281 GFLOPS       (SANDYBRIDGE: 165)      (MKL:  230)
52 *       4 threads:     366 GFLOPS       (SANDYBRIDGE: 223)      (MKL:  267)
53 *
54 *********************************************************************/
55 
56 #include "common.h"
57 #include <immintrin.h>
58 
59 
60 
61 /*******************************************************************************************
62 * 8 lines of N
63 *******************************************************************************************/
64 
65 
66 
67 
68 
69 
70 /*******************************************************************************************
71 * 4 lines of N
72 *******************************************************************************************/
73 
/*
 * 64x4 tile: four 16-float vectors of A against four values of B.
 *
 * These macros expand in the caller's scope and use the following names
 * from it: __m512 zmm0..zmm3, zmm5, zmm7 and row{0..3}{,b,c,d}; the float
 * pointers AO, A1, A2, A3, BO and CO1; and the integer stride ldc.
 */

/* Zero all sixteen accumulators. */
#define INIT64x4()							\
	row0 = row0b = row0c = row0d = _mm512_setzero_ps();		\
	row1 = row1b = row1c = row1d = _mm512_setzero_ps();		\
	row2 = row2b = row2c = row2d = _mm512_setzero_ps();		\
	row3 = row3b = row3c = row3d = _mm512_setzero_ps();

/*
 * One k step: load 16 floats from each of AO, A1, A2 and A3, broadcast
 * BO[0..3] and accumulate the products.  Advances every A pointer by 16
 * and BO by 4.
 */
#define KERNEL64x4_SUB()						\
	zmm0  = _mm512_loadu_ps(AO);					\
	zmm1  = _mm512_loadu_ps(A1);					\
	zmm5  = _mm512_loadu_ps(A2);					\
	zmm7  = _mm512_loadu_ps(A3);					\
	AO += 16;  A1 += 16;  A2 += 16;  A3 += 16;			\
	/* B values 0 and 1 */						\
	zmm2  = _mm512_set1_ps(BO[0]);					\
	zmm3  = _mm512_set1_ps(BO[1]);					\
	row0  = row0  + zmm0 * zmm2;					\
	row0b = row0b + zmm1 * zmm2;					\
	row0c = row0c + zmm5 * zmm2;					\
	row0d = row0d + zmm7 * zmm2;					\
	row1  = row1  + zmm0 * zmm3;					\
	row1b = row1b + zmm1 * zmm3;					\
	row1c = row1c + zmm5 * zmm3;					\
	row1d = row1d + zmm7 * zmm3;					\
	/* B values 2 and 3 */						\
	zmm2  = _mm512_set1_ps(BO[2]);					\
	zmm3  = _mm512_set1_ps(BO[3]);					\
	row2  = row2  + zmm0 * zmm2;					\
	row2b = row2b + zmm1 * zmm2;					\
	row2c = row2c + zmm5 * zmm2;					\
	row2d = row2d + zmm7 * zmm2;					\
	row3  = row3  + zmm0 * zmm3;					\
	row3b = row3b + zmm1 * zmm3;					\
	row3c = row3c + zmm5 * zmm3;					\
	row3d = row3d + zmm7 * zmm3;					\
	BO += 4;

/*
 * Scale every accumulator by ALPHA, add what is already stored at
 * CO1 + j*ldc + {0,16,32,48} (j = 0..3) and write the sums back.
 */
#define SAVE64x4(ALPHA)							\
	zmm0  = _mm512_set1_ps(ALPHA);					\
	row0  = row0  * zmm0;  row0b = row0b * zmm0;			\
	row0c = row0c * zmm0;  row0d = row0d * zmm0;			\
	row1  = row1  * zmm0;  row1b = row1b * zmm0;			\
	row1c = row1c * zmm0;  row1d = row1d * zmm0;			\
	row2  = row2  * zmm0;  row2b = row2b * zmm0;			\
	row2c = row2c * zmm0;  row2d = row2d * zmm0;			\
	row3  = row3  * zmm0;  row3b = row3b * zmm0;			\
	row3c = row3c * zmm0;  row3d = row3d * zmm0;			\
	/* floats 0..15 of each of the four lines */			\
	row0  = row0  + _mm512_loadu_ps(&CO1[0 * ldc]);			\
	row1  = row1  + _mm512_loadu_ps(&CO1[1 * ldc]);			\
	row2  = row2  + _mm512_loadu_ps(&CO1[2 * ldc]);			\
	row3  = row3  + _mm512_loadu_ps(&CO1[3 * ldc]);			\
	_mm512_storeu_ps(&CO1[0 * ldc], row0);				\
	_mm512_storeu_ps(&CO1[1 * ldc], row1);				\
	_mm512_storeu_ps(&CO1[2 * ldc], row2);				\
	_mm512_storeu_ps(&CO1[3 * ldc], row3);				\
	/* floats 16..31 */						\
	row0b = row0b + _mm512_loadu_ps(&CO1[0 * ldc + 16]);		\
	row1b = row1b + _mm512_loadu_ps(&CO1[1 * ldc + 16]);		\
	row2b = row2b + _mm512_loadu_ps(&CO1[2 * ldc + 16]);		\
	row3b = row3b + _mm512_loadu_ps(&CO1[3 * ldc + 16]);		\
	_mm512_storeu_ps(&CO1[0 * ldc + 16], row0b);			\
	_mm512_storeu_ps(&CO1[1 * ldc + 16], row1b);			\
	_mm512_storeu_ps(&CO1[2 * ldc + 16], row2b);			\
	_mm512_storeu_ps(&CO1[3 * ldc + 16], row3b);			\
	/* floats 32..47 */						\
	row0c = row0c + _mm512_loadu_ps(&CO1[0 * ldc + 32]);		\
	row1c = row1c + _mm512_loadu_ps(&CO1[1 * ldc + 32]);		\
	row2c = row2c + _mm512_loadu_ps(&CO1[2 * ldc + 32]);		\
	row3c = row3c + _mm512_loadu_ps(&CO1[3 * ldc + 32]);		\
	_mm512_storeu_ps(&CO1[0 * ldc + 32], row0c);			\
	_mm512_storeu_ps(&CO1[1 * ldc + 32], row1c);			\
	_mm512_storeu_ps(&CO1[2 * ldc + 32], row2c);			\
	_mm512_storeu_ps(&CO1[3 * ldc + 32], row3c);			\
	/* floats 48..63 */						\
	row0d = row0d + _mm512_loadu_ps(&CO1[0 * ldc + 48]);		\
	row1d = row1d + _mm512_loadu_ps(&CO1[1 * ldc + 48]);		\
	row2d = row2d + _mm512_loadu_ps(&CO1[2 * ldc + 48]);		\
	row3d = row3d + _mm512_loadu_ps(&CO1[3 * ldc + 48]);		\
	_mm512_storeu_ps(&CO1[0 * ldc + 48], row0d);			\
	_mm512_storeu_ps(&CO1[1 * ldc + 48], row1d);			\
	_mm512_storeu_ps(&CO1[2 * ldc + 48], row2d);			\
	_mm512_storeu_ps(&CO1[3 * ldc + 48], row3d);
174 
175 
/*
 * 48x4 tile: three 16-float vectors of A against four values of B.
 *
 * Names taken from the expanding scope: __m512 zmm0..zmm3, zmm5 and
 * row{0..3}{,b,c}; float pointers AO, A1, A2, BO, CO1; integer stride ldc.
 */

/* Zero all twelve accumulators. */
#define INIT48x4()							\
	row0 = row0b = row0c = _mm512_setzero_ps();			\
	row1 = row1b = row1c = _mm512_setzero_ps();			\
	row2 = row2b = row2c = _mm512_setzero_ps();			\
	row3 = row3b = row3c = _mm512_setzero_ps();

/*
 * One k step: load 16 floats from each of AO, A1 and A2, broadcast BO[0..3]
 * and accumulate the products.  Advances the A pointers by 16 and BO by 4.
 */
#define KERNEL48x4_SUB()						\
	zmm0  = _mm512_loadu_ps(AO);					\
	zmm1  = _mm512_loadu_ps(A1);					\
	zmm5  = _mm512_loadu_ps(A2);					\
	AO += 16;  A1 += 16;  A2 += 16;					\
	/* B values 0 and 1 */						\
	zmm2  = _mm512_set1_ps(BO[0]);					\
	zmm3  = _mm512_set1_ps(BO[1]);					\
	row0  = row0  + zmm0 * zmm2;					\
	row0b = row0b + zmm1 * zmm2;					\
	row0c = row0c + zmm5 * zmm2;					\
	row1  = row1  + zmm0 * zmm3;					\
	row1b = row1b + zmm1 * zmm3;					\
	row1c = row1c + zmm5 * zmm3;					\
	/* B values 2 and 3 */						\
	zmm2  = _mm512_set1_ps(BO[2]);					\
	zmm3  = _mm512_set1_ps(BO[3]);					\
	row2  = row2  + zmm0 * zmm2;					\
	row2b = row2b + zmm1 * zmm2;					\
	row2c = row2c + zmm5 * zmm2;					\
	row3  = row3  + zmm0 * zmm3;					\
	row3b = row3b + zmm1 * zmm3;					\
	row3c = row3c + zmm5 * zmm3;					\
	BO += 4;

/*
 * Scale every accumulator by ALPHA, add what is already stored at
 * CO1 + j*ldc + {0,16,32} (j = 0..3) and write the sums back.
 */
#define SAVE48x4(ALPHA)							\
	zmm0  = _mm512_set1_ps(ALPHA);					\
	row0  = row0  * zmm0;  row0b = row0b * zmm0;  row0c = row0c * zmm0; \
	row1  = row1  * zmm0;  row1b = row1b * zmm0;  row1c = row1c * zmm0; \
	row2  = row2  * zmm0;  row2b = row2b * zmm0;  row2c = row2c * zmm0; \
	row3  = row3  * zmm0;  row3b = row3b * zmm0;  row3c = row3c * zmm0; \
	/* floats 0..15 of each of the four lines */			\
	row0  = row0  + _mm512_loadu_ps(&CO1[0 * ldc]);			\
	row1  = row1  + _mm512_loadu_ps(&CO1[1 * ldc]);			\
	row2  = row2  + _mm512_loadu_ps(&CO1[2 * ldc]);			\
	row3  = row3  + _mm512_loadu_ps(&CO1[3 * ldc]);			\
	_mm512_storeu_ps(&CO1[0 * ldc], row0);				\
	_mm512_storeu_ps(&CO1[1 * ldc], row1);				\
	_mm512_storeu_ps(&CO1[2 * ldc], row2);				\
	_mm512_storeu_ps(&CO1[3 * ldc], row3);				\
	/* floats 16..31 */						\
	row0b = row0b + _mm512_loadu_ps(&CO1[0 * ldc + 16]);		\
	row1b = row1b + _mm512_loadu_ps(&CO1[1 * ldc + 16]);		\
	row2b = row2b + _mm512_loadu_ps(&CO1[2 * ldc + 16]);		\
	row3b = row3b + _mm512_loadu_ps(&CO1[3 * ldc + 16]);		\
	_mm512_storeu_ps(&CO1[0 * ldc + 16], row0b);			\
	_mm512_storeu_ps(&CO1[1 * ldc + 16], row1b);			\
	_mm512_storeu_ps(&CO1[2 * ldc + 16], row2b);			\
	_mm512_storeu_ps(&CO1[3 * ldc + 16], row3b);			\
	/* floats 32..47 */						\
	row0c = row0c + _mm512_loadu_ps(&CO1[0 * ldc + 32]);		\
	row1c = row1c + _mm512_loadu_ps(&CO1[1 * ldc + 32]);		\
	row2c = row2c + _mm512_loadu_ps(&CO1[2 * ldc + 32]);		\
	row3c = row3c + _mm512_loadu_ps(&CO1[3 * ldc + 32]);		\
	_mm512_storeu_ps(&CO1[0 * ldc + 32], row0c);			\
	_mm512_storeu_ps(&CO1[1 * ldc + 32], row1c);			\
	_mm512_storeu_ps(&CO1[2 * ldc + 32], row2c);			\
	_mm512_storeu_ps(&CO1[3 * ldc + 32], row3c);
254 
255 
/*
 * 32x4 tile: two 16-float vectors of A against four values of B.
 *
 * Names taken from the expanding scope: __m512 zmm0..zmm3 and
 * row{0..3}{,b}; float pointers AO, A1, BO, CO1; integer stride ldc.
 */

/* Zero all eight accumulators. */
#define INIT32x4()							\
	row0 = row0b = _mm512_setzero_ps();				\
	row1 = row1b = _mm512_setzero_ps();				\
	row2 = row2b = _mm512_setzero_ps();				\
	row3 = row3b = _mm512_setzero_ps();

/*
 * One k step: load 16 floats from each of AO and A1, broadcast BO[0..3]
 * and accumulate the products.  Advances AO and A1 by 16 and BO by 4.
 */
#define KERNEL32x4_SUB()						\
	zmm0  = _mm512_loadu_ps(AO);					\
	zmm1  = _mm512_loadu_ps(A1);					\
	AO += 16;  A1 += 16;						\
	/* B values 0 and 1 */						\
	zmm2  = _mm512_set1_ps(BO[0]);					\
	zmm3  = _mm512_set1_ps(BO[1]);					\
	row0  = row0  + zmm0 * zmm2;					\
	row0b = row0b + zmm1 * zmm2;					\
	row1  = row1  + zmm0 * zmm3;					\
	row1b = row1b + zmm1 * zmm3;					\
	/* B values 2 and 3 */						\
	zmm2  = _mm512_set1_ps(BO[2]);					\
	zmm3  = _mm512_set1_ps(BO[3]);					\
	row2  = row2  + zmm0 * zmm2;					\
	row2b = row2b + zmm1 * zmm2;					\
	row3  = row3  + zmm0 * zmm3;					\
	row3b = row3b + zmm1 * zmm3;					\
	BO += 4;

/*
 * Scale every accumulator by ALPHA, add what is already stored at
 * CO1 + j*ldc + {0,16} (j = 0..3) and write the sums back.
 */
#define SAVE32x4(ALPHA)							\
	zmm0  = _mm512_set1_ps(ALPHA);					\
	row0  = row0  * zmm0;  row0b = row0b * zmm0;			\
	row1  = row1  * zmm0;  row1b = row1b * zmm0;			\
	row2  = row2  * zmm0;  row2b = row2b * zmm0;			\
	row3  = row3  * zmm0;  row3b = row3b * zmm0;			\
	/* floats 0..15 of each of the four lines */			\
	row0  = row0  + _mm512_loadu_ps(&CO1[0 * ldc]);			\
	row1  = row1  + _mm512_loadu_ps(&CO1[1 * ldc]);			\
	row2  = row2  + _mm512_loadu_ps(&CO1[2 * ldc]);			\
	row3  = row3  + _mm512_loadu_ps(&CO1[3 * ldc]);			\
	_mm512_storeu_ps(&CO1[0 * ldc], row0);				\
	_mm512_storeu_ps(&CO1[1 * ldc], row1);				\
	_mm512_storeu_ps(&CO1[2 * ldc], row2);				\
	_mm512_storeu_ps(&CO1[3 * ldc], row3);				\
	/* floats 16..31 */						\
	row0b = row0b + _mm512_loadu_ps(&CO1[0 * ldc + 16]);		\
	row1b = row1b + _mm512_loadu_ps(&CO1[1 * ldc + 16]);		\
	row2b = row2b + _mm512_loadu_ps(&CO1[2 * ldc + 16]);		\
	row3b = row3b + _mm512_loadu_ps(&CO1[3 * ldc + 16]);		\
	_mm512_storeu_ps(&CO1[0 * ldc + 16], row0b);			\
	_mm512_storeu_ps(&CO1[1 * ldc + 16], row1b);			\
	_mm512_storeu_ps(&CO1[2 * ldc + 16], row2b);			\
	_mm512_storeu_ps(&CO1[3 * ldc + 16], row3b);
312 
313 
314 
/*
 * 16x4 tile: one 16-float vector of A against four values of B.
 *
 * Names taken from the expanding scope: __m512 zmm0, zmm2, zmm3 and
 * row0..row3; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero the four accumulators. */
#define INIT16x4()							\
	row0 = row1 = row2 = row3 = _mm512_setzero_ps();

/*
 * One k step: load 16 floats from AO, broadcast BO[0..3] and accumulate
 * the products.  Advances AO by 16 and BO by 4.
 */
#define KERNEL16x4_SUB()						\
	zmm0 = _mm512_loadu_ps(AO);					\
	AO  += 16;							\
	zmm2 = _mm512_set1_ps(BO[0]);					\
	zmm3 = _mm512_set1_ps(BO[1]);					\
	row0 = row0 + zmm0 * zmm2;					\
	row1 = row1 + zmm0 * zmm3;					\
	zmm2 = _mm512_set1_ps(BO[2]);					\
	zmm3 = _mm512_set1_ps(BO[3]);					\
	row2 = row2 + zmm0 * zmm2;					\
	row3 = row3 + zmm0 * zmm3;					\
	BO  += 4;

/*
 * Scale the accumulators by ALPHA, add what is already stored at
 * CO1 + j*ldc (j = 0..3) and write the sums back.
 */
#define SAVE16x4(ALPHA)							\
	zmm0 = _mm512_set1_ps(ALPHA);					\
	row0 = row0 * zmm0;						\
	row1 = row1 * zmm0;						\
	row2 = row2 * zmm0;						\
	row3 = row3 * zmm0;						\
	row0 = row0 + _mm512_loadu_ps(&CO1[0 * ldc]);			\
	row1 = row1 + _mm512_loadu_ps(&CO1[1 * ldc]);			\
	row2 = row2 + _mm512_loadu_ps(&CO1[2 * ldc]);			\
	row3 = row3 + _mm512_loadu_ps(&CO1[3 * ldc]);			\
	_mm512_storeu_ps(&CO1[0 * ldc], row0);				\
	_mm512_storeu_ps(&CO1[1 * ldc], row1);				\
	_mm512_storeu_ps(&CO1[2 * ldc], row2);				\
	_mm512_storeu_ps(&CO1[3 * ldc], row3);
349 
350 
351 
352 /*******************************************************************************************/
353 
/*
 * 8x4 tile: one 8-float vector of A against four values of B.
 *
 * Names taken from the expanding scope: __m256 ymm0, ymm2, ymm3 and the
 * accumulators ymm4, ymm6, ymm8, ymm10; float pointers AO, BO, CO1;
 * integer stride ldc.
 */

/* Zero the four accumulators. */
#define INIT8x4()							\
	ymm4 = ymm6 = ymm8 = ymm10 = _mm256_setzero_ps();

/*
 * One k step: load 8 floats from AO, broadcast BO[0..3] and accumulate
 * the products.  Advances AO by 8 and BO by 4.
 */
#define KERNEL8x4_SUB()							\
	ymm0  = _mm256_loadu_ps(AO);					\
	AO   += 8;							\
	ymm2  = _mm256_set1_ps(BO[0]);					\
	ymm3  = _mm256_set1_ps(BO[1]);					\
	ymm4  = ymm4  + ymm0 * ymm2;					\
	ymm6  = ymm6  + ymm0 * ymm3;					\
	ymm2  = _mm256_set1_ps(BO[2]);					\
	ymm3  = _mm256_set1_ps(BO[3]);					\
	ymm8  = ymm8  + ymm0 * ymm2;					\
	ymm10 = ymm10 + ymm0 * ymm3;					\
	BO   += 4;

/*
 * Scale the accumulators by ALPHA, add what is already stored at
 * CO1 + j*ldc (j = 0..3) and write the sums back.
 */
#define SAVE8x4(ALPHA)							\
	ymm0  = _mm256_set1_ps(ALPHA);					\
	ymm4  = ymm4  * ymm0;						\
	ymm6  = ymm6  * ymm0;						\
	ymm8  = ymm8  * ymm0;						\
	ymm10 = ymm10 * ymm0;						\
	ymm4  = ymm4  + _mm256_loadu_ps(&CO1[0 * ldc]);			\
	ymm6  = ymm6  + _mm256_loadu_ps(&CO1[1 * ldc]);			\
	ymm8  = ymm8  + _mm256_loadu_ps(&CO1[2 * ldc]);			\
	ymm10 = ymm10 + _mm256_loadu_ps(&CO1[3 * ldc]);			\
	_mm256_storeu_ps(&CO1[0 * ldc], ymm4);				\
	_mm256_storeu_ps(&CO1[1 * ldc], ymm6);				\
	_mm256_storeu_ps(&CO1[2 * ldc], ymm8);				\
	_mm256_storeu_ps(&CO1[3 * ldc], ymm10);
388 
389 
390 
391 /*******************************************************************************************/
392 
/*
 * 4x4 tile: one 4-float vector of A against four values of B.
 *
 * Names taken from the expanding scope: __m128 xmm0, xmm2, xmm3 and
 * row0..row3; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero the four accumulators. */
#define INIT4x4()							\
	row0 = row1 = row2 = row3 = _mm_setzero_ps();

/*
 * One k step: load 4 floats from AO, broadcast BO[0..3] and accumulate
 * the products.  Advances AO by 4 and BO by 4.
 */
#define KERNEL4x4_SUB()							\
	xmm0 = _mm_loadu_ps(AO);					\
	AO  += 4;							\
	xmm2 = _mm_set1_ps(BO[0]);					\
	xmm3 = _mm_set1_ps(BO[1]);					\
	row0 = row0 + xmm0 * xmm2;					\
	row1 = row1 + xmm0 * xmm3;					\
	xmm2 = _mm_set1_ps(BO[2]);					\
	xmm3 = _mm_set1_ps(BO[3]);					\
	row2 = row2 + xmm0 * xmm2;					\
	row3 = row3 + xmm0 * xmm3;					\
	BO  += 4;

/*
 * Scale the accumulators by ALPHA, add what is already stored at
 * CO1 + j*ldc (j = 0..3) and write the sums back.
 */
#define SAVE4x4(ALPHA)							\
	xmm0 = _mm_set1_ps(ALPHA);					\
	row0 = row0 * xmm0;						\
	row1 = row1 * xmm0;						\
	row2 = row2 * xmm0;						\
	row3 = row3 * xmm0;						\
	row0 = row0 + _mm_loadu_ps(&CO1[0 * ldc]);			\
	row1 = row1 + _mm_loadu_ps(&CO1[1 * ldc]);			\
	row2 = row2 + _mm_loadu_ps(&CO1[2 * ldc]);			\
	row3 = row3 + _mm_loadu_ps(&CO1[3 * ldc]);			\
	_mm_storeu_ps(&CO1[0 * ldc], row0);				\
	_mm_storeu_ps(&CO1[1 * ldc], row1);				\
	_mm_storeu_ps(&CO1[2 * ldc], row2);				\
	_mm_storeu_ps(&CO1[3 * ldc], row3);
428 
429 
430 /*******************************************************************************************/
431 
/*
 * 2x4 tile, scalar: two values of A against four values of B.
 *
 * Names taken from the expanding scope: float xmm0..xmm3 and
 * row{0..3}{,b}; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero all eight accumulators. */
#define INIT2x4()							\
	row0 = row0b = 0.0f;						\
	row1 = row1b = 0.0f;						\
	row2 = row2b = 0.0f;						\
	row3 = row3b = 0.0f;

/*
 * One k step: read AO[0..1] and BO[0..3] and accumulate the eight
 * products.  Advances AO by 2 and BO by 4.
 */
#define KERNEL2x4_SUB()							\
	xmm0  = AO[0];							\
	xmm1  = AO[1];							\
	AO   += 2;							\
	xmm2  = BO[0];							\
	xmm3  = BO[1];							\
	row0  = row0  + xmm0 * xmm2;					\
	row0b = row0b + xmm1 * xmm2;					\
	row1  = row1  + xmm0 * xmm3;					\
	row1b = row1b + xmm1 * xmm3;					\
	xmm2  = BO[2];							\
	xmm3  = BO[3];							\
	row2  = row2  + xmm0 * xmm2;					\
	row2b = row2b + xmm1 * xmm2;					\
	row3  = row3  + xmm0 * xmm3;					\
	row3b = row3b + xmm1 * xmm3;					\
	BO   += 4;

/*
 * Scale the accumulators by ALPHA and add them to the two floats at
 * CO1 + j*ldc (j = 0..3).
 */
#define SAVE2x4(ALPHA)							\
	xmm0  = ALPHA;							\
	row0  = row0  * xmm0;  row0b = row0b * xmm0;			\
	row1  = row1  * xmm0;  row1b = row1b * xmm0;			\
	row2  = row2  * xmm0;  row2b = row2b * xmm0;			\
	row3  = row3  * xmm0;  row3b = row3b * xmm0;			\
	CO1[0 * ldc + 0] += row0;					\
	CO1[0 * ldc + 1] += row0b;					\
	CO1[1 * ldc + 0] += row1;					\
	CO1[1 * ldc + 1] += row1b;					\
	CO1[2 * ldc + 0] += row2;					\
	CO1[2 * ldc + 1] += row2b;					\
	CO1[3 * ldc + 0] += row3;					\
	CO1[3 * ldc + 1] += row3b;
473 
474 
475 
476 /*******************************************************************************************/
477 
/*
 * 1x4 tile, scalar: one value of A against four values of B.
 *
 * Names taken from the expanding scope: float xmm0, xmm2, xmm3 and
 * row0..row3; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero the four accumulators. */
#define INIT1x4()							\
	row0 = row1 = row2 = row3 = 0.0f;

/*
 * One k step: read AO[0] and BO[0..3] and accumulate the four products.
 * Advances AO by 1 and BO by 4.
 */
#define KERNEL1x4_SUB()							\
	xmm0 = AO[0];							\
	AO  += 1;							\
	xmm2 = BO[0];							\
	xmm3 = BO[1];							\
	row0 = row0 + xmm0 * xmm2;					\
	row1 = row1 + xmm0 * xmm3;					\
	xmm2 = BO[2];							\
	xmm3 = BO[3];							\
	row2 = row2 + xmm0 * xmm2;					\
	row3 = row3 + xmm0 * xmm3;					\
	BO  += 4;

/* Scale the accumulators by ALPHA and add them to CO1[j*ldc] (j = 0..3). */
#define SAVE1x4(ALPHA)							\
	xmm0 = ALPHA;							\
	row0 = row0 * xmm0;						\
	row1 = row1 * xmm0;						\
	row2 = row2 * xmm0;						\
	row3 = row3 * xmm0;						\
	CO1[0 * ldc] += row0;						\
	CO1[1 * ldc] += row1;						\
	CO1[2 * ldc] += row2;						\
	CO1[3 * ldc] += row3;
504 
505 
506 
507 /*******************************************************************************************/
508 
509 /*******************************************************************************************
510 * 2 lines of N
511 *******************************************************************************************/
512 
/*
 * 16x2 tile: one 16-float vector of A against two values of B.
 *
 * Names taken from the expanding scope: __m512 zmm0, zmm2, zmm3, row0,
 * row1; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero both accumulators. */
#define INIT16x2()							\
	row0 = row1 = _mm512_setzero_ps();

/*
 * One k step: load 16 floats from AO, broadcast BO[0..1] and accumulate
 * the products.  Advances AO by 16 and BO by 2.
 */
#define KERNEL16x2_SUB()						\
	zmm0 = _mm512_loadu_ps(AO);					\
	AO  += 16;							\
	zmm2 = _mm512_set1_ps(BO[0]);					\
	zmm3 = _mm512_set1_ps(BO[1]);					\
	BO  += 2;							\
	row0 = row0 + zmm0 * zmm2;					\
	row1 = row1 + zmm0 * zmm3;

/*
 * Scale the accumulators by ALPHA, add what is already stored at CO1 and
 * CO1 + ldc and write the sums back.
 */
#define SAVE16x2(ALPHA)							\
	zmm0 = _mm512_set1_ps(ALPHA);					\
	row0 = row0 * zmm0;						\
	row1 = row1 * zmm0;						\
	row0 = row0 + _mm512_loadu_ps(&CO1[0]);				\
	row1 = row1 + _mm512_loadu_ps(&CO1[ldc]);			\
	_mm512_storeu_ps(&CO1[0], row0);				\
	_mm512_storeu_ps(&CO1[ldc], row1);
536 
537 
538 
539 
540 /*******************************************************************************************/
541 
/*
 * 8x2 tile: one 8-float vector of A against two values of B.
 *
 * Names taken from the expanding scope: __m256 ymm0, ymm2, ymm3 and the
 * accumulators ymm4, ymm6; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero both accumulators. */
#define INIT8x2()							\
	ymm4 = ymm6 = _mm256_setzero_ps();

/*
 * One k step: load 8 floats from AO, broadcast BO[0..1] and accumulate
 * the products.  Advances AO by 8 and BO by 2.
 */
#define KERNEL8x2_SUB()							\
	ymm0 = _mm256_loadu_ps(AO);					\
	AO  += 8;							\
	ymm2 = _mm256_set1_ps(BO[0]);					\
	ymm3 = _mm256_set1_ps(BO[1]);					\
	BO  += 2;							\
	ymm4 = ymm4 + ymm0 * ymm2;					\
	ymm6 = ymm6 + ymm0 * ymm3;

/*
 * Scale the accumulators by ALPHA, add what is already stored at CO1 and
 * CO1 + ldc and write the sums back.
 */
#define SAVE8x2(ALPHA)							\
	ymm0 = _mm256_set1_ps(ALPHA);					\
	ymm4 = ymm4 * ymm0;						\
	ymm6 = ymm6 * ymm0;						\
	ymm4 = ymm4 + _mm256_loadu_ps(&CO1[0]);				\
	ymm6 = ymm6 + _mm256_loadu_ps(&CO1[ldc]);			\
	_mm256_storeu_ps(&CO1[0], ymm4);				\
	_mm256_storeu_ps(&CO1[ldc], ymm6);
564 
565 
566 
567 /*******************************************************************************************/
568 
/*
 * 4x2 tile: one 4-float vector of A against two values of B.
 *
 * Names taken from the expanding scope: __m128 xmm0, xmm2, xmm3, row0,
 * row1; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero both accumulators. */
#define INIT4x2()							\
	row0 = row1 = _mm_setzero_ps();

/*
 * One k step: load 4 floats from AO, broadcast BO[0..1] and accumulate
 * the products.  Advances AO by 4 and BO by 2.
 */
#define KERNEL4x2_SUB()							\
	xmm0 = _mm_loadu_ps(AO);					\
	AO  += 4;							\
	xmm2 = _mm_set1_ps(BO[0]);					\
	xmm3 = _mm_set1_ps(BO[1]);					\
	BO  += 2;							\
	row0 = row0 + xmm0 * xmm2;					\
	row1 = row1 + xmm0 * xmm3;

/*
 * Scale the accumulators by ALPHA, add what is already stored at CO1 and
 * CO1 + ldc and write the sums back.
 */
#define SAVE4x2(ALPHA)							\
	xmm0 = _mm_set1_ps(ALPHA);					\
	row0 = row0 * xmm0;						\
	row1 = row1 * xmm0;						\
	row0 = row0 + _mm_loadu_ps(&CO1[0]);				\
	row1 = row1 + _mm_loadu_ps(&CO1[ldc]);				\
	_mm_storeu_ps(&CO1[0], row0);					\
	_mm_storeu_ps(&CO1[ldc], row1);
591 
592 
593 
594 /*******************************************************************************************/
595 
596 
/*
 * 2x2 tile, scalar: two values of A against two values of B.
 *
 * Names taken from the expanding scope: float xmm0..xmm3, row0, row0b,
 * row1, row1b; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero the four accumulators. */
#define INIT2x2()							\
	row0 = row0b = 0.0f;						\
	row1 = row1b = 0.0f;

/*
 * One k step: read AO[0..1] and BO[0..1] and accumulate the four
 * products.  Advances AO and BO by 2.
 */
#define KERNEL2x2_SUB()							\
	xmm0  = AO[0];							\
	xmm1  = AO[1];							\
	AO   += 2;							\
	xmm2  = BO[0];							\
	xmm3  = BO[1];							\
	BO   += 2;							\
	row0  = row0  + xmm0 * xmm2;					\
	row0b = row0b + xmm1 * xmm2;					\
	row1  = row1  + xmm0 * xmm3;					\
	row1b = row1b + xmm1 * xmm3;

/*
 * Scale the accumulators by ALPHA and add them to the two floats at CO1
 * and the two floats at CO1 + ldc.
 */
#define SAVE2x2(ALPHA)							\
	xmm0  = ALPHA;							\
	row0  = row0  * xmm0;						\
	row0b = row0b * xmm0;						\
	row1  = row1  * xmm0;						\
	row1b = row1b * xmm0;						\
	CO1[0]       += row0;						\
	CO1[1]       += row0b;						\
	CO1[ldc]     += row1;						\
	CO1[ldc + 1] += row1b;
623 
624 
625 /*******************************************************************************************/
626 
/*
 * 1x2 tile, scalar: one value of A against two values of B.
 *
 * Names taken from the expanding scope: float xmm0, xmm2, xmm3, row0,
 * row1; float pointers AO, BO, CO1; integer stride ldc.
 */

/* Zero both accumulators. */
#define INIT1x2()							\
	row0 = row1 = 0.0f;

/*
 * One k step: read AO[0] and BO[0..1] and accumulate the two products.
 * Advances AO by 1 and BO by 2.
 */
#define KERNEL1x2_SUB()							\
	xmm0 = AO[0];							\
	AO  += 1;							\
	xmm2 = BO[0];							\
	xmm3 = BO[1];							\
	BO  += 2;							\
	row0 = row0 + xmm0 * xmm2;					\
	row1 = row1 + xmm0 * xmm3;

/* Scale the accumulators by ALPHA and add them to CO1[0] and CO1[ldc]. */
#define SAVE1x2(ALPHA)							\
	xmm0 = ALPHA;							\
	row0 = row0 * xmm0;						\
	row1 = row1 * xmm0;						\
	CO1[0]   += row0;						\
	CO1[ldc] += row1;
646 
647 
648 /*******************************************************************************************/
649 
650 /*******************************************************************************************
651 * 1 line of N
652 *******************************************************************************************/
653 
/*
 * 16x1 tile: one 16-float vector of A against one value of B.
 *
 * Names taken from the expanding scope: __m512 zmm0, zmm2, row0; float
 * pointers AO, BO, CO1.
 */

/* Zero the accumulator. */
#define INIT16x1()							\
	row0 = _mm512_setzero_ps();

/*
 * One k step: load 16 floats from AO, broadcast BO[0] and accumulate the
 * product.  Advances AO by 16 and BO by 1.
 */
#define KERNEL16x1_SUB()						\
	zmm0 = _mm512_loadu_ps(AO);					\
	AO  += 16;							\
	zmm2 = _mm512_set1_ps(BO[0]);					\
	BO  += 1;							\
	row0 = row0 + zmm0 * zmm2;

/*
 * Scale the accumulator by ALPHA, add what is already stored at CO1 and
 * write the sum back.
 */
#define SAVE16x1(ALPHA)							\
	zmm0 = _mm512_set1_ps(ALPHA);					\
	row0 = row0 * zmm0;						\
	row0 = row0 + _mm512_loadu_ps(&CO1[0]);				\
	_mm512_storeu_ps(&CO1[0], row0);
670 
671 
672 /*******************************************************************************************/
673 
/*
 * 8x1 tile: one 8-float vector of A against one value of B.
 *
 * Names taken from the expanding scope: __m256 ymm0, ymm2 and the
 * accumulator ymm4; float pointers AO, BO, CO1.
 */

/* Zero the accumulator. */
#define INIT8x1()							\
	ymm4 = _mm256_setzero_ps();

/*
 * One k step: load 8 floats from AO, broadcast BO[0] and accumulate the
 * product.  Advances AO by 8 and BO by 1.
 */
#define KERNEL8x1_SUB()							\
	ymm0 = _mm256_loadu_ps(AO);					\
	AO  += 8;							\
	ymm2 = _mm256_set1_ps(BO[0]);					\
	BO  += 1;							\
	ymm4 = ymm4 + ymm0 * ymm2;

/*
 * Scale the accumulator by ALPHA, add what is already stored at CO1 and
 * write the sum back.
 */
#define SAVE8x1(ALPHA)							\
	ymm0 = _mm256_set1_ps(ALPHA);					\
	ymm4 = ymm4 * ymm0;						\
	ymm4 = ymm4 + _mm256_loadu_ps(&CO1[0]);				\
	_mm256_storeu_ps(&CO1[0], ymm4);
690 
691 
692 /*******************************************************************************************/
693 
/*
 * 4x1 tile: one 4-float vector of A against one value of B.
 *
 * Names taken from the expanding scope: __m128 xmm0, xmm2, row0; float
 * pointers AO, BO, CO1.
 */

/* Zero the accumulator. */
#define INIT4x1()							\
	row0 = _mm_setzero_ps();

/*
 * One k step: load 4 floats from AO, broadcast BO[0] and accumulate the
 * product.  Advances AO by 4 and BO by 1.
 */
#define KERNEL4x1_SUB()							\
	xmm0 = _mm_loadu_ps(AO);					\
	AO  += 4;							\
	xmm2 = _mm_set1_ps(BO[0]);					\
	BO  += 1;							\
	row0 = row0 + xmm0 * xmm2;

/*
 * Scale the accumulator by ALPHA, add what is already stored at CO1 and
 * write the sum back.
 */
#define SAVE4x1(ALPHA)							\
	xmm0 = _mm_set1_ps(ALPHA);					\
	row0 = row0 * xmm0;						\
	row0 = row0 + _mm_loadu_ps(&CO1[0]);				\
	_mm_storeu_ps(&CO1[0], row0);
710 
711 
712 
713 /*******************************************************************************************/
714 
/*
 * 2x1 tile, scalar: two values of A against one value of B.
 *
 * Names taken from the expanding scope: float xmm0, xmm1, xmm2, row0,
 * row0b; float pointers AO, BO, CO1.
 */

/* Zero both accumulators. */
#define INIT2x1()							\
	row0 = row0b = 0.0f;

/*
 * One k step: read AO[0..1] and BO[0] and accumulate the two products.
 * Advances AO by 2 and BO by 1.
 */
#define KERNEL2x1_SUB()							\
	xmm0  = AO[0];							\
	xmm1  = AO[1];							\
	AO   += 2;							\
	xmm2  = BO[0];							\
	BO   += 1;							\
	row0  = row0  + xmm0 * xmm2;					\
	row0b = row0b + xmm1 * xmm2;

/* Scale the accumulators by ALPHA and add them to CO1[0] and CO1[1]. */
#define SAVE2x1(ALPHA)							\
	xmm0  = ALPHA;							\
	row0  = row0  * xmm0;						\
	row0b = row0b * xmm0;						\
	CO1[0] += row0;							\
	CO1[1] += row0b;
734 
735 
736 /*******************************************************************************************/
737 
/*
 * 1x1 tile, scalar: one value of A against one value of B.
 *
 * Names taken from the expanding scope: float xmm0, xmm2, row0; float
 * pointers AO, BO, CO1.
 */

/* Zero the accumulator. */
#define INIT1x1()							\
	row0 = 0.0f;

/*
 * One k step: read AO[0] and BO[0] and accumulate their product.
 * Advances AO and BO by 1.
 */
#define KERNEL1x1_SUB()							\
	xmm0 = AO[0];							\
	AO  += 1;							\
	xmm2 = BO[0];							\
	BO  += 1;							\
	row0 = row0 + xmm0 * xmm2;

/* Scale the accumulator by ALPHA and add it to CO1[0]. */
#define SAVE1x1(ALPHA)							\
	xmm0 = ALPHA;							\
	row0 = row0 * xmm0;						\
	CO1[0] += row0;
753 
754 
755 /*******************************************************************************************/
756 
757 
758 /*************************************************************************************
759 * GEMM Kernel
760 *************************************************************************************/
761 
762 int __attribute__ ((noinline))
CNAME(BLASLONG m,BLASLONG n,BLASLONG k,float alpha,float * __restrict A,float * __restrict B,float * __restrict C,BLASLONG ldc)763 CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict A, float * __restrict B, float * __restrict C, BLASLONG ldc)
764 {
765 	unsigned long long M = m, N = n, K = k;
766 	if (M == 0)
767 		return 0;
768 	if (N == 0)
769 		return 0;
770 	if (K == 0)
771 		return 0;
772 
773 
774 	while (N >= 4) {
775 		float *CO1;
776 		float *AO;
777 		int i;
778 		// L8_10
779 		CO1 = C;
780 		C += 4 * ldc;
781 
782 		AO = A;
783 
784 		i = m;
785 		while (i >= 64) {
786 			float *BO;
787 			float *A1, *A2, *A3;
788 			// L8_11
789 			__m512 zmm0, zmm1, zmm2, zmm3, row0, zmm5, row1, zmm7, row2, row3, row0b, row1b, row2b, row3b, row0c, row1c, row2c, row3c, row0d, row1d, row2d, row3d;
790 			BO = B;
791 			int kloop = K;
792 
793 			A1 = AO + 16 * K;
794 			A2 = A1 + 16 * K;
795 			A3 = A2 + 16 * K;
796 
797 			INIT64x4()
798 
799 			while (kloop > 0) {
800 				// L12_17
801 				KERNEL64x4_SUB()
802 				kloop--;
803 			}
804 			// L8_19
805 			SAVE64x4(alpha)
806 			CO1 += 64;
807 			AO += 48 * K;
808 
809 			i -= 64;
810 		}
811 		while (i >= 32) {
812 			float *BO;
813 			float *A1;
814 			// L8_11
815 			__m512 zmm0, zmm1, zmm2, zmm3, row0, row1, row2, row3, row0b, row1b, row2b, row3b;
816 			BO = B;
817 			int kloop = K;
818 
819 			A1 = AO + 16 * K;
820 
821 			INIT32x4()
822 
823 			while (kloop > 0) {
824 				// L12_17
825 				KERNEL32x4_SUB()
826 				kloop--;
827 			}
828 			// L8_19
829 			SAVE32x4(alpha)
830 			CO1 += 32;
831 			AO += 16 * K;
832 
833 			i -= 32;
834 		}
835 		while (i >= 16) {
836 			float *BO;
837 			// L8_11
838 			__m512 zmm0, zmm2, zmm3, row0, row1, row2, row3;
839 			BO = B;
840 			int kloop = K;
841 
842 			INIT16x4()
843 
844 			while (kloop > 0) {
845 				// L12_17
846 				KERNEL16x4_SUB()
847 				kloop--;
848 			}
849 			// L8_19
850 			SAVE16x4(alpha)
851 			CO1 += 16;
852 
853 			i -= 16;
854 		}
855 		while (i >= 8) {
856 			float *BO;
857 			// L8_11
858 			__m256 ymm0, ymm2, ymm3, ymm4, ymm6,ymm8,ymm10;
859 			BO = B;
860 			int kloop = K;
861 
862 			INIT8x4()
863 
864 			while (kloop > 0) {
865 				// L12_17
866 				KERNEL8x4_SUB()
867 				kloop--;
868 			}
869 			// L8_19
870 			SAVE8x4(alpha)
871 			CO1 += 8;
872 
873 			i -= 8;
874 		}
875 		while (i >= 4) {
876 			// L8_11
877 			float *BO;
878 			__m128 xmm0, xmm2, xmm3, row0, row1, row2, row3;
879 			BO = B;
880 			int kloop = K;
881 
882 			INIT4x4()
883 			// L8_16
884 			while (kloop > 0) {
885 				// L12_17
886 				KERNEL4x4_SUB()
887 				kloop--;
888 			}
889 			// L8_19
890 			SAVE4x4(alpha)
891 			CO1 += 4;
892 
893 			i -= 4;
894 		}
895 
896 /**************************************************************************
897 * Rest of M
898 ***************************************************************************/
899 
900 		while (i >= 2) {
901 			float *BO;
902 			float xmm0, xmm1, xmm2, xmm3, row0, row0b, row1, row1b, row2, row2b, row3, row3b;
903 			BO = B;
904 
905 			INIT2x4()
906 			int kloop = K;
907 
908 			while (kloop > 0) {
909 				KERNEL2x4_SUB()
910 				kloop--;
911 			}
912 			SAVE2x4(alpha)
913 			CO1 += 2;
914 			i -= 2;
915 		}
916 			// L13_40
917 		while (i >= 1) {
918 			float *BO;
919 			float xmm0, xmm2, xmm3, row0, row1, row2, row3;
920 			int kloop = K;
921 			BO = B;
922 			INIT1x4()
923 
924 			while (kloop > 0) {
925 				KERNEL1x4_SUB()
926 				kloop--;
927 			}
928 			SAVE1x4(alpha)
929 			CO1 += 1;
930 			i -= 1;
931 		}
932 
933 		B += K * 4;
934 		N -= 4;
935 	}
936 
937 /**************************************************************************************************/
938 
939 		// L8_0
940 	while (N >= 2) {
941 		float *CO1;
942 		float *AO;
943 		int i;
944 		// L8_10
945 		CO1 = C;
946 		C += 2 * ldc;
947 
948 		AO = A;
949 
950 		i = m;
951 		while (i >= 16) {
952 			float *BO;
953 
954 			// L8_11
955 			__m512 zmm0, zmm2, zmm3, row0, row1;
956 			BO = B;
957 			int kloop = K;
958 
959 			INIT16x2()
960 
961 			while (kloop > 0) {
962 				// L12_17
963 				KERNEL16x2_SUB()
964 				kloop--;
965 			}
966 			// L8_19
967 			SAVE16x2(alpha)
968 			CO1 += 16;
969 
970 			i -= 16;
971 		}
972 		while (i >= 8) {
973 			float *BO;
974 			__m256 ymm0, ymm2, ymm3, ymm4, ymm6;
975 			// L8_11
976 			BO = B;
977 			int kloop = K;
978 
979 			INIT8x2()
980 
981 			// L8_16
982 			while (kloop > 0) {
983 				// L12_17
984 				KERNEL8x2_SUB()
985 				kloop--;
986 			}
987 			// L8_19
988 			SAVE8x2(alpha)
989 			CO1 += 8;
990 
991 			i-=8;
992 		}
993 
994 		while (i >= 4) {
995 			float *BO;
996 			__m128 xmm0, xmm2, xmm3, row0, row1;
997 			// L8_11
998 			BO = B;
999 			int kloop = K;
1000 
1001 			INIT4x2()
1002 
1003 			// L8_16
1004 			while (kloop > 0) {
1005 				// L12_17
1006 				KERNEL4x2_SUB()
1007 				kloop--;
1008 			}
1009 			// L8_19
1010 			SAVE4x2(alpha)
1011 			CO1 += 4;
1012 
1013 			i-=4;
1014 		}
1015 
1016 /**************************************************************************
1017 * Rest of M
1018 ***************************************************************************/
1019 
1020 		while (i >= 2) {
1021 			float *BO;
1022 			float xmm0, xmm1, xmm2, xmm3, row0, row0b, row1, row1b;
1023 			int kloop = K;
1024 			BO = B;
1025 
1026 			INIT2x2()
1027 
1028 			while (kloop > 0) {
1029 				KERNEL2x2_SUB()
1030 				kloop--;
1031 			}
1032 			SAVE2x2(alpha)
1033 			CO1 += 2;
1034 			i -= 2;
1035 		}
1036 			// L13_40
1037 		while (i >= 1) {
1038 			float *BO;
1039 			float xmm0, xmm2, xmm3, row0, row1;
1040 			int kloop = K;
1041 			BO = B;
1042 
1043 			INIT1x2()
1044 
1045 			while (kloop > 0) {
1046 				KERNEL1x2_SUB()
1047 				kloop--;
1048 			}
1049 			SAVE1x2(alpha)
1050 			CO1 += 1;
1051 			i -= 1;
1052 		}
1053 
1054 		B += K * 2;
1055 		N -= 2;
1056 	}
1057 
1058 		// L8_0
1059 	while (N >= 1) {
1060 		// L8_10
1061 		float *CO1;
1062 		float *AO;
1063 		int i;
1064 
1065 		CO1 = C;
1066 		C += ldc;
1067 
1068 		AO = A;
1069 
1070 		i = m;
1071 		while (i >= 16) {
1072 			float *BO;
1073 			__m512 zmm0, zmm2, row0;
1074 			// L8_11
1075 			BO = B;
1076 			int kloop = K;
1077 
1078 			INIT16x1()
1079 			// L8_16
1080 			while (kloop > 0) {
1081 				// L12_17
1082 				KERNEL16x1_SUB()
1083 				kloop--;
1084 			}
1085 			// L8_19
1086 			SAVE16x1(alpha)
1087 			CO1 += 16;
1088 
1089 			i-= 16;
1090 		}
1091 		while (i >= 8) {
1092 			float *BO;
1093 			__m256 ymm0, ymm2, ymm4;
1094 			// L8_11
1095 			BO = B;
1096 			int kloop = K;
1097 
1098 			INIT8x1()
1099 			// L8_16
1100 			while (kloop > 0) {
1101 				// L12_17
1102 				KERNEL8x1_SUB()
1103 				kloop--;
1104 			}
1105 			// L8_19
1106 			SAVE8x1(alpha)
1107 			CO1 += 8;
1108 
1109 			i-= 8;
1110 		}
1111 		while (i >= 4) {
1112 			float *BO;
1113 			__m128 xmm0, xmm2, row0;
1114 			// L8_11
1115 			BO = B;
1116 			int kloop = K;
1117 
1118 			INIT4x1()
1119 			// L8_16
1120 			while (kloop > 0) {
1121 				// L12_17
1122 				KERNEL4x1_SUB()
1123 				kloop--;
1124 			}
1125 			// L8_19
1126 			SAVE4x1(alpha)
1127 			CO1 += 4;
1128 
1129 			i-= 4;
1130 		}
1131 
1132 /**************************************************************************
1133 * Rest of M
1134 ***************************************************************************/
1135 
1136 		while (i >= 2) {
1137 			float *BO;
1138 			float xmm0, xmm1, xmm2, row0, row0b;
1139 			int kloop = K;
1140 			BO = B;
1141 
1142 			INIT2x1()
1143 
1144 			while (kloop > 0) {
1145 				KERNEL2x1_SUB()
1146 				kloop--;
1147 			}
1148 			SAVE2x1(alpha)
1149 			CO1 += 2;
1150 			i -= 2;
1151 		}
1152 				// L13_40
1153 		while (i >= 1) {
1154 			float *BO;
1155 			float xmm0, xmm2, row0;
1156 			int kloop = K;
1157 
1158 			BO = B;
1159 			INIT1x1()
1160 
1161 
1162 			while (kloop > 0) {
1163 				KERNEL1x1_SUB()
1164 				kloop--;
1165 			}
1166 			SAVE1x1(alpha)
1167 			CO1 += 1;
1168 			i -= 1;
1169 		}
1170 
1171 		B += K * 1;
1172 		N -= 1;
1173 	}
1174 
1175 
1176 	return 0;
1177 }
1178 
1179 #include "sgemm_direct_skylakex.c"
1180