1 /*********************************************************************************
2 Copyright (c) 2013, The OpenBLAS Project
3 All rights reserved.
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions are
6 met:
7 1. Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 2. Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in
11 the documentation and/or other materials provided with the
12 distribution.
13 3. Neither the name of the OpenBLAS project nor the names of
14 its contributors may be used to endorse or promote products
15 derived from this software without specific prior written permission.
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 **********************************************************************************/
27
28
29 /* comment below left for history, data does not represent the implementation in this file */
30
31 /*********************************************************************
32 * 2014/07/28 Saar
33 * BLASTEST : OK
34 * CTEST : OK
35 * TEST : OK
36 *
37 * 2013/10/28 Saar
38 * Parameter:
39 * SGEMM_DEFAULT_UNROLL_N 4
40 * SGEMM_DEFAULT_UNROLL_M 16
41 * SGEMM_DEFAULT_P 768
42 * SGEMM_DEFAULT_Q 384
43 * A_PR1 512
44 * B_PR1 512
45 *
46 *
47 * 2014/07/28 Saar
48 * Performance at 9216x9216x9216:
49 * 1 thread: 102 GFLOPS (SANDYBRIDGE: 59) (MKL: 83)
50 * 2 threads: 195 GFLOPS (SANDYBRIDGE: 116) (MKL: 155)
51 * 3 threads: 281 GFLOPS (SANDYBRIDGE: 165) (MKL: 230)
52 * 4 threads: 366 GFLOPS (SANDYBRIDGE: 223) (MKL: 267)
53 *
54 *********************************************************************/
55
56 #include "common.h"
57 #include <immintrin.h>
58
59
60
61 /*******************************************************************************************
62 * 8 lines of N
63 *******************************************************************************************/
64
65
66
67
68
69
70 /*******************************************************************************************
71 * 4 lines of N
72 *******************************************************************************************/
73
/*
 * 64x4 tile micro-kernel.  The 64 rows of C are held in four groups of
 * 16-float AVX-512 accumulators: row{0..3}, row{0..3}b, row{0..3}c,
 * row{0..3}d (one group per 16-row panel of A).  Free variables supplied
 * by the caller: AO/A1/A2/A3 (four packed 16-float A panels), BO (packed
 * B, 4 floats per k step), CO1 (top-left of the C tile), ldc (leading
 * dimension of C, in floats).  Arithmetic uses GCC/Clang vector-extension
 * operators (+=, *) on __m512 rather than explicit intrinsics; the
 * trailing backslash on each macro's last line swallows the blank line
 * that follows it.
 */
/* Zero all 16 accumulators of the 64x4 tile. */
#define INIT64x4() \
row0 = _mm512_setzero_ps(); \
row1 = _mm512_setzero_ps(); \
row2 = _mm512_setzero_ps(); \
row3 = _mm512_setzero_ps(); \
row0b = _mm512_setzero_ps(); \
row1b = _mm512_setzero_ps(); \
row2b = _mm512_setzero_ps(); \
row3b = _mm512_setzero_ps(); \
row0c = _mm512_setzero_ps(); \
row1c = _mm512_setzero_ps(); \
row2c = _mm512_setzero_ps(); \
row3c = _mm512_setzero_ps(); \
row0d = _mm512_setzero_ps(); \
row1d = _mm512_setzero_ps(); \
row2d = _mm512_setzero_ps(); \
row3d = _mm512_setzero_ps(); \

/* One k step: rank-1 update of the 64x4 tile.  Loads 4x16 floats of A,
   broadcasts this k's four B values, accumulates, then advances BO by 4
   and each A panel pointer by 16. */
#define KERNEL64x4_SUB() \
zmm0 = _mm512_loadu_ps(AO); \
zmm1 = _mm512_loadu_ps(A1); \
zmm5 = _mm512_loadu_ps(A2); \
zmm7 = _mm512_loadu_ps(A3); \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+1)); \
row0 += zmm0 * zmm2; \
row1 += zmm0 * zmm3; \
row0b += zmm1 * zmm2; \
row1b += zmm1 * zmm3; \
row0c += zmm5 * zmm2; \
row1c += zmm5 * zmm3; \
row0d += zmm7 * zmm2; \
row1d += zmm7 * zmm3; \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO+2)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+3)); \
row2 += zmm0 * zmm2; \
row3 += zmm0 * zmm3; \
row2b += zmm1 * zmm2; \
row3b += zmm1 * zmm3; \
row2c += zmm5 * zmm2; \
row3c += zmm5 * zmm3; \
row2d += zmm7 * zmm2; \
row3d += zmm7 * zmm3; \
BO += 4; \
AO += 16; \
A1 += 16; \
A2 += 16; \
A3 += 16; \

/* Epilogue: C[0:64, 0:4] += ALPHA * accumulators (scale, add the
   existing C values, store back; unaligned loads/stores on C). */
