1 /*
2
3 BLIS
4 An object-based framework for developing high-performance BLAS-like
5 libraries.
6
7 Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc.
8
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are
11 met:
12 - Redistributions of source code must retain the above copyright
13 notice, this list of conditions and the following disclaimer.
14 - Redistributions in binary form must reproduce the above copyright
15 notice, this list of conditions and the following disclaimer in the
16 documentation and/or other materials provided with the distribution.
17 - Neither the name(s) of the copyright holder(s) nor the names of its
18 contributors may be used to endorse or promote products derived
19 from this software without specific prior written permission.
20
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
33 */
34
35 #include "immintrin.h"
36 #include "xmmintrin.h"
37 #include "blis.h"
38
39 #define AOCL_DTL_TRACE_ENTRY(x) ;
40 #define AOCL_DTL_TRACE_EXIT(x) ;
41 #define AOCL_DTL_TRACE_EXIT_ERR(x,y) ;
42
43 #ifdef BLIS_ENABLE_SMALL_MATRIX
44
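// Register-blocking parameters for the small-matrix kernels: MR (32) rows of C
// are processed per outer iteration in single precision, D_MR (MR/2 = 16) rows
// in double precision, and NR (3) columns of C per inner iteration.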
45 #define MR 32
46 #define D_MR (MR >> 1)
47 #define NR 3
48 #define D_BLIS_SMALL_MATRIX_K_THRES_ROME 256
49
50 #define BLIS_ENABLE_PREFETCH
51 #define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 )
52 #define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2)
53 #define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2)
54 #define BLIS_ATBN_M_THRES 40 // Threshold value of M at or below which the small matrix code is called.
55 #define AT_MR 4 // The kernel dimension of the A-transpose GEMM kernel (AT_MR x NR).
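// Forward declarations of the type-specific small-matrix kernels.
// bli_gemm_small() dispatches to one of these based on the datatype of C and
// on whether A is transposed (the *_atbn variants handle A^T * B with B not
// transposed).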
56 static err_t bli_sgemm_small
57 (
58 obj_t* alpha,
59 obj_t* a,
60 obj_t* b,
61 obj_t* beta,
62 obj_t* c,
63 cntx_t* cntx,
64 cntl_t* cntl
65 );
66
67 static err_t bli_dgemm_small
68 (
69 obj_t* alpha,
70 obj_t* a,
71 obj_t* b,
72 obj_t* beta,
73 obj_t* c,
74 cntx_t* cntx,
75 cntl_t* cntl
76 );
77
78 static err_t bli_sgemm_small_atbn
79 (
80 obj_t* alpha,
81 obj_t* a,
82 obj_t* b,
83 obj_t* beta,
84 obj_t* c,
85 cntx_t* cntx,
86 cntl_t* cntl
87 );
88
89 static err_t bli_dgemm_small_atbn
90 (
91 obj_t* alpha,
92 obj_t* a,
93 obj_t* b,
94 obj_t* beta,
95 obj_t* c,
96 cntx_t* cntx,
97 cntl_t* cntl
98 );
99 /*
100 * The bli_gemm_small function uses the
101 * custom MRxNR kernels below to perform the computation.
102 * The custom kernels are used if M * N < 240 * 240.
103 */
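/*
 * Illustrative sketch (an assumption, not part of the build): one way a caller
 * can reach this path through the BLIS object API, assuming column-major
 * storage (unit row stride) and dimensions below the small-matrix thresholds.
 * Whether this routine is actually invoked also depends on how the library was
 * configured (e.g. BLIS_ENABLE_SMALL_MATRIX) and on the threshold values.
 *
 *   obj_t a, b, c;
 *   bli_obj_create( BLIS_FLOAT, 64, 64, 1, 64, &a );  // 64x64, column-major
 *   bli_obj_create( BLIS_FLOAT, 64, 64, 1, 64, &b );
 *   bli_obj_create( BLIS_FLOAT, 64, 64, 1, 64, &c );
 *   bli_randm( &a ); bli_randm( &b ); bli_randm( &c );
 *   bli_gemm( &BLIS_ONE, &a, &b, &BLIS_ONE, &c );     // may dispatch here
 *   bli_obj_free( &a ); bli_obj_free( &b ); bli_obj_free( &c );
 */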
104 err_t bli_gemm_small
105 (
106 obj_t* alpha,
107 obj_t* a,
108 obj_t* b,
109 obj_t* beta,
110 obj_t* c,
111 cntx_t* cntx,
112 cntl_t* cntl
113 )
114 {
115 AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
116
117 #ifdef BLIS_ENABLE_MULTITHREADING
118 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
119 return BLIS_NOT_YET_IMPLEMENTED;
120 #endif
121 // If alpha is zero, the operation reduces to scaling C by beta; that case is not handled here, so return.
122 if (bli_obj_equals(alpha, &BLIS_ZERO))
123 {
124 return BLIS_NOT_YET_IMPLEMENTED;
125 }
126
127 // Only column-major storage (unit row stride) is supported here; otherwise return.
128 if ((bli_obj_row_stride( a ) != 1) ||
129 (bli_obj_row_stride( b ) != 1) ||
130 (bli_obj_row_stride( c ) != 1))
131 {
132 return BLIS_INVALID_ROW_STRIDE;
133 }
134
135 num_t dt = bli_obj_dt(c);
136
137 if (bli_obj_has_trans( a ))
138 {
139 if (bli_obj_has_notrans( b ))
140 {
141 if (dt == BLIS_FLOAT)
142 {
143 return bli_sgemm_small_atbn(alpha, a, b, beta, c, cntx, cntl);
144 }
145 else if (dt == BLIS_DOUBLE)
146 {
147 return bli_dgemm_small_atbn(alpha, a, b, beta, c, cntx, cntl);
148 }
149 }
150
151 return BLIS_NOT_YET_IMPLEMENTED;
152 }
153
154 if (dt == BLIS_DOUBLE)
155 {
156 return bli_dgemm_small(alpha, a, b, beta, c, cntx, cntl);
157 }
158
159 if (dt == BLIS_FLOAT)
160 {
161 return bli_sgemm_small(alpha, a, b, beta, c, cntx, cntl);
162 }
163
164 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
165 return BLIS_NOT_YET_IMPLEMENTED;
166 };
167
168
169 static err_t bli_sgemm_small
170 (
171 obj_t* alpha,
172 obj_t* a,
173 obj_t* b,
174 obj_t* beta,
175 obj_t* c,
176 cntx_t* cntx,
177 cntl_t* cntl
178 )
179 {
180 AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
181 gint_t M = bli_obj_length( c ); // number of rows of Matrix C
182 gint_t N = bli_obj_width( c ); // number of columns of Matrix C
183 gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
184 gint_t L = M * N;
185
186 // when N is equal to 1 call GEMV instead of GEMM
187 if (N == 1)
188 {
189 bli_gemv
190 (
191 alpha,
192 a,
193 b,
194 beta,
195 c
196 );
197 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
198 return BLIS_SUCCESS;
199 }
200
201
202 if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES))
203 || ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
204 {
205 guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
206 guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
207 guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
208 guint_t row_idx, col_idx, k;
209
210 float *A = bli_obj_buffer_at_off(a); // pointer to elements of Matrix A
211 float *B = bli_obj_buffer_at_off(b); // pointer to elements of Matrix B
212 float *C = bli_obj_buffer_at_off(c); // pointer to elements of Matrix C
213
214 float *tA = A, *tB = B, *tC = C;//, *tA_pack;
215 float *tA_packed; // temporary pointer to hold packed A memory pointer
216
217 guint_t row_idx_packed; //packed A memory row index
218 guint_t lda_packed; //lda of packed A
219 guint_t col_idx_start; //starting index after A matrix is packed.
220 dim_t tb_inc_row = 1; // row stride of matrix B
221 dim_t tb_inc_col = ldb; // column stride of matrix B
222
223 __m256 ymm4, ymm5, ymm6, ymm7;
224 __m256 ymm8, ymm9, ymm10, ymm11;
225 __m256 ymm12, ymm13, ymm14, ymm15;
226 __m256 ymm0, ymm1, ymm2, ymm3;
227
228 gint_t n_remainder; // Remainder when N is not a multiple of 3 (N % 3).
229 gint_t m_remainder; // Remainder when M is not a multiple of 32 (M % 32).
230 gint_t required_packing_A = 1;
231 mem_t local_mem_buf_A_s;
232 float *A_pack = NULL;
233 rntm_t rntm;
234
235 const num_t dt_exec = bli_obj_dt( c );
236 float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha );
237 float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta );
238
239 /*Beta Zero Check*/
240 bool is_beta_non_zero=0;
241 if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
242 is_beta_non_zero = 1;
243 }
244
245 //update the pointer math if matrix B needs to be transposed.
246 if (bli_obj_has_trans( b )) {
247 tb_inc_col = 1; //switch row and column strides
248 tb_inc_row = ldb;
249 }
250
251 /*
252 * This function used to pack part of the A input into a global array when needed.
253 * However, using a global array makes the function non-reentrant.
254 * Instead of using a global array we should allocate a buffer for each invocation.
255 * Since the buffer is too big for the stack, and calling malloc every time would be too expensive,
256 * a better approach is to get the buffer from the pre-allocated pool and return
257 * it to the pool once we are done.
258 *
259 * In order to get the buffer from the pool, we need access to the memory broker;
260 * currently this function is not invoked in such a way that it can receive
261 * the memory broker (via rntm). The following hack gets the global memory
262 * broker, which can then be used to access the pool.
263 *
264 * Note that there will be a memory allocation at least on the first invocation,
265 * as there will not yet be a pool created for this size.
266 * Subsequent invocations will simply reuse the buffer from the pool.
267 */
268
269 bli_rntm_init_from_global( &rntm );
270 bli_rntm_set_num_threads_only( 1, &rntm );
271 bli_membrk_rntm_set_membrk( &rntm );
272
273 // Get the current size of the buffer pool for A block packing.
274 // We will use the same size to avoid pool re-initialization
275 siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
276 bli_rntm_membrk(&rntm)));
277
278 // Based on the available memory in the buffer we will decide if
279 // we want to do packing or not.
280 //
281 // This kernel assumes that "A" will be unpacked if N <= 3.
282 // Usually this range (N <= 3) is handled by SUP, however,
283 // if SUP is disabled or for any other condition if we do
284 // enter this kernel with N <= 3, we want to make sure that
285 // "A" remains unpacked.
286 //
287 // If this check is removed, it will result in a crash, as
288 // reported in CPUPL-587.
289 //
290
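// The shift by 2 converts a float element count to bytes (sizeof(float) == 4).
// For example, with MR = 32 and K = 200 the packed A panel needs
// 32 * 200 * 4 = 25600 bytes, so packing is enabled only if the pool's
// block size is at least that large.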
291 if ((N <= 3) || (((MR * K) << 2) > buffer_size))
292 {
293 required_packing_A = 0;
294 }
295 else
296 {
297 #ifdef BLIS_ENABLE_MEM_TRACING
298 printf( "bli_sgemm_small: Requesting mem pool block of size %lu\n", buffer_size);
299 #endif
300 // Get the buffer from the pool, if there is no pool with
301 // required size, it will be created.
302 bli_membrk_acquire_m(&rntm,
303 buffer_size,
304 BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
305 &local_mem_buf_A_s);
306
307 A_pack = bli_mem_buffer(&local_mem_buf_A_s);
308 }
309
310 /*
311 * The computation loop runs for MRxN columns of C matrix, thus
312 * accessing the MRxK A matrix data and KxNR B matrix data.
313 * The computation is organized as inner loops of dimension MRxNR.
314 */
315 // Process MR rows of C matrix at a time.
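// Overall structure of this kernel (sketch):
//   for each block of MR rows of C:
//     compute the first NR columns while packing the A panel into A_pack;
//     for the remaining blocks of NR columns, reuse the packed A panel;
//     handle the N % NR (1- or 2-column) remainder.
//   Remaining rows (M % MR) are then handled by a single 24-, 16- or 8-row
//   strip, followed by padded and scalar code for anything narrower.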
316 for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
317 {
318 col_idx_start = 0;
319 tA_packed = A;
320 row_idx_packed = row_idx;
321 lda_packed = lda;
322
323 // This is part of the pack-and-compute optimization.
324 // During the first column iteration, we store the accessed A matrix into
325 // contiguous memory. This helps to keep the A matrix in cache and
326 // avoids TLB misses.
327 if (required_packing_A)
328 {
329 col_idx = 0;
330
331 //pointer math to point to proper memory
332 tC = C + ldc * col_idx + row_idx;
333 tB = B + tb_inc_col * col_idx;
334 tA = A + row_idx;
335 tA_packed = A_pack;
336
337 #ifdef BLIS_ENABLE_PREFETCH
338 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
339 _mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
340 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
341 _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
342 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
343 _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
344 #endif
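// Accumulator layout for the MR x NR (32 x 3) tile of C: ymm4-ymm7 hold
// rows 0-31 of column 0 (8 floats per register), ymm8-ymm11 column 1,
// ymm12-ymm15 column 2. ymm0-ymm2 broadcast elements of B; ymm3 streams A.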
345 // clear scratch registers.
346 ymm4 = _mm256_setzero_ps();
347 ymm5 = _mm256_setzero_ps();
348 ymm6 = _mm256_setzero_ps();
349 ymm7 = _mm256_setzero_ps();
350 ymm8 = _mm256_setzero_ps();
351 ymm9 = _mm256_setzero_ps();
352 ymm10 = _mm256_setzero_ps();
353 ymm11 = _mm256_setzero_ps();
354 ymm12 = _mm256_setzero_ps();
355 ymm13 = _mm256_setzero_ps();
356 ymm14 = _mm256_setzero_ps();
357 ymm15 = _mm256_setzero_ps();
358
359 for (k = 0; k < K; ++k)
360 {
361 // The inner loop broadcasts the B matrix data and
362 // multiplies it with the A matrix.
363 // This loop is processing MR x K
364 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
365 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
366 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
367 tB += tb_inc_row;
368
369 //broadcasted matrix B elements are multiplied
370 //with matrix A columns.
371 ymm3 = _mm256_loadu_ps(tA);
372 _mm256_storeu_ps(tA_packed, ymm3); // the packing of matrix A
373 // ymm4 += ymm0 * ymm3;
374 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
375 // ymm8 += ymm1 * ymm3;
376 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
377 // ymm12 += ymm2 * ymm3;
378 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
379
380 ymm3 = _mm256_loadu_ps(tA + 8);
381 _mm256_storeu_ps(tA_packed + 8, ymm3); // the packing of matrix A
382 // ymm5 += ymm0 * ymm3;
383 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
384 // ymm9 += ymm1 * ymm3;
385 ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
386 // ymm13 += ymm2 * ymm3;
387 ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
388
389 ymm3 = _mm256_loadu_ps(tA + 16);
390 _mm256_storeu_ps(tA_packed + 16, ymm3); // the packing of matrix A
391 // ymm6 += ymm0 * ymm3;
392 ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
393 // ymm10 += ymm1 * ymm3;
394 ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
395 // ymm14 += ymm2 * ymm3;
396 ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
397
398 ymm3 = _mm256_loadu_ps(tA + 24);
399 _mm256_storeu_ps(tA_packed + 24, ymm3); // the packing of matrix A
400 // ymm7 += ymm0 * ymm3;
401 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
402 // ymm11 += ymm1 * ymm3;
403 ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
404 // ymm15 += ymm2 * ymm3;
405 ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
406
407 tA += lda;
408 tA_packed += MR;
409 }
410 // alpha, beta multiplication.
411 ymm0 = _mm256_broadcast_ss(alpha_cast);
412
413 //multiply A*B by alpha.
414 ymm4 = _mm256_mul_ps(ymm4, ymm0);
415 ymm5 = _mm256_mul_ps(ymm5, ymm0);
416 ymm6 = _mm256_mul_ps(ymm6, ymm0);
417 ymm7 = _mm256_mul_ps(ymm7, ymm0);
418 ymm8 = _mm256_mul_ps(ymm8, ymm0);
419 ymm9 = _mm256_mul_ps(ymm9, ymm0);
420 ymm10 = _mm256_mul_ps(ymm10, ymm0);
421 ymm11 = _mm256_mul_ps(ymm11, ymm0);
422 ymm12 = _mm256_mul_ps(ymm12, ymm0);
423 ymm13 = _mm256_mul_ps(ymm13, ymm0);
424 ymm14 = _mm256_mul_ps(ymm14, ymm0);
425 ymm15 = _mm256_mul_ps(ymm15, ymm0);
426
427 if(is_beta_non_zero)
428 {
429 ymm1 = _mm256_broadcast_ss(beta_cast);
430 // multiply C by beta and accumulate col 1.
431 ymm2 = _mm256_loadu_ps(tC);
432 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
433 ymm2 = _mm256_loadu_ps(tC + 8);
434 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
435 ymm2 = _mm256_loadu_ps(tC + 16);
436 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
437 ymm2 = _mm256_loadu_ps(tC + 24);
438 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
439
440 float* ttC = tC +ldc;
441 ymm2 = _mm256_loadu_ps(ttC);
442 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
443 ymm2 = _mm256_loadu_ps(ttC + 8);
444 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
445 ymm2 = _mm256_loadu_ps(ttC + 16);
446 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
447 ymm2 = _mm256_loadu_ps(ttC + 24);
448 ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
449
450 ttC += ldc;
451 ymm2 = _mm256_loadu_ps(ttC);
452 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
453 ymm2 = _mm256_loadu_ps(ttC + 8);
454 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
455 ymm2 = _mm256_loadu_ps(ttC + 16);
456 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
457 ymm2 = _mm256_loadu_ps(ttC + 24);
458 ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
459 }
460 _mm256_storeu_ps(tC, ymm4);
461 _mm256_storeu_ps(tC + 8, ymm5);
462 _mm256_storeu_ps(tC + 16, ymm6);
463 _mm256_storeu_ps(tC + 24, ymm7);
464
465 // multiply C by beta and accumulate, col 2.
466 tC += ldc;
467 _mm256_storeu_ps(tC, ymm8);
468 _mm256_storeu_ps(tC + 8, ymm9);
469 _mm256_storeu_ps(tC + 16, ymm10);
470 _mm256_storeu_ps(tC + 24, ymm11);
471
472 // multiply C by beta and accumulate, col 3.
473 tC += ldc;
474 _mm256_storeu_ps(tC, ymm12);
475 _mm256_storeu_ps(tC + 8, ymm13);
476 _mm256_storeu_ps(tC + 16, ymm14);
477 _mm256_storeu_ps(tC + 24, ymm15);
478
479 // Modify the pointer arithmetic to use the packed A matrix.
480 col_idx_start = NR;
481 tA_packed = A_pack;
482 row_idx_packed = 0;
483 lda_packed = MR;
484 }
485 // Process NR columns of C matrix at a time.
486 for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
487 {
488 //pointer math to point to proper memory
489 tC = C + ldc * col_idx + row_idx;
490 tB = B + tb_inc_col * col_idx;
491 tA = tA_packed + row_idx_packed;
492
493 #ifdef BLIS_ENABLE_PREFETCH
494 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
495 _mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
496 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
497 _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
498 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
499 _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
500 #endif
501 // clear scratch registers.
502 ymm4 = _mm256_setzero_ps();
503 ymm5 = _mm256_setzero_ps();
504 ymm6 = _mm256_setzero_ps();
505 ymm7 = _mm256_setzero_ps();
506 ymm8 = _mm256_setzero_ps();
507 ymm9 = _mm256_setzero_ps();
508 ymm10 = _mm256_setzero_ps();
509 ymm11 = _mm256_setzero_ps();
510 ymm12 = _mm256_setzero_ps();
511 ymm13 = _mm256_setzero_ps();
512 ymm14 = _mm256_setzero_ps();
513 ymm15 = _mm256_setzero_ps();
514
515 for (k = 0; k < K; ++k)
516 {
517 // The inner loop broadcasts the B matrix data and
518 // multiplies it with the A matrix.
519 // This loop is processing MR x K
520 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
521 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
522 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
523 tB += tb_inc_row;
524
525 //broadcasted matrix B elements are multiplied
526 //with matrix A columns.
527 ymm3 = _mm256_loadu_ps(tA);
528 // ymm4 += ymm0 * ymm3;
529 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
530 // ymm8 += ymm1 * ymm3;
531 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
532 // ymm12 += ymm2 * ymm3;
533 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
534
535 ymm3 = _mm256_loadu_ps(tA + 8);
536 // ymm5 += ymm0 * ymm3;
537 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
538 // ymm9 += ymm1 * ymm3;
539 ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
540 // ymm13 += ymm2 * ymm3;
541 ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
542
543 ymm3 = _mm256_loadu_ps(tA + 16);
544 // ymm6 += ymm0 * ymm3;
545 ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
546 // ymm10 += ymm1 * ymm3;
547 ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
548 // ymm14 += ymm2 * ymm3;
549 ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
550
551 ymm3 = _mm256_loadu_ps(tA + 24);
552 // ymm7 += ymm0 * ymm3;
553 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
554 // ymm11 += ymm1 * ymm3;
555 ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
556 // ymm15 += ymm2 * ymm3;
557 ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
558
559 tA += lda_packed;
560 }
561 // alpha, beta multiplication.
562 ymm0 = _mm256_broadcast_ss(alpha_cast);
563
564 //multiply A*B by alpha.
565 ymm4 = _mm256_mul_ps(ymm4, ymm0);
566 ymm5 = _mm256_mul_ps(ymm5, ymm0);
567 ymm6 = _mm256_mul_ps(ymm6, ymm0);
568 ymm7 = _mm256_mul_ps(ymm7, ymm0);
569 ymm8 = _mm256_mul_ps(ymm8, ymm0);
570 ymm9 = _mm256_mul_ps(ymm9, ymm0);
571 ymm10 = _mm256_mul_ps(ymm10, ymm0);
572 ymm11 = _mm256_mul_ps(ymm11, ymm0);
573 ymm12 = _mm256_mul_ps(ymm12, ymm0);
574 ymm13 = _mm256_mul_ps(ymm13, ymm0);
575 ymm14 = _mm256_mul_ps(ymm14, ymm0);
576 ymm15 = _mm256_mul_ps(ymm15, ymm0);
577
578 if(is_beta_non_zero)
579 {
580 ymm1 = _mm256_broadcast_ss(beta_cast);
581 // multiply C by beta and accumulate col 1.
582 ymm2 = _mm256_loadu_ps(tC);
583 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
584 ymm2 = _mm256_loadu_ps(tC + 8);
585 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
586 ymm2 = _mm256_loadu_ps(tC + 16);
587 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
588 ymm2 = _mm256_loadu_ps(tC + 24);
589 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
590 float* ttC = tC +ldc;
591 ymm2 = _mm256_loadu_ps(ttC);
592 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
593 ymm2 = _mm256_loadu_ps(ttC + 8);
594 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
595 ymm2 = _mm256_loadu_ps(ttC + 16);
596 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
597 ymm2 = _mm256_loadu_ps(ttC + 24);
598 ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
599 ttC = ttC +ldc;
600 ymm2 = _mm256_loadu_ps(ttC);
601 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
602 ymm2 = _mm256_loadu_ps(ttC + 8);
603 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
604 ymm2 = _mm256_loadu_ps(ttC + 16);
605 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
606 ymm2 = _mm256_loadu_ps(ttC + 24);
607 ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
608 }
609 _mm256_storeu_ps(tC, ymm4);
610 _mm256_storeu_ps(tC + 8, ymm5);
611 _mm256_storeu_ps(tC + 16, ymm6);
612 _mm256_storeu_ps(tC + 24, ymm7);
613
614 // multiply C by beta and accumulate, col 2.
615 tC += ldc;
616 _mm256_storeu_ps(tC, ymm8);
617 _mm256_storeu_ps(tC + 8, ymm9);
618 _mm256_storeu_ps(tC + 16, ymm10);
619 _mm256_storeu_ps(tC + 24, ymm11);
620
621 // multiply C by beta and accumulate, col 3.
622 tC += ldc;
623 _mm256_storeu_ps(tC, ymm12);
624 _mm256_storeu_ps(tC + 8, ymm13);
625 _mm256_storeu_ps(tC + 16, ymm14);
626 _mm256_storeu_ps(tC + 24, ymm15);
627
628 }
629 n_remainder = N - col_idx;
630
631 // If N is not a multiple of 3,
632 // handle the edge case.
633 if (n_remainder == 2)
634 {
635 //pointer math to point to proper memory
636 tC = C + ldc * col_idx + row_idx;
637 tB = B + tb_inc_col * col_idx;
638 tA = A + row_idx;
639
640 // clear scratch registers.
641 ymm8 = _mm256_setzero_ps();
642 ymm9 = _mm256_setzero_ps();
643 ymm10 = _mm256_setzero_ps();
644 ymm11 = _mm256_setzero_ps();
645 ymm12 = _mm256_setzero_ps();
646 ymm13 = _mm256_setzero_ps();
647 ymm14 = _mm256_setzero_ps();
648 ymm15 = _mm256_setzero_ps();
649
650 for (k = 0; k < K; ++k)
651 {
652 // The inner loop broadcasts the B matrix data and
653 // multiplies it with the A matrix.
654 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
655 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
656 tB += tb_inc_row;
657
658 //broadcasted matrix B elements are multiplied
659 //with matrix A columns.
660 ymm3 = _mm256_loadu_ps(tA);
661 ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
662 ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
663
664 ymm3 = _mm256_loadu_ps(tA + 8);
665 ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
666 ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
667
668 ymm3 = _mm256_loadu_ps(tA + 16);
669 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
670 ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
671
672 ymm3 = _mm256_loadu_ps(tA + 24);
673 ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);
674 ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
675
676 tA += lda;
677
678 }
679 // alpha, beta multiplication.
680 ymm0 = _mm256_broadcast_ss(alpha_cast);
681
682 //multiply A*B by alpha.
683 ymm8 = _mm256_mul_ps(ymm8, ymm0);
684 ymm9 = _mm256_mul_ps(ymm9, ymm0);
685 ymm10 = _mm256_mul_ps(ymm10, ymm0);
686 ymm11 = _mm256_mul_ps(ymm11, ymm0);
687 ymm12 = _mm256_mul_ps(ymm12, ymm0);
688 ymm13 = _mm256_mul_ps(ymm13, ymm0);
689 ymm14 = _mm256_mul_ps(ymm14, ymm0);
690 ymm15 = _mm256_mul_ps(ymm15, ymm0);
691
692 // multiply C by beta and accumulate, col 1.
693 if(is_beta_non_zero)
694 {
695 ymm1 = _mm256_broadcast_ss(beta_cast);
696 ymm2 = _mm256_loadu_ps(tC);
697 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
698 ymm2 = _mm256_loadu_ps(tC + 8);
699 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
700 ymm2 = _mm256_loadu_ps(tC + 16);
701 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
702 ymm2 = _mm256_loadu_ps(tC + 24);
703 ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
704
705 float* ttC = tC +ldc;
706 // multiply C by beta and accumulate, col 2.
707 ymm2 = _mm256_loadu_ps(ttC);
708 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
709 ymm2 = _mm256_loadu_ps(ttC + 8);
710 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
711 ymm2 = _mm256_loadu_ps(ttC + 16);
712 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
713 ymm2 = _mm256_loadu_ps(ttC + 24);
714 ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
715 }
716 _mm256_storeu_ps(tC, ymm8);
717 _mm256_storeu_ps(tC + 8, ymm9);
718 _mm256_storeu_ps(tC + 16, ymm10);
719 _mm256_storeu_ps(tC + 24, ymm11);
720 tC += ldc;
721 _mm256_storeu_ps(tC, ymm12);
722 _mm256_storeu_ps(tC + 8, ymm13);
723 _mm256_storeu_ps(tC + 16, ymm14);
724 _mm256_storeu_ps(tC + 24, ymm15);
725
726 col_idx += 2;
727 }
728 // If N is not a multiple of 3,
729 // handle the edge case.
730 if (n_remainder == 1)
731 {
732 //pointer math to point to proper memory
733 tC = C + ldc * col_idx + row_idx;
734 tB = B + tb_inc_col * col_idx;
735 tA = A + row_idx;
736
737 // clear scratch registers.
738 ymm12 = _mm256_setzero_ps();
739 ymm13 = _mm256_setzero_ps();
740 ymm14 = _mm256_setzero_ps();
741 ymm15 = _mm256_setzero_ps();
742
743 for (k = 0; k < K; ++k)
744 {
745 // The inner loop broadcasts the B matrix data and
746 // multiplies it with the A matrix.
747 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
748 tB += tb_inc_row;
749
750 //broadcasted matrix B elements are multiplied
751 //with matrix A columns.
752 ymm3 = _mm256_loadu_ps(tA);
753 ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
754
755 ymm3 = _mm256_loadu_ps(tA + 8);
756 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
757
758 ymm3 = _mm256_loadu_ps(tA + 16);
759 ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
760
761 ymm3 = _mm256_loadu_ps(tA + 24);
762 ymm15 = _mm256_fmadd_ps(ymm0, ymm3, ymm15);
763
764 tA += lda;
765
766 }
767 // alpha, beta multiplication.
768 ymm0 = _mm256_broadcast_ss(alpha_cast);
769
770 //multiply A*B by alpha.
771 ymm12 = _mm256_mul_ps(ymm12, ymm0);
772 ymm13 = _mm256_mul_ps(ymm13, ymm0);
773 ymm14 = _mm256_mul_ps(ymm14, ymm0);
774 ymm15 = _mm256_mul_ps(ymm15, ymm0);
775
776 if(is_beta_non_zero)
777 {
778 ymm1 = _mm256_broadcast_ss(beta_cast);
779 // multiply C by beta and accumulate.
780 ymm2 = _mm256_loadu_ps(tC + 0);
781 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
782 ymm2 = _mm256_loadu_ps(tC + 8);
783 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
784 ymm2 = _mm256_loadu_ps(tC + 16);
785 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
786 ymm2 = _mm256_loadu_ps(tC + 24);
787 ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
788 }
789
790 _mm256_storeu_ps(tC + 0, ymm12);
791 _mm256_storeu_ps(tC + 8, ymm13);
792 _mm256_storeu_ps(tC + 16, ymm14);
793 _mm256_storeu_ps(tC + 24, ymm15);
794 }
795 }
796
797 m_remainder = M - row_idx;
798
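// Remaining rows (M % 32) are handled by at most one 24-, 16- or 8-row strip
// that reuses the same broadcast/FMA pattern with fewer accumulators; anything
// narrower than 8 rows falls through to the padded f_temp path below, or to
// the scalar loop when lda <= 7.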
799 if (m_remainder >= 24)
800 {
801 m_remainder -= 24;
802
803 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
804 {
805 //pointer math to point to proper memory
806 tC = C + ldc * col_idx + row_idx;
807 tB = B + tb_inc_col * col_idx;
808 tA = A + row_idx;
809
810 // clear scratch registers.
811 ymm4 = _mm256_setzero_ps();
812 ymm5 = _mm256_setzero_ps();
813 ymm6 = _mm256_setzero_ps();
814 ymm8 = _mm256_setzero_ps();
815 ymm9 = _mm256_setzero_ps();
816 ymm10 = _mm256_setzero_ps();
817 ymm12 = _mm256_setzero_ps();
818 ymm13 = _mm256_setzero_ps();
819 ymm14 = _mm256_setzero_ps();
820
821 for (k = 0; k < K; ++k)
822 {
823 // The inner loop broadcasts the B matrix data and
824 // multiplies it with the A matrix.
825 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
826 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
827 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
828 tB += tb_inc_row;
829
830 //broadcasted matrix B elements are multiplied
831 //with matrix A columns.
832 ymm3 = _mm256_loadu_ps(tA);
833 // ymm4 += ymm0 * ymm3;
834 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
835 // ymm8 += ymm1 * ymm3;
836 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
837 // ymm12 += ymm2 * ymm3;
838 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
839
840 ymm3 = _mm256_loadu_ps(tA + 8);
841 // ymm5 += ymm0 * ymm3;
842 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
843 // ymm9 += ymm1 * ymm3;
844 ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
845 // ymm13 += ymm2 * ymm3;
846 ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
847
848 ymm3 = _mm256_loadu_ps(tA + 16);
849 // ymm6 += ymm0 * ymm3;
850 ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
851 // ymm10 += ymm1 * ymm3;
852 ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
853 // ymm14 += ymm2 * ymm3;
854 ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
855
856 tA += lda;
857 }
858 // alpha, beta multiplication.
859 ymm0 = _mm256_broadcast_ss(alpha_cast);
860
861 //multiply A*B by alpha.
862 ymm4 = _mm256_mul_ps(ymm4, ymm0);
863 ymm5 = _mm256_mul_ps(ymm5, ymm0);
864 ymm6 = _mm256_mul_ps(ymm6, ymm0);
865 ymm8 = _mm256_mul_ps(ymm8, ymm0);
866 ymm9 = _mm256_mul_ps(ymm9, ymm0);
867 ymm10 = _mm256_mul_ps(ymm10, ymm0);
868 ymm12 = _mm256_mul_ps(ymm12, ymm0);
869 ymm13 = _mm256_mul_ps(ymm13, ymm0);
870 ymm14 = _mm256_mul_ps(ymm14, ymm0);
871
872 if(is_beta_non_zero)
873 {
874 ymm1 = _mm256_broadcast_ss(beta_cast);
875 // multiply C by beta and accumulate.
876 ymm2 = _mm256_loadu_ps(tC);
877 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
878 ymm2 = _mm256_loadu_ps(tC + 8);
879 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
880 ymm2 = _mm256_loadu_ps(tC + 16);
881 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
882 float* ttC = tC +ldc;
883 ymm2 = _mm256_loadu_ps(ttC);
884 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
885 ymm2 = _mm256_loadu_ps(ttC + 8);
886 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
887 ymm2 = _mm256_loadu_ps(ttC + 16);
888 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
889 ttC += ldc;
890 ymm2 = _mm256_loadu_ps(ttC);
891 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
892 ymm2 = _mm256_loadu_ps(ttC + 8);
893 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
894 ymm2 = _mm256_loadu_ps(ttC + 16);
895 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
896 }
897 _mm256_storeu_ps(tC, ymm4);
898 _mm256_storeu_ps(tC + 8, ymm5);
899 _mm256_storeu_ps(tC + 16, ymm6);
900
901 // multiply C by beta and accumulate.
902 tC += ldc;
903 _mm256_storeu_ps(tC, ymm8);
904 _mm256_storeu_ps(tC + 8, ymm9);
905 _mm256_storeu_ps(tC + 16, ymm10);
906
907 // multiply C by beta and accumulate.
908 tC += ldc;
909 _mm256_storeu_ps(tC, ymm12);
910 _mm256_storeu_ps(tC + 8, ymm13);
911 _mm256_storeu_ps(tC + 16, ymm14);
912
913 }
914 n_remainder = N - col_idx;
915 // If N is not a multiple of 3,
916 // handle the edge case.
917 if (n_remainder == 2)
918 {
919 //pointer math to point to proper memory
920 tC = C + ldc * col_idx + row_idx;
921 tB = B + tb_inc_col * col_idx;
922 tA = A + row_idx;
923
924 // clear scratch registers.
925 ymm8 = _mm256_setzero_ps();
926 ymm9 = _mm256_setzero_ps();
927 ymm10 = _mm256_setzero_ps();
928 ymm12 = _mm256_setzero_ps();
929 ymm13 = _mm256_setzero_ps();
930 ymm14 = _mm256_setzero_ps();
931
932 for (k = 0; k < K; ++k)
933 {
934 // The inner loop broadcasts the B matrix data and
935 // multiplies it with the A matrix.
936 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
937 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
938 tB += tb_inc_row;
939
940 //broadcasted matrix B elements are multiplied
941 //with matrix A columns.
942 ymm3 = _mm256_loadu_ps(tA);
943 ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
944 ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
945
946 ymm3 = _mm256_loadu_ps(tA + 8);
947 ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
948 ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
949
950 ymm3 = _mm256_loadu_ps(tA + 16);
951 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
952 ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
953
954 tA += lda;
955
956 }
957 // alpha, beta multiplication.
958 ymm0 = _mm256_broadcast_ss(alpha_cast);
959
960 //multiply A*B by alpha.
961 ymm8 = _mm256_mul_ps(ymm8, ymm0);
962 ymm9 = _mm256_mul_ps(ymm9, ymm0);
963 ymm10 = _mm256_mul_ps(ymm10, ymm0);
964 ymm12 = _mm256_mul_ps(ymm12, ymm0);
965 ymm13 = _mm256_mul_ps(ymm13, ymm0);
966 ymm14 = _mm256_mul_ps(ymm14, ymm0);
967
968 if(is_beta_non_zero)
969 {
970 ymm1 = _mm256_broadcast_ss(beta_cast);
971 // multiply C by beta and accumulate.
972 ymm2 = _mm256_loadu_ps(tC);
973 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
974 ymm2 = _mm256_loadu_ps(tC + 8);
975 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
976 ymm2 = _mm256_loadu_ps(tC + 16);
977 ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
978
979 float* ttC = tC +ldc;
980 // multiply C by beta and accumulate.
981 ymm2 = _mm256_loadu_ps(ttC);
982 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
983 ymm2 = _mm256_loadu_ps(ttC + 8);
984 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
985 ymm2 = _mm256_loadu_ps(ttC + 16);
986 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
987 }
988
989 _mm256_storeu_ps(tC, ymm8);
990 _mm256_storeu_ps(tC + 8, ymm9);
991 _mm256_storeu_ps(tC + 16, ymm10);
992
993 tC += ldc;
994
995 _mm256_storeu_ps(tC, ymm12);
996 _mm256_storeu_ps(tC + 8, ymm13);
997 _mm256_storeu_ps(tC + 16, ymm14);
998
999 col_idx += 2;
1000 }
1001 // If N is not a multiple of 3,
1002 // handle the edge case.
1003 if (n_remainder == 1)
1004 {
1005 //pointer math to point to proper memory
1006 tC = C + ldc * col_idx + row_idx;
1007 tB = B + tb_inc_col * col_idx;
1008 tA = A + row_idx;
1009
1010 // clear scratch registers.
1011 ymm12 = _mm256_setzero_ps();
1012 ymm13 = _mm256_setzero_ps();
1013 ymm14 = _mm256_setzero_ps();
1014
1015 for (k = 0; k < K; ++k)
1016 {
1017 // The inner loop broadcasts the B matrix data and
1018 // multiplies it with the A matrix.
1019 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1020 tB += tb_inc_row;
1021
1022 //broadcasted matrix B elements are multiplied
1023 //with matrix A columns.
1024 ymm3 = _mm256_loadu_ps(tA);
1025 ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
1026
1027 ymm3 = _mm256_loadu_ps(tA + 8);
1028 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
1029
1030 ymm3 = _mm256_loadu_ps(tA + 16);
1031 ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
1032
1033 tA += lda;
1034
1035 }
1036 // alpha, beta multiplication.
1037 ymm0 = _mm256_broadcast_ss(alpha_cast);
1038
1039 //multiply A*B by alpha.
1040 ymm12 = _mm256_mul_ps(ymm12, ymm0);
1041 ymm13 = _mm256_mul_ps(ymm13, ymm0);
1042 ymm14 = _mm256_mul_ps(ymm14, ymm0);
1043
1044 if(is_beta_non_zero)
1045 {
1046 ymm1 = _mm256_broadcast_ss(beta_cast);
1047 // multiply C by beta and accumulate.
1048 ymm2 = _mm256_loadu_ps(tC + 0);
1049 ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
1050 ymm2 = _mm256_loadu_ps(tC + 8);
1051 ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
1052 ymm2 = _mm256_loadu_ps(tC + 16);
1053 ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
1054 }
1055 _mm256_storeu_ps(tC + 0, ymm12);
1056 _mm256_storeu_ps(tC + 8, ymm13);
1057 _mm256_storeu_ps(tC + 16, ymm14);
1058 }
1059
1060 row_idx += 24;
1061 }
1062
1063 if (m_remainder >= 16)
1064 {
1065 m_remainder -= 16;
1066
1067 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
1068 {
1069 //pointer math to point to proper memory
1070 tC = C + ldc * col_idx + row_idx;
1071 tB = B + tb_inc_col * col_idx;
1072 tA = A + row_idx;
1073
1074 // clear scratch registers.
1075 ymm4 = _mm256_setzero_ps();
1076 ymm5 = _mm256_setzero_ps();
1077 ymm6 = _mm256_setzero_ps();
1078 ymm7 = _mm256_setzero_ps();
1079 ymm8 = _mm256_setzero_ps();
1080 ymm9 = _mm256_setzero_ps();
1081
1082 for (k = 0; k < K; ++k)
1083 {
1084 // The inner loop broadcasts the B matrix data and
1085 // multiplies it with the A matrix.
1086 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1087 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1088 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1089 tB += tb_inc_row;
1090
1091 //broadcasted matrix B elements are multiplied
1092 //with matrix A columns.
1093 ymm3 = _mm256_loadu_ps(tA);
1094 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1095 ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
1096 ymm8 = _mm256_fmadd_ps(ymm2, ymm3, ymm8);
1097
1098 ymm3 = _mm256_loadu_ps(tA + 8);
1099 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1100 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1101 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1102
1103 tA += lda;
1104 }
1105 // alpha, beta multiplication.
1106 ymm0 = _mm256_broadcast_ss(alpha_cast);
1107
1108 //multiply A*B by alpha.
1109 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1110 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1111 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1112 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1113 ymm8 = _mm256_mul_ps(ymm8, ymm0);
1114 ymm9 = _mm256_mul_ps(ymm9, ymm0);
1115
1116 if(is_beta_non_zero)
1117 {
1118 ymm1 = _mm256_broadcast_ss(beta_cast);
1119 // multiply C by beta and accumulate.
1120 ymm2 = _mm256_loadu_ps(tC);
1121 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1122 ymm2 = _mm256_loadu_ps(tC + 8);
1123 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1124 float* ttC = tC + ldc;
1125 ymm2 = _mm256_loadu_ps(ttC);
1126 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
1127 ymm2 = _mm256_loadu_ps(ttC + 8);
1128 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
1129 ttC += ldc;
1130 ymm2 = _mm256_loadu_ps(ttC);
1131 ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
1132 ymm2 = _mm256_loadu_ps(ttC + 8);
1133 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
1134 }
1135 _mm256_storeu_ps(tC, ymm4);
1136 _mm256_storeu_ps(tC + 8, ymm5);
1137
1138 // multiply C by beta and accumulate.
1139 tC += ldc;
1140 _mm256_storeu_ps(tC, ymm6);
1141 _mm256_storeu_ps(tC + 8, ymm7);
1142
1143 // multiply C by beta and accumulate.
1144 tC += ldc;
1145 _mm256_storeu_ps(tC, ymm8);
1146 _mm256_storeu_ps(tC + 8, ymm9);
1147
1148 }
1149 n_remainder = N - col_idx;
1150 // If N is not a multiple of 3,
1151 // handle the edge case.
1152 if (n_remainder == 2)
1153 {
1154 //pointer math to point to proper memory
1155 tC = C + ldc * col_idx + row_idx;
1156 tB = B + tb_inc_col * col_idx;
1157 tA = A + row_idx;
1158
1159 // clear scratch registers.
1160 ymm4 = _mm256_setzero_ps();
1161 ymm5 = _mm256_setzero_ps();
1162 ymm6 = _mm256_setzero_ps();
1163 ymm7 = _mm256_setzero_ps();
1164
1165 for (k = 0; k < K; ++k)
1166 {
1167 // The inner loop broadcasts the B matrix data and
1168 // multiplies it with the A matrix.
1169 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1170 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1171 tB += tb_inc_row;
1172
1173 //broadcasted matrix B elements are multiplied
1174 //with matrix A columns.
1175 ymm3 = _mm256_loadu_ps(tA);
1176 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1177 ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
1178
1179 ymm3 = _mm256_loadu_ps(tA + 8);
1180 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1181 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1182
1183 tA += lda;
1184 }
1185 // alpha, beta multiplication.
1186 ymm0 = _mm256_broadcast_ss(alpha_cast);
1187
1188 //multiply A*B by alpha.
1189 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1190 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1191 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1192 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1193
1194 if(is_beta_non_zero)
1195 {
1196 ymm1 = _mm256_broadcast_ss(beta_cast);
1197 // multiply C by beta and accumulate.
1198 ymm2 = _mm256_loadu_ps(tC);
1199 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1200 ymm2 = _mm256_loadu_ps(tC + 8);
1201 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1202 float* ttC = tC + ldc;
1203 ymm2 = _mm256_loadu_ps(ttC);
1204 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
1205 ymm2 = _mm256_loadu_ps(ttC + 8);
1206 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
1207 }
1208 _mm256_storeu_ps(tC, ymm4);
1209 _mm256_storeu_ps(tC + 8, ymm5);
1210
1211 // multiply C by beta and accumulate.
1212 tC += ldc;
1213 _mm256_storeu_ps(tC, ymm6);
1214 _mm256_storeu_ps(tC + 8, ymm7);
1215
1216 col_idx += 2;
1217
1218 }
1219 // If N is not a multiple of 3,
1220 // handle the edge case.
1221 if (n_remainder == 1)
1222 {
1223 //pointer math to point to proper memory
1224 tC = C + ldc * col_idx + row_idx;
1225 tB = B + tb_inc_col * col_idx;
1226 tA = A + row_idx;
1227
1228 ymm4 = _mm256_setzero_ps();
1229 ymm5 = _mm256_setzero_ps();
1230
1231 for (k = 0; k < K; ++k)
1232 {
1233 // The inner loop broadcasts the B matrix data and
1234 // multiplies it with the A matrix.
1235 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1236 tB += tb_inc_row;
1237
1238 //broadcasted matrix B elements are multiplied
1239 //with matrix A columns.
1240 ymm3 = _mm256_loadu_ps(tA);
1241 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1242
1243 ymm3 = _mm256_loadu_ps(tA + 8);
1244 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1245
1246 tA += lda;
1247 }
1248 // alpha, beta multiplication.
1249 ymm0 = _mm256_broadcast_ss(alpha_cast);
1250
1251 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1252 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1253
1254 // multiply C by beta and accumulate.
1255 if(is_beta_non_zero)
1256 {
1257 ymm1 = _mm256_broadcast_ss(beta_cast);
1258 ymm2 = _mm256_loadu_ps(tC);
1259 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1260 ymm2 = _mm256_loadu_ps(tC + 8);
1261 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1262 }
1263 _mm256_storeu_ps(tC, ymm4);
1264 _mm256_storeu_ps(tC + 8, ymm5);
1265
1266 }
1267
1268 row_idx += 16;
1269 }
1270
1271 if (m_remainder >= 8)
1272 {
1273 m_remainder -= 8;
1274
1275 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
1276 {
1277 //pointer math to point to proper memory
1278 tC = C + ldc * col_idx + row_idx;
1279 tB = B + tb_inc_col * col_idx;
1280 tA = A + row_idx;
1281
1282 // clear scratch registers.
1283 ymm4 = _mm256_setzero_ps();
1284 ymm5 = _mm256_setzero_ps();
1285 ymm6 = _mm256_setzero_ps();
1286
1287 for (k = 0; k < K; ++k)
1288 {
1289 // The inner loop broadcasts the B matrix data and
1290 // multiplies it with the A matrix.
1291 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1292 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1293 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1294 tB += tb_inc_row;
1295
1296 //broadcasted matrix B elements are multiplied
1297 //with matrix A columns.
1298 ymm3 = _mm256_loadu_ps(tA);
1299 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1300 ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
1301 ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
1302
1303 tA += lda;
1304 }
1305 // alpha, beta multiplication.
1306 ymm0 = _mm256_broadcast_ss(alpha_cast);
1307
1308 //multiply A*B by alpha.
1309 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1310 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1311 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1312
1313 if(is_beta_non_zero)
1314 {
1315 ymm1 = _mm256_broadcast_ss(beta_cast);
1316 ymm2 = _mm256_loadu_ps(tC);
1317 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1318 ymm2 = _mm256_loadu_ps(tC + ldc);
1319 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1320 ymm2 = _mm256_loadu_ps(tC + 2*ldc);
1321 ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
1322 }
1323 _mm256_storeu_ps(tC, ymm4);
1324
1325 // multiply C by beta and accumulate.
1326 tC += ldc;
1327 _mm256_storeu_ps(tC, ymm5);
1328
1329 // multiply C by beta and accumulate.
1330 tC += ldc;
1331 _mm256_storeu_ps(tC, ymm6);
1332 }
1333 n_remainder = N - col_idx;
1334 // If N is not a multiple of 3,
1335 // handle the edge case.
1336 if (n_remainder == 2)
1337 {
1338 //pointer math to point to proper memory
1339 tC = C + ldc * col_idx + row_idx;
1340 tB = B + tb_inc_col * col_idx;
1341 tA = A + row_idx;
1342
1343 ymm4 = _mm256_setzero_ps();
1344 ymm5 = _mm256_setzero_ps();
1345
1346 for (k = 0; k < K; ++k)
1347 {
1348 // The inner loop broadcasts the B matrix data and
1349 // multiplies it with the A matrix.
1350 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1351 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1352 tB += tb_inc_row;
1353
1354 //broadcasted matrix B elements are multiplied
1355 //with matrix A columns.
1356 ymm3 = _mm256_loadu_ps(tA);
1357 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1358 ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
1359
1360 tA += lda;
1361 }
1362 // alpha, beta multiplication.
1363 ymm0 = _mm256_broadcast_ss(alpha_cast);
1364
1365 //multiply A*B by alpha.
1366 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1367 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1368
1369 if(is_beta_non_zero)
1370 {
1371 ymm1 = _mm256_broadcast_ss(beta_cast);
1372 // multiply C by beta and accumulate.
1373 ymm2 = _mm256_loadu_ps(tC);
1374 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1375 ymm2 = _mm256_loadu_ps(tC + ldc);
1376 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1377 }
1378 _mm256_storeu_ps(tC, ymm4);
1379 // multiply C by beta and accumulate.
1380 tC += ldc;
1381 _mm256_storeu_ps(tC, ymm5);
1382
1383 col_idx += 2;
1384
1385 }
1386 // If N is not a multiple of 3,
1387 // handle the edge case.
1388 if (n_remainder == 1)
1389 {
1390 //pointer math to point to proper memory
1391 tC = C + ldc * col_idx + row_idx;
1392 tB = B + tb_inc_col * col_idx;
1393 tA = A + row_idx;
1394
1395 ymm4 = _mm256_setzero_ps();
1396
1397 for (k = 0; k < K; ++k)
1398 {
1399 // The inner loop broadcasts the B matrix data and
1400 // multiplies it with the A matrix.
1401 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1402 tB += tb_inc_row;
1403
1404 //broadcasted matrix B elements are multiplied
1405 //with matrix A columns.
1406 ymm3 = _mm256_loadu_ps(tA);
1407 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1408
1409 tA += lda;
1410 }
1411 // alpha, beta multiplication.
1412 ymm0 = _mm256_broadcast_ss(alpha_cast);
1413 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1414
1415 if(is_beta_non_zero)
1416 {
1417 ymm1 = _mm256_broadcast_ss(beta_cast);
1418 // multiply C by beta and accumulate.
1419 ymm2 = _mm256_loadu_ps(tC);
1420 ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1421 }
1422 _mm256_storeu_ps(tC, ymm4);
1423
1424 }
1425
1426 row_idx += 8;
1427 }
1428 // M is not a multiple of 32.
1429 // Handle the edge case where the remaining row
1430 // dimension is less than 8; a padded temporary
1431 // buffer is used for this case.
1432 if ((m_remainder) && (lda > 7))
1433 {
1434 float f_temp[8] = {0.0};
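// The 8-wide loads below read past the m_remainder valid rows. For
// k < K-1 the extra lanes fall into the next column of A (the lda > 7
// guard keeps them inside the allocation); the final k iteration and all
// C accesses are staged through f_temp so nothing reads or writes past
// the end of the matrices.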
1435
1436 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
1437 {
1438 //pointer math to point to proper memory
1439 tC = C + ldc * col_idx + row_idx;
1440 tB = B + tb_inc_col * col_idx;
1441 tA = A + row_idx;
1442
1443 // clear scratch registers.
1444 ymm5 = _mm256_setzero_ps();
1445 ymm7 = _mm256_setzero_ps();
1446 ymm9 = _mm256_setzero_ps();
1447
1448 for (k = 0; k < (K - 1); ++k)
1449 {
1450 // The inner loop broadcasts the B matrix data and
1451 // multiplies it with the A matrix.
1452 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1453 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1454 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1455 tB += tb_inc_row;
1456
1457 //broadcasted matrix B elements are multiplied
1458 //with matrix A columns.
1459 ymm3 = _mm256_loadu_ps(tA);
1460 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1461 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1462 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1463
1464 tA += lda;
1465 }
1466 // alpha, beta multiplication.
1467 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1468 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1469 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1470 tB += tb_inc_row;
1471
1472 for (int i = 0; i < m_remainder; i++)
1473 {
1474 f_temp[i] = tA[i];
1475 }
1476 ymm3 = _mm256_loadu_ps(f_temp);
1477 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1478 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1479 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1480
1481 ymm0 = _mm256_broadcast_ss(alpha_cast);
1482 ymm1 = _mm256_broadcast_ss(beta_cast);
1483
1484 //multiply A*B by alpha.
1485 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1486 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1487 ymm9 = _mm256_mul_ps(ymm9, ymm0);
1488
1489
1490 for (int i = 0; i < m_remainder; i++)
1491 {
1492 f_temp[i] = tC[i];
1493 }
1494 ymm2 = _mm256_loadu_ps(f_temp);
1495 if(is_beta_non_zero){
1496 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1497 }
1498 _mm256_storeu_ps(f_temp, ymm5);
1499 for (int i = 0; i < m_remainder; i++)
1500 {
1501 tC[i] = f_temp[i];
1502 }
1503
1504 tC += ldc;
1505 for (int i = 0; i < m_remainder; i++)
1506 {
1507 f_temp[i] = tC[i];
1508 }
1509 ymm2 = _mm256_loadu_ps(f_temp);
1510 if(is_beta_non_zero){
1511 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
1512 }
1513 _mm256_storeu_ps(f_temp, ymm7);
1514 for (int i = 0; i < m_remainder; i++)
1515 {
1516 tC[i] = f_temp[i];
1517 }
1518
1519 tC += ldc;
1520 for (int i = 0; i < m_remainder; i++)
1521 {
1522 f_temp[i] = tC[i];
1523 }
1524 ymm2 = _mm256_loadu_ps(f_temp);
1525 if(is_beta_non_zero){
1526 ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
1527 }
1528 _mm256_storeu_ps(f_temp, ymm9);
1529 for (int i = 0; i < m_remainder; i++)
1530 {
1531 tC[i] = f_temp[i];
1532 }
1533 }
1534 n_remainder = N - col_idx;
1535 // If N is not a multiple of 3,
1536 // handle the edge case.
1537 if (n_remainder == 2)
1538 {
1539 //pointer math to point to proper memory
1540 tC = C + ldc * col_idx + row_idx;
1541 tB = B + tb_inc_col * col_idx;
1542 tA = A + row_idx;
1543
1544 ymm5 = _mm256_setzero_ps();
1545 ymm7 = _mm256_setzero_ps();
1546
1547 for (k = 0; k < (K - 1); ++k)
1548 {
1549 // The inner loop broadcasts the B matrix data and
1550 // multiplies it with the A matrix.
1551 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1552 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1553 tB += tb_inc_row;
1554
1555 ymm3 = _mm256_loadu_ps(tA);
1556 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1557 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1558
1559 tA += lda;
1560 }
1561
1562 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1563 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1564 tB += tb_inc_row;
1565
1566 for (int i = 0; i < m_remainder; i++)
1567 {
1568 f_temp[i] = tA[i];
1569 }
1570 ymm3 = _mm256_loadu_ps(f_temp);
1571 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1572 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1573
1574 ymm0 = _mm256_broadcast_ss(alpha_cast);
1575 ymm1 = _mm256_broadcast_ss(beta_cast);
1576
1577 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1578 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1579
1580 for (int i = 0; i < m_remainder; i++)
1581 {
1582 f_temp[i] = tC[i];
1583 }
1584 ymm2 = _mm256_loadu_ps(f_temp);
1585 if(is_beta_non_zero){
1586 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1587 }
1588 _mm256_storeu_ps(f_temp, ymm5);
1589 for (int i = 0; i < m_remainder; i++)
1590 {
1591 tC[i] = f_temp[i];
1592 }
1593
1594 tC += ldc;
1595 for (int i = 0; i < m_remainder; i++)
1596 {
1597 f_temp[i] = tC[i];
1598 }
1599 ymm2 = _mm256_loadu_ps(f_temp);
1600 if(is_beta_non_zero){
1601 ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
1602 }
1603 _mm256_storeu_ps(f_temp, ymm7);
1604 for (int i = 0; i < m_remainder; i++)
1605 {
1606 tC[i] = f_temp[i];
1607 }
1608 }
1609 // If N is not a multiple of 3,
1610 // handle the edge case.
1611 if (n_remainder == 1)
1612 {
1613 //pointer math to point to proper memory
1614 tC = C + ldc * col_idx + row_idx;
1615 tB = B + tb_inc_col * col_idx;
1616 tA = A + row_idx;
1617
1618 ymm5 = _mm256_setzero_ps();
1619
1620 for (k = 0; k < (K - 1); ++k)
1621 {
1622 // The inner loop broadcasts the B matrix data and
1623 // multiplies it with the A matrix.
1624 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1625 tB += tb_inc_row;
1626
1627 ymm3 = _mm256_loadu_ps(tA);
1628 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1629
1630 tA += lda;
1631 }
1632
1633 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1634 tB += tb_inc_row;
1635
1636 for (int i = 0; i < m_remainder; i++)
1637 {
1638 f_temp[i] = tA[i];
1639 }
1640 ymm3 = _mm256_loadu_ps(f_temp);
1641 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1642
1643 ymm0 = _mm256_broadcast_ss(alpha_cast);
1644
1645 // multiply C by beta and accumulate.
1646 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1647
1648 for (int i = 0; i < m_remainder; i++)
1649 {
1650 f_temp[i] = tC[i];
1651 }
1652 ymm2 = _mm256_loadu_ps(f_temp);
1653 if(is_beta_non_zero){
1654 ymm1 = _mm256_broadcast_ss(beta_cast);
1655 ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1656 }
1657 _mm256_storeu_ps(f_temp, ymm5);
1658 for (int i = 0; i < m_remainder; i++)
1659 {
1660 tC[i] = f_temp[i];
1661 }
1662 }
1663 m_remainder = 0;
1664 }
1665
1666 if (m_remainder)
1667 {
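// Scalar fallback: any rows not covered by the vector strips above
// (reached when lda <= 7, so the padded f_temp path was skipped)
// are computed one element of C at a time.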
1668 float result;
1669 for (; row_idx < M; row_idx += 1)
1670 {
1671 for (col_idx = 0; col_idx < N; col_idx += 1)
1672 {
1673 //pointer math to point to proper memory
1674 tC = C + ldc * col_idx + row_idx;
1675 tB = B + tb_inc_col * col_idx;
1676 tA = A + row_idx;
1677
1678 result = 0;
1679 for (k = 0; k < K; ++k)
1680 {
1681 result += (*tA) * (*tB);
1682 tA += lda;
1683 tB += tb_inc_row;
1684 }
1685
1686 result *= (*alpha_cast);
1687 if(is_beta_non_zero){
1688 (*tC) = (*tC) * (*beta_cast) + result;
1689 }else{
1690 (*tC) = result;
1691 }
1692 }
1693 }
1694 }
1695
1696 // Return the buffer to pool
1697 if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s) ) {
1698
1699 #ifdef BLIS_ENABLE_MEM_TRACING
1700 printf( "bli_sgemm_small(): releasing mem pool block\n" );
1701 #endif
1702 bli_membrk_release(&rntm,
1703 &local_mem_buf_A_s);
1704 }
1705
1706 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
1707 return BLIS_SUCCESS;
1708 }
1709 else
1710 {
1711 AOCL_DTL_TRACE_EXIT_ERR(
1712 AOCL_DTL_LEVEL_INFO,
1713 "Invalid dimesions for small gemm."
1714 );
1715 return BLIS_NONCONFORMAL_DIMENSIONS;
1716 }
1717
1718 };
1719
1720 static err_t bli_dgemm_small
1721 (
1722 obj_t* alpha,
1723 obj_t* a,
1724 obj_t* b,
1725 obj_t* beta,
1726 obj_t* c,
1727 cntx_t* cntx,
1728 cntl_t* cntl
1729 )
1730 {
1731
1732 AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
1733
1734 gint_t M = bli_obj_length( c ); // number of rows of Matrix C
1735 gint_t N = bli_obj_width( c ); // number of columns of Matrix C
1736 gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
1737 gint_t L = M * N;
1738
1739 // when N is equal to 1 call GEMV instead of GEMM
1740 if (N == 1)
1741 {
1742 bli_gemv
1743 (
1744 alpha,
1745 a,
1746 b,
1747 beta,
1748 c
1749 );
1750 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
1751 return BLIS_SUCCESS;
1752 }
1753
1754 if (N<3) // Implementation assumes that N is at least 3.
1755 {
1756 AOCL_DTL_TRACE_EXIT_ERR(
1757 AOCL_DTL_LEVEL_INFO,
1758 "N < 3, cannot be processed by small gemm"
1759 );
1760 return BLIS_NOT_YET_IMPLEMENTED;
1761 }
1762
1763 #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
1764 if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME))))
1765 #else
1766 if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES))
1767 || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
1768 #endif
1769 {
1770 guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
1771 guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
1772 guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
1773 guint_t row_idx, col_idx, k;
1774 double *A = bli_obj_buffer_at_off(a); // pointer to elements of Matrix A
1775 double *B = bli_obj_buffer_at_off(b); // pointer to elements of Matrix B
1776 double *C = bli_obj_buffer_at_off(c); // pointer to elements of Matrix C
1777
1778 double *tA = A, *tB = B, *tC = C;//, *tA_pack;
1779 double *tA_packed; // temporary pointer to hold packed A memory pointer
1780 guint_t row_idx_packed; //packed A memory row index
1781 guint_t lda_packed; //lda of packed A
1782 guint_t col_idx_start; //starting index after A matrix is packed.
1783 dim_t tb_inc_row = 1; // row stride of matrix B
1784 dim_t tb_inc_col = ldb; // column stride of matrix B
1785 __m256d ymm4, ymm5, ymm6, ymm7;
1786 __m256d ymm8, ymm9, ymm10, ymm11;
1787 __m256d ymm12, ymm13, ymm14, ymm15;
1788 __m256d ymm0, ymm1, ymm2, ymm3;
1789
1790 gint_t n_remainder; // Remainder when N is not a multiple of 3 (N % 3).
1791 gint_t m_remainder; // Remainder when M is not a multiple of 16 (M % 16).
1792
1793 double *alpha_cast, *beta_cast; // alpha, beta multiples
1794 alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha);
1795 beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta);
1796
1797 gint_t required_packing_A = 1;
1798 mem_t local_mem_buf_A_s;
1799 double *D_A_pack = NULL;
1800 rntm_t rntm;
1801
1802 //update the pointer math if matrix B needs to be transposed.
1803 if (bli_obj_has_trans( b ))
1804 {
1805 tb_inc_col = 1; //switch row and column strides
1806 tb_inc_row = ldb;
1807 }
1808
1809 // Check whether the beta value is zero.
1810 // If it is, we should perform the C = alpha * (A * B) operation
1811 // instead of C = beta * C + alpha * (A * B).
1812 bool is_beta_non_zero = 0;
1813 if(!bli_obj_equals(beta, &BLIS_ZERO))
1814 is_beta_non_zero = 1;
1815
1816 /*
1817 * This function used to pack part of the A input into a global array when needed.
1818 * However, using a global array makes the function non-reentrant.
1819 * Instead of using a global array we should allocate a buffer for each invocation.
1820 * Since the buffer is too big for the stack, and calling malloc every time would be too expensive,
1821 * a better approach is to get the buffer from the pre-allocated pool and return
1822 * it to the pool once we are done.
1823 *
1824 * In order to get the buffer from the pool, we need access to the memory broker;
1825 * currently this function is not invoked in such a way that it can receive
1826 * the memory broker (via rntm). The following hack gets the global memory
1827 * broker, which can then be used to access the pool.
1828 *
1829 * Note that there will be a memory allocation at least on the first invocation,
1830 * as there will not yet be a pool created for this size.
1831 * Subsequent invocations will simply reuse the buffer from the pool.
1832 */
1833
1834 bli_rntm_init_from_global( &rntm );
1835 bli_rntm_set_num_threads_only( 1, &rntm );
1836 bli_membrk_rntm_set_membrk( &rntm );
1837
1838     // Get the current block size of the buffer pool used for A block packing.
1839     // We request the same size to avoid re-initializing the pool.
1840 siz_t buffer_size = bli_pool_block_size(
1841 bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
1842 bli_rntm_membrk(&rntm)));
1843
1844     //
1845     // This kernel assumes that "A" will be unpacked if N <= 3.
1846     // Usually this range (N <= 3) is handled by SUP; however,
1847     // if SUP is disabled, or if for any other reason we do
1848     // enter this kernel with N <= 3, we want to make sure that
1849     // "A" remains unpacked.
1850     //
1851     // Removing this check results in the crash reported
1852     // in CPUPL-587.
1853     //
1854
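     // (D_MR * K) << 3 is the footprint in bytes of one packed A panel
     // (D_MR * K doubles); packing is skipped if it does not fit in the
     // pool's current block size.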
1855     if ((N <= 3) || (((D_MR * K) << 3) > buffer_size))
1856 {
1857 required_packing_A = 0;
1858 }
1859
1860 if (required_packing_A == 1)
1861 {
1862 #ifdef BLIS_ENABLE_MEM_TRACING
1863 printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", buffer_size);
1864 #endif
1865 // Get the buffer from the pool.
1866 bli_membrk_acquire_m(&rntm,
1867 buffer_size,
1868 BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
1869 &local_mem_buf_A_s);
1870
1871 D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
1872 }
1873
1874     /*
1875      * The computation loop runs over D_MR x N blocks of the C matrix, thus
1876      * accessing D_MR x K of the A matrix data and K x NR of the B matrix data.
1877      * The computation is organized as inner loops over D_MR x NR tiles.
1878      */
1879 // Process D_MR rows of C matrix at a time.
1880 for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR)
1881 {
1882
1883 col_idx_start = 0;
1884 tA_packed = A;
1885 row_idx_packed = row_idx;
1886 lda_packed = lda;
1887
1888         // This is part of the pack-and-compute optimization.
1889         // During the first column iteration, we store the accessed A matrix into
1890         // contiguous scratch memory. This helps to keep the A matrix in cache and
1891         // avoids TLB misses.
1892 if (required_packing_A)
1893 {
1894 col_idx = 0;
1895
1896 //pointer math to point to proper memory
1897 tC = C + ldc * col_idx + row_idx;
1898 tB = B + tb_inc_col * col_idx;
1899 tA = A + row_idx;
1900 tA_packed = D_A_pack;
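             // Packed A layout: for each k, the D_MR doubles of the current
             // row block are stored contiguously, so later column iterations
             // can walk the packed copy with a stride of D_MR (lda_packed).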
1901
1902 #ifdef BLIS_ENABLE_PREFETCH
1903 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
1904 _mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
1905 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
1906 _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
1907 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
1908 _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
1909 #endif
1910 // clear scratch registers.
1911 ymm4 = _mm256_setzero_pd();
1912 ymm5 = _mm256_setzero_pd();
1913 ymm6 = _mm256_setzero_pd();
1914 ymm7 = _mm256_setzero_pd();
1915 ymm8 = _mm256_setzero_pd();
1916 ymm9 = _mm256_setzero_pd();
1917 ymm10 = _mm256_setzero_pd();
1918 ymm11 = _mm256_setzero_pd();
1919 ymm12 = _mm256_setzero_pd();
1920 ymm13 = _mm256_setzero_pd();
1921 ymm14 = _mm256_setzero_pd();
1922 ymm15 = _mm256_setzero_pd();
1923
1924 for (k = 0; k < K; ++k)
1925 {
1926 // The inner loop broadcasts the B matrix data and
1927 // multiplies it with the A matrix.
1928 // This loop is processing D_MR x K
1929 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
1930 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
1931 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
1932 tB += tb_inc_row;
1933
1934 //broadcasted matrix B elements are multiplied
1935 //with matrix A columns.
1936 ymm3 = _mm256_loadu_pd(tA);
1937 _mm256_storeu_pd(tA_packed, ymm3); // the packing of matrix A
1938 // ymm4 += ymm0 * ymm3;
1939 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
1940 // ymm8 += ymm1 * ymm3;
1941 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
1942 // ymm12 += ymm2 * ymm3;
1943 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
1944
1945 ymm3 = _mm256_loadu_pd(tA + 4);
1946 _mm256_storeu_pd(tA_packed + 4, ymm3); // the packing of matrix A
1947 // ymm5 += ymm0 * ymm3;
1948 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
1949 // ymm9 += ymm1 * ymm3;
1950 ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
1951 // ymm13 += ymm2 * ymm3;
1952 ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
1953
1954 ymm3 = _mm256_loadu_pd(tA + 8);
1955 _mm256_storeu_pd(tA_packed + 8, ymm3); // the packing of matrix A
1956 // ymm6 += ymm0 * ymm3;
1957 ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
1958 // ymm10 += ymm1 * ymm3;
1959 ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
1960 // ymm14 += ymm2 * ymm3;
1961 ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
1962
1963 ymm3 = _mm256_loadu_pd(tA + 12);
1964 _mm256_storeu_pd(tA_packed + 12, ymm3); // the packing of matrix A
1965 // ymm7 += ymm0 * ymm3;
1966 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
1967 // ymm11 += ymm1 * ymm3;
1968 ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
1969 // ymm15 += ymm2 * ymm3;
1970 ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
1971
1972 tA += lda;
1973 tA_packed += D_MR;
1974 }
1975 // alpha, beta multiplication.
1976 ymm0 = _mm256_broadcast_sd(alpha_cast);
1977 ymm1 = _mm256_broadcast_sd(beta_cast);
1978
1979 //multiply A*B by alpha.
1980 ymm4 = _mm256_mul_pd(ymm4, ymm0);
1981 ymm5 = _mm256_mul_pd(ymm5, ymm0);
1982 ymm6 = _mm256_mul_pd(ymm6, ymm0);
1983 ymm7 = _mm256_mul_pd(ymm7, ymm0);
1984 ymm8 = _mm256_mul_pd(ymm8, ymm0);
1985 ymm9 = _mm256_mul_pd(ymm9, ymm0);
1986 ymm10 = _mm256_mul_pd(ymm10, ymm0);
1987 ymm11 = _mm256_mul_pd(ymm11, ymm0);
1988 ymm12 = _mm256_mul_pd(ymm12, ymm0);
1989 ymm13 = _mm256_mul_pd(ymm13, ymm0);
1990 ymm14 = _mm256_mul_pd(ymm14, ymm0);
1991 ymm15 = _mm256_mul_pd(ymm15, ymm0);
1992
1993 if(is_beta_non_zero)
1994 {
1995 // multiply C by beta and accumulate col 1.
1996 ymm2 = _mm256_loadu_pd(tC);
1997 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
1998 ymm2 = _mm256_loadu_pd(tC + 4);
1999 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2000 ymm2 = _mm256_loadu_pd(tC + 8);
2001 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2002 ymm2 = _mm256_loadu_pd(tC + 12);
2003 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
2004
2005 double* ttC = tC + ldc;
2006
2007 // multiply C by beta and accumulate, col 2.
2008 ymm2 = _mm256_loadu_pd(ttC);
2009 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2010 ymm2 = _mm256_loadu_pd(ttC + 4);
2011 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2012 ymm2 = _mm256_loadu_pd(ttC + 8);
2013 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2014 ymm2 = _mm256_loadu_pd(ttC + 12);
2015 ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
2016
2017 ttC += ldc;
2018
2019 // multiply C by beta and accumulate, col 3.
2020 ymm2 = _mm256_loadu_pd(ttC);
2021 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2022 ymm2 = _mm256_loadu_pd(ttC + 4);
2023 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2024 ymm2 = _mm256_loadu_pd(ttC + 8);
2025 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2026 ymm2 = _mm256_loadu_pd(ttC + 12);
2027 ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
2028 }
2029 _mm256_storeu_pd(tC, ymm4);
2030 _mm256_storeu_pd(tC + 4, ymm5);
2031 _mm256_storeu_pd(tC + 8, ymm6);
2032 _mm256_storeu_pd(tC + 12, ymm7);
2033
2034 tC += ldc;
2035
2036 _mm256_storeu_pd(tC, ymm8);
2037 _mm256_storeu_pd(tC + 4, ymm9);
2038 _mm256_storeu_pd(tC + 8, ymm10);
2039 _mm256_storeu_pd(tC + 12, ymm11);
2040
2041 tC += ldc;
2042
2043 _mm256_storeu_pd(tC, ymm12);
2044 _mm256_storeu_pd(tC + 4, ymm13);
2045 _mm256_storeu_pd(tC + 8, ymm14);
2046 _mm256_storeu_pd(tC + 12, ymm15);
2047
2048             // modify the pointer arithmetic to use the packed A matrix.
2049 col_idx_start = NR;
2050 tA_packed = D_A_pack;
2051 row_idx_packed = 0;
2052 lda_packed = D_MR;
2053 }
2054 // Process NR columns of C matrix at a time.
2055 for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
2056 {
2057 //pointer math to point to proper memory
2058 tC = C + ldc * col_idx + row_idx;
2059 tB = B + tb_inc_col * col_idx;
2060 tA = tA_packed + row_idx_packed;
2061
2062 #ifdef BLIS_ENABLE_PREFETCH
2063 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
2064 _mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
2065 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
2066 _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
2067 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
2068 _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
2069 #endif
2070 // clear scratch registers.
2071 ymm4 = _mm256_setzero_pd();
2072 ymm5 = _mm256_setzero_pd();
2073 ymm6 = _mm256_setzero_pd();
2074 ymm7 = _mm256_setzero_pd();
2075 ymm8 = _mm256_setzero_pd();
2076 ymm9 = _mm256_setzero_pd();
2077 ymm10 = _mm256_setzero_pd();
2078 ymm11 = _mm256_setzero_pd();
2079 ymm12 = _mm256_setzero_pd();
2080 ymm13 = _mm256_setzero_pd();
2081 ymm14 = _mm256_setzero_pd();
2082 ymm15 = _mm256_setzero_pd();
2083
2084 for (k = 0; k < K; ++k)
2085 {
2086 // The inner loop broadcasts the B matrix data and
2087 // multiplies it with the A matrix.
2088 // This loop is processing D_MR x K
2089 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2090 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2091 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2092 tB += tb_inc_row;
2093
2094 //broadcasted matrix B elements are multiplied
2095 //with matrix A columns.
2096 ymm3 = _mm256_loadu_pd(tA);
2097 // ymm4 += ymm0 * ymm3;
2098 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2099 // ymm8 += ymm1 * ymm3;
2100 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
2101 // ymm12 += ymm2 * ymm3;
2102 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
2103
2104 ymm3 = _mm256_loadu_pd(tA + 4);
2105 // ymm5 += ymm0 * ymm3;
2106 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2107 // ymm9 += ymm1 * ymm3;
2108 ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
2109 // ymm13 += ymm2 * ymm3;
2110 ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
2111
2112 ymm3 = _mm256_loadu_pd(tA + 8);
2113 // ymm6 += ymm0 * ymm3;
2114 ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
2115 // ymm10 += ymm1 * ymm3;
2116 ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
2117 // ymm14 += ymm2 * ymm3;
2118 ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
2119
2120 ymm3 = _mm256_loadu_pd(tA + 12);
2121 // ymm7 += ymm0 * ymm3;
2122 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
2123 // ymm11 += ymm1 * ymm3;
2124 ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
2125 // ymm15 += ymm2 * ymm3;
2126 ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
2127
2128 tA += lda_packed;
2129 }
2130 // alpha, beta multiplication.
2131 ymm0 = _mm256_broadcast_sd(alpha_cast);
2132 ymm1 = _mm256_broadcast_sd(beta_cast);
2133
2134 //multiply A*B by alpha.
2135 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2136 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2137 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2138 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2139 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2140 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2141 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2142 ymm11 = _mm256_mul_pd(ymm11, ymm0);
2143 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2144 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2145 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2146 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2147
2148 if(is_beta_non_zero)
2149 {
2150 // multiply C by beta and accumulate col 1.
2151 ymm2 = _mm256_loadu_pd(tC);
2152 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2153 ymm2 = _mm256_loadu_pd(tC + 4);
2154 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2155 ymm2 = _mm256_loadu_pd(tC + 8);
2156 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2157 ymm2 = _mm256_loadu_pd(tC + 12);
2158 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
2159
2160 // multiply C by beta and accumulate, col 2.
2161 double* ttC = tC + ldc;
2162 ymm2 = _mm256_loadu_pd(ttC);
2163 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2164 ymm2 = _mm256_loadu_pd(ttC + 4);
2165 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2166 ymm2 = _mm256_loadu_pd(ttC + 8);
2167 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2168 ymm2 = _mm256_loadu_pd(ttC + 12);
2169 ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
2170
2171 // multiply C by beta and accumulate, col 3.
2172 ttC += ldc;
2173 ymm2 = _mm256_loadu_pd(ttC);
2174 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2175 ymm2 = _mm256_loadu_pd(ttC + 4);
2176 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2177 ymm2 = _mm256_loadu_pd(ttC + 8);
2178 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2179 ymm2 = _mm256_loadu_pd(ttC + 12);
2180 ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
2181 }
2182 _mm256_storeu_pd(tC, ymm4);
2183 _mm256_storeu_pd(tC + 4, ymm5);
2184 _mm256_storeu_pd(tC + 8, ymm6);
2185 _mm256_storeu_pd(tC + 12, ymm7);
2186
2187 tC += ldc;
2188
2189 _mm256_storeu_pd(tC, ymm8);
2190 _mm256_storeu_pd(tC + 4, ymm9);
2191 _mm256_storeu_pd(tC + 8, ymm10);
2192 _mm256_storeu_pd(tC + 12, ymm11);
2193
2194 tC += ldc;
2195
2196 _mm256_storeu_pd(tC, ymm12);
2197 _mm256_storeu_pd(tC + 4, ymm13);
2198 _mm256_storeu_pd(tC + 8, ymm14);
2199 _mm256_storeu_pd(tC + 12, ymm15);
2200
2201 }
2202 n_remainder = N - col_idx;
2203
2204         // If N is not a multiple of 3,
2205         // handle the edge case.
2206 if (n_remainder == 2)
2207 {
2208 //pointer math to point to proper memory
2209 tC = C + ldc * col_idx + row_idx;
2210 tB = B + tb_inc_col * col_idx;
2211 tA = A + row_idx;
2212
2213 // clear scratch registers.
2214 ymm8 = _mm256_setzero_pd();
2215 ymm9 = _mm256_setzero_pd();
2216 ymm10 = _mm256_setzero_pd();
2217 ymm11 = _mm256_setzero_pd();
2218 ymm12 = _mm256_setzero_pd();
2219 ymm13 = _mm256_setzero_pd();
2220 ymm14 = _mm256_setzero_pd();
2221 ymm15 = _mm256_setzero_pd();
2222
2223 for (k = 0; k < K; ++k)
2224 {
2225 // The inner loop broadcasts the B matrix data and
2226 // multiplies it with the A matrix.
2227 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2228 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2229 tB += tb_inc_row;
2230
2231 //broadcasted matrix B elements are multiplied
2232 //with matrix A columns.
2233 ymm3 = _mm256_loadu_pd(tA);
2234 ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
2235 ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
2236
2237 ymm3 = _mm256_loadu_pd(tA + 4);
2238 ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
2239 ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
2240
2241 ymm3 = _mm256_loadu_pd(tA + 8);
2242 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
2243 ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
2244
2245 ymm3 = _mm256_loadu_pd(tA + 12);
2246 ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11);
2247 ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);
2248
2249 tA += lda;
2250
2251 }
2252 // alpha, beta multiplication.
2253 ymm0 = _mm256_broadcast_sd(alpha_cast);
2254 ymm1 = _mm256_broadcast_sd(beta_cast);
2255
2256 //multiply A*B by alpha.
2257 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2258 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2259 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2260 ymm11 = _mm256_mul_pd(ymm11, ymm0);
2261 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2262 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2263 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2264 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2265
2266 if(is_beta_non_zero)
2267 {
2268 // multiply C by beta and accumulate, col 1.
2269 ymm2 = _mm256_loadu_pd(tC + 0);
2270 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2271 ymm2 = _mm256_loadu_pd(tC + 4);
2272 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2273 ymm2 = _mm256_loadu_pd(tC + 8);
2274 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2275 ymm2 = _mm256_loadu_pd(tC + 12);
2276 ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
2277
2278 // multiply C by beta and accumulate, col 2.
2279 double *ttC = tC + ldc;
2280
2281 ymm2 = _mm256_loadu_pd(ttC);
2282 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2283 ymm2 = _mm256_loadu_pd(ttC + 4);
2284 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2285 ymm2 = _mm256_loadu_pd(ttC + 8);
2286 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2287 ymm2 = _mm256_loadu_pd(ttC + 12);
2288 ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
2289 }
2290
2291 _mm256_storeu_pd(tC + 0, ymm8);
2292 _mm256_storeu_pd(tC + 4, ymm9);
2293 _mm256_storeu_pd(tC + 8, ymm10);
2294 _mm256_storeu_pd(tC + 12, ymm11);
2295
2296 tC += ldc;
2297
2298 _mm256_storeu_pd(tC, ymm12);
2299 _mm256_storeu_pd(tC + 4, ymm13);
2300 _mm256_storeu_pd(tC + 8, ymm14);
2301 _mm256_storeu_pd(tC + 12, ymm15);
2302 col_idx += 2;
2303 }
2304         // If N is not a multiple of 3,
2305         // handle the edge case.
2306 if (n_remainder == 1)
2307 {
2308 //pointer math to point to proper memory
2309 tC = C + ldc * col_idx + row_idx;
2310 tB = B + tb_inc_col * col_idx;
2311 tA = A + row_idx;
2312
2313 // clear scratch registers.
2314 ymm12 = _mm256_setzero_pd();
2315 ymm13 = _mm256_setzero_pd();
2316 ymm14 = _mm256_setzero_pd();
2317 ymm15 = _mm256_setzero_pd();
2318
2319 for (k = 0; k < K; ++k)
2320 {
2321 // The inner loop broadcasts the B matrix data and
2322 // multiplies it with the A matrix.
2323 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2324 tB += tb_inc_row;
2325
2326 //broadcasted matrix B elements are multiplied
2327 //with matrix A columns.
2328 ymm3 = _mm256_loadu_pd(tA);
2329 ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
2330
2331 ymm3 = _mm256_loadu_pd(tA + 4);
2332 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
2333
2334 ymm3 = _mm256_loadu_pd(tA + 8);
2335 ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
2336
2337 ymm3 = _mm256_loadu_pd(tA + 12);
2338 ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15);
2339
2340 tA += lda;
2341
2342 }
2343 // alpha, beta multiplication.
2344 ymm0 = _mm256_broadcast_sd(alpha_cast);
2345 ymm1 = _mm256_broadcast_sd(beta_cast);
2346
2347 //multiply A*B by alpha.
2348 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2349 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2350 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2351 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2352
2353 if(is_beta_non_zero)
2354 {
2355 // multiply C by beta and accumulate.
2356 ymm2 = _mm256_loadu_pd(tC + 0);
2357 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2358 ymm2 = _mm256_loadu_pd(tC + 4);
2359 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2360 ymm2 = _mm256_loadu_pd(tC + 8);
2361 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2362 ymm2 = _mm256_loadu_pd(tC + 12);
2363 ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
2364 }
2365
2366 _mm256_storeu_pd(tC + 0, ymm12);
2367 _mm256_storeu_pd(tC + 4, ymm13);
2368 _mm256_storeu_pd(tC + 8, ymm14);
2369 _mm256_storeu_pd(tC + 12, ymm15);
2370 }
2371 }
2372
2373 m_remainder = M - row_idx;
2374
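     // The leftover (M % D_MR) rows are handled in successively narrower
     // tiles: 12, 8, or 4 rows with full vector loads, then a zero-padded
     // 4-wide tile for 1-3 remaining rows (when lda > 3), and finally a
     // scalar fallback loop.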
2375 if (m_remainder >= 12)
2376 {
2377 m_remainder -= 12;
2378
2379 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2380 {
2381 //pointer math to point to proper memory
2382 tC = C + ldc * col_idx + row_idx;
2383 tB = B + tb_inc_col * col_idx;
2384 tA = A + row_idx;
2385
2386 // clear scratch registers.
2387 ymm4 = _mm256_setzero_pd();
2388 ymm5 = _mm256_setzero_pd();
2389 ymm6 = _mm256_setzero_pd();
2390 ymm8 = _mm256_setzero_pd();
2391 ymm9 = _mm256_setzero_pd();
2392 ymm10 = _mm256_setzero_pd();
2393 ymm12 = _mm256_setzero_pd();
2394 ymm13 = _mm256_setzero_pd();
2395 ymm14 = _mm256_setzero_pd();
2396
2397 for (k = 0; k < K; ++k)
2398 {
2399 // The inner loop broadcasts the B matrix data and
2400 // multiplies it with the A matrix.
2401 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2402 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2403 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2404 tB += tb_inc_row;
2405
2406 //broadcasted matrix B elements are multiplied
2407 //with matrix A columns.
2408 ymm3 = _mm256_loadu_pd(tA);
2409 // ymm4 += ymm0 * ymm3;
2410 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2411 // ymm8 += ymm1 * ymm3;
2412 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
2413 // ymm12 += ymm2 * ymm3;
2414 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
2415
2416 ymm3 = _mm256_loadu_pd(tA + 4);
2417 // ymm5 += ymm0 * ymm3;
2418 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2419 // ymm9 += ymm1 * ymm3;
2420 ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
2421 // ymm13 += ymm2 * ymm3;
2422 ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
2423
2424 ymm3 = _mm256_loadu_pd(tA + 8);
2425 // ymm6 += ymm0 * ymm3;
2426 ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
2427 // ymm10 += ymm1 * ymm3;
2428 ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
2429 // ymm14 += ymm2 * ymm3;
2430 ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
2431
2432 tA += lda;
2433 }
2434 // alpha, beta multiplication.
2435 ymm0 = _mm256_broadcast_sd(alpha_cast);
2436 ymm1 = _mm256_broadcast_sd(beta_cast);
2437
2438 //multiply A*B by alpha.
2439 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2440 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2441 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2442 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2443 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2444 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2445 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2446 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2447 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2448
2449 if(is_beta_non_zero)
2450 {
2451 // multiply C by beta and accumulate.
2452 ymm2 = _mm256_loadu_pd(tC);
2453 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2454 ymm2 = _mm256_loadu_pd(tC + 4);
2455 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2456 ymm2 = _mm256_loadu_pd(tC + 8);
2457 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2458
2459 // multiply C by beta and accumulate.
2460 double *ttC = tC +ldc;
2461 ymm2 = _mm256_loadu_pd(ttC);
2462 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2463 ymm2 = _mm256_loadu_pd(ttC + 4);
2464 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2465 ymm2 = _mm256_loadu_pd(ttC + 8);
2466 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2467
2468 // multiply C by beta and accumulate.
2469 ttC += ldc;
2470 ymm2 = _mm256_loadu_pd(ttC);
2471 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2472 ymm2 = _mm256_loadu_pd(ttC + 4);
2473 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2474 ymm2 = _mm256_loadu_pd(ttC + 8);
2475 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2476
2477 }
2478 _mm256_storeu_pd(tC, ymm4);
2479 _mm256_storeu_pd(tC + 4, ymm5);
2480 _mm256_storeu_pd(tC + 8, ymm6);
2481
2482 tC += ldc;
2483
2484 _mm256_storeu_pd(tC, ymm8);
2485 _mm256_storeu_pd(tC + 4, ymm9);
2486 _mm256_storeu_pd(tC + 8, ymm10);
2487
2488 tC += ldc;
2489
2490 _mm256_storeu_pd(tC, ymm12);
2491 _mm256_storeu_pd(tC + 4, ymm13);
2492 _mm256_storeu_pd(tC + 8, ymm14);
2493 }
2494 n_remainder = N - col_idx;
2495         // If N is not a multiple of 3,
2496         // handle the edge case.
2497 if (n_remainder == 2)
2498 {
2499 //pointer math to point to proper memory
2500 tC = C + ldc * col_idx + row_idx;
2501 tB = B + tb_inc_col * col_idx;
2502 tA = A + row_idx;
2503
2504 // clear scratch registers.
2505 ymm8 = _mm256_setzero_pd();
2506 ymm9 = _mm256_setzero_pd();
2507 ymm10 = _mm256_setzero_pd();
2508 ymm12 = _mm256_setzero_pd();
2509 ymm13 = _mm256_setzero_pd();
2510 ymm14 = _mm256_setzero_pd();
2511
2512 for (k = 0; k < K; ++k)
2513 {
2514 // The inner loop broadcasts the B matrix data and
2515 // multiplies it with the A matrix.
2516 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2517 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2518 tB += tb_inc_row;
2519
2520 //broadcasted matrix B elements are multiplied
2521 //with matrix A columns.
2522 ymm3 = _mm256_loadu_pd(tA);
2523 ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
2524 ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
2525
2526 ymm3 = _mm256_loadu_pd(tA + 4);
2527 ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
2528 ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
2529
2530 ymm3 = _mm256_loadu_pd(tA + 8);
2531 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
2532 ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
2533
2534 tA += lda;
2535
2536 }
2537 // alpha, beta multiplication.
2538 ymm0 = _mm256_broadcast_sd(alpha_cast);
2539 ymm1 = _mm256_broadcast_sd(beta_cast);
2540
2541 //multiply A*B by alpha.
2542 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2543 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2544 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2545 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2546 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2547 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2548
2549
2550 if(is_beta_non_zero)
2551 {
2552 // multiply C by beta and accumulate.
2553 ymm2 = _mm256_loadu_pd(tC + 0);
2554 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2555 ymm2 = _mm256_loadu_pd(tC + 4);
2556 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2557 ymm2 = _mm256_loadu_pd(tC + 8);
2558 ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2559
2560 double *ttC = tC + ldc;
2561
2562 // multiply C by beta and accumulate.
2563 ymm2 = _mm256_loadu_pd(ttC);
2564 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2565 ymm2 = _mm256_loadu_pd(ttC + 4);
2566 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2567 ymm2 = _mm256_loadu_pd(ttC + 8);
2568 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2569
2570 }
2571 _mm256_storeu_pd(tC + 0, ymm8);
2572 _mm256_storeu_pd(tC + 4, ymm9);
2573 _mm256_storeu_pd(tC + 8, ymm10);
2574
2575 tC += ldc;
2576
2577 _mm256_storeu_pd(tC, ymm12);
2578 _mm256_storeu_pd(tC + 4, ymm13);
2579 _mm256_storeu_pd(tC + 8, ymm14);
2580
2581 col_idx += 2;
2582 }
2583         // If N is not a multiple of 3,
2584         // handle the edge case.
2585 if (n_remainder == 1)
2586 {
2587 //pointer math to point to proper memory
2588 tC = C + ldc * col_idx + row_idx;
2589 tB = B + tb_inc_col * col_idx;
2590 tA = A + row_idx;
2591
2592 // clear scratch registers.
2593 ymm12 = _mm256_setzero_pd();
2594 ymm13 = _mm256_setzero_pd();
2595 ymm14 = _mm256_setzero_pd();
2596
2597 for (k = 0; k < K; ++k)
2598 {
2599 // The inner loop broadcasts the B matrix data and
2600 // multiplies it with the A matrix.
2601 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2602 tB += tb_inc_row;
2603
2604 //broadcasted matrix B elements are multiplied
2605 //with matrix A columns.
2606 ymm3 = _mm256_loadu_pd(tA);
2607 ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
2608
2609 ymm3 = _mm256_loadu_pd(tA + 4);
2610 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
2611
2612 ymm3 = _mm256_loadu_pd(tA + 8);
2613 ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
2614
2615 tA += lda;
2616
2617 }
2618 // alpha, beta multiplication.
2619 ymm0 = _mm256_broadcast_sd(alpha_cast);
2620 ymm1 = _mm256_broadcast_sd(beta_cast);
2621
2622 //multiply A*B by alpha.
2623 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2624 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2625 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2626
2627
2628 if(is_beta_non_zero)
2629 {
2630 // multiply C by beta and accumulate.
2631 ymm2 = _mm256_loadu_pd(tC + 0);
2632 ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2633 ymm2 = _mm256_loadu_pd(tC + 4);
2634 ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2635 ymm2 = _mm256_loadu_pd(tC + 8);
2636 ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2637
2638 }
2639 _mm256_storeu_pd(tC + 0, ymm12);
2640 _mm256_storeu_pd(tC + 4, ymm13);
2641 _mm256_storeu_pd(tC + 8, ymm14);
2642 }
2643
2644 row_idx += 12;
2645 }
2646
2647 if (m_remainder >= 8)
2648 {
2649 m_remainder -= 8;
2650
2651 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2652 {
2653 //pointer math to point to proper memory
2654 tC = C + ldc * col_idx + row_idx;
2655 tB = B + tb_inc_col * col_idx;
2656 tA = A + row_idx;
2657
2658 // clear scratch registers.
2659 ymm4 = _mm256_setzero_pd();
2660 ymm5 = _mm256_setzero_pd();
2661 ymm6 = _mm256_setzero_pd();
2662 ymm7 = _mm256_setzero_pd();
2663 ymm8 = _mm256_setzero_pd();
2664 ymm9 = _mm256_setzero_pd();
2665
2666 for (k = 0; k < K; ++k)
2667 {
2668 // The inner loop broadcasts the B matrix data and
2669 // multiplies it with the A matrix.
2670 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2671 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2672 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2673 tB += tb_inc_row;
2674
2675 //broadcasted matrix B elements are multiplied
2676 //with matrix A columns.
2677 ymm3 = _mm256_loadu_pd(tA);
2678 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2679 ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
2680 ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8);
2681
2682 ymm3 = _mm256_loadu_pd(tA + 4);
2683 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2684 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
2685 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
2686
2687 tA += lda;
2688 }
2689 // alpha, beta multiplication.
2690 ymm0 = _mm256_broadcast_sd(alpha_cast);
2691 ymm1 = _mm256_broadcast_sd(beta_cast);
2692
2693 //multiply A*B by alpha.
2694 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2695 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2696 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2697 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2698 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2699 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2700
2701 if(is_beta_non_zero)
2702 {
2703 // multiply C by beta and accumulate.
2704 ymm2 = _mm256_loadu_pd(tC);
2705 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2706 ymm2 = _mm256_loadu_pd(tC + 4);
2707 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2708
2709 double* ttC = tC + ldc;
2710
2711 // multiply C by beta and accumulate.
2712 ymm2 = _mm256_loadu_pd(ttC);
2713 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2714 ymm2 = _mm256_loadu_pd(ttC + 4);
2715 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
2716
2717 ttC += ldc;
2718
2719 // multiply C by beta and accumulate.
2720 ymm2 = _mm256_loadu_pd(ttC);
2721 ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2722 ymm2 = _mm256_loadu_pd(ttC + 4);
2723 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2724 }
2725
2726 _mm256_storeu_pd(tC, ymm4);
2727 _mm256_storeu_pd(tC + 4, ymm5);
2728
2729 tC += ldc;
2730 _mm256_storeu_pd(tC, ymm6);
2731 _mm256_storeu_pd(tC + 4, ymm7);
2732
2733 tC += ldc;
2734 _mm256_storeu_pd(tC, ymm8);
2735 _mm256_storeu_pd(tC + 4, ymm9);
2736
2737 }
2738 n_remainder = N - col_idx;
2739         // If N is not a multiple of 3,
2740         // handle the edge case.
2741 if (n_remainder == 2)
2742 {
2743 //pointer math to point to proper memory
2744 tC = C + ldc * col_idx + row_idx;
2745 tB = B + tb_inc_col * col_idx;
2746 tA = A + row_idx;
2747
2748 // clear scratch registers.
2749 ymm4 = _mm256_setzero_pd();
2750 ymm5 = _mm256_setzero_pd();
2751 ymm6 = _mm256_setzero_pd();
2752 ymm7 = _mm256_setzero_pd();
2753
2754 for (k = 0; k < K; ++k)
2755 {
2756 // The inner loop broadcasts the B matrix data and
2757 // multiplies it with the A matrix.
2758 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2759 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2760 tB += tb_inc_row;
2761
2762 //broadcasted matrix B elements are multiplied
2763 //with matrix A columns.
2764 ymm3 = _mm256_loadu_pd(tA);
2765 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2766 ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
2767
2768 ymm3 = _mm256_loadu_pd(tA + 4);
2769 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2770 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
2771
2772 tA += lda;
2773 }
2774 // alpha, beta multiplication.
2775 ymm0 = _mm256_broadcast_sd(alpha_cast);
2776 ymm1 = _mm256_broadcast_sd(beta_cast);
2777
2778 //multiply A*B by alpha.
2779 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2780 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2781 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2782 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2783
2784 if(is_beta_non_zero)
2785 {
2786 // multiply C by beta and accumulate.
2787 ymm2 = _mm256_loadu_pd(tC);
2788 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2789 ymm2 = _mm256_loadu_pd(tC + 4);
2790 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2791
2792 double* ttC = tC + ldc;
2793
2794 // multiply C by beta and accumulate.
2795 ymm2 = _mm256_loadu_pd(ttC);
2796 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2797 ymm2 = _mm256_loadu_pd(ttC + 4);
2798 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
2799 }
2800 _mm256_storeu_pd(tC, ymm4);
2801 _mm256_storeu_pd(tC + 4, ymm5);
2802
2803 tC += ldc;
2804 _mm256_storeu_pd(tC, ymm6);
2805 _mm256_storeu_pd(tC + 4, ymm7);
2806
2807 col_idx += 2;
2808
2809 }
2810         // If N is not a multiple of 3,
2811         // handle the edge case.
2812 if (n_remainder == 1)
2813 {
2814 //pointer math to point to proper memory
2815 tC = C + ldc * col_idx + row_idx;
2816 tB = B + tb_inc_col * col_idx;
2817 tA = A + row_idx;
2818
2819 ymm4 = _mm256_setzero_pd();
2820 ymm5 = _mm256_setzero_pd();
2821
2822 for (k = 0; k < K; ++k)
2823 {
2824 // The inner loop broadcasts the B matrix data and
2825 // multiplies it with the A matrix.
2826 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2827 tB += tb_inc_row;
2828
2829 //broadcasted matrix B elements are multiplied
2830 //with matrix A columns.
2831 ymm3 = _mm256_loadu_pd(tA);
2832 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2833
2834 ymm3 = _mm256_loadu_pd(tA + 4);
2835 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2836
2837 tA += lda;
2838 }
2839 // alpha, beta multiplication.
2840 ymm0 = _mm256_broadcast_sd(alpha_cast);
2841 ymm1 = _mm256_broadcast_sd(beta_cast);
2842
2843 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2844 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2845
2846 if(is_beta_non_zero)
2847 {
2848 // multiply C by beta and accumulate.
2849 ymm2 = _mm256_loadu_pd(tC);
2850 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2851 ymm2 = _mm256_loadu_pd(tC + 4);
2852 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2853 }
2854 _mm256_storeu_pd(tC, ymm4);
2855 _mm256_storeu_pd(tC + 4, ymm5);
2856
2857 }
2858
2859 row_idx += 8;
2860 }
2861
2862 if (m_remainder >= 4)
2863 {
2864
2865 m_remainder -= 4;
2866
2867 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2868 {
2869 //pointer math to point to proper memory
2870 tC = C + ldc * col_idx + row_idx;
2871 tB = B + tb_inc_col * col_idx;
2872 tA = A + row_idx;
2873
2874 // clear scratch registers.
2875 ymm4 = _mm256_setzero_pd();
2876 ymm5 = _mm256_setzero_pd();
2877 ymm6 = _mm256_setzero_pd();
2878
2879 for (k = 0; k < K; ++k)
2880 {
2881 // The inner loop broadcasts the B matrix data and
2882 // multiplies it with the A matrix.
2883 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2884 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2885 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2886 tB += tb_inc_row;
2887
2888 //broadcasted matrix B elements are multiplied
2889 //with matrix A columns.
2890 ymm3 = _mm256_loadu_pd(tA);
2891 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2892 ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
2893 ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
2894
2895 tA += lda;
2896 }
2897 // alpha, beta multiplication.
2898 ymm0 = _mm256_broadcast_sd(alpha_cast);
2899 ymm1 = _mm256_broadcast_sd(beta_cast);
2900
2901 //multiply A*B by alpha.
2902 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2903 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2904 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2905
2906 if(is_beta_non_zero)
2907 {
2908 // multiply C by beta and accumulate.
2909 ymm2 = _mm256_loadu_pd(tC);
2910 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2911
2912 double* ttC = tC + ldc;
2913
2914 // multiply C by beta and accumulate.
2915 ymm2 = _mm256_loadu_pd(ttC);
2916 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2917
2918 ttC += ldc;
2919
2920 // multiply C by beta and accumulate.
2921 ymm2 = _mm256_loadu_pd(ttC);
2922 ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2923 }
2924 _mm256_storeu_pd(tC, ymm4);
2925
2926 tC += ldc;
2927 _mm256_storeu_pd(tC, ymm5);
2928
2929 tC += ldc;
2930 _mm256_storeu_pd(tC, ymm6);
2931 }
2932 n_remainder = N - col_idx;
2933         // If N is not a multiple of 3,
2934         // handle the edge case.
2935 if (n_remainder == 2)
2936 {
2937 //pointer math to point to proper memory
2938 tC = C + ldc * col_idx + row_idx;
2939 tB = B + tb_inc_col * col_idx;
2940 tA = A + row_idx;
2941
2942 ymm4 = _mm256_setzero_pd();
2943 ymm5 = _mm256_setzero_pd();
2944
2945 for (k = 0; k < K; ++k)
2946 {
2947 // The inner loop broadcasts the B matrix data and
2948 // multiplies it with the A matrix.
2949 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2950 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2951 tB += tb_inc_row;
2952
2953 //broadcasted matrix B elements are multiplied
2954 //with matrix A columns.
2955 ymm3 = _mm256_loadu_pd(tA);
2956 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2957 ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
2958
2959 tA += lda;
2960 }
2961 // alpha, beta multiplication.
2962 ymm0 = _mm256_broadcast_sd(alpha_cast);
2963 ymm1 = _mm256_broadcast_sd(beta_cast);
2964
2965 //multiply A*B by alpha.
2966 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2967 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2968
2969 if(is_beta_non_zero)
2970 {
2971 // multiply C by beta and accumulate.
2972 ymm2 = _mm256_loadu_pd(tC);
2973 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2974
2975 double* ttC = tC + ldc;
2976
2977 // multiply C by beta and accumulate.
2978 ymm2 = _mm256_loadu_pd(ttC);
2979 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2980 }
2981 _mm256_storeu_pd(tC, ymm4);
2982
2983 tC += ldc;
2984 _mm256_storeu_pd(tC, ymm5);
2985
2986 col_idx += 2;
2987
2988 }
2989         // If N is not a multiple of 3,
2990         // handle the edge case.
2991 if (n_remainder == 1)
2992 {
2993 //pointer math to point to proper memory
2994 tC = C + ldc * col_idx + row_idx;
2995 tB = B + tb_inc_col * col_idx;
2996 tA = A + row_idx;
2997
2998 ymm4 = _mm256_setzero_pd();
2999
3000 for (k = 0; k < K; ++k)
3001 {
3002 // The inner loop broadcasts the B matrix data and
3003 // multiplies it with the A matrix.
3004 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3005 tB += tb_inc_row;
3006
3007 //broadcasted matrix B elements are multiplied
3008 //with matrix A columns.
3009 ymm3 = _mm256_loadu_pd(tA);
3010 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
3011
3012 tA += lda;
3013 }
3014 // alpha, beta multiplication.
3015 ymm0 = _mm256_broadcast_sd(alpha_cast);
3016 ymm1 = _mm256_broadcast_sd(beta_cast);
3017
3018 ymm4 = _mm256_mul_pd(ymm4, ymm0);
3019
3020 if(is_beta_non_zero)
3021 {
3022 // multiply C by beta and accumulate.
3023 ymm2 = _mm256_loadu_pd(tC);
3024 ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
3025
3026 }
3027 _mm256_storeu_pd(tC, ymm4);
3028
3029 }
3030
3031 row_idx += 4;
3032 }
3033     // M is not a multiple of D_MR (16).
3034     // This handles the edge case where the remaining row
3035     // dimension is less than 4. Padding is used
3036     // to handle this case.
3037 if ((m_remainder) && (lda > 3))
3038 {
3039 double f_temp[8] = {0.0};
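         // f_temp is a zero-padded staging buffer: the 1-3 valid elements
         // of A and C are copied in so full 256-bit loads/stores can be
         // used, and only the valid m_remainder elements are copied back.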
3040
3041 for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
3042 {
3043 //pointer math to point to proper memory
3044 tC = C + ldc * col_idx + row_idx;
3045 tB = B + tb_inc_col * col_idx;
3046 tA = A + row_idx;
3047
3048 // clear scratch registers.
3049 ymm5 = _mm256_setzero_pd();
3050 ymm7 = _mm256_setzero_pd();
3051 ymm9 = _mm256_setzero_pd();
3052
3053 for (k = 0; k < (K - 1); ++k)
3054 {
3055 // The inner loop broadcasts the B matrix data and
3056 // multiplies it with the A matrix.
3057 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3058 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3059 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
3060 tB += tb_inc_row;
3061
3062 //broadcasted matrix B elements are multiplied
3063 //with matrix A columns.
3064 ymm3 = _mm256_loadu_pd(tA);
3065 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3066 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3067 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
3068
3069 tA += lda;
3070 }
3071 // alpha, beta multiplication.
3072 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3073 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3074 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
3075 tB += tb_inc_row;
3076
3077 for (int i = 0; i < m_remainder; i++)
3078 {
3079 f_temp[i] = tA[i];
3080 }
3081 ymm3 = _mm256_loadu_pd(f_temp);
3082 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3083 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3084 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
3085
3086 ymm0 = _mm256_broadcast_sd(alpha_cast);
3087 ymm1 = _mm256_broadcast_sd(beta_cast);
3088
3089 //multiply A*B by alpha.
3090 ymm5 = _mm256_mul_pd(ymm5, ymm0);
3091 ymm7 = _mm256_mul_pd(ymm7, ymm0);
3092 ymm9 = _mm256_mul_pd(ymm9, ymm0);
3093
3094 if(is_beta_non_zero)
3095 {
3096 for (int i = 0; i < m_remainder; i++)
3097 {
3098 f_temp[i] = tC[i];
3099 }
3100 ymm2 = _mm256_loadu_pd(f_temp);
3101 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
3102
3103
3104 double* ttC = tC + ldc;
3105
3106 for (int i = 0; i < m_remainder; i++)
3107 {
3108 f_temp[i] = ttC[i];
3109 }
3110 ymm2 = _mm256_loadu_pd(f_temp);
3111 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
3112
3113 ttC += ldc;
3114 for (int i = 0; i < m_remainder; i++)
3115 {
3116 f_temp[i] = ttC[i];
3117 }
3118 ymm2 = _mm256_loadu_pd(f_temp);
3119 ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
3120 }
3121 _mm256_storeu_pd(f_temp, ymm5);
3122 for (int i = 0; i < m_remainder; i++)
3123 {
3124 tC[i] = f_temp[i];
3125 }
3126
3127 tC += ldc;
3128 _mm256_storeu_pd(f_temp, ymm7);
3129 for (int i = 0; i < m_remainder; i++)
3130 {
3131 tC[i] = f_temp[i];
3132 }
3133
3134 tC += ldc;
3135 _mm256_storeu_pd(f_temp, ymm9);
3136 for (int i = 0; i < m_remainder; i++)
3137 {
3138 tC[i] = f_temp[i];
3139 }
3140 }
3141 n_remainder = N - col_idx;
3142         // If N is not a multiple of 3,
3143         // handle the edge case.
3144 if (n_remainder == 2)
3145 {
3146 //pointer math to point to proper memory
3147 tC = C + ldc * col_idx + row_idx;
3148 tB = B + tb_inc_col * col_idx;
3149 tA = A + row_idx;
3150
3151 ymm5 = _mm256_setzero_pd();
3152 ymm7 = _mm256_setzero_pd();
3153
3154 for (k = 0; k < (K - 1); ++k)
3155 {
3156 // The inner loop broadcasts the B matrix data and
3157 // multiplies it with the A matrix.
3158 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3159 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3160 tB += tb_inc_row;
3161
3162 ymm3 = _mm256_loadu_pd(tA);
3163 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3164 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3165
3166 tA += lda;
3167 }
3168
3169 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3170 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3171 tB += tb_inc_row;
3172
3173 for (int i = 0; i < m_remainder; i++)
3174 {
3175 f_temp[i] = tA[i];
3176 }
3177 ymm3 = _mm256_loadu_pd(f_temp);
3178 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3179 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3180
3181 ymm0 = _mm256_broadcast_sd(alpha_cast);
3182 ymm1 = _mm256_broadcast_sd(beta_cast);
3183
3184 ymm5 = _mm256_mul_pd(ymm5, ymm0);
3185 ymm7 = _mm256_mul_pd(ymm7, ymm0);
3186
3187 if(is_beta_non_zero)
3188 {
3189 for (int i = 0; i < m_remainder; i++)
3190 {
3191 f_temp[i] = tC[i];
3192 }
3193 ymm2 = _mm256_loadu_pd(f_temp);
3194 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
3195
3196 double* ttC = tC + ldc;
3197
3198 for (int i = 0; i < m_remainder; i++)
3199 {
3200 f_temp[i] = ttC[i];
3201 }
3202 ymm2 = _mm256_loadu_pd(f_temp);
3203 ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
3204
3205 }
3206 _mm256_storeu_pd(f_temp, ymm5);
3207 for (int i = 0; i < m_remainder; i++)
3208 {
3209 tC[i] = f_temp[i];
3210 }
3211
3212 tC += ldc;
3213 _mm256_storeu_pd(f_temp, ymm7);
3214 for (int i = 0; i < m_remainder; i++)
3215 {
3216 tC[i] = f_temp[i];
3217 }
3218 }
3219         // If N is not a multiple of 3,
3220         // handle the edge case.
3221 if (n_remainder == 1)
3222 {
3223 //pointer math to point to proper memory
3224 tC = C + ldc * col_idx + row_idx;
3225 tB = B + tb_inc_col * col_idx;
3226 tA = A + row_idx;
3227
3228 ymm5 = _mm256_setzero_pd();
3229
3230 for (k = 0; k < (K - 1); ++k)
3231 {
3232 // The inner loop broadcasts the B matrix data and
3233 // multiplies it with the A matrix.
3234 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3235 tB += tb_inc_row;
3236
3237 ymm3 = _mm256_loadu_pd(tA);
3238 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3239
3240 tA += lda;
3241 }
3242
3243 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3244 tB += tb_inc_row;
3245
3246 for (int i = 0; i < m_remainder; i++)
3247 {
3248 f_temp[i] = tA[i];
3249 }
3250 ymm3 = _mm256_loadu_pd(f_temp);
3251 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3252
3253 ymm0 = _mm256_broadcast_sd(alpha_cast);
3254 ymm1 = _mm256_broadcast_sd(beta_cast);
3255
3256 // multiply C by beta and accumulate.
3257 ymm5 = _mm256_mul_pd(ymm5, ymm0);
3258
3259 if(is_beta_non_zero)
3260 {
3261
3262 for (int i = 0; i < m_remainder; i++)
3263 {
3264 f_temp[i] = tC[i];
3265 }
3266 ymm2 = _mm256_loadu_pd(f_temp);
3267 ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
3268 }
3269 _mm256_storeu_pd(f_temp, ymm5);
3270 for (int i = 0; i < m_remainder; i++)
3271 {
3272 tC[i] = f_temp[i];
3273 }
3274 }
3275 m_remainder = 0;
3276 }
3277
3278 if (m_remainder)
3279 {
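        // Scalar fallback: any rows not covered by the vector paths above are
        // computed element by element as length-K dot products.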
3280 double result;
3281 for (; row_idx < M; row_idx += 1)
3282 {
3283 for (col_idx = 0; col_idx < N; col_idx += 1)
3284 {
3285 //pointer math to point to proper memory
3286 tC = C + ldc * col_idx + row_idx;
3287 tB = B + tb_inc_col * col_idx;
3288 tA = A + row_idx;
3289
3290 result = 0;
3291 for (k = 0; k < K; ++k)
3292 {
3293 result += (*tA) * (*tB);
3294 tA += lda;
3295 tB += tb_inc_row;
3296 }
3297
3298 result *= (*alpha_cast);
3299 if(is_beta_non_zero)
3300 (*tC) = (*tC) * (*beta_cast) + result;
3301 else
3302 (*tC) = result;
3303 }
3304 }
3305 }
3306
3307 // Return the buffer to pool
3308 if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
3309 #ifdef BLIS_ENABLE_MEM_TRACING
3310 printf( "bli_dgemm_small(): releasing mem pool block\n" );
3311 #endif
3312 bli_membrk_release(&rntm,
3313 &local_mem_buf_A_s);
3314 }
3315 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
3316 return BLIS_SUCCESS;
3317 }
3318 else
3319 {
3320 AOCL_DTL_TRACE_EXIT_ERR(
3321 AOCL_DTL_LEVEL_INFO,
3322           "Invalid dimensions for small gemm."
3323 );
3324 return BLIS_NONCONFORMAL_DIMENSIONS;
3325 }
3326 }
3327
3328 static err_t bli_sgemm_small_atbn
3329 (
3330 obj_t* alpha,
3331 obj_t* a,
3332 obj_t* b,
3333 obj_t* beta,
3334 obj_t* c,
3335 cntx_t* cntx,
3336 cntl_t* cntl
3337 )
3338 {
3339 AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
3340
3341 gint_t M = bli_obj_length( c ); // number of rows of Matrix C
3342 gint_t N = bli_obj_width( c ); // number of columns of Matrix C
3343 gint_t K = bli_obj_length( b ); // number of rows of Matrix B
3344
3345 guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
3346 guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
3347 guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
3348
3349 int row_idx = 0, col_idx = 0, k;
3350
3351 float *A = bli_obj_buffer_at_off(a); // pointer to matrix A elements, stored in row major format
3352 float *B = bli_obj_buffer_at_off(b); // pointer to matrix B elements, stored in column major format
3353 float *C = bli_obj_buffer_at_off(c); // pointer to matrix C elements, stored in column major format
3354
3355 float *tA = A, *tB = B, *tC = C;
3356
3357 __m256 ymm4, ymm5, ymm6, ymm7;
3358 __m256 ymm8, ymm9, ymm10, ymm11;
3359 __m256 ymm12, ymm13, ymm14, ymm15;
3360 __m256 ymm0, ymm1, ymm2, ymm3;
3361
3362 float result;
3363 float scratch[8] = {0.0};
3364 const num_t dt_exec = bli_obj_dt( c );
3365 float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha );
3366 float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta );
3367
3368 /*Beta Zero Check*/
3369 bool is_beta_non_zero=0;
3370 if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
3371 is_beta_non_zero = 1;
3372 }
3373
3374 // The non-copy version of the A^T GEMM gives better performance for the small M cases.
3375 // The threshold is controlled by BLIS_ATBN_M_THRES
3376 if (M <= BLIS_ATBN_M_THRES)
3377 {
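        // A is accessed row-major (stride lda) and B column-major (stride ldb);
        // each 4x3 tile of C is computed as length-K dot products, vectorized
        // 8 floats at a time and reduced horizontally after the k loop.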
3378 for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
3379 {
3380 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
3381 {
3382 tA = A + row_idx * lda;
3383 tB = B + col_idx * ldb;
3384 tC = C + col_idx * ldc + row_idx;
3385 // clear scratch registers.
3386 ymm4 = _mm256_setzero_ps();
3387 ymm5 = _mm256_setzero_ps();
3388 ymm6 = _mm256_setzero_ps();
3389 ymm7 = _mm256_setzero_ps();
3390 ymm8 = _mm256_setzero_ps();
3391 ymm9 = _mm256_setzero_ps();
3392 ymm10 = _mm256_setzero_ps();
3393 ymm11 = _mm256_setzero_ps();
3394 ymm12 = _mm256_setzero_ps();
3395 ymm13 = _mm256_setzero_ps();
3396 ymm14 = _mm256_setzero_ps();
3397 ymm15 = _mm256_setzero_ps();
3398
3399 //The inner loop computes the 4x3 values of the matrix.
3400 //The computation pattern is:
3401 // ymm4 ymm5 ymm6
3402 // ymm7 ymm8 ymm9
3403 // ymm10 ymm11 ymm12
3404 // ymm13 ymm14 ymm15
3405
3406                 //The dot products are accumulated in the inner loop; 8 float elements fit
3407                 //in a YMM register, hence the loop counter is incremented by 8.
3408 for (k = 0; (k + 7) < K; k += 8)
3409 {
3410 ymm0 = _mm256_loadu_ps(tB + 0);
3411 ymm1 = _mm256_loadu_ps(tB + ldb);
3412 ymm2 = _mm256_loadu_ps(tB + 2 * ldb);
3413
3414 ymm3 = _mm256_loadu_ps(tA);
3415 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3416 ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
3417 ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
3418
3419 ymm3 = _mm256_loadu_ps(tA + lda);
3420 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3421 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
3422 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
3423
3424 ymm3 = _mm256_loadu_ps(tA + 2 * lda);
3425 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3426 ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
3427 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
3428
3429 ymm3 = _mm256_loadu_ps(tA + 3 * lda);
3430 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3431 ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
3432 ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
3433
3434 tA += 8;
3435 tB += 8;
3436
3437 }
3438
3439                 // If K is not a multiple of 8, padding is done before the load using a temporary array.
3440 if (k < K)
3441 {
3442 int iter;
3443 float data_feeder[8] = { 0.0 };
3444
3445 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3446 ymm0 = _mm256_loadu_ps(data_feeder);
3447 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
3448 ymm1 = _mm256_loadu_ps(data_feeder);
3449 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
3450 ymm2 = _mm256_loadu_ps(data_feeder);
3451
3452 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3453 ymm3 = _mm256_loadu_ps(data_feeder);
3454 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3455 ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
3456 ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
3457
3458 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
3459 ymm3 = _mm256_loadu_ps(data_feeder);
3460 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3461 ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
3462 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
3463
3464 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
3465 ymm3 = _mm256_loadu_ps(data_feeder);
3466 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3467 ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
3468 ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
3469
3470 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
3471 ymm3 = _mm256_loadu_ps(data_feeder);
3472 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3473 ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
3474 ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
3475
3476 }
3477
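                // Each accumulator holds 8 partial sums; two hadd_ps calls
                // collapse each 128-bit lane to a single value, and the final
                // dot product is scratch[0] + scratch[4] (low lane + high lane).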
3478                 //Horizontal addition and storage of the data.
3479                 //Results for the 4x3 block of C are stored here.
3480 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3481 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3482 _mm256_storeu_ps(scratch, ymm4);
3483 result = scratch[0] + scratch[4];
3484 result *= (*alpha_cast);
3485 if(is_beta_non_zero){
3486 tC[0] = result + tC[0] * (*beta_cast);
3487 }else{
3488 tC[0] = result;
3489 }
3490
3491 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3492 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3493 _mm256_storeu_ps(scratch, ymm7);
3494 result = scratch[0] + scratch[4];
3495 result *= (*alpha_cast);
3496 if(is_beta_non_zero){
3497 tC[1] = result + tC[1] * (*beta_cast);
3498 }else{
3499 tC[1] = result;
3500 }
3501
3502 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3503 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3504 _mm256_storeu_ps(scratch, ymm10);
3505 result = scratch[0] + scratch[4];
3506 result *= (*alpha_cast);
3507 if(is_beta_non_zero){
3508 tC[2] = result + tC[2] * (*beta_cast);
3509 }else{
3510 tC[2] = result;
3511 }
3512
3513 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3514 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3515 _mm256_storeu_ps(scratch, ymm13);
3516 result = scratch[0] + scratch[4];
3517 result *= (*alpha_cast);
3518 if(is_beta_non_zero){
3519 tC[3] = result + tC[3] * (*beta_cast);
3520 }else{
3521 tC[3] = result;
3522 }
3523
3524 tC += ldc;
3525 ymm5 = _mm256_hadd_ps(ymm5, ymm5);
3526 ymm5 = _mm256_hadd_ps(ymm5, ymm5);
3527 _mm256_storeu_ps(scratch, ymm5);
3528 result = scratch[0] + scratch[4];
3529 result *= (*alpha_cast);
3530 if(is_beta_non_zero){
3531 tC[0] = result + tC[0] * (*beta_cast);
3532 }else{
3533 tC[0] = result;
3534 }
3535
3536 ymm8 = _mm256_hadd_ps(ymm8, ymm8);
3537 ymm8 = _mm256_hadd_ps(ymm8, ymm8);
3538 _mm256_storeu_ps(scratch, ymm8);
3539 result = scratch[0] + scratch[4];
3540 result *= (*alpha_cast);
3541 if(is_beta_non_zero){
3542 tC[1] = result + tC[1] * (*beta_cast);
3543 }else{
3544 tC[1] = result;
3545 }
3546
3547 ymm11 = _mm256_hadd_ps(ymm11, ymm11);
3548 ymm11 = _mm256_hadd_ps(ymm11, ymm11);
3549 _mm256_storeu_ps(scratch, ymm11);
3550 result = scratch[0] + scratch[4];
3551 result *= (*alpha_cast);
3552 if(is_beta_non_zero){
3553 tC[2] = result + tC[2] * (*beta_cast);
3554 }else{
3555 tC[2] = result;
3556 }
3557
3558 ymm14 = _mm256_hadd_ps(ymm14, ymm14);
3559 ymm14 = _mm256_hadd_ps(ymm14, ymm14);
3560 _mm256_storeu_ps(scratch, ymm14);
3561 result = scratch[0] + scratch[4];
3562 result *= (*alpha_cast);
3563 if(is_beta_non_zero){
3564 tC[3] = result + tC[3] * (*beta_cast);
3565 }else{
3566 tC[3] = result;
3567 }
3568
3569 tC += ldc;
3570 ymm6 = _mm256_hadd_ps(ymm6, ymm6);
3571 ymm6 = _mm256_hadd_ps(ymm6, ymm6);
3572 _mm256_storeu_ps(scratch, ymm6);
3573 result = scratch[0] + scratch[4];
3574 result *= (*alpha_cast);
3575 if(is_beta_non_zero){
3576 tC[0] = result + tC[0] * (*beta_cast);
3577 }else{
3578 tC[0] = result;
3579 }
3580
3581 ymm9 = _mm256_hadd_ps(ymm9, ymm9);
3582 ymm9 = _mm256_hadd_ps(ymm9, ymm9);
3583 _mm256_storeu_ps(scratch, ymm9);
3584 result = scratch[0] + scratch[4];
3585 result *= (*alpha_cast);
3586 if(is_beta_non_zero){
3587 tC[1] = result + tC[1] * (*beta_cast);
3588 }else{
3589 tC[1] = result;
3590 }
3591
3592 ymm12 = _mm256_hadd_ps(ymm12, ymm12);
3593 ymm12 = _mm256_hadd_ps(ymm12, ymm12);
3594 _mm256_storeu_ps(scratch, ymm12);
3595 result = scratch[0] + scratch[4];
3596 result *= (*alpha_cast);
3597 if(is_beta_non_zero){
3598 tC[2] = result + tC[2] * (*beta_cast);
3599 }else{
3600 tC[2] = result;
3601 }
3602
3603 ymm15 = _mm256_hadd_ps(ymm15, ymm15);
3604 ymm15 = _mm256_hadd_ps(ymm15, ymm15);
3605 _mm256_storeu_ps(scratch, ymm15);
3606 result = scratch[0] + scratch[4];
3607 result *= (*alpha_cast);
3608 if(is_beta_non_zero){
3609 tC[3] = result + tC[3] * (*beta_cast);
3610 }else{
3611 tC[3] = result;
3612 }
3613 }
3614 }
3615
3616 int processed_col = col_idx;
3617 int processed_row = row_idx;
3618
3619 //The edge case handling where N is not a multiple of 3
3620 if (processed_col < N)
3621 {
3622 for (col_idx = processed_col; col_idx < N; col_idx += 1)
3623 {
3624 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
3625 {
3626 tA = A + row_idx * lda;
3627 tB = B + col_idx * ldb;
3628 tC = C + col_idx * ldc + row_idx;
3629 // clear scratch registers.
3630 ymm4 = _mm256_setzero_ps();
3631 ymm7 = _mm256_setzero_ps();
3632 ymm10 = _mm256_setzero_ps();
3633 ymm13 = _mm256_setzero_ps();
3634
3635 //The inner loop computes the 4x1 values of the matrix.
3636 //The computation pattern is:
3637 // ymm4
3638 // ymm7
3639 // ymm10
3640 // ymm13
3641
3642 for (k = 0; (k + 7) < K; k += 8)
3643 {
3644 ymm0 = _mm256_loadu_ps(tB + 0);
3645
3646 ymm3 = _mm256_loadu_ps(tA);
3647 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3648
3649 ymm3 = _mm256_loadu_ps(tA + lda);
3650 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3651
3652 ymm3 = _mm256_loadu_ps(tA + 2 * lda);
3653 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3654
3655 ymm3 = _mm256_loadu_ps(tA + 3 * lda);
3656 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3657
3658 tA += 8;
3659 tB += 8;
3660 }
3661
3662                     // If K is not a multiple of 8, padding is done before the load using a temporary array.
3663 if (k < K)
3664 {
3665 int iter;
3666 float data_feeder[8] = { 0.0 };
3667
3668 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3669 ymm0 = _mm256_loadu_ps(data_feeder);
3670
3671 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3672 ymm3 = _mm256_loadu_ps(data_feeder);
3673 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3674
3675 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
3676 ymm3 = _mm256_loadu_ps(data_feeder);
3677 ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3678
3679 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
3680 ymm3 = _mm256_loadu_ps(data_feeder);
3681 ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3682
3683 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
3684 ymm3 = _mm256_loadu_ps(data_feeder);
3685 ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3686
3687 }
3688
3689                     //Horizontal addition and storage of the data.
3690                     //Results for the 4x1 block of C are stored here.
3691 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3692 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3693 _mm256_storeu_ps(scratch, ymm4);
3694 result = scratch[0] + scratch[4];
3695 result *= (*alpha_cast);
3696 if(is_beta_non_zero){
3697 tC[0] = result + tC[0] * (*beta_cast);
3698 }else{
3699 tC[0] = result;
3700 }
3701
3702 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3703 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3704 _mm256_storeu_ps(scratch, ymm7);
3705 result = scratch[0] + scratch[4];
3706 result *= (*alpha_cast);
3707 if(is_beta_non_zero){
3708 tC[1] = result + tC[1] * (*beta_cast);
3709 }else{
3710 tC[1] = result;
3711 }
3712
3713 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3714 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3715 _mm256_storeu_ps(scratch, ymm10);
3716 result = scratch[0] + scratch[4];
3717 result *= (*alpha_cast);
3718 if(is_beta_non_zero){
3719 tC[2] = result + tC[2] * (*beta_cast);
3720 }else{
3721 tC[2] = result;
3722 }
3723
3724 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3725 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3726 _mm256_storeu_ps(scratch, ymm13);
3727 result = scratch[0] + scratch[4];
3728 result *= (*alpha_cast);
3729 if(is_beta_non_zero){
3730 tC[3] = result + tC[3] * (*beta_cast);
3731 }else{
3732 tC[3] = result;
3733 }
3734 }
3735 }
3736 processed_row = row_idx;
3737 }
3738
3739 //Edge-case handling where M is not a multiple of AT_MR (4)
3740 if (processed_row < M)
3741 {
3742 for (row_idx = processed_row; row_idx < M; row_idx += 1)
3743 {
3744 for (col_idx = 0; col_idx < N; col_idx += 1)
3745 {
3746 tA = A + row_idx * lda;
3747 tB = B + col_idx * ldb;
3748 tC = C + col_idx * ldc + row_idx;
3749 // clear scratch registers.
3750 ymm4 = _mm256_setzero_ps();
3751
3752 for (k = 0; (k + 7) < K; k += 8)
3753 {
3754 ymm0 = _mm256_loadu_ps(tB + 0);
3755 ymm3 = _mm256_loadu_ps(tA);
3756 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3757
3758 tA += 8;
3759 tB += 8;
3760 }
3761
3762 // If K is not a multiple of 8, the remainder is zero-padded into a temporary array before loading.
3763 if (k < K)
3764 {
3765 int iter;
3766 float data_feeder[8] = { 0.0 };
3767
3768 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3769 ymm0 = _mm256_loadu_ps(data_feeder);
3770
3771 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3772 ymm3 = _mm256_loadu_ps(data_feeder);
3773 ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3774
3775 }
3776
3777 //horizontal addition and storage of the data.
3778 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3779 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3780 _mm256_storeu_ps(scratch, ymm4);
3781 result = scratch[0] + scratch[4];
3782 result *= (*alpha_cast);
3783 if(is_beta_non_zero){
3784 tC[0] = result + tC[0] * (*beta_cast);
3785 }else{
3786 tC[0] = result;
3787 }
3788
3789 }
3790 }
3791 }
3792 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
3793 return BLIS_SUCCESS;
3794 }
3795 else
3796 {
3797 AOCL_DTL_TRACE_EXIT_ERR(
3798 AOCL_DTL_LEVEL_INFO,
3799 "Invalid dimensions for small gemm."
3800 );
3801 return BLIS_NONCONFORMAL_DIMENSIONS;
3802 }
3803 }
3804
3805 static err_t bli_dgemm_small_atbn
3806 (
3807 obj_t* alpha,
3808 obj_t* a,
3809 obj_t* b,
3810 obj_t* beta,
3811 obj_t* c,
3812 cntx_t* cntx,
3813 cntl_t* cntl
3814 )
3815 {
3816 AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
3817
3818 gint_t M = bli_obj_length( c ); // number of rows of Matrix C
3819 gint_t N = bli_obj_width( c ); // number of columns of Matrix C
3820 gint_t K = bli_obj_length( b ); // number of rows of Matrix B
3821
3822 // The non-copy version of the A^T GEMM gives better performance for the small M cases.
3823 // The threshold is controlled by BLIS_ATBN_M_THRES
3824 if (M <= BLIS_ATBN_M_THRES)
3825 {
3826 guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) = A^T when transA is enabled.
3827 guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) = B^T when transB is enabled.
3828 guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
3829 guint_t row_idx = 0, col_idx = 0, k;
3830 double *A = bli_obj_buffer_at_off(a); // pointer to matrix A elements, stored in row major format
3831 double *B = bli_obj_buffer_at_off(b); // pointer to matrix B elements, stored in column major format
3832 double *C = bli_obj_buffer_at_off(c); // pointer to matrix C elements, stored in column major format
3833
3834 double *tA = A, *tB = B, *tC = C;
3835
3836 __m256d ymm4, ymm5, ymm6, ymm7;
3837 __m256d ymm8, ymm9, ymm10, ymm11;
3838 __m256d ymm12, ymm13, ymm14, ymm15;
3839 __m256d ymm0, ymm1, ymm2, ymm3;
3840
3841 double result;
3842 double scratch[8] = {0.0};
3843 double *alpha_cast, *beta_cast; // alpha, beta multiples
3844 alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha);
3845 beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta);
3846
3847 //Check whether beta is zero.
3848 //If it is, we compute C = alpha * (A * B)
3849 //instead of C = beta * C + alpha * (A * B).
3850 bool is_beta_non_zero = 0;
3851 if(!bli_obj_equals(beta,&BLIS_ZERO))
3852 is_beta_non_zero = 1;
3853
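        /*
         * Main AT_MR x NR (4x3) blocked loop. Since C = beta * C + alpha * A^T * B,
         * each C(i,j) is the dot product of column i of A with column j of B.
         * A scalar sketch of what the vectorized code below computes, with i = row_idx
         * and j = col_idx (illustrative only, not part of the build):
         *
         *   double dot = 0.0;
         *   for ( k = 0; k < K; ++k ) dot += A[ i*lda + k ] * B[ j*ldb + k ];
         *   C[ j*ldc + i ] = is_beta_non_zero ? alpha*dot + beta*C[ j*ldc + i ]
         *                                     : alpha*dot;
         */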
3854 for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
3855 {
3856 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
3857 {
3858 tA = A + row_idx * lda;
3859 tB = B + col_idx * ldb;
3860 tC = C + col_idx * ldc + row_idx;
3861 // clear scratch registers.
3862 ymm4 = _mm256_setzero_pd();
3863 ymm5 = _mm256_setzero_pd();
3864 ymm6 = _mm256_setzero_pd();
3865 ymm7 = _mm256_setzero_pd();
3866 ymm8 = _mm256_setzero_pd();
3867 ymm9 = _mm256_setzero_pd();
3868 ymm10 = _mm256_setzero_pd();
3869 ymm11 = _mm256_setzero_pd();
3870 ymm12 = _mm256_setzero_pd();
3871 ymm13 = _mm256_setzero_pd();
3872 ymm14 = _mm256_setzero_pd();
3873 ymm15 = _mm256_setzero_pd();
3874
3875 //The inner loop computes a 4x3 block of C.
3876 //The accumulator layout is:
3877 // ymm4 ymm5 ymm6
3878 // ymm7 ymm8 ymm9
3879 // ymm10 ymm11 ymm12
3880 // ymm13 ymm14 ymm15
3881
3882 //The dot products are accumulated in the inner loop; four double elements fit
3883 //in a YMM register, hence the loop counter advances by 4.
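                // Per iteration: ymm0-ymm2 hold four k-elements from columns col_idx,
                // col_idx+1 and col_idx+2 of B; ymm3 is reloaded with four k-elements
                // from each of the four columns of A, feeding 12 FMAs into the 12
                // accumulators above.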
3884 for (k = 0; (k + 3) < K; k += 4)
3885 {
3886 ymm0 = _mm256_loadu_pd(tB + 0);
3887 ymm1 = _mm256_loadu_pd(tB + ldb);
3888 ymm2 = _mm256_loadu_pd(tB + 2 * ldb);
3889
3890 ymm3 = _mm256_loadu_pd(tA);
3891 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
3892 ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
3893 ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
3894
3895 ymm3 = _mm256_loadu_pd(tA + lda);
3896 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
3897 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
3898 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
3899
3900 ymm3 = _mm256_loadu_pd(tA + 2 * lda);
3901 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
3902 ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
3903 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
3904
3905 ymm3 = _mm256_loadu_pd(tA + 3 * lda);
3906 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
3907 ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
3908 ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
3909
3910 tA += 4;
3911 tB += 4;
3912
3913 }
3914
3915 // If K is not a multiple of 4, the remainder is zero-padded into a temporary array before loading.
3916 if (k < K)
3917 {
3918 int iter;
3919 double data_feeder[4] = { 0.0 };
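                    // As in the float kernel, the (K - k) tail elements are copied into a
                    // zero-initialized buffer so the padded lanes add nothing to the FMAs.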
3920
3921 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3922 ymm0 = _mm256_loadu_pd(data_feeder);
3923 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
3924 ymm1 = _mm256_loadu_pd(data_feeder);
3925 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
3926 ymm2 = _mm256_loadu_pd(data_feeder);
3927
3928 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3929 ymm3 = _mm256_loadu_pd(data_feeder);
3930 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
3931 ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
3932 ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
3933
3934 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
3935 ymm3 = _mm256_loadu_pd(data_feeder);
3936 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
3937 ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
3938 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
3939
3940 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
3941 ymm3 = _mm256_loadu_pd(data_feeder);
3942 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
3943 ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
3944 ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
3945
3946 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
3947 ymm3 = _mm256_loadu_pd(data_feeder);
3948 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
3949 ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
3950 ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
3951
3952 }
3953
3954 //Horizontal addition and storage of the data.
3955 //Results for the 4x3 block of C are stored here.
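                // _mm256_hadd_pd adds adjacent pairs within each 128-bit lane, leaving the
                // lane sums in elements 0 and 2; scratch[0] + scratch[2] is therefore the
                // full 4-element dot product.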
3956 ymm4 = _mm256_hadd_pd(ymm4, ymm4);
3957 _mm256_storeu_pd(scratch, ymm4);
3958 result = scratch[0] + scratch[2];
3959 result *= (*alpha_cast);
3960 if(is_beta_non_zero)
3961 tC[0] = result + tC[0] * (*beta_cast);
3962 else
3963 tC[0] = result;
3964
3965 ymm7 = _mm256_hadd_pd(ymm7, ymm7);
3966 _mm256_storeu_pd(scratch, ymm7);
3967 result = scratch[0] + scratch[2];
3968 result *= (*alpha_cast);
3969 if(is_beta_non_zero)
3970 tC[1] = result + tC[1] * (*beta_cast);
3971 else
3972 tC[1] = result;
3973
3974 ymm10 = _mm256_hadd_pd(ymm10, ymm10);
3975 _mm256_storeu_pd(scratch, ymm10);
3976 result = scratch[0] + scratch[2];
3977 result *= (*alpha_cast);
3978 if(is_beta_non_zero)
3979 tC[2] = result + tC[2] * (*beta_cast);
3980 else
3981 tC[2] = result;
3982
3983 ymm13 = _mm256_hadd_pd(ymm13, ymm13);
3984 _mm256_storeu_pd(scratch, ymm13);
3985 result = scratch[0] + scratch[2];
3986 result *= (*alpha_cast);
3987 if(is_beta_non_zero)
3988 tC[3] = result + tC[3] * (*beta_cast);
3989 else
3990 tC[3] = result;
3991
3992 tC += ldc;
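                // tC now points at the next column of C; store the second column of
                // results (ymm5, ymm8, ymm11, ymm14) of the 4x3 block.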
3993 ymm5 = _mm256_hadd_pd(ymm5, ymm5);
3994 _mm256_storeu_pd(scratch, ymm5);
3995 result = scratch[0] + scratch[2];
3996 result *= (*alpha_cast);
3997 if(is_beta_non_zero)
3998 tC[0] = result + tC[0] * (*beta_cast);
3999 else
4000 tC[0] = result;
4001
4002 ymm8 = _mm256_hadd_pd(ymm8, ymm8);
4003 _mm256_storeu_pd(scratch, ymm8);
4004 result = scratch[0] + scratch[2];
4005 result *= (*alpha_cast);
4006 if(is_beta_non_zero)
4007 tC[1] = result + tC[1] * (*beta_cast);
4008 else
4009 tC[1] = result;
4010
4011 ymm11 = _mm256_hadd_pd(ymm11, ymm11);
4012 _mm256_storeu_pd(scratch, ymm11);
4013 result = scratch[0] + scratch[2];
4014 result *= (*alpha_cast);
4015 if(is_beta_non_zero)
4016 tC[2] = result + tC[2] * (*beta_cast);
4017 else
4018 tC[2] = result;
4019
4020 ymm14 = _mm256_hadd_pd(ymm14, ymm14);
4021 _mm256_storeu_pd(scratch, ymm14);
4022 result = scratch[0] + scratch[2];
4023 result *= (*alpha_cast);
4024 if(is_beta_non_zero)
4025 tC[3] = result + tC[3] * (*beta_cast);
4026 else
4027 tC[3] = result;
4028
4029 tC += ldc;
4030 ymm6 = _mm256_hadd_pd(ymm6, ymm6);
4031 _mm256_storeu_pd(scratch, ymm6);
4032 result = scratch[0] + scratch[2];
4033 result *= (*alpha_cast);
4034 if(is_beta_non_zero)
4035 tC[0] = result + tC[0] * (*beta_cast);
4036 else
4037 tC[0] = result;
4038
4039 ymm9 = _mm256_hadd_pd(ymm9, ymm9);
4040 _mm256_storeu_pd(scratch, ymm9);
4041 result = scratch[0] + scratch[2];
4042 result *= (*alpha_cast);
4043 if(is_beta_non_zero)
4044 tC[1] = result + tC[1] * (*beta_cast);
4045 else
4046 tC[1] = result;
4047
4048 ymm12 = _mm256_hadd_pd(ymm12, ymm12);
4049 _mm256_storeu_pd(scratch, ymm12);
4050 result = scratch[0] + scratch[2];
4051 result *= (*alpha_cast);
4052 if(is_beta_non_zero)
4053 tC[2] = result + tC[2] * (*beta_cast);
4054 else
4055 tC[2] = result;
4056
4057 ymm15 = _mm256_hadd_pd(ymm15, ymm15);
4058 _mm256_storeu_pd(scratch, ymm15);
4059 result = scratch[0] + scratch[2];
4060 result *= (*alpha_cast);
4061 if(is_beta_non_zero)
4062 tC[3] = result + tC[3] * (*beta_cast);
4063 else
4064 tC[3] = result;
4065 }
4066 }
4067
4068 int processed_col = col_idx;
4069 int processed_row = row_idx;
4070
4071 //Edge-case handling where N is not a multiple of NR (3)
4072 if (processed_col < N)
4073 {
4074 for (col_idx = processed_col; col_idx < N; col_idx += 1)
4075 {
4076 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
4077 {
4078 tA = A + row_idx * lda;
4079 tB = B + col_idx * ldb;
4080 tC = C + col_idx * ldc + row_idx;
4081 // clear scratch registers.
4082 ymm4 = _mm256_setzero_pd();
4083 ymm7 = _mm256_setzero_pd();
4084 ymm10 = _mm256_setzero_pd();
4085 ymm13 = _mm256_setzero_pd();
4086
4087 //The inner loop computes a 4x1 block of C.
4088 //The accumulator layout is:
4089 // ymm4
4090 // ymm7
4091 // ymm10
4092 // ymm13
4093
4094 for (k = 0; (k + 3) < K; k += 4)
4095 {
4096 ymm0 = _mm256_loadu_pd(tB + 0);
4097
4098 ymm3 = _mm256_loadu_pd(tA);
4099 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
4100
4101 ymm3 = _mm256_loadu_pd(tA + lda);
4102 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
4103
4104 ymm3 = _mm256_loadu_pd(tA + 2 * lda);
4105 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
4106
4107 ymm3 = _mm256_loadu_pd(tA + 3 * lda);
4108 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
4109
4110 tA += 4;
4111 tB += 4;
4112 }
4113 // If K is not a multiple of 4, the remainder is zero-padded into a temporary array before loading.
4114 if (k < K)
4115 {
4116 int iter;
4117 double data_feeder[4] = { 0.0 };
4118
4119 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
4120 ymm0 = _mm256_loadu_pd(data_feeder);
4121
4122 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
4123 ymm3 = _mm256_loadu_pd(data_feeder);
4124 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
4125
4126 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
4127 ymm3 = _mm256_loadu_pd(data_feeder);
4128 ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
4129
4130 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
4131 ymm3 = _mm256_loadu_pd(data_feeder);
4132 ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
4133
4134 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
4135 ymm3 = _mm256_loadu_pd(data_feeder);
4136 ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
4137
4138 }
4139
4140 //Horizontal addition and storage of the data.
4141 //Results for the 4x1 block of C are stored here.
4142 ymm4 = _mm256_hadd_pd(ymm4, ymm4);
4143 _mm256_storeu_pd(scratch, ymm4);
4144 result = scratch[0] + scratch[2];
4145 result *= (*alpha_cast);
4146 if(is_beta_non_zero)
4147 tC[0] = result + tC[0] * (*beta_cast);
4148 else
4149 tC[0] = result;
4150
4151 ymm7 = _mm256_hadd_pd(ymm7, ymm7);
4152 _mm256_storeu_pd(scratch, ymm7);
4153 result = scratch[0] + scratch[2];
4154 result *= (*alpha_cast);
4155 if(is_beta_non_zero)
4156 tC[1] = result + tC[1] * (*beta_cast);
4157 else
4158 tC[1] = result;
4159
4160 ymm10 = _mm256_hadd_pd(ymm10, ymm10);
4161 _mm256_storeu_pd(scratch, ymm10);
4162 result = scratch[0] + scratch[2];
4163 result *= (*alpha_cast);
4164 if(is_beta_non_zero)
4165 tC[2] = result + tC[2] * (*beta_cast);
4166 else
4167 tC[2] = result;
4168
4169 ymm13 = _mm256_hadd_pd(ymm13, ymm13);
4170 _mm256_storeu_pd(scratch, ymm13);
4171 result = scratch[0] + scratch[2];
4172 result *= (*alpha_cast);
4173 if(is_beta_non_zero)
4174 tC[3] = result + tC[3] * (*beta_cast);
4175 else
4176 tC[3] = result;
4177 }
4178 }
4179 processed_row = row_idx;
4180 }
4181
4182 // Edge-case handling where M is not a multiple of AT_MR (4)
4183 if (processed_row < M)
4184 {
4185 for (row_idx = processed_row; row_idx < M; row_idx += 1)
4186 {
4187 for (col_idx = 0; col_idx < N; col_idx += 1)
4188 {
4189 tA = A + row_idx * lda;
4190 tB = B + col_idx * ldb;
4191 tC = C + col_idx * ldc + row_idx;
4192 // clear scratch registers.
4193 ymm4 = _mm256_setzero_pd();
4194
4195 for (k = 0; (k + 3) < K; k += 4)
4196 {
4197 ymm0 = _mm256_loadu_pd(tB + 0);
4198 ymm3 = _mm256_loadu_pd(tA);
4199 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
4200
4201 tA += 4;
4202 tB += 4;
4203 }
4204
4205 // If K is not a multiple of 4, the remainder is zero-padded into a temporary array before loading.
4206 if (k < K)
4207 {
4208 int iter;
4209 double data_feeder[4] = { 0.0 };
4210
4211 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
4212 ymm0 = _mm256_loadu_pd(data_feeder);
4213
4214 for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
4215 ymm3 = _mm256_loadu_pd(data_feeder);
4216 ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
4217
4218 }
4219
4220 //horizontal addition and storage of the data.
4221 ymm4 = _mm256_hadd_pd(ymm4, ymm4);
4222 _mm256_storeu_pd(scratch, ymm4);
4223 result = scratch[0] + scratch[2];
4224 result *= (*alpha_cast);
4225 if(is_beta_non_zero)
4226 tC[0] = result + tC[0] * (*beta_cast);
4227 else
4228 tC[0] = result;
4229 }
4230 }
4231 }
4232 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
4233 return BLIS_SUCCESS;
4234 }
4235 else
4236 {
4237 AOCL_DTL_TRACE_EXIT_ERR(
4238 AOCL_DTL_LEVEL_INFO,
4239 "Invalid dimensions for small gemm."
4240 );
4241 return BLIS_NONCONFORMAL_DIMENSIONS;
4242 }
4243 }
4244 #endif
4245
4246