1 /*
2
3 BLIS
4 An object-based framework for developing high-performance BLAS-like
5 libraries.
6
7 Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc.
8
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are
11 met:
12 - Redistributions of source code must retain the above copyright
13 notice, this list of conditions and the following disclaimer.
14 - Redistributions in binary form must reproduce the above copyright
15 notice, this list of conditions and the following disclaimer in the
16 documentation and/or other materials provided with the distribution.
17 - Neither the name of The University of Texas at Austin nor the names
18 of its contributors may be used to endorse or promote products
19 derived from this software without specific prior written permission.
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
33 */
34
35 #include "immintrin.h"
36 #include "xmmintrin.h"
37 #include "blis.h"
38
39 #ifdef BLIS_ENABLE_SMALL_MATRIX
40
41 #define MR 32
42 #define D_MR (MR >> 1)
43 #define NR 3
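// Register blocking of the C tile (descriptive note): MR = 32 single-precision
// rows are held in four 8-float YMM registers per column, and NR = 3 columns
// give the twelve accumulators ymm4..ymm15 used by the main kernel below.
// D_MR = 16 is the analogous row blocking for double precision (four 4-double
// YMM registers per column).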
44
45 #define BLIS_ENABLE_PREFETCH
46 #define F_SCRATCH_DIM (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)
47 static float A_pack[F_SCRATCH_DIM] __attribute__((aligned(64)));
48 static float C_pack[F_SCRATCH_DIM] __attribute__((aligned(64)));
49 #define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 )
50 #define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2)
51 #define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2)
52 #define D_SCRATCH_DIM (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)
53 static double D_A_pack[D_SCRATCH_DIM] __attribute__((aligned(64)));
54 static double D_C_pack[D_SCRATCH_DIM] __attribute__((aligned(64)));
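// The static scratch buffers play two roles in the small-matrix path:
// A_pack / D_A_pack cache one MR x K (or D_MR x K) panel of A so that later
// column sweeps can read it contiguously, while C_pack / D_C_pack receive the
// alpha * A * B products for the whole small problem and are copied back into
// the user's C buffer (honoring uplo and beta) at the end of each kernel.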
55 #define BLIS_ATBN_M_THRES 40 // Threshold value of M at or below which the small-matrix code is called.
56 #define AT_MR 4 // The kernel dimension of the A-transpose SYRK kernel (AT_MR * NR).
57 static err_t bli_ssyrk_small
58 (
59 obj_t* alpha,
60 obj_t* a,
61 obj_t* b,
62 obj_t* beta,
63 obj_t* c,
64 cntx_t* cntx,
65 cntl_t* cntl
66 );
67
68 static err_t bli_dsyrk_small
69 (
70 obj_t* alpha,
71 obj_t* a,
72 obj_t* b,
73 obj_t* beta,
74 obj_t* c,
75 cntx_t* cntx,
76 cntl_t* cntl
77 );
78
79 static err_t bli_ssyrk_small_atbn
80 (
81 obj_t* alpha,
82 obj_t* a,
83 obj_t* b,
84 obj_t* beta,
85 obj_t* c,
86 cntx_t* cntx,
87 cntl_t* cntl
88 );
89
90 static err_t bli_dsyrk_small_atbn
91 (
92 obj_t* alpha,
93 obj_t* a,
94 obj_t* b,
95 obj_t* beta,
96 obj_t* c,
97 cntx_t* cntx,
98 cntl_t* cntl
99 );
100 /*
101  * The bli_syrk_small function uses the custom MRxNR kernels to perform
102  * the computation.
103  * The custom kernels are used when [M * N] < 240 * 240.
104 */
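/*
 * A minimal usage sketch (illustrative only; the dimensions are arbitrary,
 * and whether the small path is actually taken depends on the configured
 * threshold macros and build options):
 *
 *   obj_t a, c, alpha, beta;
 *   bli_obj_create( BLIS_DOUBLE, 64, 32, 0, 0, &a );    // A is m x k
 *   bli_obj_create( BLIS_DOUBLE, 64, 64, 0, 0, &c );    // C is m x m
 *   bli_obj_create_1x1( BLIS_DOUBLE, &alpha );
 *   bli_obj_create_1x1( BLIS_DOUBLE, &beta );
 *   bli_setsc( 1.0, 0.0, &alpha );
 *   bli_setsc( 0.0, 0.0, &beta );
 *   bli_obj_set_struc( BLIS_SYMMETRIC, &c );
 *   bli_obj_set_uplo( BLIS_LOWER, &c );
 *   bli_syrk( &alpha, &a, &beta, &c );  // the front end may route here
 */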
105 err_t bli_syrk_small
106 (
107 obj_t* alpha,
108 obj_t* a,
109 obj_t* b,
110 obj_t* beta,
111 obj_t* c,
112 cntx_t* cntx,
113 cntl_t* cntl
114 )
115 {
116 // FGVZ: This code was originally in bli_syrk_front(). However, it really
117 // fits more naturally here within the bli_syrk_small() function. This
118 // becomes a bit more obvious now that the code is here, as it contains
119 // cpp macros such as BLIS_SMALL_MATRIX_A_THRES_M_SYRK, which are specific
120 // to this implementation.
121 if ( bli_obj_has_trans( a ) )
122 {
123 // Continue with small implementation.
124 ;
125 }
126 else if ( ( bli_obj_length( a ) <= BLIS_SMALL_MATRIX_A_THRES_M_SYRK &&
127 bli_obj_width( a ) < BLIS_SMALL_MATRIX_A_THRES_N_SYRK ) ||
128 ( bli_obj_length( a ) < BLIS_SMALL_MATRIX_A_THRES_M_SYRK &&
129 bli_obj_width( a ) <= BLIS_SMALL_MATRIX_A_THRES_N_SYRK ) )
130 {
131 // Continue with small implementation.
132 ;
133 }
134 else
135 {
136 // Reject the problem and return to large code path.
137 return BLIS_FAILURE;
138 }
139
140 #ifdef BLIS_ENABLE_MULTITHREADING
141 return BLIS_NOT_YET_IMPLEMENTED;
142 #endif
143 // If alpha is zero, scale by beta and return.
144 if (bli_obj_equals(alpha, &BLIS_ZERO))
145 {
146 return BLIS_NOT_YET_IMPLEMENTED;
147 }
148
149     // If any matrix has a non-unit row stride (row-major storage), return.
150 if ((bli_obj_row_stride( a ) != 1) ||
151 (bli_obj_row_stride( b ) != 1) ||
152 (bli_obj_row_stride( c ) != 1))
153 {
154 return BLIS_INVALID_ROW_STRIDE;
155 }
156
157 num_t dt = ((*c).info & (0x7 << 0));
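    // Extract the storage datatype of C from the low three bits of the
    // object's info bitfield; the small kernels below only handle
    // BLIS_FLOAT and BLIS_DOUBLE.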
158
159 if (bli_obj_has_trans( a ))
160 {
161 if (bli_obj_has_notrans( b ))
162 {
163 if (dt == BLIS_FLOAT)
164 {
165 return bli_ssyrk_small_atbn(alpha, a, b, beta, c, cntx, cntl);
166 }
167 else if (dt == BLIS_DOUBLE)
168 {
169 return bli_dsyrk_small_atbn(alpha, a, b, beta, c, cntx, cntl);
170 }
171 }
172
173 return BLIS_NOT_YET_IMPLEMENTED;
174 }
175
176 if (dt == BLIS_DOUBLE)
177 {
178 return bli_dsyrk_small(alpha, a, b, beta, c, cntx, cntl);
179 }
180
181 if (dt == BLIS_FLOAT)
182 {
183 return bli_ssyrk_small(alpha, a, b, beta, c, cntx, cntl);
184 }
185
186 return BLIS_NOT_YET_IMPLEMENTED;
187 };
188
189
190 static err_t bli_ssyrk_small
191 (
192 obj_t* alpha,
193 obj_t* a,
194 obj_t* b,
195 obj_t* beta,
196 obj_t* c,
197 cntx_t* cntx,
198 cntl_t* cntl
199 )
200 {
201
202 int M = bli_obj_length( c ); // number of rows of Matrix C
203 int N = bli_obj_width( c ); // number of columns of Matrix C
204 int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
205 int L = M * N;
206
207 if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES))
208 || ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
209 {
210
211 int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
212 int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
213 int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
214 int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
215 int row_idx, col_idx, k;
216 int rs_matC = bli_obj_row_stride( c );
217 int rsc = 1;
218 float *A = a->buffer; // pointer to elements of Matrix A
219 float *B = b->buffer; // pointer to elements of Matrix B
220 float *C = C_pack; // pointer to elements of Matrix C
221 float *matCbuf = c->buffer;
222
223 float *tA = A, *tB = B, *tC = C;//, *tA_pack;
224         float *tA_packed; // temporary pointer to the packed A buffer
225 int row_idx_packed; //packed A memory row index
226 int lda_packed; //lda of packed A
227 int col_idx_start; //starting index after A matrix is packed.
228 dim_t tb_inc_row = 1; // row stride of matrix B
229 dim_t tb_inc_col = ldb; // column stride of matrix B
230 __m256 ymm4, ymm5, ymm6, ymm7;
231 __m256 ymm8, ymm9, ymm10, ymm11;
232 __m256 ymm12, ymm13, ymm14, ymm15;
233 __m256 ymm0, ymm1, ymm2, ymm3;
234
235         int n_remainder; // remainder when N is not a multiple of 3 (N % 3)
236         int m_remainder; // remainder when M is not a multiple of 32 (M % 32)
237
238 float *alpha_cast, *beta_cast; // alpha, beta multiples
239 alpha_cast = (alpha->buffer);
240 beta_cast = (beta->buffer);
241 int required_packing_A = 1;
242
243 // when N is equal to 1 call GEMV instead of SYRK
244 if (N == 1)
245 {
246 bli_gemv
247 (
248 alpha,
249 a,
250 b,
251 beta,
252 c
253 );
254 return BLIS_SUCCESS;
255 }
256
257 //update the pointer math if matrix B needs to be transposed.
258 if (bli_obj_has_trans( b ))
259 {
260 tb_inc_col = 1; //switch row and column strides
261 tb_inc_row = ldb;
262 }
263
264 if ((N <= 3) || ((MR * K) > F_SCRATCH_DIM))
265 {
266 required_packing_A = 0;
267 }
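        // Packing A is only worthwhile when the packed panel will be reused:
        // with N <= 3 there is just a single NR-wide column sweep, and when
        // MR * K exceeds the static scratch buffer the panel cannot be cached.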
268         /*
269          * The computation loop processes an MR x N panel of the C matrix, thus
270          * accessing MR x K of the A matrix and K x NR of the B matrix at a time.
271          * The computation is organized as inner loops of dimension MR x NR.
272          */
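        // Within each MR x NR sub-tile, ymm0..ymm2 broadcast the three B
        // scalars for the current k, ymm3 streams eight elements of A at a
        // time, and the products accumulate into ymm4..ymm7 (column 0 of C),
        // ymm8..ymm11 (column 1) and ymm12..ymm15 (column 2).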
273 // Process MR rows of C matrix at a time.
274 for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
275 {
276
277 col_idx_start = 0;
278 tA_packed = A;
279 row_idx_packed = row_idx;
280 lda_packed = lda;
281
282 // This is the part of the pack and compute optimization.
283 // During the first column iteration, we store the accessed A matrix into
284             // contiguous static memory. This helps to keep the A matrix in cache and
285             // avoids TLB misses.
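            // Concretely: the first NR-column sweep reads A from user memory
            // and, as a side effect, writes the MR x K panel into A_pack;
            // every later sweep for this row block (col_idx >= NR) reads the
            // panel from A_pack with lda_packed = MR instead of striding
            // through the original matrix.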
286 if (required_packing_A)
287 {
288 col_idx = 0;
289
290 //pointer math to point to proper memory
291 tC = C + ldc * col_idx + row_idx;
292 tB = B + tb_inc_col * col_idx;
293 tA = A + row_idx;
294 tA_packed = A_pack;
295
296 #if 0//def BLIS_ENABLE_PREFETCH
297 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
298 _mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
299 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
300 _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
301 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
302 _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
303 #endif
304 // clear scratch registers.
305 ymm4 = _mm256_setzero_ps();
306 ymm5 = _mm256_setzero_ps();
307 ymm6 = _mm256_setzero_ps();
308 ymm7 = _mm256_setzero_ps();
309 ymm8 = _mm256_setzero_ps();
310 ymm9 = _mm256_setzero_ps();
311 ymm10 = _mm256_setzero_ps();
312 ymm11 = _mm256_setzero_ps();
313 ymm12 = _mm256_setzero_ps();
314 ymm13 = _mm256_setzero_ps();
315 ymm14 = _mm256_setzero_ps();
316 ymm15 = _mm256_setzero_ps();
317
318 for (k = 0; k < K; ++k)
319 {
320 // The inner loop broadcasts the B matrix data and
321 // multiplies it with the A matrix.
322 // This loop is processing MR x K
323 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
324 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
325 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
326 tB += tb_inc_row;
327
328 //broadcasted matrix B elements are multiplied
329 //with matrix A columns.
330 ymm3 = _mm256_loadu_ps(tA);
331 _mm256_storeu_ps(tA_packed, ymm3); // the packing of matrix A
332 // ymm4 += ymm0 * ymm3;
333 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
334 // ymm8 += ymm1 * ymm3;
335 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
336 // ymm12 += ymm2 * ymm3;
337 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
338
339 ymm3 = _mm256_loadu_ps(tA + 8);
340 _mm256_storeu_ps(tA_packed + 8, ymm3); // the packing of matrix A
341 // ymm5 += ymm0 * ymm3;
342 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
343 // ymm9 += ymm1 * ymm3;
344 ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
345 // ymm13 += ymm2 * ymm3;
346 ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
347
348 ymm3 = _mm256_loadu_ps(tA + 16);
349 _mm256_storeu_ps(tA_packed + 16, ymm3); // the packing of matrix A
350 // ymm6 += ymm0 * ymm3;
351 ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
352 // ymm10 += ymm1 * ymm3;
353 ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
354 // ymm14 += ymm2 * ymm3;
355 ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
356
357 ymm3 = _mm256_loadu_ps(tA + 24);
358 _mm256_storeu_ps(tA_packed + 24, ymm3); // the packing of matrix A
359 // ymm7 += ymm0 * ymm3;
360 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
361 // ymm11 += ymm1 * ymm3;
362 ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
363 // ymm15 += ymm2 * ymm3;
364 ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
365
366 tA += lda;
367 tA_packed += MR;
368 }
369 // alpha, beta multiplication.
370 ymm0 = _mm256_broadcast_ss(alpha_cast);
371 //ymm1 = _mm256_broadcast_ss(beta_cast);
372
373 //multiply A*B by alpha.
374 ymm4 = _mm256_mul_ps(ymm4, ymm0);
375 ymm5 = _mm256_mul_ps(ymm5, ymm0);
376 ymm6 = _mm256_mul_ps(ymm6, ymm0);
377 ymm7 = _mm256_mul_ps(ymm7, ymm0);
378 ymm8 = _mm256_mul_ps(ymm8, ymm0);
379 ymm9 = _mm256_mul_ps(ymm9, ymm0);
380 ymm10 = _mm256_mul_ps(ymm10, ymm0);
381 ymm11 = _mm256_mul_ps(ymm11, ymm0);
382 ymm12 = _mm256_mul_ps(ymm12, ymm0);
383 ymm13 = _mm256_mul_ps(ymm13, ymm0);
384 ymm14 = _mm256_mul_ps(ymm14, ymm0);
385 ymm15 = _mm256_mul_ps(ymm15, ymm0);
386
387 // multiply C by beta and accumulate col 1.
388 /*ymm2 = _mm256_loadu_ps(tC);
389 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
390 ymm2 = _mm256_loadu_ps(tC + 8);
391 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
392 ymm2 = _mm256_loadu_ps(tC + 16);
393 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
394 ymm2 = _mm256_loadu_ps(tC + 24);
395 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
396 _mm256_storeu_ps(tC, ymm4);
397 _mm256_storeu_ps(tC + 8, ymm5);
398 _mm256_storeu_ps(tC + 16, ymm6);
399 _mm256_storeu_ps(tC + 24, ymm7);
400
401 // multiply C by beta and accumulate, col 2.
402 tC += ldc;
403 /*ymm2 = _mm256_loadu_ps(tC);
404 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
405 ymm2 = _mm256_loadu_ps(tC + 8);
406 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
407 ymm2 = _mm256_loadu_ps(tC + 16);
408 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
409 ymm2 = _mm256_loadu_ps(tC + 24);
410 ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
411 _mm256_storeu_ps(tC, ymm8);
412 _mm256_storeu_ps(tC + 8, ymm9);
413 _mm256_storeu_ps(tC + 16, ymm10);
414 _mm256_storeu_ps(tC + 24, ymm11);
415
416 // multiply C by beta and accumulate, col 3.
417 tC += ldc;
418 /*ymm2 = _mm256_loadu_ps(tC);
419 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
420 ymm2 = _mm256_loadu_ps(tC + 8);
421 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
422 ymm2 = _mm256_loadu_ps(tC + 16);
423 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
424 ymm2 = _mm256_loadu_ps(tC + 24);
425 ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
426 _mm256_storeu_ps(tC, ymm12);
427 _mm256_storeu_ps(tC + 8, ymm13);
428 _mm256_storeu_ps(tC + 16, ymm14);
429 _mm256_storeu_ps(tC + 24, ymm15);
430
431                 // Modify the pointer arithmetic to use the packed A matrix.
432 col_idx_start = NR;
433 tA_packed = A_pack;
434 row_idx_packed = 0;
435 lda_packed = MR;
436 }
437 // Process NR columns of C matrix at a time.
438 for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
439 {
440 //pointer math to point to proper memory
441 tC = C + ldc * col_idx + row_idx;
442 tB = B + tb_inc_col * col_idx;
443 tA = tA_packed + row_idx_packed;
444
445 #if 0//def BLIS_ENABLE_PREFETCH
446 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
447 _mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
448 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
449 _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
450 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
451 _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
452 #endif
453 // clear scratch registers.
454 ymm4 = _mm256_setzero_ps();
455 ymm5 = _mm256_setzero_ps();
456 ymm6 = _mm256_setzero_ps();
457 ymm7 = _mm256_setzero_ps();
458 ymm8 = _mm256_setzero_ps();
459 ymm9 = _mm256_setzero_ps();
460 ymm10 = _mm256_setzero_ps();
461 ymm11 = _mm256_setzero_ps();
462 ymm12 = _mm256_setzero_ps();
463 ymm13 = _mm256_setzero_ps();
464 ymm14 = _mm256_setzero_ps();
465 ymm15 = _mm256_setzero_ps();
466
467 for (k = 0; k < K; ++k)
468 {
469 // The inner loop broadcasts the B matrix data and
470 // multiplies it with the A matrix.
471 // This loop is processing MR x K
472 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
473 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
474 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
475 tB += tb_inc_row;
476
477 //broadcasted matrix B elements are multiplied
478 //with matrix A columns.
479 ymm3 = _mm256_loadu_ps(tA);
480 // ymm4 += ymm0 * ymm3;
481 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
482 // ymm8 += ymm1 * ymm3;
483 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
484 // ymm12 += ymm2 * ymm3;
485 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
486
487 ymm3 = _mm256_loadu_ps(tA + 8);
488 // ymm5 += ymm0 * ymm3;
489 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
490 // ymm9 += ymm1 * ymm3;
491 ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
492 // ymm13 += ymm2 * ymm3;
493 ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
494
495 ymm3 = _mm256_loadu_ps(tA + 16);
496 // ymm6 += ymm0 * ymm3;
497 ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
498 // ymm10 += ymm1 * ymm3;
499 ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
500 // ymm14 += ymm2 * ymm3;
501 ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
502
503 ymm3 = _mm256_loadu_ps(tA + 24);
504 // ymm7 += ymm0 * ymm3;
505 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
506 // ymm11 += ymm1 * ymm3;
507 ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
508 // ymm15 += ymm2 * ymm3;
509 ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
510
511 tA += lda_packed;
512 }
513 // alpha, beta multiplication.
514 ymm0 = _mm256_broadcast_ss(alpha_cast);
515 //ymm1 = _mm256_broadcast_ss(beta_cast);
516
517 //multiply A*B by alpha.
518 ymm4 = _mm256_mul_ps(ymm4, ymm0);
519 ymm5 = _mm256_mul_ps(ymm5, ymm0);
520 ymm6 = _mm256_mul_ps(ymm6, ymm0);
521 ymm7 = _mm256_mul_ps(ymm7, ymm0);
522 ymm8 = _mm256_mul_ps(ymm8, ymm0);
523 ymm9 = _mm256_mul_ps(ymm9, ymm0);
524 ymm10 = _mm256_mul_ps(ymm10, ymm0);
525 ymm11 = _mm256_mul_ps(ymm11, ymm0);
526 ymm12 = _mm256_mul_ps(ymm12, ymm0);
527 ymm13 = _mm256_mul_ps(ymm13, ymm0);
528 ymm14 = _mm256_mul_ps(ymm14, ymm0);
529 ymm15 = _mm256_mul_ps(ymm15, ymm0);
530
531 // multiply C by beta and accumulate col 1.
532 /*ymm2 = _mm256_loadu_ps(tC);
533 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
534 ymm2 = _mm256_loadu_ps(tC + 8);
535 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
536 ymm2 = _mm256_loadu_ps(tC + 16);
537 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
538 ymm2 = _mm256_loadu_ps(tC + 24);
539 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
540 _mm256_storeu_ps(tC, ymm4);
541 _mm256_storeu_ps(tC + 8, ymm5);
542 _mm256_storeu_ps(tC + 16, ymm6);
543 _mm256_storeu_ps(tC + 24, ymm7);
544
545 // multiply C by beta and accumulate, col 2.
546 tC += ldc;
547 /*ymm2 = _mm256_loadu_ps(tC);
548 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
549 ymm2 = _mm256_loadu_ps(tC + 8);
550 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
551 ymm2 = _mm256_loadu_ps(tC + 16);
552 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
553 ymm2 = _mm256_loadu_ps(tC + 24);
554 ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
555 _mm256_storeu_ps(tC, ymm8);
556 _mm256_storeu_ps(tC + 8, ymm9);
557 _mm256_storeu_ps(tC + 16, ymm10);
558 _mm256_storeu_ps(tC + 24, ymm11);
559
560 // multiply C by beta and accumulate, col 3.
561 tC += ldc;
562 /*ymm2 = _mm256_loadu_ps(tC);
563 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
564 ymm2 = _mm256_loadu_ps(tC + 8);
565 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
566 ymm2 = _mm256_loadu_ps(tC + 16);
567 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
568 ymm2 = _mm256_loadu_ps(tC + 24);
569 ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
570 _mm256_storeu_ps(tC, ymm12);
571 _mm256_storeu_ps(tC + 8, ymm13);
572 _mm256_storeu_ps(tC + 16, ymm14);
573 _mm256_storeu_ps(tC + 24, ymm15);
574
575 }
576 n_remainder = N - col_idx;
577
578             // If N is not a multiple of 3,
579             // handle the edge case.
580 if (n_remainder == 2)
581 {
582 //pointer math to point to proper memory
583 tC = C + ldc * col_idx + row_idx;
584 tB = B + tb_inc_col * col_idx;
585 tA = A + row_idx;
586
587 // clear scratch registers.
588 ymm8 = _mm256_setzero_ps();
589 ymm9 = _mm256_setzero_ps();
590 ymm10 = _mm256_setzero_ps();
591 ymm11 = _mm256_setzero_ps();
592 ymm12 = _mm256_setzero_ps();
593 ymm13 = _mm256_setzero_ps();
594 ymm14 = _mm256_setzero_ps();
595 ymm15 = _mm256_setzero_ps();
596
597 for (k = 0; k < K; ++k)
598 {
599 // The inner loop broadcasts the B matrix data and
600 // multiplies it with the A matrix.
601 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
602 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
603 tB += tb_inc_row;
604
605 //broadcasted matrix B elements are multiplied
606 //with matrix A columns.
607 ymm3 = _mm256_loadu_ps(tA);
608 ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
609 ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
610
611 ymm3 = _mm256_loadu_ps(tA + 8);
612 ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
613 ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
614
615 ymm3 = _mm256_loadu_ps(tA + 16);
616 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
617 ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
618
619 ymm3 = _mm256_loadu_ps(tA + 24);
620 ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);
621 ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
622
623 tA += lda;
624
625 }
626 // alpha, beta multiplication.
627 ymm0 = _mm256_broadcast_ss(alpha_cast);
628 //ymm1 = _mm256_broadcast_ss(beta_cast);
629
630 //multiply A*B by alpha.
631 ymm8 = _mm256_mul_ps(ymm8, ymm0);
632 ymm9 = _mm256_mul_ps(ymm9, ymm0);
633 ymm10 = _mm256_mul_ps(ymm10, ymm0);
634 ymm11 = _mm256_mul_ps(ymm11, ymm0);
635 ymm12 = _mm256_mul_ps(ymm12, ymm0);
636 ymm13 = _mm256_mul_ps(ymm13, ymm0);
637 ymm14 = _mm256_mul_ps(ymm14, ymm0);
638 ymm15 = _mm256_mul_ps(ymm15, ymm0);
639
640 // multiply C by beta and accumulate, col 1.
641 /*ymm2 = _mm256_loadu_ps(tC + 0);
642 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
643 ymm2 = _mm256_loadu_ps(tC + 8);
644 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
645 ymm2 = _mm256_loadu_ps(tC + 16);
646 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
647 ymm2 = _mm256_loadu_ps(tC + 24);
648 ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/
649 _mm256_storeu_ps(tC + 0, ymm8);
650 _mm256_storeu_ps(tC + 8, ymm9);
651 _mm256_storeu_ps(tC + 16, ymm10);
652 _mm256_storeu_ps(tC + 24, ymm11);
653
654 // multiply C by beta and accumulate, col 2.
655 tC += ldc;
656 /*ymm2 = _mm256_loadu_ps(tC);
657 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
658 ymm2 = _mm256_loadu_ps(tC + 8);
659 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
660 ymm2 = _mm256_loadu_ps(tC + 16);
661 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
662 ymm2 = _mm256_loadu_ps(tC + 24);
663 ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
664 _mm256_storeu_ps(tC, ymm12);
665 _mm256_storeu_ps(tC + 8, ymm13);
666 _mm256_storeu_ps(tC + 16, ymm14);
667 _mm256_storeu_ps(tC + 24, ymm15);
668
669 col_idx += 2;
670 }
671             // If N is not a multiple of 3,
672             // handle the edge case.
673 if (n_remainder == 1)
674 {
675 //pointer math to point to proper memory
676 tC = C + ldc * col_idx + row_idx;
677 tB = B + tb_inc_col * col_idx;
678 tA = A + row_idx;
679
680 // clear scratch registers.
681 ymm12 = _mm256_setzero_ps();
682 ymm13 = _mm256_setzero_ps();
683 ymm14 = _mm256_setzero_ps();
684 ymm15 = _mm256_setzero_ps();
685
686 for (k = 0; k < K; ++k)
687 {
688 // The inner loop broadcasts the B matrix data and
689 // multiplies it with the A matrix.
690 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
691 tB += tb_inc_row;
692
693 //broadcasted matrix B elements are multiplied
694 //with matrix A columns.
695 ymm3 = _mm256_loadu_ps(tA);
696 ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
697
698 ymm3 = _mm256_loadu_ps(tA + 8);
699 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
700
701 ymm3 = _mm256_loadu_ps(tA + 16);
702 ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
703
704 ymm3 = _mm256_loadu_ps(tA + 24);
705 ymm15 = _mm256_fmadd_ps(ymm0, ymm3, ymm15);
706
707 tA += lda;
708
709 }
710 // alpha, beta multiplication.
711 ymm0 = _mm256_broadcast_ss(alpha_cast);
712 //ymm1 = _mm256_broadcast_ss(beta_cast);
713
714 //multiply A*B by alpha.
715 ymm12 = _mm256_mul_ps(ymm12, ymm0);
716 ymm13 = _mm256_mul_ps(ymm13, ymm0);
717 ymm14 = _mm256_mul_ps(ymm14, ymm0);
718 ymm15 = _mm256_mul_ps(ymm15, ymm0);
719
720 // multiply C by beta and accumulate.
721 /*ymm2 = _mm256_loadu_ps(tC + 0);
722 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
723 ymm2 = _mm256_loadu_ps(tC + 8);
724 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
725 ymm2 = _mm256_loadu_ps(tC + 16);
726 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
727 ymm2 = _mm256_loadu_ps(tC + 24);
728 ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/
729
730 _mm256_storeu_ps(tC + 0, ymm12);
731 _mm256_storeu_ps(tC + 8, ymm13);
732 _mm256_storeu_ps(tC + 16, ymm14);
733 _mm256_storeu_ps(tC + 24, ymm15);
734 }
735 }
736
737 m_remainder = M - row_idx;
738
739 if (m_remainder >= 24)
740 {
741 m_remainder -= 24;
742
743 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
744 {
745 //pointer math to point to proper memory
746 tC = C + ldc * col_idx + row_idx;
747 tB = B + tb_inc_col * col_idx;
748 tA = A + row_idx;
749
750 // clear scratch registers.
751 ymm4 = _mm256_setzero_ps();
752 ymm5 = _mm256_setzero_ps();
753 ymm6 = _mm256_setzero_ps();
754 ymm8 = _mm256_setzero_ps();
755 ymm9 = _mm256_setzero_ps();
756 ymm10 = _mm256_setzero_ps();
757 ymm12 = _mm256_setzero_ps();
758 ymm13 = _mm256_setzero_ps();
759 ymm14 = _mm256_setzero_ps();
760
761 for (k = 0; k < K; ++k)
762 {
763 // The inner loop broadcasts the B matrix data and
764 // multiplies it with the A matrix.
765 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
766 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
767 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
768 tB += tb_inc_row;
769
770 //broadcasted matrix B elements are multiplied
771 //with matrix A columns.
772 ymm3 = _mm256_loadu_ps(tA);
773 // ymm4 += ymm0 * ymm3;
774 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
775 // ymm8 += ymm1 * ymm3;
776 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
777 // ymm12 += ymm2 * ymm3;
778 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
779
780 ymm3 = _mm256_loadu_ps(tA + 8);
781 // ymm5 += ymm0 * ymm3;
782 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
783 // ymm9 += ymm1 * ymm3;
784 ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
785 // ymm13 += ymm2 * ymm3;
786 ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
787
788 ymm3 = _mm256_loadu_ps(tA + 16);
789 // ymm6 += ymm0 * ymm3;
790 ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
791 // ymm10 += ymm1 * ymm3;
792 ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
793 // ymm14 += ymm2 * ymm3;
794 ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
795
796 tA += lda;
797 }
798 // alpha, beta multiplication.
799 ymm0 = _mm256_broadcast_ss(alpha_cast);
800 //ymm1 = _mm256_broadcast_ss(beta_cast);
801
802 //multiply A*B by alpha.
803 ymm4 = _mm256_mul_ps(ymm4, ymm0);
804 ymm5 = _mm256_mul_ps(ymm5, ymm0);
805 ymm6 = _mm256_mul_ps(ymm6, ymm0);
806 ymm8 = _mm256_mul_ps(ymm8, ymm0);
807 ymm9 = _mm256_mul_ps(ymm9, ymm0);
808 ymm10 = _mm256_mul_ps(ymm10, ymm0);
809 ymm12 = _mm256_mul_ps(ymm12, ymm0);
810 ymm13 = _mm256_mul_ps(ymm13, ymm0);
811 ymm14 = _mm256_mul_ps(ymm14, ymm0);
812
813 // multiply C by beta and accumulate.
814 /*ymm2 = _mm256_loadu_ps(tC);
815 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
816 ymm2 = _mm256_loadu_ps(tC + 8);
817 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
818 ymm2 = _mm256_loadu_ps(tC + 16);
819 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/
820 _mm256_storeu_ps(tC, ymm4);
821 _mm256_storeu_ps(tC + 8, ymm5);
822 _mm256_storeu_ps(tC + 16, ymm6);
823
824 // multiply C by beta and accumulate.
825 tC += ldc;
826 /*ymm2 = _mm256_loadu_ps(tC);
827 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
828 ymm2 = _mm256_loadu_ps(tC + 8);
829 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
830 ymm2 = _mm256_loadu_ps(tC + 16);
831 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/
832 _mm256_storeu_ps(tC, ymm8);
833 _mm256_storeu_ps(tC + 8, ymm9);
834 _mm256_storeu_ps(tC + 16, ymm10);
835
836 // multiply C by beta and accumulate.
837 tC += ldc;
838 /*ymm2 = _mm256_loadu_ps(tC);
839 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
840 ymm2 = _mm256_loadu_ps(tC + 8);
841 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
842 ymm2 = _mm256_loadu_ps(tC + 16);
843 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
844 _mm256_storeu_ps(tC, ymm12);
845 _mm256_storeu_ps(tC + 8, ymm13);
846 _mm256_storeu_ps(tC + 16, ymm14);
847
848 }
849 n_remainder = N - col_idx;
850             // If N is not a multiple of 3,
851             // handle the edge case.
852 if (n_remainder == 2)
853 {
854 //pointer math to point to proper memory
855 tC = C + ldc * col_idx + row_idx;
856 tB = B + tb_inc_col * col_idx;
857 tA = A + row_idx;
858
859 // clear scratch registers.
860 ymm8 = _mm256_setzero_ps();
861 ymm9 = _mm256_setzero_ps();
862 ymm10 = _mm256_setzero_ps();
863 ymm12 = _mm256_setzero_ps();
864 ymm13 = _mm256_setzero_ps();
865 ymm14 = _mm256_setzero_ps();
866
867 for (k = 0; k < K; ++k)
868 {
869 // The inner loop broadcasts the B matrix data and
870 // multiplies it with the A matrix.
871 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
872 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
873 tB += tb_inc_row;
874
875 //broadcasted matrix B elements are multiplied
876 //with matrix A columns.
877 ymm3 = _mm256_loadu_ps(tA);
878 ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
879 ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
880
881 ymm3 = _mm256_loadu_ps(tA + 8);
882 ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
883 ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
884
885 ymm3 = _mm256_loadu_ps(tA + 16);
886 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
887 ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
888
889 tA += lda;
890
891 }
892 // alpha, beta multiplication.
893 ymm0 = _mm256_broadcast_ss(alpha_cast);
894 //ymm1 = _mm256_broadcast_ss(beta_cast);
895
896 //multiply A*B by alpha.
897 ymm8 = _mm256_mul_ps(ymm8, ymm0);
898 ymm9 = _mm256_mul_ps(ymm9, ymm0);
899 ymm10 = _mm256_mul_ps(ymm10, ymm0);
900 ymm12 = _mm256_mul_ps(ymm12, ymm0);
901 ymm13 = _mm256_mul_ps(ymm13, ymm0);
902 ymm14 = _mm256_mul_ps(ymm14, ymm0);
903
904 // multiply C by beta and accumulate.
905 /*ymm2 = _mm256_loadu_ps(tC + 0);
906 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
907 ymm2 = _mm256_loadu_ps(tC + 8);
908 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
909 ymm2 = _mm256_loadu_ps(tC + 16);
910 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/
911 _mm256_storeu_ps(tC + 0, ymm8);
912 _mm256_storeu_ps(tC + 8, ymm9);
913 _mm256_storeu_ps(tC + 16, ymm10);
914
915 // multiply C by beta and accumulate.
916 tC += ldc;
917 /*ymm2 = _mm256_loadu_ps(tC);
918 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
919 ymm2 = _mm256_loadu_ps(tC + 8);
920 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
921 ymm2 = _mm256_loadu_ps(tC + 16);
922 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
923 _mm256_storeu_ps(tC, ymm12);
924 _mm256_storeu_ps(tC + 8, ymm13);
925 _mm256_storeu_ps(tC + 16, ymm14);
926
927 col_idx += 2;
928 }
929             // If N is not a multiple of 3,
930             // handle the edge case.
931 if (n_remainder == 1)
932 {
933 //pointer math to point to proper memory
934 tC = C + ldc * col_idx + row_idx;
935 tB = B + tb_inc_col * col_idx;
936 tA = A + row_idx;
937
938 // clear scratch registers.
939 ymm12 = _mm256_setzero_ps();
940 ymm13 = _mm256_setzero_ps();
941 ymm14 = _mm256_setzero_ps();
942
943 for (k = 0; k < K; ++k)
944 {
945 // The inner loop broadcasts the B matrix data and
946 // multiplies it with the A matrix.
947 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
948 tB += tb_inc_row;
949
950 //broadcasted matrix B elements are multiplied
951 //with matrix A columns.
952 ymm3 = _mm256_loadu_ps(tA);
953 ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
954
955 ymm3 = _mm256_loadu_ps(tA + 8);
956 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
957
958 ymm3 = _mm256_loadu_ps(tA + 16);
959 ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
960
961 tA += lda;
962
963 }
964 // alpha, beta multiplication.
965 ymm0 = _mm256_broadcast_ss(alpha_cast);
966 //ymm1 = _mm256_broadcast_ss(beta_cast);
967
968 //multiply A*B by alpha.
969 ymm12 = _mm256_mul_ps(ymm12, ymm0);
970 ymm13 = _mm256_mul_ps(ymm13, ymm0);
971 ymm14 = _mm256_mul_ps(ymm14, ymm0);
972
973 // multiply C by beta and accumulate.
974 /*ymm2 = _mm256_loadu_ps(tC + 0);
975 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
976 ymm2 = _mm256_loadu_ps(tC + 8);
977 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
978 ymm2 = _mm256_loadu_ps(tC + 16);
979 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/
980
981 _mm256_storeu_ps(tC + 0, ymm12);
982 _mm256_storeu_ps(tC + 8, ymm13);
983 _mm256_storeu_ps(tC + 16, ymm14);
984 }
985
986 row_idx += 24;
987 }
988
989 if (m_remainder >= 16)
990 {
991 m_remainder -= 16;
992
993 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
994 {
995 //pointer math to point to proper memory
996 tC = C + ldc * col_idx + row_idx;
997 tB = B + tb_inc_col * col_idx;
998 tA = A + row_idx;
999
1000 // clear scratch registers.
1001 ymm4 = _mm256_setzero_ps();
1002 ymm5 = _mm256_setzero_ps();
1003 ymm6 = _mm256_setzero_ps();
1004 ymm7 = _mm256_setzero_ps();
1005 ymm8 = _mm256_setzero_ps();
1006 ymm9 = _mm256_setzero_ps();
1007
1008 for (k = 0; k < K; ++k)
1009 {
1010 // The inner loop broadcasts the B matrix data and
1011 // multiplies it with the A matrix.
1012 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1013 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1014 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1015 tB += tb_inc_row;
1016
1017 //broadcasted matrix B elements are multiplied
1018 //with matrix A columns.
1019 ymm3 = _mm256_loadu_ps(tA);
1020 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1021 ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
1022 ymm8 = _mm256_fmadd_ps(ymm2, ymm3, ymm8);
1023
1024 ymm3 = _mm256_loadu_ps(tA + 8);
1025 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1026 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1027 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1028
1029 tA += lda;
1030 }
1031 // alpha, beta multiplication.
1032 ymm0 = _mm256_broadcast_ss(alpha_cast);
1033 //ymm1 = _mm256_broadcast_ss(beta_cast);
1034
1035 //multiply A*B by alpha.
1036 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1037 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1038 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1039 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1040 ymm8 = _mm256_mul_ps(ymm8, ymm0);
1041 ymm9 = _mm256_mul_ps(ymm9, ymm0);
1042
1043 // multiply C by beta and accumulate.
1044 /*ymm2 = _mm256_loadu_ps(tC);
1045 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1046 ymm2 = _mm256_loadu_ps(tC + 8);
1047 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
1048 _mm256_storeu_ps(tC, ymm4);
1049 _mm256_storeu_ps(tC + 8, ymm5);
1050
1051 // multiply C by beta and accumulate.
1052 tC += ldc;
1053 /*ymm2 = _mm256_loadu_ps(tC);
1054 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
1055 ymm2 = _mm256_loadu_ps(tC + 8);
1056 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
1057 _mm256_storeu_ps(tC, ymm6);
1058 _mm256_storeu_ps(tC + 8, ymm7);
1059
1060 // multiply C by beta and accumulate.
1061 tC += ldc;
1062 /*ymm2 = _mm256_loadu_ps(tC);
1063 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
1064 ymm2 = _mm256_loadu_ps(tC + 8);
1065 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/
1066 _mm256_storeu_ps(tC, ymm8);
1067 _mm256_storeu_ps(tC + 8, ymm9);
1068
1069 }
1070 n_remainder = N - col_idx;
1071             // If N is not a multiple of 3,
1072             // handle the edge case.
1073 if (n_remainder == 2)
1074 {
1075 //pointer math to point to proper memory
1076 tC = C + ldc * col_idx + row_idx;
1077 tB = B + tb_inc_col * col_idx;
1078 tA = A + row_idx;
1079
1080 // clear scratch registers.
1081 ymm4 = _mm256_setzero_ps();
1082 ymm5 = _mm256_setzero_ps();
1083 ymm6 = _mm256_setzero_ps();
1084 ymm7 = _mm256_setzero_ps();
1085
1086 for (k = 0; k < K; ++k)
1087 {
1088 // The inner loop broadcasts the B matrix data and
1089 // multiplies it with the A matrix.
1090 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1091 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1092 tB += tb_inc_row;
1093
1094 //broadcasted matrix B elements are multiplied
1095 //with matrix A columns.
1096 ymm3 = _mm256_loadu_ps(tA);
1097 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1098 ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
1099
1100 ymm3 = _mm256_loadu_ps(tA + 8);
1101 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1102 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1103
1104 tA += lda;
1105 }
1106 // alpha, beta multiplication.
1107 ymm0 = _mm256_broadcast_ss(alpha_cast);
1108 //ymm1 = _mm256_broadcast_ss(beta_cast);
1109
1110 //multiply A*B by alpha.
1111 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1112 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1113 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1114 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1115
1116 // multiply C by beta and accumulate.
1117 /*ymm2 = _mm256_loadu_ps(tC);
1118 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1119 ymm2 = _mm256_loadu_ps(tC + 8);
1120 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
1121 _mm256_storeu_ps(tC, ymm4);
1122 _mm256_storeu_ps(tC + 8, ymm5);
1123
1124 // multiply C by beta and accumulate.
1125 tC += ldc;
1126 /*ymm2 = _mm256_loadu_ps(tC);
1127 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
1128 ymm2 = _mm256_loadu_ps(tC + 8);
1129 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
1130 _mm256_storeu_ps(tC, ymm6);
1131 _mm256_storeu_ps(tC + 8, ymm7);
1132
1133 col_idx += 2;
1134
1135 }
1136             // If N is not a multiple of 3,
1137             // handle the edge case.
1138 if (n_remainder == 1)
1139 {
1140 //pointer math to point to proper memory
1141 tC = C + ldc * col_idx + row_idx;
1142 tB = B + tb_inc_col * col_idx;
1143 tA = A + row_idx;
1144
1145 ymm4 = _mm256_setzero_ps();
1146 ymm5 = _mm256_setzero_ps();
1147
1148 for (k = 0; k < K; ++k)
1149 {
1150 // The inner loop broadcasts the B matrix data and
1151 // multiplies it with the A matrix.
1152 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1153 tB += tb_inc_row;
1154
1155 //broadcasted matrix B elements are multiplied
1156 //with matrix A columns.
1157 ymm3 = _mm256_loadu_ps(tA);
1158 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1159
1160 ymm3 = _mm256_loadu_ps(tA + 8);
1161 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1162
1163 tA += lda;
1164 }
1165 // alpha, beta multiplication.
1166 ymm0 = _mm256_broadcast_ss(alpha_cast);
1167 //ymm1 = _mm256_broadcast_ss(beta_cast);
1168
1169 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1170 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1171
1172 // multiply C by beta and accumulate.
1173 /*ymm2 = _mm256_loadu_ps(tC);
1174 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1175 ymm2 = _mm256_loadu_ps(tC + 8);
1176 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
1177 _mm256_storeu_ps(tC, ymm4);
1178 _mm256_storeu_ps(tC + 8, ymm5);
1179
1180 }
1181
1182 row_idx += 16;
1183 }
1184
1185 if (m_remainder >= 8)
1186 {
1187 m_remainder -= 8;
1188
1189 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
1190 {
1191 //pointer math to point to proper memory
1192 tC = C + ldc * col_idx + row_idx;
1193 tB = B + tb_inc_col * col_idx;
1194 tA = A + row_idx;
1195
1196 // clear scratch registers.
1197 ymm4 = _mm256_setzero_ps();
1198 ymm5 = _mm256_setzero_ps();
1199 ymm6 = _mm256_setzero_ps();
1200
1201 for (k = 0; k < K; ++k)
1202 {
1203 // The inner loop broadcasts the B matrix data and
1204 // multiplies it with the A matrix.
1205 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1206 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1207 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1208 tB += tb_inc_row;
1209
1210 //broadcasted matrix B elements are multiplied
1211 //with matrix A columns.
1212 ymm3 = _mm256_loadu_ps(tA);
1213 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1214 ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
1215 ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
1216
1217 tA += lda;
1218 }
1219 // alpha, beta multiplication.
1220 ymm0 = _mm256_broadcast_ss(alpha_cast);
1221 //ymm1 = _mm256_broadcast_ss(beta_cast);
1222
1223 //multiply A*B by alpha.
1224 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1225 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1226 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1227
1228 // multiply C by beta and accumulate.
1229 /*ymm2 = _mm256_loadu_ps(tC);
1230 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
1231 _mm256_storeu_ps(tC, ymm4);
1232
1233 // multiply C by beta and accumulate.
1234 tC += ldc;
1235 /*ymm2 = _mm256_loadu_ps(tC);
1236 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
1237 _mm256_storeu_ps(tC, ymm5);
1238
1239 // multiply C by beta and accumulate.
1240 tC += ldc;
1241 /*ymm2 = _mm256_loadu_ps(tC);
1242 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/
1243 _mm256_storeu_ps(tC, ymm6);
1244 }
1245 n_remainder = N - col_idx;
1246             // If N is not a multiple of 3,
1247             // handle the edge case.
1248 if (n_remainder == 2)
1249 {
1250 //pointer math to point to proper memory
1251 tC = C + ldc * col_idx + row_idx;
1252 tB = B + tb_inc_col * col_idx;
1253 tA = A + row_idx;
1254
1255 ymm4 = _mm256_setzero_ps();
1256 ymm5 = _mm256_setzero_ps();
1257
1258 for (k = 0; k < K; ++k)
1259 {
1260 // The inner loop broadcasts the B matrix data and
1261 // multiplies it with the A matrix.
1262 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1263 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1264 tB += tb_inc_row;
1265
1266 //broadcasted matrix B elements are multiplied
1267 //with matrix A columns.
1268 ymm3 = _mm256_loadu_ps(tA);
1269 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1270 ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
1271
1272 tA += lda;
1273 }
1274 // alpha, beta multiplication.
1275 ymm0 = _mm256_broadcast_ss(alpha_cast);
1276 //ymm1 = _mm256_broadcast_ss(beta_cast);
1277
1278 //multiply A*B by alpha.
1279 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1280 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1281
1282 // multiply C by beta and accumulate.
1283 /*ymm2 = _mm256_loadu_ps(tC);
1284 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
1285 _mm256_storeu_ps(tC, ymm4);
1286
1287 // multiply C by beta and accumulate.
1288 tC += ldc;
1289 /*ymm2 = _mm256_loadu_ps(tC);
1290 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
1291 _mm256_storeu_ps(tC, ymm5);
1292
1293 col_idx += 2;
1294
1295 }
1296             // If N is not a multiple of 3,
1297             // handle the edge case.
1298 if (n_remainder == 1)
1299 {
1300 //pointer math to point to proper memory
1301 tC = C + ldc * col_idx + row_idx;
1302 tB = B + tb_inc_col * col_idx;
1303 tA = A + row_idx;
1304
1305 ymm4 = _mm256_setzero_ps();
1306
1307 for (k = 0; k < K; ++k)
1308 {
1309 // The inner loop broadcasts the B matrix data and
1310 // multiplies it with the A matrix.
1311 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1312 tB += tb_inc_row;
1313
1314 //broadcasted matrix B elements are multiplied
1315 //with matrix A columns.
1316 ymm3 = _mm256_loadu_ps(tA);
1317 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1318
1319 tA += lda;
1320 }
1321 // alpha, beta multiplication.
1322 ymm0 = _mm256_broadcast_ss(alpha_cast);
1323 //ymm1 = _mm256_broadcast_ss(beta_cast);
1324
1325 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1326
1327 // multiply C by beta and accumulate.
1328 /*ymm2 = _mm256_loadu_ps(tC);
1329 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/
1330 _mm256_storeu_ps(tC, ymm4);
1331
1332 }
1333
1334 row_idx += 8;
1335 }
1336         // M is not a multiple of 32.
1337         // Handle the edge case where the remaining row dimension is less
1338         // than 8; the rows are padded through a small staging buffer so
1339         // that full-width vector operations can still be used.
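        // The f_temp[8] staging buffer lets this path keep using full 8-wide
        // vector operations: the final k iteration copies the live
        // m_remainder (< 8) elements of A into f_temp before loading, and
        // every store of C goes through f_temp and is copied back element by
        // element, so no vector store writes past the end of a column. When
        // lda <= 7 this path is skipped and the scalar loop further below
        // handles the remainder instead.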
1340 if ((m_remainder) && (lda > 7))
1341 {
1342 float f_temp[8];
1343
1344 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
1345 {
1346 //pointer math to point to proper memory
1347 tC = C + ldc * col_idx + row_idx;
1348 tB = B + tb_inc_col * col_idx;
1349 tA = A + row_idx;
1350
1351 // clear scratch registers.
1352 ymm5 = _mm256_setzero_ps();
1353 ymm7 = _mm256_setzero_ps();
1354 ymm9 = _mm256_setzero_ps();
1355
1356 for (k = 0; k < (K - 1); ++k)
1357 {
1358 // The inner loop broadcasts the B matrix data and
1359 // multiplies it with the A matrix.
1360 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1361 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1362 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1363 tB += tb_inc_row;
1364
1365 //broadcasted matrix B elements are multiplied
1366 //with matrix A columns.
1367 ymm3 = _mm256_loadu_ps(tA);
1368 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1369 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1370 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1371
1372 tA += lda;
1373 }
1374 // alpha, beta multiplication.
1375 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1376 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1377 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1378 tB += tb_inc_row;
1379
1380 for (int i = 0; i < m_remainder; i++)
1381 {
1382 f_temp[i] = tA[i];
1383 }
1384 ymm3 = _mm256_loadu_ps(f_temp);
1385 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1386 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1387 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1388
1389 ymm0 = _mm256_broadcast_ss(alpha_cast);
1390 //ymm1 = _mm256_broadcast_ss(beta_cast);
1391
1392 //multiply A*B by alpha.
1393 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1394 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1395 ymm9 = _mm256_mul_ps(ymm9, ymm0);
1396
1397
1398 /*for (int i = 0; i < m_remainder; i++)
1399 {
1400 f_temp[i] = tC[i];
1401 }
1402 ymm2 = _mm256_loadu_ps(f_temp);
1403 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
1404 _mm256_storeu_ps(f_temp, ymm5);
1405 for (int i = 0; i < m_remainder; i++)
1406 {
1407 tC[i] = f_temp[i];
1408 }
1409
1410 tC += ldc;
1411 /*for (int i = 0; i < m_remainder; i++)
1412 {
1413 f_temp[i] = tC[i];
1414 }
1415 ymm2 = _mm256_loadu_ps(f_temp);
1416 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
1417 _mm256_storeu_ps(f_temp, ymm7);
1418 for (int i = 0; i < m_remainder; i++)
1419 {
1420 tC[i] = f_temp[i];
1421 }
1422
1423 tC += ldc;
1424 /*for (int i = 0; i < m_remainder; i++)
1425 {
1426 f_temp[i] = tC[i];
1427 }
1428 ymm2 = _mm256_loadu_ps(f_temp);
1429 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/
1430 _mm256_storeu_ps(f_temp, ymm9);
1431 for (int i = 0; i < m_remainder; i++)
1432 {
1433 tC[i] = f_temp[i];
1434 }
1435 }
1436 n_remainder = N - col_idx;
1437             // If N is not a multiple of 3,
1438             // handle the edge case.
1439 if (n_remainder == 2)
1440 {
1441 //pointer math to point to proper memory
1442 tC = C + ldc * col_idx + row_idx;
1443 tB = B + tb_inc_col * col_idx;
1444 tA = A + row_idx;
1445
1446 ymm5 = _mm256_setzero_ps();
1447 ymm7 = _mm256_setzero_ps();
1448
1449 for (k = 0; k < (K - 1); ++k)
1450 {
1451 // The inner loop broadcasts the B matrix data and
1452 // multiplies it with the A matrix.
1453 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1454 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1455 tB += tb_inc_row;
1456
1457 ymm3 = _mm256_loadu_ps(tA);
1458 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1459 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1460
1461 tA += lda;
1462 }
1463
1464 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1465 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1466 tB += tb_inc_row;
1467
1468 for (int i = 0; i < m_remainder; i++)
1469 {
1470 f_temp[i] = tA[i];
1471 }
1472 ymm3 = _mm256_loadu_ps(f_temp);
1473 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1474 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1475
1476 ymm0 = _mm256_broadcast_ss(alpha_cast);
1477 //ymm1 = _mm256_broadcast_ss(beta_cast);
1478
1479 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1480 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1481
1482 /*for (int i = 0; i < m_remainder; i++)
1483 {
1484 f_temp[i] = tC[i];
1485 }
1486 ymm2 = _mm256_loadu_ps(f_temp);
1487 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
1488 _mm256_storeu_ps(f_temp, ymm5);
1489 for (int i = 0; i < m_remainder; i++)
1490 {
1491 tC[i] = f_temp[i];
1492 }
1493
1494 tC += ldc;
1495 /*for (int i = 0; i < m_remainder; i++)
1496 {
1497 f_temp[i] = tC[i];
1498 }
1499 ymm2 = _mm256_loadu_ps(f_temp);
1500 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/
1501 _mm256_storeu_ps(f_temp, ymm7);
1502 for (int i = 0; i < m_remainder; i++)
1503 {
1504 tC[i] = f_temp[i];
1505 }
1506 }
1507             // If N is not a multiple of 3,
1508             // handle the edge case.
1509 if (n_remainder == 1)
1510 {
1511 //pointer math to point to proper memory
1512 tC = C + ldc * col_idx + row_idx;
1513 tB = B + tb_inc_col * col_idx;
1514 tA = A + row_idx;
1515
1516 ymm5 = _mm256_setzero_ps();
1517
1518 for (k = 0; k < (K - 1); ++k)
1519 {
1520 // The inner loop broadcasts the B matrix data and
1521 // multiplies it with the A matrix.
1522 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1523 tB += tb_inc_row;
1524
1525 ymm3 = _mm256_loadu_ps(tA);
1526 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1527
1528 tA += lda;
1529 }
1530
1531 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1532 tB += tb_inc_row;
1533
1534 for (int i = 0; i < m_remainder; i++)
1535 {
1536 f_temp[i] = tA[i];
1537 }
1538 ymm3 = _mm256_loadu_ps(f_temp);
1539 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1540
1541 ymm0 = _mm256_broadcast_ss(alpha_cast);
1542 //ymm1 = _mm256_broadcast_ss(beta_cast);
1543
1544 // multiply C by beta and accumulate.
1545 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1546
1547 /*for (int i = 0; i < m_remainder; i++)
1548 {
1549 f_temp[i] = tC[i];
1550 }
1551 ymm2 = _mm256_loadu_ps(f_temp);
1552 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/
1553 _mm256_storeu_ps(f_temp, ymm5);
1554 for (int i = 0; i < m_remainder; i++)
1555 {
1556 tC[i] = f_temp[i];
1557 }
1558 }
1559 m_remainder = 0;
1560 }
1561
1562 if (m_remainder)
1563 {
1564 float result;
1565 for (; row_idx < M; row_idx += 1)
1566 {
1567 for (col_idx = 0; col_idx < N; col_idx += 1)
1568 {
1569 //pointer math to point to proper memory
1570 tC = C + ldc * col_idx + row_idx;
1571 tB = B + tb_inc_col * col_idx;
1572 tA = A + row_idx;
1573
1574 result = 0;
1575 for (k = 0; k < K; ++k)
1576 {
1577 result += (*tA) * (*tB);
1578 tA += lda;
1579 tB += tb_inc_row;
1580 }
1581
1582 result *= (*alpha_cast);
1583 (*tC) = /*(*tC) * (*beta_cast) + */result;
1584 }
1585 }
1586 }
1587
1588         // Copy/compute the syrk results back into C using SIMD.
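        // The results above live in the C_pack scratch buffer over the full
        // M x N rectangle; only the triangle selected by the uplo of c is
        // written back to the user's buffer. Full 8-wide vector copies (or
        // fmadds with beta) handle the aligned runs of each column, and the
        // scalar bli_sscopys / bli_sssxpbys macros handle the ragged edge
        // next to the diagonal.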
1589 if ( bli_seq0( *beta_cast ) )
1590 {//just copy in case of beta = 0
1591 dim_t _i, _j, k, _l;
1592 if(bli_obj_is_lower(c)) // c is lower
1593 {
1594 //first column
1595 _j = 0;
1596 k = M >> 3;
1597 _i = 0;
1598 for ( _l = 0; _l < k; _l++ )
1599 {
1600 ymm0 = _mm256_loadu_ps((C + _i*rsc));
1601 _mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0);
1602 _i += 8;
1603 }
1604 while (_i < M )
1605 {
1606 bli_sscopys( *(C + _i*rsc + _j*ldc),
1607 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
1608 _i++;
1609 }
1610 _j++;
1611 while ( _j < N ) //next column
1612 {
1613 //k = (_j + (8 - (_j & 7)));
1614 _l = _j & 7;
1615 k = (_l != 0) ? (_j + (8 - _l)) : _j;
1616 k = (k <= M) ? k : M;
1617 for ( _i = _j; _i < k; ++_i )
1618 {
1619 bli_sscopys( *(C + _i*rsc + _j*ldc),
1620 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
1621 }
1622 k = (M - _i) >> 3;
1623 _l = 0;
1624 while ( _l < k )
1625 {
1626 ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
1627 _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
1628
1629 _i += 8;
1630 _l++;
1631 }
1632 while (_i < M )
1633 {
1634 bli_sscopys( *(C + _i*rsc + _j*ldc),
1635 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
1636 _i++;
1637 }
1638 _j++;
1639 }
1640 }
1641 else //c is upper
1642 {
1643 for ( _j = 0; _j < N; ++_j )
1644 {
1645 k = (_j + 1) >> 3;
1646 _i = 0;
1647 _l = 0;
1648 while ( _l < k )
1649 {
1650 ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
1651 _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
1652 _i += 8;
1653 _l++;
1654 }
1655 while (_i <= _j )
1656 {
1657 bli_sscopys( *(C + _i*rsc + _j*ldc),
1658 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
1659 ++_i;
1660 }
1661 }
1662 }
1663 }
1664 else
1665 {//when beta is non-zero, fmadd and store the results
1666 dim_t _i, _j, k, _l;
1667 ymm1 = _mm256_broadcast_ss(beta_cast);
1668 if(bli_obj_is_lower(c)) //c is lower
1669 {
1670 //first column
1671 _j = 0;
1672 k = M >> 3;
1673 _i = 0;
1674 for ( _l = 0; _l < k; _l++ )
1675 {
1676 ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC));
1677 ymm0 = _mm256_loadu_ps((C + _i*rsc));
1678 ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
1679 _mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0);
1680 _i += 8;
1681 }
1682 while (_i < M )
1683 {
1684 bli_sssxpbys( *(C + _i*rsc + _j*ldc),
1685 *(beta_cast),
1686 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
1687 _i++;
1688 }
1689 _j++;
1690 while ( _j < N ) //next column
1691 {
1692 //k = (_j + (8 - (_j & 7)));
1693 _l = _j & 7;
1694 k = (_l != 0) ? (_j + (8 - _l)) : _j;
1695 k = (k <= M) ? k : M;
1696 for ( _i = _j; _i < k; ++_i )
1697 {
1698 bli_sssxpbys( *(C + _i*rsc + _j*ldc),
1699 *(beta_cast),
1700 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
1701 }
1702 k = (M - _i) >> 3;
1703 _l = 0;
1704 while ( _l < k )
1705 {
1706 ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC));
1707 ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
1708 ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
1709 _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
1710
1711 _i += 8;
1712 _l++;
1713 }
1714 while (_i < M )
1715 {
1716 bli_sssxpbys( *(C + _i*rsc + _j*ldc),
1717 *(beta_cast),
1718 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
1719 _i++;
1720 }
1721 _j++;
1722 }
1723 }
1724 else //c is upper
1725 {
1726 for ( _j = 0; _j < N; ++_j )
1727 {
1728 k = (_j + 1) >> 3;
1729 _i = 0;
1730 _l = 0;
1731 while ( _l < k )
1732 {
1733 ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC));
1734 ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc));
1735 ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0);
1736 _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
1737 _i += 8;
1738 _l++;
1739 }
1740 while (_i <= _j )
1741 {
1742 bli_sssxpbys( *(C + _i*rsc + _j*ldc),
1743 *(beta_cast),
1744 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
1745 ++_i;
1746 }
1747 }
1748 }
1749 }
1750
1751 return BLIS_SUCCESS;
1752 }
1753 else
1754 return BLIS_NONCONFORMAL_DIMENSIONS;
1755
1756
1757 };
1758
1759 static err_t bli_dsyrk_small
1760 (
1761 obj_t* alpha,
1762 obj_t* a,
1763 obj_t* b,
1764 obj_t* beta,
1765 obj_t* c,
1766 cntx_t* cntx,
1767 cntl_t* cntl
1768 )
1769 {
1770
1771 int M = bli_obj_length( c ); // number of rows of Matrix C
1772 int N = bli_obj_width( c ); // number of columns of Matrix C
1773 int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
1774 int L = M * N;
1775
1776 // If alpha is zero, scale by beta and return.
1777 if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES))
1778 || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
1779 {
1780
1781 int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
1782 int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
1783 int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
1784 int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
1785 int row_idx, col_idx, k;
1786 int rs_matC = bli_obj_row_stride( c );
1787 int rsc = 1;
1788 double *A = a->buffer; // pointer to elements of Matrix A
1789 double *B = b->buffer; // pointer to elements of Matrix B
1790 double *C = D_C_pack; // pointer to elements of Matrix C
1791 double *matCbuf = c->buffer;
1792
1793 double *tA = A, *tB = B, *tC = C;//, *tA_pack;
1794         double *tA_packed; // temporary pointer to the packed A buffer
1795 int row_idx_packed; //packed A memory row index
1796 int lda_packed; //lda of packed A
1797 int col_idx_start; //starting index after A matrix is packed.
1798 dim_t tb_inc_row = 1; // row stride of matrix B
1799 dim_t tb_inc_col = ldb; // column stride of matrix B
1800 __m256d ymm4, ymm5, ymm6, ymm7;
1801 __m256d ymm8, ymm9, ymm10, ymm11;
1802 __m256d ymm12, ymm13, ymm14, ymm15;
1803 __m256d ymm0, ymm1, ymm2, ymm3;
1804
1805         int n_remainder; // remainder when N is not a multiple of 3 (N % 3)
1806         int m_remainder; // remainder when M is not a multiple of 16 (M % 16)
1807
1808 double *alpha_cast, *beta_cast; // alpha, beta multiples
1809 alpha_cast = (alpha->buffer);
1810 beta_cast = (beta->buffer);
1811 int required_packing_A = 1;
1812
1813 // when N is equal to 1 call GEMV instead of SYRK
1814 if (N == 1)
1815 {
1816 bli_gemv
1817 (
1818 alpha,
1819 a,
1820 b,
1821 beta,
1822 c
1823 );
1824 return BLIS_SUCCESS;
1825 }
1826
1827 //update the pointer math if matrix B needs to be transposed.
1828 if (bli_obj_has_trans( b ))
1829 {
1830 tb_inc_col = 1; //switch row and column strides
1831 tb_inc_row = ldb;
1832 }
1833
1834 if ((N <= 3) || ((D_MR * K) > D_SCRATCH_DIM))
1835 {
1836 required_packing_A = 0;
1837 }
1838         /*
1839          * The computation loop processes a D_MR x N panel of the C matrix, thus
1840          * accessing D_MR x K of the A matrix and K x NR of the B matrix at a time.
1841          * The computation is organized as inner loops of dimension D_MR x NR.
1842          */
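        // The double-precision tile mirrors the single-precision layout:
        // ymm0..ymm2 broadcast B, ymm3 streams four doubles of A, and
        // ymm4..ymm15 accumulate the 16x3 block of C (four 4-double
        // registers per column).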
1843 // Process D_MR rows of C matrix at a time.
1844 for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR)
1845 {
1846
1847 col_idx_start = 0;
1848 tA_packed = A;
1849 row_idx_packed = row_idx;
1850 lda_packed = lda;
1851
1852 // This is the part of the pack and compute optimization.
1853 // During the first column iteration, we store the accessed A matrix into
1854             // contiguous static memory. This helps to keep the A matrix in cache and
1855             // avoids TLB misses.
1856 if (required_packing_A)
1857 {
1858 col_idx = 0;
1859
1860 //pointer math to point to proper memory
1861 tC = C + ldc * col_idx + row_idx;
1862 tB = B + tb_inc_col * col_idx;
1863 tA = A + row_idx;
1864 tA_packed = D_A_pack;
1865
1866 #if 0//def BLIS_ENABLE_PREFETCH
1867 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
1868 _mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
1869 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
1870 _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
1871 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
1872 _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
1873 #endif
1874 // clear scratch registers.
1875 ymm4 = _mm256_setzero_pd();
1876 ymm5 = _mm256_setzero_pd();
1877 ymm6 = _mm256_setzero_pd();
1878 ymm7 = _mm256_setzero_pd();
1879 ymm8 = _mm256_setzero_pd();
1880 ymm9 = _mm256_setzero_pd();
1881 ymm10 = _mm256_setzero_pd();
1882 ymm11 = _mm256_setzero_pd();
1883 ymm12 = _mm256_setzero_pd();
1884 ymm13 = _mm256_setzero_pd();
1885 ymm14 = _mm256_setzero_pd();
1886 ymm15 = _mm256_setzero_pd();
1887
1888 for (k = 0; k < K; ++k)
1889 {
1890 // The inner loop broadcasts the B matrix data and
1891 // multiplies it with the A matrix.
1892 // This loop is processing D_MR x K
1893 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
1894 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
1895 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
1896 tB += tb_inc_row;
1897
1898 //broadcasted matrix B elements are multiplied
1899 //with matrix A columns.
1900 ymm3 = _mm256_loadu_pd(tA);
1901 _mm256_storeu_pd(tA_packed, ymm3); // the packing of matrix A
1902 // ymm4 += ymm0 * ymm3;
1903 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
1904 // ymm8 += ymm1 * ymm3;
1905 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
1906 // ymm12 += ymm2 * ymm3;
1907 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
1908
1909 ymm3 = _mm256_loadu_pd(tA + 4);
1910 _mm256_storeu_pd(tA_packed + 4, ymm3); // the packing of matrix A
1911 // ymm5 += ymm0 * ymm3;
1912 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
1913 // ymm9 += ymm1 * ymm3;
1914 ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
1915 // ymm13 += ymm2 * ymm3;
1916 ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
1917
1918 ymm3 = _mm256_loadu_pd(tA + 8);
1919 _mm256_storeu_pd(tA_packed + 8, ymm3); // the packing of matrix A
1920 // ymm6 += ymm0 * ymm3;
1921 ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
1922 // ymm10 += ymm1 * ymm3;
1923 ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
1924 // ymm14 += ymm2 * ymm3;
1925 ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
1926
1927 ymm3 = _mm256_loadu_pd(tA + 12);
1928 _mm256_storeu_pd(tA_packed + 12, ymm3); // the packing of matrix A
1929 // ymm7 += ymm0 * ymm3;
1930 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
1931 // ymm11 += ymm1 * ymm3;
1932 ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
1933 // ymm15 += ymm2 * ymm3;
1934 ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
1935
1936 tA += lda;
1937 tA_packed += D_MR;
1938 }
1939 // alpha, beta multiplication.
1940 ymm0 = _mm256_broadcast_sd(alpha_cast);
1941 //ymm1 = _mm256_broadcast_sd(beta_cast);
1942
1943 //multiply A*B by alpha.
1944 ymm4 = _mm256_mul_pd(ymm4, ymm0);
1945 ymm5 = _mm256_mul_pd(ymm5, ymm0);
1946 ymm6 = _mm256_mul_pd(ymm6, ymm0);
1947 ymm7 = _mm256_mul_pd(ymm7, ymm0);
1948 ymm8 = _mm256_mul_pd(ymm8, ymm0);
1949 ymm9 = _mm256_mul_pd(ymm9, ymm0);
1950 ymm10 = _mm256_mul_pd(ymm10, ymm0);
1951 ymm11 = _mm256_mul_pd(ymm11, ymm0);
1952 ymm12 = _mm256_mul_pd(ymm12, ymm0);
1953 ymm13 = _mm256_mul_pd(ymm13, ymm0);
1954 ymm14 = _mm256_mul_pd(ymm14, ymm0);
1955 ymm15 = _mm256_mul_pd(ymm15, ymm0);
1956
1957 // multiply C by beta and accumulate col 1.
1958 /*ymm2 = _mm256_loadu_pd(tC);
1959 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
1960 ymm2 = _mm256_loadu_pd(tC + 4);
1961 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
1962 ymm2 = _mm256_loadu_pd(tC + 8);
1963 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
1964 ymm2 = _mm256_loadu_pd(tC + 12);
1965 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
1966 _mm256_storeu_pd(tC, ymm4);
1967 _mm256_storeu_pd(tC + 4, ymm5);
1968 _mm256_storeu_pd(tC + 8, ymm6);
1969 _mm256_storeu_pd(tC + 12, ymm7);
1970
1971 // multiply C by beta and accumulate, col 2.
1972 tC += ldc;
1973 /*ymm2 = _mm256_loadu_pd(tC);
1974 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
1975 ymm2 = _mm256_loadu_pd(tC + 4);
1976 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
1977 ymm2 = _mm256_loadu_pd(tC + 8);
1978 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
1979 ymm2 = _mm256_loadu_pd(tC + 12);
1980 ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
1981 _mm256_storeu_pd(tC, ymm8);
1982 _mm256_storeu_pd(tC + 4, ymm9);
1983 _mm256_storeu_pd(tC + 8, ymm10);
1984 _mm256_storeu_pd(tC + 12, ymm11);
1985
1986 // multiply C by beta and accumulate, col 3.
1987 tC += ldc;
1988 /*ymm2 = _mm256_loadu_pd(tC);
1989 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
1990 ymm2 = _mm256_loadu_pd(tC + 4);
1991 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
1992 ymm2 = _mm256_loadu_pd(tC + 8);
1993 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
1994 ymm2 = _mm256_loadu_pd(tC + 12);
1995 ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
1996 _mm256_storeu_pd(tC, ymm12);
1997 _mm256_storeu_pd(tC + 4, ymm13);
1998 _mm256_storeu_pd(tC + 8, ymm14);
1999 _mm256_storeu_pd(tC + 12, ymm15);
2000
2001 // Modify the pointer arithmetic so the remaining column blocks read A from the packed buffer.
2002 col_idx_start = NR;
2003 tA_packed = D_A_pack;
2004 row_idx_packed = 0;
2005 lda_packed = D_MR;
2006 }
2007 // Process NR columns of C matrix at a time.
2008 for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
2009 {
2010 //pointer math to point to proper memory
2011 tC = C + ldc * col_idx + row_idx;
2012 tB = B + tb_inc_col * col_idx;
2013 tA = tA_packed + row_idx_packed;
2014
2015 #if 0//def BLIS_ENABLE_PREFETCH
2016 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
2017 _mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
2018 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
2019 _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
2020 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
2021 _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
2022 #endif
2023 // clear scratch registers.
2024 ymm4 = _mm256_setzero_pd();
2025 ymm5 = _mm256_setzero_pd();
2026 ymm6 = _mm256_setzero_pd();
2027 ymm7 = _mm256_setzero_pd();
2028 ymm8 = _mm256_setzero_pd();
2029 ymm9 = _mm256_setzero_pd();
2030 ymm10 = _mm256_setzero_pd();
2031 ymm11 = _mm256_setzero_pd();
2032 ymm12 = _mm256_setzero_pd();
2033 ymm13 = _mm256_setzero_pd();
2034 ymm14 = _mm256_setzero_pd();
2035 ymm15 = _mm256_setzero_pd();
2036
2037 for (k = 0; k < K; ++k)
2038 {
2039 // The inner loop broadcasts the B matrix data and
2040 // multiplies it with the A matrix.
2041 // This loop is processing D_MR x K
2042 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2043 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2044 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2045 tB += tb_inc_row;
2046
2047 //broadcasted matrix B elements are multiplied
2048 //with matrix A columns.
2049 ymm3 = _mm256_loadu_pd(tA);
2050 // ymm4 += ymm0 * ymm3;
2051 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2052 // ymm8 += ymm1 * ymm3;
2053 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
2054 // ymm12 += ymm2 * ymm3;
2055 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
2056
2057 ymm3 = _mm256_loadu_pd(tA + 4);
2058 // ymm5 += ymm0 * ymm3;
2059 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2060 // ymm9 += ymm1 * ymm3;
2061 ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
2062 // ymm13 += ymm2 * ymm3;
2063 ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
2064
2065 ymm3 = _mm256_loadu_pd(tA + 8);
2066 // ymm6 += ymm0 * ymm3;
2067 ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
2068 // ymm10 += ymm1 * ymm3;
2069 ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
2070 // ymm14 += ymm2 * ymm3;
2071 ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
2072
2073 ymm3 = _mm256_loadu_pd(tA + 12);
2074 // ymm7 += ymm0 * ymm3;
2075 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
2076 // ymm11 += ymm1 * ymm3;
2077 ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
2078 // ymm15 += ymm2 * ymm3;
2079 ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
2080
2081 tA += lda_packed;
2082 }
2083 // alpha, beta multiplication.
2084 ymm0 = _mm256_broadcast_sd(alpha_cast);
2085 //ymm1 = _mm256_broadcast_sd(beta_cast);
2086
2087 //multiply A*B by alpha.
2088 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2089 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2090 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2091 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2092 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2093 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2094 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2095 ymm11 = _mm256_mul_pd(ymm11, ymm0);
2096 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2097 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2098 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2099 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2100
2101 // multiply C by beta and accumulate col 1.
2102 /*ymm2 = _mm256_loadu_pd(tC);
2103 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2104 ymm2 = _mm256_loadu_pd(tC + 4);
2105 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2106 ymm2 = _mm256_loadu_pd(tC + 8);
2107 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2108 ymm2 = _mm256_loadu_pd(tC + 12);
2109 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
2110 _mm256_storeu_pd(tC, ymm4);
2111 _mm256_storeu_pd(tC + 4, ymm5);
2112 _mm256_storeu_pd(tC + 8, ymm6);
2113 _mm256_storeu_pd(tC + 12, ymm7);
2114
2115 // multiply C by beta and accumulate, col 2.
2116 tC += ldc;
2117 /*ymm2 = _mm256_loadu_pd(tC);
2118 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2119 ymm2 = _mm256_loadu_pd(tC + 4);
2120 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2121 ymm2 = _mm256_loadu_pd(tC + 8);
2122 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2123 ymm2 = _mm256_loadu_pd(tC + 12);
2124 ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
2125 _mm256_storeu_pd(tC, ymm8);
2126 _mm256_storeu_pd(tC + 4, ymm9);
2127 _mm256_storeu_pd(tC + 8, ymm10);
2128 _mm256_storeu_pd(tC + 12, ymm11);
2129
2130 // multiply C by beta and accumulate, col 3.
2131 tC += ldc;
2132 /*ymm2 = _mm256_loadu_pd(tC);
2133 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2134 ymm2 = _mm256_loadu_pd(tC + 4);
2135 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2136 ymm2 = _mm256_loadu_pd(tC + 8);
2137 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2138 ymm2 = _mm256_loadu_pd(tC + 12);
2139 ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
2140 _mm256_storeu_pd(tC, ymm12);
2141 _mm256_storeu_pd(tC + 4, ymm13);
2142 _mm256_storeu_pd(tC + 8, ymm14);
2143 _mm256_storeu_pd(tC + 12, ymm15);
2144
2145 }
2146 n_remainder = N - col_idx;
2147
2148 // if N is not a multiple of 3,
2149 // handle the edge case.
2150 if (n_remainder == 2)
2151 {
2152 //pointer math to point to proper memory
2153 tC = C + ldc * col_idx + row_idx;
2154 tB = B + tb_inc_col * col_idx;
2155 tA = A + row_idx;
2156
2157 // clear scratch registers.
2158 ymm8 = _mm256_setzero_pd();
2159 ymm9 = _mm256_setzero_pd();
2160 ymm10 = _mm256_setzero_pd();
2161 ymm11 = _mm256_setzero_pd();
2162 ymm12 = _mm256_setzero_pd();
2163 ymm13 = _mm256_setzero_pd();
2164 ymm14 = _mm256_setzero_pd();
2165 ymm15 = _mm256_setzero_pd();
2166
2167 for (k = 0; k < K; ++k)
2168 {
2169 // The inner loop broadcasts the B matrix data and
2170 // multiplies it with the A matrix.
2171 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2172 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2173 tB += tb_inc_row;
2174
2175 //broadcasted matrix B elements are multiplied
2176 //with matrix A columns.
2177 ymm3 = _mm256_loadu_pd(tA);
2178 ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
2179 ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
2180
2181 ymm3 = _mm256_loadu_pd(tA + 4);
2182 ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
2183 ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
2184
2185 ymm3 = _mm256_loadu_pd(tA + 8);
2186 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
2187 ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
2188
2189 ymm3 = _mm256_loadu_pd(tA + 12);
2190 ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11);
2191 ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);
2192
2193 tA += lda;
2194
2195 }
2196 // alpha, beta multiplication.
2197 ymm0 = _mm256_broadcast_sd(alpha_cast);
2198 //ymm1 = _mm256_broadcast_sd(beta_cast);
2199
2200 //multiply A*B by alpha.
2201 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2202 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2203 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2204 ymm11 = _mm256_mul_pd(ymm11, ymm0);
2205 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2206 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2207 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2208 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2209
2210 // multiply C by beta and accumulate, col 1.
2211 /*ymm2 = _mm256_loadu_pd(tC + 0);
2212 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2213 ymm2 = _mm256_loadu_pd(tC + 4);
2214 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2215 ymm2 = _mm256_loadu_pd(tC + 8);
2216 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2217 ymm2 = _mm256_loadu_pd(tC + 12);
2218 ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/
2219 _mm256_storeu_pd(tC + 0, ymm8);
2220 _mm256_storeu_pd(tC + 4, ymm9);
2221 _mm256_storeu_pd(tC + 8, ymm10);
2222 _mm256_storeu_pd(tC + 12, ymm11);
2223
2224 // multiply C by beta and accumulate, col 2.
2225 tC += ldc;
2226 /*ymm2 = _mm256_loadu_pd(tC);
2227 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2228 ymm2 = _mm256_loadu_pd(tC + 4);
2229 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2230 ymm2 = _mm256_loadu_pd(tC + 8);
2231 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2232 ymm2 = _mm256_loadu_pd(tC + 12);
2233 ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
2234 _mm256_storeu_pd(tC, ymm12);
2235 _mm256_storeu_pd(tC + 4, ymm13);
2236 _mm256_storeu_pd(tC + 8, ymm14);
2237 _mm256_storeu_pd(tC + 12, ymm15);
2238
2239 col_idx += 2;
2240 }
2241 // if N is not a multiple of 3,
2242 // handle the edge case.
2243 if (n_remainder == 1)
2244 {
2245 //pointer math to point to proper memory
2246 tC = C + ldc * col_idx + row_idx;
2247 tB = B + tb_inc_col * col_idx;
2248 tA = A + row_idx;
2249
2250 // clear scratch registers.
2251 ymm12 = _mm256_setzero_pd();
2252 ymm13 = _mm256_setzero_pd();
2253 ymm14 = _mm256_setzero_pd();
2254 ymm15 = _mm256_setzero_pd();
2255
2256 for (k = 0; k < K; ++k)
2257 {
2258 // The inner loop broadcasts the B matrix data and
2259 // multiplies it with the A matrix.
2260 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2261 tB += tb_inc_row;
2262
2263 //broadcasted matrix B elements are multiplied
2264 //with matrix A columns.
2265 ymm3 = _mm256_loadu_pd(tA);
2266 ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
2267
2268 ymm3 = _mm256_loadu_pd(tA + 4);
2269 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
2270
2271 ymm3 = _mm256_loadu_pd(tA + 8);
2272 ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
2273
2274 ymm3 = _mm256_loadu_pd(tA + 12);
2275 ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15);
2276
2277 tA += lda;
2278
2279 }
2280 // alpha, beta multiplication.
2281 ymm0 = _mm256_broadcast_sd(alpha_cast);
2282 //ymm1 = _mm256_broadcast_sd(beta_cast);
2283
2284 //multiply A*B by alpha.
2285 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2286 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2287 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2288 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2289
2290 // multiply C by beta and accumulate.
2291 /*ymm2 = _mm256_loadu_pd(tC + 0);
2292 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2293 ymm2 = _mm256_loadu_pd(tC + 4);
2294 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2295 ymm2 = _mm256_loadu_pd(tC + 8);
2296 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2297 ymm2 = _mm256_loadu_pd(tC + 12);
2298 ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/
2299
2300 _mm256_storeu_pd(tC + 0, ymm12);
2301 _mm256_storeu_pd(tC + 4, ymm13);
2302 _mm256_storeu_pd(tC + 8, ymm14);
2303 _mm256_storeu_pd(tC + 12, ymm15);
2304 }
2305 }
2306
2307 m_remainder = M - row_idx;
2308
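// The leftover rows (at most 15) are handled with one pass of 12, 8 or 4
// rows of full-width vector code, followed by a padded vector path (when
// lda > 3) or a scalar loop for the final 1-3 rows.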
2309 if (m_remainder >= 12)
2310 {
2311 m_remainder -= 12;
2312
2313 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2314 {
2315 //pointer math to point to proper memory
2316 tC = C + ldc * col_idx + row_idx;
2317 tB = B + tb_inc_col * col_idx;
2318 tA = A + row_idx;
2319
2320 // clear scratch registers.
2321 ymm4 = _mm256_setzero_pd();
2322 ymm5 = _mm256_setzero_pd();
2323 ymm6 = _mm256_setzero_pd();
2324 ymm8 = _mm256_setzero_pd();
2325 ymm9 = _mm256_setzero_pd();
2326 ymm10 = _mm256_setzero_pd();
2327 ymm12 = _mm256_setzero_pd();
2328 ymm13 = _mm256_setzero_pd();
2329 ymm14 = _mm256_setzero_pd();
2330
2331 for (k = 0; k < K; ++k)
2332 {
2333 // The inner loop broadcasts the B matrix data and
2334 // multiplies it with the A matrix.
2335 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2336 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2337 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2338 tB += tb_inc_row;
2339
2340 //broadcasted matrix B elements are multiplied
2341 //with matrix A columns.
2342 ymm3 = _mm256_loadu_pd(tA);
2343 // ymm4 += ymm0 * ymm3;
2344 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2345 // ymm8 += ymm1 * ymm3;
2346 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
2347 // ymm12 += ymm2 * ymm3;
2348 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
2349
2350 ymm3 = _mm256_loadu_pd(tA + 4);
2351 // ymm5 += ymm0 * ymm3;
2352 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2353 // ymm9 += ymm1 * ymm3;
2354 ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
2355 // ymm13 += ymm2 * ymm3;
2356 ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
2357
2358 ymm3 = _mm256_loadu_pd(tA + 8);
2359 // ymm6 += ymm0 * ymm3;
2360 ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
2361 // ymm10 += ymm1 * ymm3;
2362 ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
2363 // ymm14 += ymm2 * ymm3;
2364 ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
2365
2366 tA += lda;
2367 }
2368 // alpha, beta multiplication.
2369 ymm0 = _mm256_broadcast_sd(alpha_cast);
2370 //ymm1 = _mm256_broadcast_sd(beta_cast);
2371
2372 //multiply A*B by alpha.
2373 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2374 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2375 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2376 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2377 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2378 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2379 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2380 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2381 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2382
2383 // multiply C by beta and accumulate.
2384 /*ymm2 = _mm256_loadu_pd(tC);
2385 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2386 ymm2 = _mm256_loadu_pd(tC + 4);
2387 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2388 ymm2 = _mm256_loadu_pd(tC + 8);
2389 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/
2390 _mm256_storeu_pd(tC, ymm4);
2391 _mm256_storeu_pd(tC + 4, ymm5);
2392 _mm256_storeu_pd(tC + 8, ymm6);
2393
2394 // multiply C by beta and accumulate.
2395 tC += ldc;
2396 /*ymm2 = _mm256_loadu_pd(tC);
2397 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2398 ymm2 = _mm256_loadu_pd(tC + 4);
2399 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2400 ymm2 = _mm256_loadu_pd(tC + 8);
2401 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/
2402 _mm256_storeu_pd(tC, ymm8);
2403 _mm256_storeu_pd(tC + 4, ymm9);
2404 _mm256_storeu_pd(tC + 8, ymm10);
2405
2406 // multiply C by beta and accumulate.
2407 tC += ldc;
2408 /*ymm2 = _mm256_loadu_pd(tC);
2409 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2410 ymm2 = _mm256_loadu_pd(tC + 4);
2411 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2412 ymm2 = _mm256_loadu_pd(tC + 8);
2413 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
2414 _mm256_storeu_pd(tC, ymm12);
2415 _mm256_storeu_pd(tC + 4, ymm13);
2416 _mm256_storeu_pd(tC + 8, ymm14);
2417
2418 }
2419 n_remainder = N - col_idx;
2420 // if N is not a multiple of 3,
2421 // handle the edge case.
2422 if (n_remainder == 2)
2423 {
2424 //pointer math to point to proper memory
2425 tC = C + ldc * col_idx + row_idx;
2426 tB = B + tb_inc_col * col_idx;
2427 tA = A + row_idx;
2428
2429 // clear scratch registers.
2430 ymm8 = _mm256_setzero_pd();
2431 ymm9 = _mm256_setzero_pd();
2432 ymm10 = _mm256_setzero_pd();
2433 ymm12 = _mm256_setzero_pd();
2434 ymm13 = _mm256_setzero_pd();
2435 ymm14 = _mm256_setzero_pd();
2436
2437 for (k = 0; k < K; ++k)
2438 {
2439 // The inner loop broadcasts the B matrix data and
2440 // multiplies it with the A matrix.
2441 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2442 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2443 tB += tb_inc_row;
2444
2445 //broadcasted matrix B elements are multiplied
2446 //with matrix A columns.
2447 ymm3 = _mm256_loadu_pd(tA);
2448 ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
2449 ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
2450
2451 ymm3 = _mm256_loadu_pd(tA + 4);
2452 ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
2453 ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
2454
2455 ymm3 = _mm256_loadu_pd(tA + 8);
2456 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
2457 ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
2458
2459 tA += lda;
2460
2461 }
2462 // alpha, beta multiplication.
2463 ymm0 = _mm256_broadcast_sd(alpha_cast);
2464 //ymm1 = _mm256_broadcast_sd(beta_cast);
2465
2466 //multiply A*B by alpha.
2467 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2468 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2469 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2470 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2471 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2472 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2473
2474 // multiply C by beta and accumulate.
2475 /*ymm2 = _mm256_loadu_pd(tC + 0);
2476 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2477 ymm2 = _mm256_loadu_pd(tC + 4);
2478 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2479 ymm2 = _mm256_loadu_pd(tC + 8);
2480 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/
2481 _mm256_storeu_pd(tC + 0, ymm8);
2482 _mm256_storeu_pd(tC + 4, ymm9);
2483 _mm256_storeu_pd(tC + 8, ymm10);
2484
2485 // multiply C by beta and accumulate.
2486 tC += ldc;
2487 /*ymm2 = _mm256_loadu_pd(tC);
2488 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2489 ymm2 = _mm256_loadu_pd(tC + 4);
2490 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2491 ymm2 = _mm256_loadu_pd(tC + 8);
2492 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
2493 _mm256_storeu_pd(tC, ymm12);
2494 _mm256_storeu_pd(tC + 4, ymm13);
2495 _mm256_storeu_pd(tC + 8, ymm14);
2496
2497 col_idx += 2;
2498 }
2499 // if N is not a multiple of 3,
2500 // handle the edge case.
2501 if (n_remainder == 1)
2502 {
2503 //pointer math to point to proper memory
2504 tC = C + ldc * col_idx + row_idx;
2505 tB = B + tb_inc_col * col_idx;
2506 tA = A + row_idx;
2507
2508 // clear scratch registers.
2509 ymm12 = _mm256_setzero_pd();
2510 ymm13 = _mm256_setzero_pd();
2511 ymm14 = _mm256_setzero_pd();
2512
2513 for (k = 0; k < K; ++k)
2514 {
2515 // The inner loop broadcasts the B matrix data and
2516 // multiplies it with the A matrix.
2517 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2518 tB += tb_inc_row;
2519
2520 //broadcasted matrix B elements are multiplied
2521 //with matrix A columns.
2522 ymm3 = _mm256_loadu_pd(tA);
2523 ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
2524
2525 ymm3 = _mm256_loadu_pd(tA + 4);
2526 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
2527
2528 ymm3 = _mm256_loadu_pd(tA + 8);
2529 ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
2530
2531 tA += lda;
2532
2533 }
2534 // alpha, beta multiplication.
2535 ymm0 = _mm256_broadcast_sd(alpha_cast);
2536 //ymm1 = _mm256_broadcast_sd(beta_cast);
2537
2538 //multiply A*B by alpha.
2539 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2540 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2541 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2542
2543 // multiply C by beta and accumulate.
2544 /*ymm2 = _mm256_loadu_pd(tC + 0);
2545 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2546 ymm2 = _mm256_loadu_pd(tC + 4);
2547 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2548 ymm2 = _mm256_loadu_pd(tC + 8);
2549 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/
2550
2551 _mm256_storeu_pd(tC + 0, ymm12);
2552 _mm256_storeu_pd(tC + 4, ymm13);
2553 _mm256_storeu_pd(tC + 8, ymm14);
2554 }
2555
2556 row_idx += 12;
2557 }
2558
2559 if (m_remainder >= 8)
2560 {
2561 m_remainder -= 8;
2562
2563 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2564 {
2565 //pointer math to point to proper memory
2566 tC = C + ldc * col_idx + row_idx;
2567 tB = B + tb_inc_col * col_idx;
2568 tA = A + row_idx;
2569
2570 // clear scratch registers.
2571 ymm4 = _mm256_setzero_pd();
2572 ymm5 = _mm256_setzero_pd();
2573 ymm6 = _mm256_setzero_pd();
2574 ymm7 = _mm256_setzero_pd();
2575 ymm8 = _mm256_setzero_pd();
2576 ymm9 = _mm256_setzero_pd();
2577
2578 for (k = 0; k < K; ++k)
2579 {
2580 // The inner loop broadcasts the B matrix data and
2581 // multiplies it with the A matrix.
2582 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2583 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2584 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2585 tB += tb_inc_row;
2586
2587 //broadcasted matrix B elements are multiplied
2588 //with matrix A columns.
2589 ymm3 = _mm256_loadu_pd(tA);
2590 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2591 ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
2592 ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8);
2593
2594 ymm3 = _mm256_loadu_pd(tA + 4);
2595 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2596 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
2597 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
2598
2599 tA += lda;
2600 }
2601 // alpha, beta multiplication.
2602 ymm0 = _mm256_broadcast_sd(alpha_cast);
2603 //ymm1 = _mm256_broadcast_sd(beta_cast);
2604
2605 //multiply A*B by alpha.
2606 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2607 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2608 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2609 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2610 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2611 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2612
2613 // multiply C by beta and accumulate.
2614 /*ymm2 = _mm256_loadu_pd(tC);
2615 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2616 ymm2 = _mm256_loadu_pd(tC + 4);
2617 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
2618 _mm256_storeu_pd(tC, ymm4);
2619 _mm256_storeu_pd(tC + 4, ymm5);
2620
2621 // multiply C by beta and accumulate.
2622 tC += ldc;
2623 /*ymm2 = _mm256_loadu_pd(tC);
2624 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2625 ymm2 = _mm256_loadu_pd(tC + 4);
2626 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
2627 _mm256_storeu_pd(tC, ymm6);
2628 _mm256_storeu_pd(tC + 4, ymm7);
2629
2630 // multiply C by beta and accumulate.
2631 tC += ldc;
2632 /*ymm2 = _mm256_loadu_pd(tC);
2633 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2634 ymm2 = _mm256_loadu_pd(tC + 4);
2635 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/
2636 _mm256_storeu_pd(tC, ymm8);
2637 _mm256_storeu_pd(tC + 4, ymm9);
2638
2639 }
2640 n_remainder = N - col_idx;
2641 // if N is not a multiple of 3,
2642 // handle the edge case.
2643 if (n_remainder == 2)
2644 {
2645 //pointer math to point to proper memory
2646 tC = C + ldc * col_idx + row_idx;
2647 tB = B + tb_inc_col * col_idx;
2648 tA = A + row_idx;
2649
2650 // clear scratch registers.
2651 ymm4 = _mm256_setzero_pd();
2652 ymm5 = _mm256_setzero_pd();
2653 ymm6 = _mm256_setzero_pd();
2654 ymm7 = _mm256_setzero_pd();
2655
2656 for (k = 0; k < K; ++k)
2657 {
2658 // The inner loop broadcasts the B matrix data and
2659 // multiplies it with the A matrix.
2660 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2661 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2662 tB += tb_inc_row;
2663
2664 //broadcasted matrix B elements are multiplied
2665 //with matrix A columns.
2666 ymm3 = _mm256_loadu_pd(tA);
2667 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2668 ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
2669
2670 ymm3 = _mm256_loadu_pd(tA + 4);
2671 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2672 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
2673
2674 tA += lda;
2675 }
2676 // alpha, beta multiplication.
2677 ymm0 = _mm256_broadcast_sd(alpha_cast);
2678 //ymm1 = _mm256_broadcast_sd(beta_cast);
2679
2680 //multiply A*B by alpha.
2681 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2682 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2683 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2684 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2685
2686 // multiply C by beta and accumulate.
2687 /*ymm2 = _mm256_loadu_pd(tC);
2688 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2689 ymm2 = _mm256_loadu_pd(tC + 4);
2690 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
2691 _mm256_storeu_pd(tC, ymm4);
2692 _mm256_storeu_pd(tC + 4, ymm5);
2693
2694 // multiply C by beta and accumulate.
2695 tC += ldc;
2696 /*ymm2 = _mm256_loadu_pd(tC);
2697 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2698 ymm2 = _mm256_loadu_pd(tC + 4);
2699 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
2700 _mm256_storeu_pd(tC, ymm6);
2701 _mm256_storeu_pd(tC + 4, ymm7);
2702
2703 col_idx += 2;
2704
2705 }
2706 // if N is not a multiple of 3,
2707 // handle the edge case.
2708 if (n_remainder == 1)
2709 {
2710 //pointer math to point to proper memory
2711 tC = C + ldc * col_idx + row_idx;
2712 tB = B + tb_inc_col * col_idx;
2713 tA = A + row_idx;
2714
2715 ymm4 = _mm256_setzero_pd();
2716 ymm5 = _mm256_setzero_pd();
2717
2718 for (k = 0; k < K; ++k)
2719 {
2720 // The inner loop broadcasts the B matrix data and
2721 // multiplies it with the A matrix.
2722 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2723 tB += tb_inc_row;
2724
2725 //broadcasted matrix B elements are multiplied
2726 //with matrix A columns.
2727 ymm3 = _mm256_loadu_pd(tA);
2728 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2729
2730 ymm3 = _mm256_loadu_pd(tA + 4);
2731 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2732
2733 tA += lda;
2734 }
2735 // alpha, beta multiplication.
2736 ymm0 = _mm256_broadcast_sd(alpha_cast);
2737 //ymm1 = _mm256_broadcast_sd(beta_cast);
2738
2739 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2740 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2741
2742 // multiply C by beta and accumulate.
2743 /*ymm2 = _mm256_loadu_pd(tC);
2744 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2745 ymm2 = _mm256_loadu_pd(tC + 4);
2746 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
2747 _mm256_storeu_pd(tC, ymm4);
2748 _mm256_storeu_pd(tC + 4, ymm5);
2749
2750 }
2751
2752 row_idx += 8;
2753 }
2754
2755 if (m_remainder >= 4)
2756 {
2757 m_remainder -= 4;
2758
2759 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2760 {
2761 //pointer math to point to proper memory
2762 tC = C + ldc * col_idx + row_idx;
2763 tB = B + tb_inc_col * col_idx;
2764 tA = A + row_idx;
2765
2766 // clear scratch registers.
2767 ymm4 = _mm256_setzero_pd();
2768 ymm5 = _mm256_setzero_pd();
2769 ymm6 = _mm256_setzero_pd();
2770
2771 for (k = 0; k < K; ++k)
2772 {
2773 // The inner loop broadcasts the B matrix data and
2774 // multiplies it with the A matrix.
2775 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2776 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2777 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2778 tB += tb_inc_row;
2779
2780 //broadcasted matrix B elements are multiplied
2781 //with matrix A columns.
2782 ymm3 = _mm256_loadu_pd(tA);
2783 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2784 ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
2785 ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
2786
2787 tA += lda;
2788 }
2789 // alpha, beta multiplication.
2790 ymm0 = _mm256_broadcast_sd(alpha_cast);
2791 //ymm1 = _mm256_broadcast_sd(beta_cast);
2792
2793 //multiply A*B by alpha.
2794 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2795 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2796 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2797
2798 // multiply C by beta and accumulate.
2799 /*ymm2 = _mm256_loadu_pd(tC);
2800 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
2801 _mm256_storeu_pd(tC, ymm4);
2802
2803 // multiply C by beta and accumulate.
2804 tC += ldc;
2805 /*ymm2 = _mm256_loadu_pd(tC);
2806 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
2807 _mm256_storeu_pd(tC, ymm5);
2808
2809 // multiply C by beta and accumulate.
2810 tC += ldc;
2811 /*ymm2 = _mm256_loadu_pd(tC);
2812 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/
2813 _mm256_storeu_pd(tC, ymm6);
2814 }
2815 n_remainder = N - col_idx;
2816 // if N is not a multiple of 3,
2817 // handle the edge case.
2818 if (n_remainder == 2)
2819 {
2820 //pointer math to point to proper memory
2821 tC = C + ldc * col_idx + row_idx;
2822 tB = B + tb_inc_col * col_idx;
2823 tA = A + row_idx;
2824
2825 ymm4 = _mm256_setzero_pd();
2826 ymm5 = _mm256_setzero_pd();
2827
2828 for (k = 0; k < K; ++k)
2829 {
2830 // The inner loop broadcasts the B matrix data and
2831 // multiplies it with the A matrix.
2832 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2833 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2834 tB += tb_inc_row;
2835
2836 //broadcasted matrix B elements are multiplied
2837 //with matrix A columns.
2838 ymm3 = _mm256_loadu_pd(tA);
2839 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2840 ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
2841
2842 tA += lda;
2843 }
2844 // alpha, beta multiplication.
2845 ymm0 = _mm256_broadcast_sd(alpha_cast);
2846 //ymm1 = _mm256_broadcast_sd(beta_cast);
2847
2848 //multiply A*B by alpha.
2849 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2850 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2851
2852 // multiply C by beta and accumulate.
2853 /*ymm2 = _mm256_loadu_pd(tC);
2854 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
2855 _mm256_storeu_pd(tC, ymm4);
2856
2857 // multiply C by beta and accumulate.
2858 tC += ldc;
2859 /*ymm2 = _mm256_loadu_pd(tC);
2860 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
2861 _mm256_storeu_pd(tC, ymm5);
2862
2863 col_idx += 2;
2864
2865 }
2866 // if N is not a multiple of 3,
2867 // handle the edge case.
2868 if (n_remainder == 1)
2869 {
2870 //pointer math to point to proper memory
2871 tC = C + ldc * col_idx + row_idx;
2872 tB = B + tb_inc_col * col_idx;
2873 tA = A + row_idx;
2874
2875 ymm4 = _mm256_setzero_pd();
2876
2877 for (k = 0; k < K; ++k)
2878 {
2879 // The inner loop broadcasts the B matrix data and
2880 // multiplies it with the A matrix.
2881 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2882 tB += tb_inc_row;
2883
2884 //broadcasted matrix B elements are multiplied
2885 //with matrix A columns.
2886 ymm3 = _mm256_loadu_pd(tA);
2887 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2888
2889 tA += lda;
2890 }
2891 // alpha, beta multiplication.
2892 ymm0 = _mm256_broadcast_sd(alpha_cast);
2893 //ymm1 = _mm256_broadcast_sd(beta_cast);
2894
2895 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2896
2897 // multiply C by beta and accumulate.
2898 /*ymm2 = _mm256_loadu_pd(tC);
2899 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/
2900 _mm256_storeu_pd(tC, ymm4);
2901
2902 }
2903
2904 row_idx += 4;
2905 }
2906 // M is not a multiple of D_MR (16).
2907 // Handle the edge case where the remaining row
2908 // dimension is less than 4. Padding through a
2909 // temporary buffer is used to handle this case.
2910 if ((m_remainder) && (lda > 3))
2911 {
2912 double f_temp[8];
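// f_temp stages the partial loads from the last column of A and every store
// to C, so that full 4-wide vector loads/stores never touch memory past the
// end of the buffers.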
2913
2914 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2915 {
2916 //pointer math to point to proper memory
2917 tC = C + ldc * col_idx + row_idx;
2918 tB = B + tb_inc_col * col_idx;
2919 tA = A + row_idx;
2920
2921 // clear scratch registers.
2922 ymm5 = _mm256_setzero_pd();
2923 ymm7 = _mm256_setzero_pd();
2924 ymm9 = _mm256_setzero_pd();
2925
2926 for (k = 0; k < (K - 1); ++k)
2927 {
2928 // The inner loop broadcasts the B matrix data and
2929 // multiplies it with the A matrix.
2930 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2931 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2932 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2933 tB += tb_inc_row;
2934
2935 //broadcasted matrix B elements are multiplied
2936 //with matrix A columns.
2937 ymm3 = _mm256_loadu_pd(tA);
2938 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2939 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
2940 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
2941
2942 tA += lda;
2943 }
2944 // alpha, beta multiplication.
2945 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2946 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2947 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2948 tB += tb_inc_row;
2949
2950 for (int i = 0; i < m_remainder; i++)
2951 {
2952 f_temp[i] = tA[i];
2953 }
2954 ymm3 = _mm256_loadu_pd(f_temp);
2955 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2956 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
2957 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
2958
2959 ymm0 = _mm256_broadcast_sd(alpha_cast);
2960 //ymm1 = _mm256_broadcast_sd(beta_cast);
2961
2962 //multiply A*B by alpha.
2963 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2964 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2965 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2966
2967
2968 /*for (int i = 0; i < m_remainder; i++)
2969 {
2970 f_temp[i] = tC[i];
2971 }
2972 ymm2 = _mm256_loadu_pd(f_temp);
2973 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
2974 _mm256_storeu_pd(f_temp, ymm5);
2975 for (int i = 0; i < m_remainder; i++)
2976 {
2977 tC[i] = f_temp[i];
2978 }
2979
2980 tC += ldc;
2981 /*for (int i = 0; i < m_remainder; i++)
2982 {
2983 f_temp[i] = tC[i];
2984 }
2985 ymm2 = _mm256_loadu_pd(f_temp);
2986 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
2987 _mm256_storeu_pd(f_temp, ymm7);
2988 for (int i = 0; i < m_remainder; i++)
2989 {
2990 tC[i] = f_temp[i];
2991 }
2992
2993 tC += ldc;
2994 /*for (int i = 0; i < m_remainder; i++)
2995 {
2996 f_temp[i] = tC[i];
2997 }
2998 ymm2 = _mm256_loadu_pd(f_temp);
2999 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/
3000 _mm256_storeu_pd(f_temp, ymm9);
3001 for (int i = 0; i < m_remainder; i++)
3002 {
3003 tC[i] = f_temp[i];
3004 }
3005 }
3006 n_remainder = N - col_idx;
3007 // if N is not a multiple of 3,
3008 // handle the edge case.
3009 if (n_remainder == 2)
3010 {
3011 //pointer math to point to proper memory
3012 tC = C + ldc * col_idx + row_idx;
3013 tB = B + tb_inc_col * col_idx;
3014 tA = A + row_idx;
3015
3016 ymm5 = _mm256_setzero_pd();
3017 ymm7 = _mm256_setzero_pd();
3018
3019 for (k = 0; k < (K - 1); ++k)
3020 {
3021 // The inner loop broadcasts the B matrix data and
3022 // multiplies it with the A matrix.
3023 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3024 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3025 tB += tb_inc_row;
3026
3027 ymm3 = _mm256_loadu_pd(tA);
3028 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3029 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3030
3031 tA += lda;
3032 }
3033
3034 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3035 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3036 tB += tb_inc_row;
3037
3038 for (int i = 0; i < m_remainder; i++)
3039 {
3040 f_temp[i] = tA[i];
3041 }
3042 ymm3 = _mm256_loadu_pd(f_temp);
3043 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3044 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3045
3046 ymm0 = _mm256_broadcast_sd(alpha_cast);
3047 //ymm1 = _mm256_broadcast_sd(beta_cast);
3048
3049 ymm5 = _mm256_mul_pd(ymm5, ymm0);
3050 ymm7 = _mm256_mul_pd(ymm7, ymm0);
3051
3052 /*for (int i = 0; i < m_remainder; i++)
3053 {
3054 f_temp[i] = tC[i];
3055 }
3056 ymm2 = _mm256_loadu_pd(f_temp);
3057 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
3058 _mm256_storeu_pd(f_temp, ymm5);
3059 for (int i = 0; i < m_remainder; i++)
3060 {
3061 tC[i] = f_temp[i];
3062 }
3063
3064 tC += ldc;
3065 /*for (int i = 0; i < m_remainder; i++)
3066 {
3067 f_temp[i] = tC[i];
3068 }
3069 ymm2 = _mm256_loadu_pd(f_temp);
3070 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/
3071 _mm256_storeu_pd(f_temp, ymm7);
3072 for (int i = 0; i < m_remainder; i++)
3073 {
3074 tC[i] = f_temp[i];
3075 }
3076 }
3077 // if N is not a multiple of 3,
3078 // handle the edge case.
3079 if (n_remainder == 1)
3080 {
3081 //pointer math to point to proper memory
3082 tC = C + ldc * col_idx + row_idx;
3083 tB = B + tb_inc_col * col_idx;
3084 tA = A + row_idx;
3085
3086 ymm5 = _mm256_setzero_pd();
3087
3088 for (k = 0; k < (K - 1); ++k)
3089 {
3090 // The inner loop broadcasts the B matrix data and
3091 // multiplies it with the A matrix.
3092 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3093 tB += tb_inc_row;
3094
3095 ymm3 = _mm256_loadu_pd(tA);
3096 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3097
3098 tA += lda;
3099 }
3100
3101 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3102 tB += tb_inc_row;
3103
3104 for (int i = 0; i < m_remainder; i++)
3105 {
3106 f_temp[i] = tA[i];
3107 }
3108 ymm3 = _mm256_loadu_pd(f_temp);
3109 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3110
3111 ymm0 = _mm256_broadcast_sd(alpha_cast);
3112 //ymm1 = _mm256_broadcast_sd(beta_cast);
3113
3114 // multiply C by beta and accumulate.
3115 ymm5 = _mm256_mul_pd(ymm5, ymm0);
3116
3117 /*for (int i = 0; i < m_remainder; i++)
3118 {
3119 f_temp[i] = tC[i];
3120 }
3121 ymm2 = _mm256_loadu_pd(f_temp);
3122 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/
3123 _mm256_storeu_pd(f_temp, ymm5);
3124 for (int i = 0; i < m_remainder; i++)
3125 {
3126 tC[i] = f_temp[i];
3127 }
3128 }
3129 m_remainder = 0;
3130 }
3131
3132 if (m_remainder)
3133 {
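// Scalar fallback: reached only when the padded vector path above was
// skipped (lda <= 3); each remaining element of C is a plain dot product.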
3134 double result;
3135 for (; row_idx < M; row_idx += 1)
3136 {
3137 for (col_idx = 0; col_idx < N; col_idx += 1)
3138 {
3139 //pointer math to point to proper memory
3140 tC = C + ldc * col_idx + row_idx;
3141 tB = B + tb_inc_col * col_idx;
3142 tA = A + row_idx;
3143
3144 result = 0;
3145 for (k = 0; k < K; ++k)
3146 {
3147 result += (*tA) * (*tB);
3148 tA += lda;
3149 tB += tb_inc_row;
3150 }
3151
3152 result *= (*alpha_cast);
3153 (*tC) = /*(*tC) * (*beta_cast) + */result;
3154 }
3155 }
3156 }
3157
3158 // Copy/compute the syrk results back into C using SIMD.
3159 if ( bli_seq0( *beta_cast ) )
3160 {//just copy for beta = 0
3161 dim_t _i, _j, k, _l;
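// Only the stored triangle of C receives the results: 4-wide vector copies
// are used where four in-triangle rows are available, and scalar copies
// handle the rows next to the diagonal and any tail.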
3162 if(bli_obj_is_lower(c)) //c is lower
3163 {
3164 //first column
3165 _j = 0;
3166 k = M >> 2;
3167 _i = 0;
3168 for ( _l = 0; _l < k; _l++ )
3169 {
3170 ymm0 = _mm256_loadu_pd((C + _i*rsc));
3171 _mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0);
3172 _i += 4;
3173 }
3174 while (_i < M )
3175 {
3176 bli_ddcopys( *(C + _i*rsc + _j*ldc),
3177 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3178 _i++;
3179 }
3180 _j++;
3181 while ( _j < N ) //next column
3182 {
3183 //k = (_j + (4 - (_j & 3)));
3184 _l = _j & 3;
3185 k = (_l != 0) ? (_j + (4 - _l)) : _j;
3186 k = (k <= M) ? k : M;
3187 for ( _i = _j; _i < k; ++_i )
3188 {
3189 bli_ddcopys( *(C + _i*rsc + _j*ldc),
3190 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3191 }
3192 k = (M - _i) >> 2;
3193 _l = 0;
3194 while ( _l < k )
3195 {
3196 ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
3197 _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
3198
3199 _i += 4;
3200 _l++;
3201 }
3202 while (_i < M )
3203 {
3204 bli_ddcopys( *(C + _i*rsc + _j*ldc),
3205 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3206 _i++;
3207 }
3208 _j++;
3209 }
3210 }
3211 else //c is upper
3212 {
3213 for ( _j = 0; _j < N; ++_j )
3214 {
3215 k = (_j + 1) >> 2;
3216 _i = 0;
3217 _l = 0;
3218 while ( _l < k )
3219 {
3220 ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
3221 _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
3222 _i += 4;
3223 _l++;
3224 }
3225 while (_i <= _j )
3226 {
3227 bli_ddcopys( *(C + _i*rsc + _j*ldc),
3228 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3229 ++_i;
3230 }
3231 }
3232 }
3233 }
3234 else
3235 {//when beta is non-zero, fmadd and store the results
3236 dim_t _i, _j, k, _l;
3237 ymm1 = _mm256_broadcast_sd(beta_cast);
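// Same triangular traversal as the beta == 0 branch, but each element is
// updated as C = beta*C + computed result (fmadd for the vector parts,
// bli_dddxpbys for the scalar parts).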
3238 if(bli_obj_is_lower(c)) //c is lower
3239 {
3240 //first column
3241 _j = 0;
3242 k = M >> 2;
3243 _i = 0;
3244 for ( _l = 0; _l < k; _l++ )
3245 {
3246 ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC));
3247 ymm0 = _mm256_loadu_pd((C + _i*rsc));
3248 ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
3249 _mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0);
3250 _i += 4;
3251 }
3252 while (_i < M )
3253 {
3254 bli_dddxpbys( *(C + _i*rsc + _j*ldc),
3255 *(beta_cast),
3256 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3257 _i++;
3258 }
3259 _j++;
3260 while ( _j < N ) //next column
3261 {
3262 //k = (_j + (4 - (_j & 3)));
3263 _l = _j & 3;
3264 k = (_l != 0) ? (_j + (4 - _l)) : _j;
3265 k = (k <= M) ? k : M;
3266 for ( _i = _j; _i < k; ++_i )
3267 {
3268 bli_dddxpbys( *(C + _i*rsc + _j*ldc),
3269 *(beta_cast),
3270 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3271 }
3272 k = (M - _i) >> 2;
3273 _l = 0;
3274 while ( _l < k )
3275 {
3276 ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC));
3277 ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
3278 ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
3279 _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
3280
3281 _i += 4;
3282 _l++;
3283 }
3284 while (_i < M )
3285 {
3286 bli_dddxpbys( *(C + _i*rsc + _j*ldc),
3287 *(beta_cast),
3288 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3289 _i++;
3290 }
3291 _j++;
3292 }
3293 }
3294 else //c is upper
3295 {
3296 for ( _j = 0; _j < N; ++_j )
3297 {
3298 k = (_j + 1) >> 2;
3299 _i = 0;
3300 _l = 0;
3301 while ( _l < k )
3302 {
3303 ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC));
3304 ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc));
3305 ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0);
3306 _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0);
3307 _i += 4;
3308 _l++;
3309 }
3310 while (_i <= _j )
3311 {
3312 bli_dddxpbys( *(C + _i*rsc + _j*ldc),
3313 *(beta_cast),
3314 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3315 ++_i;
3316 }
3317 }
3318 }
3319 }
3320
3321 return BLIS_SUCCESS;
3322 }
3323 else
3324 return BLIS_NONCONFORMAL_DIMENSIONS;
3325
3326
3327 }
3328
3329 static err_t bli_ssyrk_small_atbn
3330 (
3331 obj_t* alpha,
3332 obj_t* a,
3333 obj_t* b,
3334 obj_t* beta,
3335 obj_t* c,
3336 cntx_t* cntx,
3337 cntl_t* cntl
3338 )
3339 {
3340 int M = bli_obj_length(c); // number of rows of Matrix C
3341 int N = bli_obj_width(c); // number of columns of Matrix C
3342 int K = bli_obj_length(b); // number of rows of Matrix B
3343 int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
3344 int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
3345 int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
3346 int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
3347 int row_idx = 0, col_idx = 0, k;
3348 int rs_matC = bli_obj_row_stride( c );
3349 int rsc = 1;
3350 float *A = a->buffer; // pointer to matrix A elements, stored in row major format
3351 float *B = b->buffer; // pointer to matrix B elements, stored in column major format
3352 float *C = C_pack; // pointer to matrix C elements, stored in column major format
3353 float *matCbuf = c->buffer;
3354
3355 float *tA = A, *tB = B, *tC = C;
3356
3357 __m256 ymm4, ymm5, ymm6, ymm7;
3358 __m256 ymm8, ymm9, ymm10, ymm11;
3359 __m256 ymm12, ymm13, ymm14, ymm15;
3360 __m256 ymm0, ymm1, ymm2, ymm3;
3361
3362 float result, scratch[8];
3363 float *alpha_cast, *beta_cast; // alpha, beta multiples
3364 alpha_cast = (alpha->buffer);
3365 beta_cast = (beta->buffer);
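// C = alpha * A^T * B: each C(i,j) is the length-K dot product of column i
// of A with column j of B, accumulated eight floats at a time and reduced
// with horizontal adds; the beta update is applied in the copy-back below.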
3366
3367 // The non-copy version of the A^T SYRK gives better performance for the small M cases.
3368 // The threshold is controlled by BLIS_ATBN_M_THRES
3369 if (M <= BLIS_ATBN_M_THRES)
3370 {
3371 for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
3372 {
3373 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
3374 {
3375 tA = A + row_idx * lda;
3376 tB = B + col_idx * ldb;
3377 tC = C + col_idx * ldc + row_idx;
3378 // clear scratch registers.
3379 ymm4 = _mm256_setzero_ps();
3380 ymm5 = _mm256_setzero_ps();
3381 ymm6 = _mm256_setzero_ps();
3382 ymm7 = _mm256_setzero_ps();
3383 ymm8 = _mm256_setzero_ps();
3384 ymm9 = _mm256_setzero_ps();
3385 ymm10 = _mm256_setzero_ps();
3386 ymm11 = _mm256_setzero_ps();
3387 ymm12 = _mm256_setzero_ps();
3388 ymm13 = _mm256_setzero_ps();
3389 ymm14 = _mm256_setzero_ps();
3390 ymm15 = _mm256_setzero_ps();
3391
3392 //The inner loop computes a 4x3 block of the C matrix.
3393 //The computation pattern is:
3394 // ymm4 ymm5 ymm6
3395 // ymm7 ymm8 ymm9
3396 // ymm10 ymm11 ymm12
3397 // ymm13 ymm14 ymm15
3398
3399 //The dot products are accumulated in the inner loop; eight float elements fit
3400 //in a YMM register, hence the loop counter is incremented by 8.
3401 for (k = 0; (k + 7) < K; k += 8)
3402 {
3403 ymm0 = _mm256_loadu_ps(tB + 0);
3404 ymm1 = _mm256_loadu_ps(tB + ldb);
3405 ymm2 = _mm256_loadu_ps(tB + 2 * ldb);
3406
3407 ymm3 = _mm256_loadu_ps(tA);
3408 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3409 ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
3410 ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
3411
3412 ymm3 = _mm256_loadu_ps(tA + lda);
3413 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3414 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
3415 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
3416
3417 ymm3 = _mm256_loadu_ps(tA + 2 * lda);
3418 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3419 ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
3420 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
3421
3422 ymm3 = _mm256_loadu_ps(tA + 3 * lda);
3423 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3424 ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
3425 ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
3426
3427 tA += 8;
3428 tB += 8;
3429
3430 }
3431
3432 // If K is not a multiple of 8, the data is zero-padded into a temporary array before loading.
3433 if (k < K)
3434 {
3435 int iter;
3436 float data_feeder[8] = { 0.0 };
3437
3438 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3439 ymm0 = _mm256_loadu_ps(data_feeder);
3440 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
3441 ymm1 = _mm256_loadu_ps(data_feeder);
3442 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
3443 ymm2 = _mm256_loadu_ps(data_feeder);
3444
3445 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3446 ymm3 = _mm256_loadu_ps(data_feeder);
3447 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3448 ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
3449 ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
3450
3451 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
3452 ymm3 = _mm256_loadu_ps(data_feeder);
3453 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3454 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
3455 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
3456
3457 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
3458 ymm3 = _mm256_loadu_ps(data_feeder);
3459 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3460 ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
3461 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
3462
3463 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
3464 ymm3 = _mm256_loadu_ps(data_feeder);
3465 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3466 ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
3467 ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
3468
3469 }
3470
3471 //horizontal addition and storage of the data.
3472 //Results for the 4x3 block of C are stored here.
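//(two hadd passes reduce each 128-bit lane to its partial sum; adding
//scratch[0] and scratch[4] combines the lanes into the full dot product)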
3473 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3474 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3475 _mm256_storeu_ps(scratch, ymm4);
3476 result = scratch[0] + scratch[4];
3477 result *= (*alpha_cast);
3478 tC[0] = result/* + tC[0] * (*beta_cast)*/;
3479
3480 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3481 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3482 _mm256_storeu_ps(scratch, ymm7);
3483 result = scratch[0] + scratch[4];
3484 result *= (*alpha_cast);
3485 tC[1] = result/* + tC[1] * (*beta_cast)*/;
3486
3487 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3488 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3489 _mm256_storeu_ps(scratch, ymm10);
3490 result = scratch[0] + scratch[4];
3491 result *= (*alpha_cast);
3492 tC[2] = result/* + tC[2] * (*beta_cast)*/;
3493
3494 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3495 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3496 _mm256_storeu_ps(scratch, ymm13);
3497 result = scratch[0] + scratch[4];
3498 result *= (*alpha_cast);
3499 tC[3] = result/* + tC[3] * (*beta_cast)*/;
3500
3501 tC += ldc;
3502 ymm5 = _mm256_hadd_ps(ymm5, ymm5);
3503 ymm5 = _mm256_hadd_ps(ymm5, ymm5);
3504 _mm256_storeu_ps(scratch, ymm5);
3505 result = scratch[0] + scratch[4];
3506 result *= (*alpha_cast);
3507 tC[0] = result/* + tC[0] * (*beta_cast)*/;
3508
3509 ymm8 = _mm256_hadd_ps(ymm8, ymm8);
3510 ymm8 = _mm256_hadd_ps(ymm8, ymm8);
3511 _mm256_storeu_ps(scratch, ymm8);
3512 result = scratch[0] + scratch[4];
3513 result *= (*alpha_cast);
3514 tC[1] = result/* + tC[1] * (*beta_cast)*/;
3515
3516 ymm11 = _mm256_hadd_ps(ymm11, ymm11);
3517 ymm11 = _mm256_hadd_ps(ymm11, ymm11);
3518 _mm256_storeu_ps(scratch, ymm11);
3519 result = scratch[0] + scratch[4];
3520 result *= (*alpha_cast);
3521 tC[2] = result/* + tC[2] * (*beta_cast)*/;
3522
3523 ymm14 = _mm256_hadd_ps(ymm14, ymm14);
3524 ymm14 = _mm256_hadd_ps(ymm14, ymm14);
3525 _mm256_storeu_ps(scratch, ymm14);
3526 result = scratch[0] + scratch[4];
3527 result *= (*alpha_cast);
3528 tC[3] = result/* + tC[3] * (*beta_cast)*/;
3529
3530 tC += ldc;
3531 ymm6 = _mm256_hadd_ps(ymm6, ymm6);
3532 ymm6 = _mm256_hadd_ps(ymm6, ymm6);
3533 _mm256_storeu_ps(scratch, ymm6);
3534 result = scratch[0] + scratch[4];
3535 result *= (*alpha_cast);
3536 tC[0] = result/* + tC[0] * (*beta_cast)*/;
3537
3538 ymm9 = _mm256_hadd_ps(ymm9, ymm9);
3539 ymm9 = _mm256_hadd_ps(ymm9, ymm9);
3540 _mm256_storeu_ps(scratch, ymm9);
3541 result = scratch[0] + scratch[4];
3542 result *= (*alpha_cast);
3543 tC[1] = result/* + tC[1] * (*beta_cast)*/;
3544
3545 ymm12 = _mm256_hadd_ps(ymm12, ymm12);
3546 ymm12 = _mm256_hadd_ps(ymm12, ymm12);
3547 _mm256_storeu_ps(scratch, ymm12);
3548 result = scratch[0] + scratch[4];
3549 result *= (*alpha_cast);
3550 tC[2] = result/* + tC[2] * (*beta_cast)*/;
3551
3552 ymm15 = _mm256_hadd_ps(ymm15, ymm15);
3553 ymm15 = _mm256_hadd_ps(ymm15, ymm15);
3554 _mm256_storeu_ps(scratch, ymm15);
3555 result = scratch[0] + scratch[4];
3556 result *= (*alpha_cast);
3557 tC[3] = result/* + tC[3] * (*beta_cast)*/;
3558 }
3559 }
3560
3561 int processed_col = col_idx;
3562 int processed_row = row_idx;
3563
3564 //The edge case handling where N is not a multiple of 3
3565 if (processed_col < N)
3566 {
3567 for (col_idx = processed_col; col_idx < N; col_idx += 1)
3568 {
3569 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
3570 {
3571 tA = A + row_idx * lda;
3572 tB = B + col_idx * ldb;
3573 tC = C + col_idx * ldc + row_idx;
3574 // clear scratch registers.
3575 ymm4 = _mm256_setzero_ps();
3576 ymm7 = _mm256_setzero_ps();
3577 ymm10 = _mm256_setzero_ps();
3578 ymm13 = _mm256_setzero_ps();
3579
3580 //The inner loop computes a 4x1 block of the C matrix.
3581 //The computation pattern is:
3582 // ymm4
3583 // ymm7
3584 // ymm10
3585 // ymm13
3586
3587 for (k = 0; (k + 7) < K; k += 8)
3588 {
3589 ymm0 = _mm256_loadu_ps(tB + 0);
3590
3591 ymm3 = _mm256_loadu_ps(tA);
3592 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3593
3594 ymm3 = _mm256_loadu_ps(tA + lda);
3595 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3596
3597 ymm3 = _mm256_loadu_ps(tA + 2 * lda);
3598 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3599
3600 ymm3 = _mm256_loadu_ps(tA + 3 * lda);
3601 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3602
3603 tA += 8;
3604 tB += 8;
3605 }
3606
3607 // If K is not a multiple of 8, the data is zero-padded into a temporary array before loading.
3608 if (k < K)
3609 {
3610 int iter;
3611 float data_feeder[8] = { 0.0 };
3612
3613 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3614 ymm0 = _mm256_loadu_ps(data_feeder);
3615
3616 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3617 ymm3 = _mm256_loadu_ps(data_feeder);
3618 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3619
3620 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
3621 ymm3 = _mm256_loadu_ps(data_feeder);
3622 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3623
3624 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
3625 ymm3 = _mm256_loadu_ps(data_feeder);
3626 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3627
3628 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
3629 ymm3 = _mm256_loadu_ps(data_feeder);
3630 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3631
3632 }
3633
3634 //horizontal addition and storage of the data.
3635 //Results for the 4x1 block of C are stored here.
3636 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3637 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3638 _mm256_storeu_ps(scratch, ymm4);
3639 result = scratch[0] + scratch[4];
3640 result *= (*alpha_cast);
3641 tC[0] = result/* + tC[0] * (*beta_cast)*/;
3642
3643 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3644 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3645 _mm256_storeu_ps(scratch, ymm7);
3646 result = scratch[0] + scratch[4];
3647 result *= (*alpha_cast);
3648 tC[1] = result/* + tC[1] * (*beta_cast)*/;
3649
3650 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3651 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3652 _mm256_storeu_ps(scratch, ymm10);
3653 result = scratch[0] + scratch[4];
3654 result *= (*alpha_cast);
3655 tC[2] = result/* + tC[2] * (*beta_cast)*/;
3656
3657 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3658 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3659 _mm256_storeu_ps(scratch, ymm13);
3660 result = scratch[0] + scratch[4];
3661 result *= (*alpha_cast);
3662 tC[3] = result/* + tC[3] * (*beta_cast)*/;
3663
3664 }
3665 }
3666 processed_row = row_idx;
3667 }
3668
3669 //The edge case handling where M is not a multiple of 4
3670 if (processed_row < M)
3671 {
3672 for (row_idx = processed_row; row_idx < M; row_idx += 1)
3673 {
3674 for (col_idx = 0; col_idx < N; col_idx += 1)
3675 {
3676 tA = A + row_idx * lda;
3677 tB = B + col_idx * ldb;
3678 tC = C + col_idx * ldc + row_idx;
3679 // clear scratch registers.
3680 ymm4 = _mm256_setzero_ps();
3681
3682 for (k = 0; (k + 7) < K; k += 8)
3683 {
3684 ymm0 = _mm256_loadu_ps(tB + 0);
3685 ymm3 = _mm256_loadu_ps(tA);
3686 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3687
3688 tA += 8;
3689 tB += 8;
3690 }
3691
3692 // If K is not a multiple of 8, the data is zero-padded into a temporary array before loading.
3693 if (k < K)
3694 {
3695 int iter;
3696 float data_feeder[8] = { 0.0 };
3697
3698 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3699 ymm0 = _mm256_loadu_ps(data_feeder);
3700
3701 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3702 ymm3 = _mm256_loadu_ps(data_feeder);
3703 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3704
3705 }
3706
3707 //horizontal addition and storage of the data.
3708 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3709 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3710 _mm256_storeu_ps(scratch, ymm4);
3711 result = scratch[0] + scratch[4];
3712 result *= (*alpha_cast);
3713 tC[0] = result/* + tC[0] * (*beta_cast)*/;
3714
3715 }
3716 }
3717 }
3718
3719 //copy/compute syrk values back to C
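// Only the stored triangle of C is written back: for lower (upper) storage an
// element (_i, _j) is touched when _j - _i <= 0 (>= 0). With beta == 0 the
// accumulated result in the scratch buffer is copied; otherwise it is combined
// as C := result + beta * C via the xpbys macro.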
3720 if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C
3721 {
3722 dim_t _i, _j;
3723 if(bli_obj_is_lower(c)) //c is lower
3724 {
3725 for ( _j = 0; _j < N; ++_j )
3726 for ( _i = 0; _i < M; ++_i )
3727 if ( (doff_t)_j - (doff_t)_i <= 0 )
3728 {
3729 bli_sscopys( *(C + _i*rsc + _j*ldc),
3730 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3731 }
3732 }
3733 else //c is upper
3734 {
3735 for ( _j = 0; _j < N; ++_j )
3736 for ( _i = 0; _i < M; ++_i )
3737 if ( (doff_t)_j - (doff_t)_i >= 0 )
3738 {
3739 bli_sscopys( *(C + _i*rsc + _j*ldc),
3740 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3741 }
3742 }
3743 }
3744 else //when beta is non-zero, multiply and store result to C
3745 {
3746 dim_t _i, _j;
3747 if(bli_obj_is_lower(c)) //c is lower
3748 {
3749 for ( _j = 0; _j < N; ++_j )
3750 for ( _i = 0; _i < M; ++_i )
3751 if ( (doff_t)_j - (doff_t)_i <= 0 )
3752 {
3753 bli_sssxpbys( *(C + _i*rsc + _j*ldc),
3754 *(beta_cast),
3755 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3756 }
3757 }
3758 else //c is upper
3759 {
3760 for ( _j = 0; _j < N; ++_j )
3761 for ( _i = 0; _i < M; ++_i )
3762 if ( (doff_t)_j - (doff_t)_i >= 0 )
3763 {
3764 bli_sssxpbys( *(C + _i*rsc + _j*ldc),
3765 *(beta_cast),
3766 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
3767 }
3768 }
3769 }
3770
3771 return BLIS_SUCCESS;
3772 }
3773 else
3774 return BLIS_NONCONFORMAL_DIMENSIONS;
3775 }
3776
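/*
 * bli_dsyrk_small_atbn: double-precision SYRK kernel for the A^T * B (no
 * transpose on B) case. The product alpha * A^T * B is computed into the
 * static scratch buffer D_C_pack using 4x3 dot-product blocks, and only the
 * stored triangle of the user's C is then updated with beta in the write-back
 * loops at the end of the function.
 */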
3777 static err_t bli_dsyrk_small_atbn
3778 (
3779 obj_t* alpha,
3780 obj_t* a,
3781 obj_t* b,
3782 obj_t* beta,
3783 obj_t* c,
3784 cntx_t* cntx,
3785 cntl_t* cntl
3786 )
3787 {
3788 int M = bli_obj_length( c ); // number of rows of Matrix C
3789 int N = bli_obj_width( c ); // number of columns of Matrix C
3790 int K = bli_obj_length( b ); // number of rows of Matrix B
3791 int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
3792 int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
3793 int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C
3794 int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C
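// ldc is the leading dimension of the static scratch buffer D_C_pack (column
// major, M rows); ldc_matC and rs_matC describe the caller's C and are used
// only in the final write-back.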
3795 int row_idx = 0, col_idx = 0, k;
3796 int rs_matC = bli_obj_row_stride( c );
3797 int rsc = 1;
3798 double *A = a->buffer; // pointer to matrix A elements, stored in row major format
3799 double *B = b->buffer; // pointer to matrix B elements, stored in column major format
3800 double *C = D_C_pack; // pointer to matrix C elements, stored in column major format
3801 double *matCbuf = c->buffer;
3802
3803 double *tA = A, *tB = B, *tC = C;
3804
3805 __m256d ymm4, ymm5, ymm6, ymm7;
3806 __m256d ymm8, ymm9, ymm10, ymm11;
3807 __m256d ymm12, ymm13, ymm14, ymm15;
3808 __m256d ymm0, ymm1, ymm2, ymm3;
3809
3810 double result, scratch[8];
3811 double *alpha_cast, *beta_cast; // alpha, beta multiples
3812 alpha_cast = (alpha->buffer);
3813 beta_cast = (beta->buffer);
3814
3815 // The non-copy version of the A^T SYRK gives better performance for small M.
3816 // The threshold is controlled by BLIS_ATBN_M_THRES
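// Blocking used below: the outer loops step over C in NR (= 3) columns and
// AT_MR (= 4) rows at a time, so each iteration produces one 4x3 block of C
// from length-K dot products.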
3817 if (M <= BLIS_ATBN_M_THRES)
3818 {
3819 for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
3820 {
3821 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
3822 {
3823 tA = A + row_idx * lda;
3824 tB = B + col_idx * ldb;
3825 tC = C + col_idx * ldc + row_idx;
3826 // clear scratch registers.
3827 ymm4 = _mm256_setzero_pd();
3828 ymm5 = _mm256_setzero_pd();
3829 ymm6 = _mm256_setzero_pd();
3830 ymm7 = _mm256_setzero_pd();
3831 ymm8 = _mm256_setzero_pd();
3832 ymm9 = _mm256_setzero_pd();
3833 ymm10 = _mm256_setzero_pd();
3834 ymm11 = _mm256_setzero_pd();
3835 ymm12 = _mm256_setzero_pd();
3836 ymm13 = _mm256_setzero_pd();
3837 ymm14 = _mm256_setzero_pd();
3838 ymm15 = _mm256_setzero_pd();
3839
3840 //The inner loop computes a 4x3 block of the matrix.
3841 //The computation pattern is:
3842 // ymm4 ymm5 ymm6
3843 // ymm7 ymm8 ymm9
3844 // ymm10 ymm11 ymm12
3845 // ymm13 ymm14 ymm15
3846
3847 //The dot product is computed in the inner loop; 4 double elements fit
3848 //in a YMM register, hence the loop counter is incremented by 4.
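// In scalar terms (illustration only; dot() is a hypothetical length-K
// dot-product helper, not a BLIS routine), each block entry computed below is
//
//   C[(col_idx + j) * ldc + (row_idx + i)] =
//       (*alpha_cast) * dot(A + (row_idx + i) * lda,
//                           B + (col_idx + j) * ldb, K);
//
// for i = 0..3 and j = 0..2, with beta applied later in the write-back.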
3849 for (k = 0; (k + 3) < K; k += 4)
3850 {
3851 ymm0 = _mm256_loadu_pd(tB + 0);
3852 ymm1 = _mm256_loadu_pd(tB + ldb);
3853 ymm2 = _mm256_loadu_pd(tB + 2 * ldb);
3854
3855 ymm3 = _mm256_loadu_pd(tA);
3856 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
3857 ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
3858 ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
3859
3860 ymm3 = _mm256_loadu_pd(tA + lda);
3861 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
3862 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
3863 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
3864
3865 ymm3 = _mm256_loadu_pd(tA + 2 * lda);
3866 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
3867 ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
3868 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
3869
3870 ymm3 = _mm256_loadu_pd(tA + 3 * lda);
3871 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
3872 ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
3873 ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
3874
3875 tA += 4;
3876 tB += 4;
3877
3878 }
3879
3880 // if K is not a multiple of 4, padding is done before the load using a temporary array.
3881 if (k < K)
3882 {
3883 int iter;
3884 double data_feeder[4] = { 0.0 };
3885
3886 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3887 ymm0 = _mm256_loadu_pd(data_feeder);
3888 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
3889 ymm1 = _mm256_loadu_pd(data_feeder);
3890 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
3891 ymm2 = _mm256_loadu_pd(data_feeder);
3892
3893 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3894 ymm3 = _mm256_loadu_pd(data_feeder);
3895 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
3896 ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
3897 ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
3898
3899 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
3900 ymm3 = _mm256_loadu_pd(data_feeder);
3901 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
3902 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
3903 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
3904
3905 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
3906 ymm3 = _mm256_loadu_pd(data_feeder);
3907 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
3908 ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
3909 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
3910
3911 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
3912 ymm3 = _mm256_loadu_pd(data_feeder);
3913 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
3914 ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
3915 ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
3916
3917 }
3918
3919 //horizontal addition and storage of the data.
3920 //Results for the 4x3 block of C are stored here
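// Reduction detail: a single _mm256_hadd_pd leaves scratch[0] holding the sum
// of the low 128-bit lane and scratch[2] the sum of the high lane, so their
// sum is the complete 4-element dot-product contribution.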
3921 ymm4 = _mm256_hadd_pd(ymm4, ymm4);
3922 _mm256_storeu_pd(scratch, ymm4);
3923 result = scratch[0] + scratch[2];
3924 result *= (*alpha_cast);
3925 tC[0] = result/* + tC[0] * (*beta_cast)*/;
3926
3927 ymm7 = _mm256_hadd_pd(ymm7, ymm7);
3928 _mm256_storeu_pd(scratch, ymm7);
3929 result = scratch[0] + scratch[2];
3930 result *= (*alpha_cast);
3931 tC[1] = result/* + tC[1] * (*beta_cast)*/;
3932
3933 ymm10 = _mm256_hadd_pd(ymm10, ymm10);
3934 _mm256_storeu_pd(scratch, ymm10);
3935 result = scratch[0] + scratch[2];
3936 result *= (*alpha_cast);
3937 tC[2] = result/* + tC[2] * (*beta_cast)*/;
3938
3939 ymm13 = _mm256_hadd_pd(ymm13, ymm13);
3940 _mm256_storeu_pd(scratch, ymm13);
3941 result = scratch[0] + scratch[2];
3942 result *= (*alpha_cast);
3943 tC[3] = result/* + tC[3] * (*beta_cast)*/;
3944
3945
3946 tC += ldc;
3947 ymm5 = _mm256_hadd_pd(ymm5, ymm5);
3948 _mm256_storeu_pd(scratch, ymm5);
3949 result = scratch[0] + scratch[2];
3950 result *= (*alpha_cast);
3951 tC[0] = result/* + tC[0] * (*beta_cast)*/;
3952
3953 ymm8 = _mm256_hadd_pd(ymm8, ymm8);
3954 _mm256_storeu_pd(scratch, ymm8);
3955 result = scratch[0] + scratch[2];
3956 result *= (*alpha_cast);
3957 tC[1] = result/* + tC[1] * (*beta_cast)*/;
3958
3959 ymm11 = _mm256_hadd_pd(ymm11, ymm11);
3960 _mm256_storeu_pd(scratch, ymm11);
3961 result = scratch[0] + scratch[2];
3962 result *= (*alpha_cast);
3963 tC[2] = result/* + tC[2] * (*beta_cast)*/;
3964
3965 ymm14 = _mm256_hadd_pd(ymm14, ymm14);
3966 _mm256_storeu_pd(scratch, ymm14);
3967 result = scratch[0] + scratch[2];
3968 result *= (*alpha_cast);
3969 tC[3] = result/* + tC[3] * (*beta_cast)*/;
3970
3971
3972 tC += ldc;
3973 ymm6 = _mm256_hadd_pd(ymm6, ymm6);
3974 _mm256_storeu_pd(scratch, ymm6);
3975 result = scratch[0] + scratch[2];
3976 result *= (*alpha_cast);
3977 tC[0] = result/* + tC[0] * (*beta_cast)*/;
3978
3979 ymm9 = _mm256_hadd_pd(ymm9, ymm9);
3980 _mm256_storeu_pd(scratch, ymm9);
3981 result = scratch[0] + scratch[2];
3982 result *= (*alpha_cast);
3983 tC[1] = result/* + tC[1] * (*beta_cast)*/;
3984
3985 ymm12 = _mm256_hadd_pd(ymm12, ymm12);
3986 _mm256_storeu_pd(scratch, ymm12);
3987 result = scratch[0] + scratch[2];
3988 result *= (*alpha_cast);
3989 tC[2] = result/* + tC[2] * (*beta_cast)*/;
3990
3991 ymm15 = _mm256_hadd_pd(ymm15, ymm15);
3992 _mm256_storeu_pd(scratch, ymm15);
3993 result = scratch[0] + scratch[2];
3994 result *= (*alpha_cast);
3995 tC[3] = result/* + tC[3] * (*beta_cast)*/;
3996 }
3997 }
3998
3999 int processed_col = col_idx;
4000 int processed_row = row_idx;
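// processed_col/processed_row record how far the blocked loops above advanced;
// the leftover columns (N % NR) are handled next for the blocked rows, and the
// leftover rows (M % AT_MR) are then handled across all columns.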
4001
4002 //Handles the edge case where N is not a multiple of 3
4003 if (processed_col < N)
4004 {
4005 for (col_idx = processed_col; col_idx < N; col_idx += 1)
4006 {
4007 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
4008 {
4009 tA = A + row_idx * lda;
4010 tB = B + col_idx * ldb;
4011 tC = C + col_idx * ldc + row_idx;
4012 // clear scratch registers.
4013 ymm4 = _mm256_setzero_pd();
4014 ymm7 = _mm256_setzero_pd();
4015 ymm10 = _mm256_setzero_pd();
4016 ymm13 = _mm256_setzero_pd();
4017
4018 //The inner loop computes a 4x1 block of the matrix.
4019 //The computation pattern is:
4020 // ymm4
4021 // ymm7
4022 // ymm10
4023 // ymm13
4024
4025 for (k = 0; (k + 3) < K; k += 4)
4026 {
4027 ymm0 = _mm256_loadu_pd(tB + 0);
4028
4029 ymm3 = _mm256_loadu_pd(tA);
4030 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
4031
4032 ymm3 = _mm256_loadu_pd(tA + lda);
4033 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
4034
4035 ymm3 = _mm256_loadu_pd(tA + 2 * lda);
4036 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
4037
4038 ymm3 = _mm256_loadu_pd(tA + 3 * lda);
4039 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
4040
4041 tA += 4;
4042 tB += 4;
4043 }
4044 // if K is not a multiple of 4, padding is done before the load using a temporary array.
4045 if (k < K)
4046 {
4047 int iter;
4048 double data_feeder[4] = { 0.0 };
4049
4050 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
4051 ymm0 = _mm256_loadu_pd(data_feeder);
4052
4053 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
4054 ymm3 = _mm256_loadu_pd(data_feeder);
4055 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
4056
4057 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
4058 ymm3 = _mm256_loadu_pd(data_feeder);
4059 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
4060
4061 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
4062 ymm3 = _mm256_loadu_pd(data_feeder);
4063 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
4064
4065 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
4066 ymm3 = _mm256_loadu_pd(data_feeder);
4067 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
4068
4069 }
4070
4071 //horizontal addition and storage of the data.
4072 //Results for the 4x1 block of C are stored here
4073 ymm4 = _mm256_hadd_pd(ymm4, ymm4);
4074 _mm256_storeu_pd(scratch, ymm4);
4075 result = scratch[0] + scratch[2];
4076 result *= (*alpha_cast);
4077 tC[0] = result/* + tC[0] * (*beta_cast)*/;
4078
4079 ymm7 = _mm256_hadd_pd(ymm7, ymm7);
4080 _mm256_storeu_pd(scratch, ymm7);
4081 result = scratch[0] + scratch[2];
4082 result *= (*alpha_cast);
4083 tC[1] = result/* + tC[1] * (*beta_cast)*/;
4084
4085 ymm10 = _mm256_hadd_pd(ymm10, ymm10);
4086 _mm256_storeu_pd(scratch, ymm10);
4087 result = scratch[0] + scratch[2];
4088 result *= (*alpha_cast);
4089 tC[2] = result/* + tC[2] * (*beta_cast)*/;
4090
4091 ymm13 = _mm256_hadd_pd(ymm13, ymm13);
4092 _mm256_storeu_pd(scratch, ymm13);
4093 result = scratch[0] + scratch[2];
4094 result *= (*alpha_cast);
4095 tC[3] = result/* + tC[3] * (*beta_cast)*/;
4096
4097 }
4098 }
4099 processed_row = row_idx;
4100 }
4101
4102 // Handles the edge case where M is not a multiple of 4
4103 if (processed_row < M)
4104 {
4105 for (row_idx = processed_row; row_idx < M; row_idx += 1)
4106 {
4107 for (col_idx = 0; col_idx < N; col_idx += 1)
4108 {
4109 tA = A + row_idx * lda;
4110 tB = B + col_idx * ldb;
4111 tC = C + col_idx * ldc + row_idx;
4112 // clear scratch registers.
4113 ymm4 = _mm256_setzero_pd();
4114
4115 for (k = 0; (k + 3) < K; k += 4)
4116 {
4117 ymm0 = _mm256_loadu_pd(tB + 0);
4118 ymm3 = _mm256_loadu_pd(tA);
4119 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
4120
4121 tA += 4;
4122 tB += 4;
4123 }
4124
4125 // if K is not a multiple of 4, padding is done before the load using a temporary array.
4126 if (k < K)
4127 {
4128 int iter;
4129 double data_feeder[4] = { 0.0 };
4130
4131 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
4132 ymm0 = _mm256_loadu_pd(data_feeder);
4133
4134 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
4135 ymm3 = _mm256_loadu_pd(data_feeder);
4136 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
4137
4138 }
4139
4140 //horizontal addition and storage of the data.
4141 ymm4 = _mm256_hadd_pd(ymm4, ymm4);
4142 _mm256_storeu_pd(scratch, ymm4);
4143 result = scratch[0] + scratch[2];
4144 result *= (*alpha_cast);
4145 tC[0] = result/* + tC[0] * (*beta_cast)*/;
4146
4147 }
4148 }
4149 }
4150
4151 //copy/compute syrk values back to C
4152 if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C
4153 {
4154 dim_t _i, _j;
4155 if(bli_obj_is_lower(c)) //c is lower
4156 {
4157 for ( _j = 0; _j < N; ++_j )
4158 for ( _i = 0; _i < M; ++_i )
4159 if ( (doff_t)_j - (doff_t)_i <= 0 )
4160 {
4161 bli_ddcopys( *(C + _i*rsc + _j*ldc),
4162 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
4163 }
4164 }
4165 else //c is upper
4166 {
4167 for ( _j = 0; _j < N; ++_j )
4168 for ( _i = 0; _i < M; ++_i )
4169 if ( (doff_t)_j - (doff_t)_i >= 0 )
4170 {
4171 bli_ddcopys( *(C + _i*rsc + _j*ldc),
4172 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
4173 }
4174 }
4175 }
4176 else //when beta is non-zero, multiply and store result to C
4177 {
4178 dim_t _i, _j;
4179 if(bli_obj_is_lower(c)) //c is lower
4180 {
4181 for ( _j = 0; _j < N; ++_j )
4182 for ( _i = 0; _i < M; ++_i )
4183 if ( (doff_t)_j - (doff_t)_i <= 0 )
4184 {
4185 bli_dddxpbys( *(C + _i*rsc + _j*ldc),
4186 *(beta_cast),
4187 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
4188 }
4189 }
4190 else //c is upper
4191 {
4192 for ( _j = 0; _j < N; ++_j )
4193 for ( _i = 0; _i < M; ++_i )
4194 if ( (doff_t)_j - (doff_t)_i >= 0 )
4195 {
4196 bli_dddxpbys( *(C + _i*rsc + _j*ldc),
4197 *(beta_cast),
4198 *(matCbuf + _i*rs_matC + _j*ldc_matC) );
4199 }
4200 }
4201 }
4202
4203 return BLIS_SUCCESS;
4204 }
4205 else
4206 return BLIS_NONCONFORMAL_DIMENSIONS;
4207 }
4208
4209 #endif
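/*
 * Illustrative usage sketch (not compiled here; exact setup may differ): a
 * caller reaches these small-matrix kernels through the normal object API,
 * roughly as follows for C := alpha * A^T * A + beta * C in double precision,
 * for some small dimensions k and n.
 *
 *   obj_t a, c, alpha, beta;
 *   bli_obj_create( BLIS_DOUBLE, k, n, 0, 0, &a );      // A is k x n
 *   bli_obj_create( BLIS_DOUBLE, n, n, 0, 0, &c );      // C is n x n
 *   bli_obj_create_1x1( BLIS_DOUBLE, &alpha );
 *   bli_obj_create_1x1( BLIS_DOUBLE, &beta );
 *   bli_setsc( 1.0, 0.0, &alpha );
 *   bli_setsc( 0.0, 0.0, &beta );
 *   bli_obj_set_struc( BLIS_SYMMETRIC, &c );
 *   bli_obj_set_uplo( BLIS_LOWER, &c );
 *   bli_obj_set_onlytrans( BLIS_TRANSPOSE, &a );
 *   bli_syrk( &alpha, &a, &beta, &c );  // may dispatch to the kernels above
 *                                       // when the dimensions are small
 */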
4210
4211