#define SAVE64x4(ALPHA) \
zmm0 = _mm512_set1_ps(ALPHA); \
row0 *= zmm0; \
row1 *= zmm0; \
row2 *= zmm0; \
row3 *= zmm0; \
row0b *= zmm0; \
row1b *= zmm0; \
row2b *= zmm0; \
row3b *= zmm0; \
row0c *= zmm0; \
row1c *= zmm0; \
row2c *= zmm0; \
row3c *= zmm0; \
row0d *= zmm0; \
row1d *= zmm0; \
row2d *= zmm0; \
row3d *= zmm0; \
row0 += _mm512_loadu_ps(CO1 + 0*ldc); \
row1 += _mm512_loadu_ps(CO1 + 1*ldc); \
row2 += _mm512_loadu_ps(CO1 + 2*ldc); \
row3 += _mm512_loadu_ps(CO1 + 3*ldc); \
_mm512_storeu_ps(CO1 + 0*ldc, row0); \
_mm512_storeu_ps(CO1 + 1*ldc, row1); \
_mm512_storeu_ps(CO1 + 2*ldc, row2); \
_mm512_storeu_ps(CO1 + 3*ldc, row3); \
row0b += _mm512_loadu_ps(CO1 + 0*ldc + 16); \
row1b += _mm512_loadu_ps(CO1 + 1*ldc + 16); \
row2b += _mm512_loadu_ps(CO1 + 2*ldc + 16); \
row3b += _mm512_loadu_ps(CO1 + 3*ldc + 16); \
_mm512_storeu_ps(CO1 + 0*ldc + 16, row0b); \
_mm512_storeu_ps(CO1 + 1*ldc + 16, row1b); \
_mm512_storeu_ps(CO1 + 2*ldc + 16, row2b); \
_mm512_storeu_ps(CO1 + 3*ldc + 16, row3b); \
row0c += _mm512_loadu_ps(CO1 + 0*ldc + 32); \
row1c += _mm512_loadu_ps(CO1 + 1*ldc + 32); \
row2c += _mm512_loadu_ps(CO1 + 2*ldc + 32); \
row3c += _mm512_loadu_ps(CO1 + 3*ldc + 32); \
_mm512_storeu_ps(CO1 + 0*ldc + 32, row0c); \
_mm512_storeu_ps(CO1 + 1*ldc + 32, row1c); \
_mm512_storeu_ps(CO1 + 2*ldc + 32, row2c); \
_mm512_storeu_ps(CO1 + 3*ldc + 32, row3c); \
row0d += _mm512_loadu_ps(CO1 + 0*ldc + 48); \
row1d += _mm512_loadu_ps(CO1 + 1*ldc + 48); \
row2d += _mm512_loadu_ps(CO1 + 2*ldc + 48); \
row3d += _mm512_loadu_ps(CO1 + 3*ldc + 48); \
_mm512_storeu_ps(CO1 + 0*ldc + 48, row0d); \
_mm512_storeu_ps(CO1 + 1*ldc + 48, row1d); \
_mm512_storeu_ps(CO1 + 2*ldc + 48, row2d); \
_mm512_storeu_ps(CO1 + 3*ldc + 48, row3d);
174
175
/*
 * 48x4 tile micro-kernel: three 16-float accumulator groups (row*,
 * row*b, row*c) over A panels AO/A1/A2.  Same free-variable contract as
 * the 64x4 kernel above.  NOTE(review): these 48-wide macros are not
 * referenced by the driver below in this file — presumably kept for
 * completeness/history; verify before removing.
 */
/* Zero the 12 accumulators of the 48x4 tile. */
#define INIT48x4() \
row0 = _mm512_setzero_ps(); \
row1 = _mm512_setzero_ps(); \
row2 = _mm512_setzero_ps(); \
row3 = _mm512_setzero_ps(); \
row0b = _mm512_setzero_ps(); \
row1b = _mm512_setzero_ps(); \
row2b = _mm512_setzero_ps(); \
row3b = _mm512_setzero_ps(); \
row0c = _mm512_setzero_ps(); \
row1c = _mm512_setzero_ps(); \
row2c = _mm512_setzero_ps(); \

/* One k step: rank-1 update of the 48x4 tile; advances BO by 4 and
   AO/A1/A2 by 16 each. */
#define KERNEL48x4_SUB() \
zmm0 = _mm512_loadu_ps(AO); \
zmm1 = _mm512_loadu_ps(A1); \
zmm5 = _mm512_loadu_ps(A2); \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+1)); \
row0 += zmm0 * zmm2; \
row1 += zmm0 * zmm3; \
row0b += zmm1 * zmm2; \
row1b += zmm1 * zmm3; \
row0c += zmm5 * zmm2; \
row1c += zmm5 * zmm3; \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO+2)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+3)); \
row2 += zmm0 * zmm2; \
row3 += zmm0 * zmm3; \
row2b += zmm1 * zmm2; \
row3b += zmm1 * zmm3; \
row2c += zmm5 * zmm2; \
row3c += zmm5 * zmm3; \
BO += 4; \
AO += 16; \
A1 += 16; \
A2 += 16;


/* Epilogue: C[0:48, 0:4] += ALPHA * accumulators. */
#define SAVE48x4(ALPHA) \
zmm0 = _mm512_set1_ps(ALPHA); \
row0 *= zmm0; \
row1 *= zmm0; \
row2 *= zmm0; \
row3 *= zmm0; \
row0b *= zmm0; \
row1b *= zmm0; \
row2b *= zmm0; \
row3b *= zmm0; \
row0c *= zmm0; \
row1c *= zmm0; \
row2c *= zmm0; \
row3c *= zmm0; \
row0 += _mm512_loadu_ps(CO1 + 0*ldc); \
row1 += _mm512_loadu_ps(CO1 + 1*ldc); \
row2 += _mm512_loadu_ps(CO1 + 2*ldc); \
row3 += _mm512_loadu_ps(CO1 + 3*ldc); \
_mm512_storeu_ps(CO1 + 0*ldc, row0); \
_mm512_storeu_ps(CO1 + 1*ldc, row1); \
_mm512_storeu_ps(CO1 + 2*ldc, row2); \
_mm512_storeu_ps(CO1 + 3*ldc, row3); \
row0b += _mm512_loadu_ps(CO1 + 0*ldc + 16); \
row1b += _mm512_loadu_ps(CO1 + 1*ldc + 16); \
row2b += _mm512_loadu_ps(CO1 + 2*ldc + 16); \
row3b += _mm512_loadu_ps(CO1 + 3*ldc + 16); \
_mm512_storeu_ps(CO1 + 0*ldc + 16, row0b); \
_mm512_storeu_ps(CO1 + 1*ldc + 16, row1b); \
_mm512_storeu_ps(CO1 + 2*ldc + 16, row2b); \
_mm512_storeu_ps(CO1 + 3*ldc + 16, row3b); \
row0c += _mm512_loadu_ps(CO1 + 0*ldc + 32); \
row1c += _mm512_loadu_ps(CO1 + 1*ldc + 32); \
row2c += _mm512_loadu_ps(CO1 + 2*ldc + 32); \
row3c += _mm512_loadu_ps(CO1 + 3*ldc + 32); \
_mm512_storeu_ps(CO1 + 0*ldc + 32, row0c); \
_mm512_storeu_ps(CO1 + 1*ldc + 32, row1c); \
_mm512_storeu_ps(CO1 + 2*ldc + 32, row2c); \
_mm512_storeu_ps(CO1 + 3*ldc + 32, row3c);
254
255
/*
 * 32x4 tile micro-kernel: two 16-float accumulator groups (row*, row*b)
 * over A panels AO/A1.  Same free-variable contract as the 64x4 kernel.
 */
/* Zero the 8 accumulators of the 32x4 tile. */
#define INIT32x4() \
row0 = _mm512_setzero_ps(); \
row1 = _mm512_setzero_ps(); \
row2 = _mm512_setzero_ps(); \
row3 = _mm512_setzero_ps(); \
row0b = _mm512_setzero_ps(); \
row1b = _mm512_setzero_ps(); \
row2b = _mm512_setzero_ps(); \
row3b = _mm512_setzero_ps(); \

/* One k step: rank-1 update of the 32x4 tile; advances BO by 4 and
   AO/A1 by 16 each. */
#define KERNEL32x4_SUB() \
zmm0 = _mm512_loadu_ps(AO); \
zmm1 = _mm512_loadu_ps(A1); \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+1)); \
row0 += zmm0 * zmm2; \
row1 += zmm0 * zmm3; \
row0b += zmm1 * zmm2; \
row1b += zmm1 * zmm3; \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO+2)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+3)); \
row2 += zmm0 * zmm2; \
row3 += zmm0 * zmm3; \
row2b += zmm1 * zmm2; \
row3b += zmm1 * zmm3; \
BO += 4; \
AO += 16; \
A1 += 16;


/* Epilogue: C[0:32, 0:4] += ALPHA * accumulators. */
#define SAVE32x4(ALPHA) \
zmm0 = _mm512_set1_ps(ALPHA); \
row0 *= zmm0; \
row1 *= zmm0; \
row2 *= zmm0; \
row3 *= zmm0; \
row0b *= zmm0; \
row1b *= zmm0; \
row2b *= zmm0; \
row3b *= zmm0; \
row0 += _mm512_loadu_ps(CO1 + 0*ldc); \
row1 += _mm512_loadu_ps(CO1 + 1*ldc); \
row2 += _mm512_loadu_ps(CO1 + 2*ldc); \
row3 += _mm512_loadu_ps(CO1 + 3*ldc); \
_mm512_storeu_ps(CO1 + 0*ldc, row0); \
_mm512_storeu_ps(CO1 + 1*ldc, row1); \
_mm512_storeu_ps(CO1 + 2*ldc, row2); \
_mm512_storeu_ps(CO1 + 3*ldc, row3); \
row0b += _mm512_loadu_ps(CO1 + 0*ldc + 16); \
row1b += _mm512_loadu_ps(CO1 + 1*ldc + 16); \
row2b += _mm512_loadu_ps(CO1 + 2*ldc + 16); \
row3b += _mm512_loadu_ps(CO1 + 3*ldc + 16); \
_mm512_storeu_ps(CO1 + 0*ldc + 16, row0b); \
_mm512_storeu_ps(CO1 + 1*ldc + 16, row1b); \
_mm512_storeu_ps(CO1 + 2*ldc + 16, row2b); \
_mm512_storeu_ps(CO1 + 3*ldc + 16, row3b);
312
313
314
/*
 * 16x4 tile micro-kernel: one 16-float accumulator per column of the
 * tile.  Free variables: AO (packed A), BO (packed B), CO1, ldc.
 */
/* Zero the 4 accumulators of the 16x4 tile. */
#define INIT16x4() \
row0 = _mm512_setzero_ps(); \
row1 = _mm512_setzero_ps(); \
row2 = _mm512_setzero_ps(); \
row3 = _mm512_setzero_ps(); \

/* One k step: rank-1 update; advances BO by 4 and AO by 16. */
#define KERNEL16x4_SUB() \
zmm0 = _mm512_loadu_ps(AO); \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+1)); \
row0 += zmm0 * zmm2; \
row1 += zmm0 * zmm3; \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO+2)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO+3)); \
row2 += zmm0 * zmm2; \
row3 += zmm0 * zmm3; \
BO += 4; \
AO += 16;


/* Epilogue: C[0:16, 0:4] += ALPHA * accumulators. */
#define SAVE16x4(ALPHA) \
zmm0 = _mm512_set1_ps(ALPHA); \
row0 *= zmm0; \
row1 *= zmm0; \
row2 *= zmm0; \
row3 *= zmm0; \
row0 += _mm512_loadu_ps(CO1 + 0 * ldc); \
row1 += _mm512_loadu_ps(CO1 + 1 * ldc); \
row2 += _mm512_loadu_ps(CO1 + 2 * ldc); \
row3 += _mm512_loadu_ps(CO1 + 3 * ldc); \
_mm512_storeu_ps(CO1 + 0 * ldc, row0); \
_mm512_storeu_ps(CO1 + 1 * ldc, row1); \
_mm512_storeu_ps(CO1 + 2 * ldc, row2); \
_mm512_storeu_ps(CO1 + 3 * ldc, row3);
349
350
351
352 /*******************************************************************************************/
353
/*
 * 8x4 tile micro-kernel (AVX2, 8-float __m256 accumulators ymm4/6/8/10).
 * Free variables: AO, BO, CO1, ldc.
 */
/* Zero the 4 accumulators of the 8x4 tile. */
#define INIT8x4() \
ymm4 = _mm256_setzero_ps(); \
ymm6 = _mm256_setzero_ps(); \
ymm8 = _mm256_setzero_ps(); \
ymm10 = _mm256_setzero_ps(); \

/* One k step: rank-1 update; advances BO by 4 and AO by 8. */
#define KERNEL8x4_SUB() \
ymm0 = _mm256_loadu_ps(AO); \
ymm2 = _mm256_broadcastss_ps(_mm_load_ss(BO + 0)); \
ymm3 = _mm256_broadcastss_ps(_mm_load_ss(BO + 1)); \
ymm4 += ymm0 * ymm2; \
ymm6 += ymm0 * ymm3; \
ymm2 = _mm256_broadcastss_ps(_mm_load_ss(BO + 2)); \
ymm3 = _mm256_broadcastss_ps(_mm_load_ss(BO + 3)); \
ymm8 += ymm0 * ymm2; \
ymm10 += ymm0 * ymm3; \
BO += 4; \
AO += 8;


/* Epilogue: C[0:8, 0:4] += ALPHA * accumulators. */
#define SAVE8x4(ALPHA) \
ymm0 = _mm256_set1_ps(ALPHA); \
ymm4 *= ymm0; \
ymm6 *= ymm0; \
ymm8 *= ymm0; \
ymm10 *= ymm0; \
ymm4 += _mm256_loadu_ps(CO1 + 0 * ldc); \
ymm6 += _mm256_loadu_ps(CO1 + 1 * ldc); \
ymm8 += _mm256_loadu_ps(CO1 + 2 * ldc); \
ymm10 += _mm256_loadu_ps(CO1 + 3 * ldc); \
_mm256_storeu_ps(CO1 + 0 * ldc, ymm4); \
_mm256_storeu_ps(CO1 + 1 * ldc, ymm6); \
_mm256_storeu_ps(CO1 + 2 * ldc, ymm8); \
_mm256_storeu_ps(CO1 + 3 * ldc, ymm10); \

389
390
391 /*******************************************************************************************/
392
/*
 * 4x4 tile micro-kernel (SSE, 4-float __m128 accumulators row0..row3).
 * Free variables: AO, BO, CO1, ldc.
 */
/* Zero the 4 accumulators of the 4x4 tile. */
#define INIT4x4() \
row0 = _mm_setzero_ps(); \
row1 = _mm_setzero_ps(); \
row2 = _mm_setzero_ps(); \
row3 = _mm_setzero_ps(); \


/* One k step: rank-1 update; advances BO by 4 and AO by 4. */
#define KERNEL4x4_SUB() \
xmm0 = _mm_loadu_ps(AO); \
xmm2 = _mm_broadcastss_ps(_mm_load_ss(BO + 0)); \
xmm3 = _mm_broadcastss_ps(_mm_load_ss(BO + 1)); \
row0 += xmm0 * xmm2; \
row1 += xmm0 * xmm3; \
xmm2 = _mm_broadcastss_ps(_mm_load_ss(BO + 2)); \
xmm3 = _mm_broadcastss_ps(_mm_load_ss(BO + 3)); \
row2 += xmm0 * xmm2; \
row3 += xmm0 * xmm3; \
BO += 4; \
AO += 4;


/* Epilogue: C[0:4, 0:4] += ALPHA * accumulators. */
#define SAVE4x4(ALPHA) \
xmm0 = _mm_set1_ps(ALPHA); \
row0 *= xmm0; \
row1 *= xmm0; \
row2 *= xmm0; \
row3 *= xmm0; \
row0 += _mm_loadu_ps(CO1 + 0 * ldc); \
row1 += _mm_loadu_ps(CO1 + 1 * ldc); \
row2 += _mm_loadu_ps(CO1 + 2 * ldc); \
row3 += _mm_loadu_ps(CO1 + 3 * ldc); \
_mm_storeu_ps(CO1 + 0 * ldc, row0); \
_mm_storeu_ps(CO1 + 1 * ldc, row1); \
_mm_storeu_ps(CO1 + 2 * ldc, row2); \
_mm_storeu_ps(CO1 + 3 * ldc, row3); \

429
430 /*******************************************************************************************/
431
/*
 * 2x4 tile micro-kernel, scalar version: the "xmm"/"row" names are plain
 * floats here (declared as float in the driver), not SSE registers.
 * row<j> holds C[0][j], row<j>b holds C[1][j].
 */
/* Zero the 8 scalar accumulators of the 2x4 tile. */
#define INIT2x4() \
row0 = 0; row0b = 0; row1 = 0; row1b = 0; \
row2 = 0; row2b = 0; row3 = 0; row3b = 0;

/* One k step: scalar rank-1 update; advances BO by 4 and AO by 2. */
#define KERNEL2x4_SUB() \
xmm0 = *(AO); \
xmm1 = *(AO + 1); \
xmm2 = *(BO + 0); \
xmm3 = *(BO + 1); \
row0 += xmm0 * xmm2; \
row0b += xmm1 * xmm2; \
row1 += xmm0 * xmm3; \
row1b += xmm1 * xmm3; \
xmm2 = *(BO + 2); \
xmm3 = *(BO + 3); \
row2 += xmm0 * xmm2; \
row2b += xmm1 * xmm2; \
row3 += xmm0 * xmm3; \
row3b += xmm1 * xmm3; \
BO += 4; \
AO += 2;


/* Epilogue: C[0:2, 0:4] += ALPHA * accumulators. */
#define SAVE2x4(ALPHA) \
xmm0 = ALPHA; \
row0 *= xmm0; \
row0b *= xmm0; \
row1 *= xmm0; \
row1b *= xmm0; \
row2 *= xmm0; \
row2b *= xmm0; \
row3 *= xmm0; \
row3b *= xmm0; \
*(CO1 + 0 * ldc + 0) += row0; \
*(CO1 + 0 * ldc + 1) += row0b; \
*(CO1 + 1 * ldc + 0) += row1; \
*(CO1 + 1 * ldc + 1) += row1b; \
*(CO1 + 2 * ldc + 0) += row2; \
*(CO1 + 2 * ldc + 1) += row2b; \
*(CO1 + 3 * ldc + 0) += row3; \
*(CO1 + 3 * ldc + 1) += row3b; \

474
475
476 /*******************************************************************************************/
477
/*
 * 1x4 tile micro-kernel, scalar version (one row of A against four
 * columns of B; all "xmm"/"row" names are plain floats).
 */
/* Zero the 4 scalar accumulators of the 1x4 tile. */
#define INIT1x4() \
row0 = 0; row1 = 0; row2 = 0; row3 = 0;
/* One k step: scalar update; advances BO by 4 and AO by 1. */
#define KERNEL1x4_SUB() \
xmm0 = *(AO ); \
xmm2 = *(BO + 0); \
xmm3 = *(BO + 1); \
row0 += xmm0 * xmm2; \
row1 += xmm0 * xmm3; \
xmm2 = *(BO + 2); \
xmm3 = *(BO + 3); \
row2 += xmm0 * xmm2; \
row3 += xmm0 * xmm3; \
BO += 4; \
AO += 1;


/* Epilogue: C[0, 0:4] += ALPHA * accumulators. */
#define SAVE1x4(ALPHA) \
xmm0 = ALPHA; \
row0 *= xmm0; \
row1 *= xmm0; \
row2 *= xmm0; \
row3 *= xmm0; \
*(CO1 + 0 * ldc) += row0; \
*(CO1 + 1 * ldc) += row1; \
*(CO1 + 2 * ldc) += row2; \
*(CO1 + 3 * ldc) += row3; \

505
506
507 /*******************************************************************************************/
508
509 /*******************************************************************************************
510 * 2 lines of N
511 *******************************************************************************************/
512
/*
 * 16x2 tile micro-kernel (AVX-512): two 16-float accumulators, one per
 * B column.  B is packed 2 floats per k step here.
 */
/* Zero the 2 accumulators of the 16x2 tile. */
#define INIT16x2() \
row0 = _mm512_setzero_ps(); \
row1 = _mm512_setzero_ps(); \

/* One k step: rank-1 update; advances BO by 2 and AO by 16. */
#define KERNEL16x2_SUB() \
zmm0 = _mm512_loadu_ps(AO); \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
zmm3 = _mm512_broadcastss_ps(_mm_load_ss(BO + 1)); \
row0 += zmm0 * zmm2; \
row1 += zmm0 * zmm3; \
BO += 2; \
AO += 16;


/* Epilogue: C[0:16, 0:2] += ALPHA * accumulators. */
#define SAVE16x2(ALPHA) \
zmm0 = _mm512_set1_ps(ALPHA); \
row0 *= zmm0; \
row1 *= zmm0; \
row0 += _mm512_loadu_ps(CO1); \
row1 += _mm512_loadu_ps(CO1 + ldc); \
_mm512_storeu_ps(CO1 , row0); \
_mm512_storeu_ps(CO1 + ldc, row1); \

537
538
539
540 /*******************************************************************************************/
541
/* 8x2 tile micro-kernel (AVX2, two __m256 accumulators ymm4/ymm6). */
/* Zero the 2 accumulators of the 8x2 tile. */
#define INIT8x2() \
ymm4 = _mm256_setzero_ps(); \
ymm6 = _mm256_setzero_ps(); \

/* One k step: rank-1 update; advances BO by 2 and AO by 8. */
#define KERNEL8x2_SUB() \
ymm0 = _mm256_loadu_ps(AO); \
ymm2 = _mm256_broadcastss_ps(_mm_load_ss(BO)); \
ymm3 = _mm256_broadcastss_ps(_mm_load_ss(BO + 1)); \
ymm4 += ymm0 * ymm2; \
ymm6 += ymm0 * ymm3; \
BO += 2; \
AO += 8;


/* Epilogue: C[0:8, 0:2] += ALPHA * accumulators. */
#define SAVE8x2(ALPHA) \
ymm0 = _mm256_set1_ps(ALPHA); \
ymm4 *= ymm0; \
ymm6 *= ymm0; \
ymm4 += _mm256_loadu_ps(CO1); \
ymm6 += _mm256_loadu_ps(CO1 + ldc); \
_mm256_storeu_ps(CO1 , ymm4); \
_mm256_storeu_ps(CO1 + ldc, ymm6); \

565
566
567 /*******************************************************************************************/
568
/* 4x2 tile micro-kernel (SSE, two __m128 accumulators row0/row1). */
/* Zero the 2 accumulators of the 4x2 tile. */
#define INIT4x2() \
row0 = _mm_setzero_ps(); \
row1 = _mm_setzero_ps(); \

/* One k step: rank-1 update; advances BO by 2 and AO by 4. */
#define KERNEL4x2_SUB() \
xmm0 = _mm_loadu_ps(AO); \
xmm2 = _mm_broadcastss_ps(_mm_load_ss(BO)); \
xmm3 = _mm_broadcastss_ps(_mm_load_ss(BO + 1)); \
row0 += xmm0 * xmm2; \
row1 += xmm0 * xmm3; \
BO += 2; \
AO += 4;


/* Epilogue: C[0:4, 0:2] += ALPHA * accumulators. */
#define SAVE4x2(ALPHA) \
xmm0 = _mm_set1_ps(ALPHA); \
row0 *= xmm0; \
row1 *= xmm0; \
row0 += _mm_loadu_ps(CO1); \
row1 += _mm_loadu_ps(CO1 + ldc); \
_mm_storeu_ps(CO1 , row0); \
_mm_storeu_ps(CO1 + ldc, row1); \

592
593
594 /*******************************************************************************************/
595
596
/*
 * 2x2 tile micro-kernel, scalar version ("xmm"/"row" names are plain
 * floats; row<j> = C[0][j], row<j>b = C[1][j]).
 */
/* Zero the 4 scalar accumulators of the 2x2 tile. */
#define INIT2x2() \
row0 = 0; row0b = 0; row1 = 0; row1b = 0; \

/* One k step: scalar update; advances BO by 2 and AO by 2. */
#define KERNEL2x2_SUB() \
xmm0 = *(AO + 0); \
xmm1 = *(AO + 1); \
xmm2 = *(BO + 0); \
xmm3 = *(BO + 1); \
row0 += xmm0 * xmm2; \
row0b += xmm1 * xmm2; \
row1 += xmm0 * xmm3; \
row1b += xmm1 * xmm3; \
BO += 2; \
AO += 2; \

/* Epilogue: C[0:2, 0:2] += ALPHA * accumulators. */
#define SAVE2x2(ALPHA) \
xmm0 = ALPHA; \
row0 *= xmm0; \
row0b *= xmm0; \
row1 *= xmm0; \
row1b *= xmm0; \
*(CO1 ) += row0; \
*(CO1 +1 ) += row0b; \
*(CO1 + ldc ) += row1; \
*(CO1 + ldc +1) += row1b; \

624
625 /*******************************************************************************************/
626
/* 1x2 tile micro-kernel, scalar version. */
/* Zero the 2 scalar accumulators. */
#define INIT1x2() \
row0 = 0; row1 = 0;

/* One k step: scalar update; advances BO by 2 and AO by 1. */
#define KERNEL1x2_SUB() \
xmm0 = *(AO); \
xmm2 = *(BO + 0); \
xmm3 = *(BO + 1); \
row0 += xmm0 * xmm2; \
row1 += xmm0 * xmm3; \
BO += 2; \
AO += 1;


/* Epilogue: C[0, 0:2] += ALPHA * accumulators. */
#define SAVE1x2(ALPHA) \
xmm0 = ALPHA; \
row0 *= xmm0; \
row1 *= xmm0; \
*(CO1 ) += row0; \
*(CO1 + ldc ) += row1; \

647
648 /*******************************************************************************************/
649
650 /*******************************************************************************************
651 * 1 line of N
652 *******************************************************************************************/
653
/* 16x1 tile micro-kernel (AVX-512): single accumulator; B is packed
   1 float per k step in this section. */
/* Zero the accumulator. */
#define INIT16x1() \
row0 = _mm512_setzero_ps(); \

/* One k step: advances BO by 1 and AO by 16. */
#define KERNEL16x1_SUB() \
zmm0 = _mm512_loadu_ps(AO); \
zmm2 = _mm512_broadcastss_ps(_mm_load_ss(BO)); \
row0 += zmm0 * zmm2; \
BO += 1; \
AO += 16;


/* Epilogue: C[0:16, 0] += ALPHA * accumulator. */
#define SAVE16x1(ALPHA) \
zmm0 = _mm512_set1_ps(ALPHA); \
row0 *= zmm0; \
row0 += _mm512_loadu_ps(CO1); \
_mm512_storeu_ps(CO1 , row0); \

671
672 /*******************************************************************************************/
673
/* 8x1 tile micro-kernel (AVX2, single __m256 accumulator ymm4). */
/* Zero the accumulator. */
#define INIT8x1() \
ymm4 = _mm256_setzero_ps();

/* One k step: advances BO by 1 and AO by 8. */
#define KERNEL8x1_SUB() \
ymm0 = _mm256_loadu_ps(AO); \
ymm2 = _mm256_broadcastss_ps(_mm_load_ss(BO)); \
ymm4 += ymm0 * ymm2; \
BO += 1; \
AO += 8;


/* Epilogue: C[0:8, 0] += ALPHA * accumulator. */
#define SAVE8x1(ALPHA) \
ymm0 = _mm256_set1_ps(ALPHA); \
ymm4 *= ymm0; \
ymm4 += _mm256_loadu_ps(CO1); \
_mm256_storeu_ps(CO1 , ymm4); \

691
692 /*******************************************************************************************/
693
/* 4x1 tile micro-kernel (SSE, single __m128 accumulator row0). */
/* Zero the accumulator. */
#define INIT4x1() \
row0 = _mm_setzero_ps(); \

/* One k step: advances BO by 1 and AO by 4. */
#define KERNEL4x1_SUB() \
xmm0 = _mm_loadu_ps(AO); \
xmm2 = _mm_broadcastss_ps(_mm_load_ss(BO)); \
row0 += xmm0 * xmm2; \
BO += 1; \
AO += 4;


/* Epilogue: C[0:4, 0] += ALPHA * accumulator. */
#define SAVE4x1(ALPHA) \
xmm0 = _mm_set1_ps(ALPHA); \
row0 *= xmm0; \
row0 += _mm_loadu_ps(CO1); \
_mm_storeu_ps(CO1 , row0); \

711
712
713 /*******************************************************************************************/
714
/* 2x1 tile micro-kernel, scalar version. */
/* Zero the 2 scalar accumulators. */
#define INIT2x1() \
row0 = 0; row0b = 0;

/* One k step: advances BO by 1 and AO by 2. */
#define KERNEL2x1_SUB() \
xmm0 = *(AO + 0); \
xmm1 = *(AO + 1); \
xmm2 = *(BO); \
row0 += xmm0 * xmm2; \
row0b += xmm1 * xmm2; \
BO += 1; \
AO += 2;


/* Epilogue: C[0:2, 0] += ALPHA * accumulators. */
#define SAVE2x1(ALPHA) \
xmm0 = ALPHA; \
row0 *= xmm0; \
row0b *= xmm0; \
*(CO1 ) += row0; \
*(CO1 +1 ) += row0b; \

735
736 /*******************************************************************************************/
737
/* 1x1 tile micro-kernel, scalar version (single dot-product element). */
/* Zero the accumulator. */
#define INIT1x1() \
row0 = 0;

/* One k step: advances BO and AO by 1 each. */
#define KERNEL1x1_SUB() \
xmm0 = *(AO); \
xmm2 = *(BO); \
row0 += xmm0 * xmm2; \
BO += 1; \
AO += 1;


/* Epilogue: C[0, 0] += ALPHA * accumulator. */
#define SAVE1x1(ALPHA) \
xmm0 = ALPHA; \
row0 *= xmm0; \
*(CO1 ) += row0; \

754
755 /*******************************************************************************************/
756
757
758 /*************************************************************************************
759 * GEMM Kernel
760 *************************************************************************************/
761
/*
 * SGEMM compute kernel: C += alpha * A * B on pre-packed buffers.
 *
 * m, n, k   - dimensions of this block.
 * alpha     - scalar multiplier applied to A*B (beta handling is done
 *             elsewhere; this kernel always accumulates into C).
 * A         - packed A, laid out by this kernel's unroll: consecutive
 *             16-float column panels, each panel 16*K floats long (the
 *             A1/A2/A3 pointers below step 16*K apart).
 * B         - packed B: 4 (then 2, then 1) floats per k step depending
 *             on the current N strip.
 * C, ldc    - output matrix and its leading dimension, in floats.
 *
 * Structure: N is peeled in strips of 4, then 2, then 1 column; inside
 * each strip, m is peeled with the widest available micro-kernel
 * (64/32/16/8/4/2/1 rows for N=4; 16/8/4/2/1 for N=2 and N=1).
 *
 * NOTE(review): A and B are assumed packed by the matching OpenBLAS
 * copy routines — standard kernel contract, confirm against the build's
 * KERNEL file.  noinline keeps this as a standalone unit for the
 * compiler's register allocator.
 */
int __attribute__ ((noinline))
CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict A, float * __restrict B, float * __restrict C, BLASLONG ldc)
{
    /* Work with unsigned copies of the dimensions for the loop math. */
    unsigned long long M = m, N = n, K = k;
    if (M == 0)
        return 0;
    if (N == 0)
        return 0;
    if (K == 0)
        return 0;


    /* ---- strips of 4 columns of B ---- */
    while (N >= 4) {
        float *CO1;
        float *AO;
        int i;  /* rows of this strip still to process; assumes m fits in int */
        // L8_10
        CO1 = C;
        C += 4 * ldc;

        AO = A;

        i = m;
        while (i >= 64) {
            float *BO;
            float *A1, *A2, *A3;
            // L8_11
            __m512 zmm0, zmm1, zmm2, zmm3, row0, zmm5, row1, zmm7, row2, row3, row0b, row1b, row2b, row3b, row0c, row1c, row2c, row3c, row0d, row1d, row2d, row3d;
            BO = B;
            int kloop = K;  /* NOTE(review): truncates K to int — assumes K <= INT_MAX (blocking keeps Q small) */

            /* The packed A panels are 16*K floats apart. */
            A1 = AO + 16 * K;
            A2 = A1 + 16 * K;
            A3 = A2 + 16 * K;

            INIT64x4()

            while (kloop > 0) {
                // L12_17
                KERNEL64x4_SUB()
                kloop--;
            }
            // L8_19
            SAVE64x4(alpha)
            CO1 += 64;
            /* The k-loop already advanced AO by 16*K; add 48*K more so
               AO moves a full 64*K (four panels) per 64-row block. */
            AO += 48 * K;

            i -= 64;
        }
        while (i >= 32) {
            float *BO;
            float *A1;
            // L8_11
            __m512 zmm0, zmm1, zmm2, zmm3, row0, row1, row2, row3, row0b, row1b, row2b, row3b;
            BO = B;
            int kloop = K;

            A1 = AO + 16 * K;

            INIT32x4()

            while (kloop > 0) {
                // L12_17
                KERNEL32x4_SUB()
                kloop--;
            }
            // L8_19
            SAVE32x4(alpha)
            CO1 += 32;
            /* k-loop advanced AO by 16*K; skip the second panel too. */
            AO += 16 * K;

            i -= 32;
        }
        while (i >= 16) {
            float *BO;
            // L8_11
            __m512 zmm0, zmm2, zmm3, row0, row1, row2, row3;
            BO = B;
            int kloop = K;

            INIT16x4()

            while (kloop > 0) {
                // L12_17
                KERNEL16x4_SUB()
                kloop--;
            }
            // L8_19
            SAVE16x4(alpha)
            CO1 += 16;
            /* AO already advanced by the full 16*K inside the k-loop. */

            i -= 16;
        }
        while (i >= 8) {
            float *BO;
            // L8_11
            __m256 ymm0, ymm2, ymm3, ymm4, ymm6,ymm8,ymm10;
            BO = B;
            int kloop = K;

            INIT8x4()

            while (kloop > 0) {
                // L12_17
                KERNEL8x4_SUB()
                kloop--;
            }
            // L8_19
            SAVE8x4(alpha)
            CO1 += 8;

            i -= 8;
        }
        while (i >= 4) {
            // L8_11
            float *BO;
            __m128 xmm0, xmm2, xmm3, row0, row1, row2, row3;
            BO = B;
            int kloop = K;

            INIT4x4()
            // L8_16
            while (kloop > 0) {
                // L12_17
                KERNEL4x4_SUB()
                kloop--;
            }
            // L8_19
            SAVE4x4(alpha)
            CO1 += 4;

            i -= 4;
        }

        /**************************************************************************
        * Rest of M (scalar tails: the "xmm"/"row" names are plain floats here)
        ***************************************************************************/

        while (i >= 2) {
            float *BO;
            float xmm0, xmm1, xmm2, xmm3, row0, row0b, row1, row1b, row2, row2b, row3, row3b;
            BO = B;

            INIT2x4()
            int kloop = K;

            while (kloop > 0) {
                KERNEL2x4_SUB()
                kloop--;
            }
            SAVE2x4(alpha)
            CO1 += 2;
            i -= 2;
        }
        // L13_40
        while (i >= 1) {
            float *BO;
            float xmm0, xmm2, xmm3, row0, row1, row2, row3;
            int kloop = K;
            BO = B;
            INIT1x4()

            while (kloop > 0) {
                KERNEL1x4_SUB()
                kloop--;
            }
            SAVE1x4(alpha)
            CO1 += 1;
            i -= 1;
        }

        /* B holds 4 floats per k step in this strip. */
        B += K * 4;
        N -= 4;
    }

    /**************************************************************************************************/

    // L8_0
    /* ---- strips of 2 columns of B ---- */
    while (N >= 2) {
        float *CO1;
        float *AO;
        int i;
        // L8_10
        CO1 = C;
        C += 2 * ldc;

        AO = A;

        i = m;
        while (i >= 16) {
            float *BO;

            // L8_11
            __m512 zmm0, zmm2, zmm3, row0, row1;
            BO = B;
            int kloop = K;

            INIT16x2()

            while (kloop > 0) {
                // L12_17
                KERNEL16x2_SUB()
                kloop--;
            }
            // L8_19
            SAVE16x2(alpha)
            CO1 += 16;

            i -= 16;
        }
        while (i >= 8) {
            float *BO;
            __m256 ymm0, ymm2, ymm3, ymm4, ymm6;
            // L8_11
            BO = B;
            int kloop = K;

            INIT8x2()

            // L8_16
            while (kloop > 0) {
                // L12_17
                KERNEL8x2_SUB()
                kloop--;
            }
            // L8_19
            SAVE8x2(alpha)
            CO1 += 8;

            i-=8;
        }

        while (i >= 4) {
            float *BO;
            __m128 xmm0, xmm2, xmm3, row0, row1;
            // L8_11
            BO = B;
            int kloop = K;

            INIT4x2()

            // L8_16
            while (kloop > 0) {
                // L12_17
                KERNEL4x2_SUB()
                kloop--;
            }
            // L8_19
            SAVE4x2(alpha)
            CO1 += 4;

            i-=4;
        }

        /**************************************************************************
        * Rest of M (scalar tails)
        ***************************************************************************/

        while (i >= 2) {
            float *BO;
            float xmm0, xmm1, xmm2, xmm3, row0, row0b, row1, row1b;
            int kloop = K;
            BO = B;

            INIT2x2()

            while (kloop > 0) {
                KERNEL2x2_SUB()
                kloop--;
            }
            SAVE2x2(alpha)
            CO1 += 2;
            i -= 2;
        }
        // L13_40
        while (i >= 1) {
            float *BO;
            float xmm0, xmm2, xmm3, row0, row1;
            int kloop = K;
            BO = B;

            INIT1x2()

            while (kloop > 0) {
                KERNEL1x2_SUB()
                kloop--;
            }
            SAVE1x2(alpha)
            CO1 += 1;
            i -= 1;
        }

        /* B holds 2 floats per k step in this strip. */
        B += K * 2;
        N -= 2;
    }

    // L8_0
    /* ---- final single-column strip of B ---- */
    while (N >= 1) {
        // L8_10
        float *CO1;
        float *AO;
        int i;

        CO1 = C;
        C += ldc;

        AO = A;

        i = m;
        while (i >= 16) {
            float *BO;
            __m512 zmm0, zmm2, row0;
            // L8_11
            BO = B;
            int kloop = K;

            INIT16x1()
            // L8_16
            while (kloop > 0) {
                // L12_17
                KERNEL16x1_SUB()
                kloop--;
            }
            // L8_19
            SAVE16x1(alpha)
            CO1 += 16;

            i-= 16;
        }
        while (i >= 8) {
            float *BO;
            __m256 ymm0, ymm2, ymm4;
            // L8_11
            BO = B;
            int kloop = K;

            INIT8x1()
            // L8_16
            while (kloop > 0) {
                // L12_17
                KERNEL8x1_SUB()
                kloop--;
            }
            // L8_19
            SAVE8x1(alpha)
            CO1 += 8;

            i-= 8;
        }
        while (i >= 4) {
            float *BO;
            __m128 xmm0, xmm2, row0;
            // L8_11
            BO = B;
            int kloop = K;

            INIT4x1()
            // L8_16
            while (kloop > 0) {
                // L12_17
                KERNEL4x1_SUB()
                kloop--;
            }
            // L8_19
            SAVE4x1(alpha)
            CO1 += 4;

            i-= 4;
        }

        /**************************************************************************
        * Rest of M (scalar tails)
        ***************************************************************************/

        while (i >= 2) {
            float *BO;
            float xmm0, xmm1, xmm2, row0, row0b;
            int kloop = K;
            BO = B;

            INIT2x1()

            while (kloop > 0) {
                KERNEL2x1_SUB()
                kloop--;
            }
            SAVE2x1(alpha)
            CO1 += 2;
            i -= 2;
        }
        // L13_40
        while (i >= 1) {
            float *BO;
            float xmm0, xmm2, row0;
            int kloop = K;

            BO = B;
            INIT1x1()


            while (kloop > 0) {
                KERNEL1x1_SUB()
                kloop--;
            }
            SAVE1x1(alpha)
            CO1 += 1;
            i -= 1;
        }

        /* B holds 1 float per k step in this strip. */
        B += K * 1;
        N -= 1;
    }


    return 0;
}
1178
1179 #include "sgemm_direct_skylakex.c"
1180