1 /*
2 
3    BLIS
4    An object-based framework for developing high-performance BLAS-like
5    libraries.
6 
7    Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc.
8 
9    Redistribution and use in source and binary forms, with or without
10    modification, are permitted provided that the following conditions are
11    met:
12     - Redistributions of source code must retain the above copyright
13       notice, this list of conditions and the following disclaimer.
14     - Redistributions in binary form must reproduce the above copyright
15       notice, this list of conditions and the following disclaimer in the
16       documentation and/or other materials provided with the distribution.
17     - Neither the name(s) of the copyright holder(s) nor the names of its
18       contributors may be used to endorse or promote products derived
19       from this software without specific prior written permission.
20 
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25    HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 
33 */
34 
35 #include "immintrin.h"
36 #include "xmmintrin.h"
37 #include "blis.h"
38 
39 #define AOCL_DTL_TRACE_ENTRY(x)      ;
40 #define AOCL_DTL_TRACE_EXIT(x)       ;
41 #define AOCL_DTL_TRACE_EXIT_ERR(x,y) ;
42 
43 #ifdef BLIS_ENABLE_SMALL_MATRIX
44 
45 #define MR 32
46 #define D_MR (MR >> 1)
47 #define NR 3
48 #define D_BLIS_SMALL_MATRIX_K_THRES_ROME    256
49 
50 #define BLIS_ENABLE_PREFETCH
51 #define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 )
52 #define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2)
53 #define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2)
54 #define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called.
55 #define AT_MR 4 // The kernel dimension of the A transpose GEMM kernel.(AT_MR * NR).
56 static err_t bli_sgemm_small
57      (
58        obj_t*  alpha,
59        obj_t*  a,
60        obj_t*  b,
61        obj_t*  beta,
62        obj_t*  c,
63        cntx_t* cntx,
64        cntl_t* cntl
65      );
66 
67 static err_t bli_dgemm_small
68      (
69        obj_t*  alpha,
70        obj_t*  a,
71        obj_t*  b,
72        obj_t*  beta,
73        obj_t*  c,
74        cntx_t* cntx,
75        cntl_t* cntl
76      );
77 
78 static err_t bli_sgemm_small_atbn
79      (
80        obj_t*  alpha,
81        obj_t*  a,
82        obj_t*  b,
83        obj_t*  beta,
84        obj_t*  c,
85        cntx_t* cntx,
86        cntl_t* cntl
87      );
88 
89 static err_t bli_dgemm_small_atbn
90      (
91        obj_t*  alpha,
92        obj_t*  a,
93        obj_t*  b,
94        obj_t*  beta,
95        obj_t*  c,
96        cntx_t* cntx,
97        cntl_t* cntl
98      );
99 /*
100 * The bli_gemm_small function will use the
101 * custom MRxNR kernels, to perform the computation.
102 * The custom kernels are used if the [M * N] < 240 * 240
103 */
bli_gemm_small(obj_t * alpha,obj_t * a,obj_t * b,obj_t * beta,obj_t * c,cntx_t * cntx,cntl_t * cntl)104 err_t bli_gemm_small
105      (
106        obj_t*  alpha,
107        obj_t*  a,
108        obj_t*  b,
109        obj_t*  beta,
110        obj_t*  c,
111        cntx_t* cntx,
112        cntl_t* cntl
113      )
114 {
115 	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
116 
117 #ifdef BLIS_ENABLE_MULTITHREADING
118 	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
119     return BLIS_NOT_YET_IMPLEMENTED;
120 #endif
121     // If alpha is zero, scale by beta and return.
122     if (bli_obj_equals(alpha, &BLIS_ZERO))
123     {
124         return BLIS_NOT_YET_IMPLEMENTED;
125     }
126 
127     // if row major format return.
128     if ((bli_obj_row_stride( a ) != 1) ||
129         (bli_obj_row_stride( b ) != 1) ||
130         (bli_obj_row_stride( c ) != 1))
131     {
132         return BLIS_INVALID_ROW_STRIDE;
133     }
134 
135     num_t dt = bli_obj_dt(c);
136 
137     if (bli_obj_has_trans( a ))
138     {
139         if (bli_obj_has_notrans( b ))
140         {
141             if (dt == BLIS_FLOAT)
142             {
143                 return bli_sgemm_small_atbn(alpha, a, b, beta, c, cntx, cntl);
144             }
145             else if (dt == BLIS_DOUBLE)
146             {
147                 return bli_dgemm_small_atbn(alpha, a, b, beta, c, cntx, cntl);
148             }
149         }
150 
151         return BLIS_NOT_YET_IMPLEMENTED;
152     }
153 
154     if (dt == BLIS_DOUBLE)
155     {
156         return bli_dgemm_small(alpha, a, b, beta, c, cntx, cntl);
157     }
158 
159     if (dt == BLIS_FLOAT)
160     {
161         return bli_sgemm_small(alpha, a, b, beta, c, cntx, cntl);
162     }
163 
164 	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
165     return BLIS_NOT_YET_IMPLEMENTED;
166 };
167 
168 
bli_sgemm_small(obj_t * alpha,obj_t * a,obj_t * b,obj_t * beta,obj_t * c,cntx_t * cntx,cntl_t * cntl)169 static err_t bli_sgemm_small
170      (
171        obj_t*  alpha,
172        obj_t*  a,
173        obj_t*  b,
174        obj_t*  beta,
175        obj_t*  c,
176        cntx_t* cntx,
177        cntl_t* cntl
178      )
179 {
180 	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
181     gint_t M = bli_obj_length( c ); // number of rows of Matrix C
182     gint_t N = bli_obj_width( c );  // number of columns of Matrix C
183     gint_t K = bli_obj_width( a );  // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
184     gint_t L = M * N;
185 
186 	// when N is equal to 1 call GEMV instead of GEMM
187     if (N == 1)
188     {
189         bli_gemv
190         (
191             alpha,
192             a,
193             b,
194             beta,
195             c
196         );
197 		AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
198         return BLIS_SUCCESS;
199     }
200 
201 
202     if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES))
203         || ((M  < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
204     {
205         guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
206         guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
207         guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
208         guint_t row_idx, col_idx, k;
209 
210         float *A = bli_obj_buffer_at_off(a); // pointer to elements of Matrix A
211         float *B = bli_obj_buffer_at_off(b); // pointer to elements of Matrix B
212         float *C = bli_obj_buffer_at_off(c); // pointer to elements of Matrix C
213 
214         float *tA = A, *tB = B, *tC = C;//, *tA_pack;
215         float *tA_packed; // temporary pointer to hold packed A memory pointer
216 
217         guint_t row_idx_packed; //packed A memory row index
218         guint_t lda_packed; //lda of packed A
219         guint_t col_idx_start; //starting index after A matrix is packed.
220         dim_t tb_inc_row = 1; // row stride of matrix B
221         dim_t tb_inc_col = ldb; // column stride of matrix B
222 
223 	    __m256 ymm4, ymm5, ymm6, ymm7;
224         __m256 ymm8, ymm9, ymm10, ymm11;
225         __m256 ymm12, ymm13, ymm14, ymm15;
226         __m256 ymm0, ymm1, ymm2, ymm3;
227 
228         gint_t n_remainder; // If the N is non multiple of 3.(N%3)
229         gint_t m_remainder; // If the M is non multiple of 32.(M%32)
230         gint_t required_packing_A = 1;
231         mem_t local_mem_buf_A_s;
232         float *A_pack = NULL;
233         rntm_t rntm;
234 
235         const num_t    dt_exec   = bli_obj_dt( c );
236         float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha );
237         float* restrict beta_cast  = bli_obj_buffer_for_1x1( dt_exec, beta );
238 
239         /*Beta Zero Check*/
240         bool is_beta_non_zero=0;
241         if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
242             is_beta_non_zero = 1;
243         }
244 
245 	//update the pointer math if matrix B needs to be transposed.
246         if (bli_obj_has_trans( b )) {
247             tb_inc_col = 1; //switch row and column strides
248             tb_inc_row = ldb;
249         }
250 
251         /*
252          * This function was using global array to pack part of A input when needed.
253          * However, using this global array make the function non-reentrant.
254          * Instead of using a global array we should allocate buffer for each invocation.
255          * Since the buffer size is too big or stack and doing malloc every time will be too expensive,
256          * better approach is to get the buffer from the pre-allocated pool and return
257          * it the pool once we are doing.
258          *
259          * In order to get the buffer from pool, we need access to memory broker,
260          * currently this function is not invoked in such a way that it can receive
261          * the memory broker (via rntm). Following hack will get the global memory
262          * broker that can be use it to access the pool.
263          *
264          * Note there will be memory allocation at least on first innovation
265          * as there will not be any pool created for this size.
266          * Subsequent invocations will just reuse the buffer from the pool.
267          */
268 
269         bli_rntm_init_from_global( &rntm );
270         bli_rntm_set_num_threads_only( 1, &rntm );
271         bli_membrk_rntm_set_membrk( &rntm );
272 
273         // Get the current size of the buffer pool for A block packing.
274         // We will use the same size to avoid pool re-initialization
275         siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
276                                                 bli_rntm_membrk(&rntm)));
277 
278         // Based on the available memory in the buffer we will decide if
279         // we want to do packing or not.
280         //
281         // This kernel assumes that "A" will be un-packged if N <= 3.
282         // Usually this range (N <= 3) is handled by SUP, however,
283         // if SUP is disabled or for any other condition if we do
284         // enter this kernel with N <= 3, we want to make sure that
285         // "A" remains unpacked.
286         //
287         // If this check is removed it will result in the crash as
288         // reported in CPUPL-587.
289         //
290 
291         if ((N <= 3) || (((MR * K) << 2) > buffer_size))
292         {
293             required_packing_A = 0;
294         }
295         else
296         {
297 #ifdef BLIS_ENABLE_MEM_TRACING
298             printf( "bli_sgemm_small: Requesting mem pool block of size %lu\n", buffer_size);
299 #endif
300             // Get the buffer from the pool, if there is no pool with
301             // required size, it will be created.
302             bli_membrk_acquire_m(&rntm,
303                                  buffer_size,
304                                  BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
305                                  &local_mem_buf_A_s);
306 
307             A_pack = bli_mem_buffer(&local_mem_buf_A_s);
308         }
309 
310         /*
311         * The computation loop runs for MRxN columns of C matrix, thus
312         * accessing the MRxK A matrix data and KxNR B matrix data.
313         * The computation is organized as inner loops of dimension MRxNR.
314         */
315         // Process MR rows of C matrix at a time.
316         for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR)
317         {
318             col_idx_start = 0;
319             tA_packed = A;
320             row_idx_packed = row_idx;
321             lda_packed = lda;
322 
323             // This is the part of the pack and compute optimization.
324             // During the first column iteration, we store the accessed A matrix into
325             // contiguous static memory. This helps to keep te A matrix in Cache and
326             // aviods the TLB misses.
327             if (required_packing_A)
328             {
329                 col_idx = 0;
330 
331                 //pointer math to point to proper memory
332                 tC = C + ldc * col_idx + row_idx;
333                 tB = B + tb_inc_col * col_idx;
334                 tA = A + row_idx;
335                 tA_packed = A_pack;
336 
337 #ifdef BLIS_ENABLE_PREFETCH
338                 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
339                 _mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
340                 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
341                 _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
342                 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
343                 _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
344 #endif
345                 // clear scratch registers.
346                 ymm4 = _mm256_setzero_ps();
347                 ymm5 = _mm256_setzero_ps();
348                 ymm6 = _mm256_setzero_ps();
349                 ymm7 = _mm256_setzero_ps();
350                 ymm8 = _mm256_setzero_ps();
351                 ymm9 = _mm256_setzero_ps();
352                 ymm10 = _mm256_setzero_ps();
353                 ymm11 = _mm256_setzero_ps();
354                 ymm12 = _mm256_setzero_ps();
355                 ymm13 = _mm256_setzero_ps();
356                 ymm14 = _mm256_setzero_ps();
357                 ymm15 = _mm256_setzero_ps();
358 
359                 for (k = 0; k < K; ++k)
360                 {
361                     // The inner loop broadcasts the B matrix data and
362                     // multiplies it with the A matrix.
363                     // This loop is processing MR x K
364                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
365                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
366                     ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
367                     tB += tb_inc_row;
368 
369                     //broadcasted matrix B elements are multiplied
370                     //with matrix A columns.
371                     ymm3 = _mm256_loadu_ps(tA);
372                     _mm256_storeu_ps(tA_packed, ymm3); // the packing of matrix A
373                                                        //                   ymm4 += ymm0 * ymm3;
374                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
375                     //                    ymm8 += ymm1 * ymm3;
376                     ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
377                     //                    ymm12 += ymm2 * ymm3;
378                     ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
379 
380                     ymm3 = _mm256_loadu_ps(tA + 8);
381                     _mm256_storeu_ps(tA_packed + 8, ymm3); // the packing of matrix A
382                                                            //                    ymm5 += ymm0 * ymm3;
383                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
384                     //                    ymm9 += ymm1 * ymm3;
385                     ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
386                     //                    ymm13 += ymm2 * ymm3;
387                     ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
388 
389                     ymm3 = _mm256_loadu_ps(tA + 16);
390                     _mm256_storeu_ps(tA_packed + 16, ymm3); // the packing of matrix A
391                                                             //                   ymm6 += ymm0 * ymm3;
392                     ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
393                     //                    ymm10 += ymm1 * ymm3;
394                     ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
395                     //                    ymm14 += ymm2 * ymm3;
396                     ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
397 
398                     ymm3 = _mm256_loadu_ps(tA + 24);
399                     _mm256_storeu_ps(tA_packed + 24, ymm3); // the packing of matrix A
400                                                             //                    ymm7 += ymm0 * ymm3;
401                     ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
402                     //                    ymm11 += ymm1 * ymm3;
403                     ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
404                     //                   ymm15 += ymm2 * ymm3;
405                     ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
406 
407                     tA += lda;
408                     tA_packed += MR;
409                 }
410                 // alpha, beta multiplication.
411                 ymm0 = _mm256_broadcast_ss(alpha_cast);
412 
413                 //multiply A*B by alpha.
414                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
415                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
416                 ymm6 = _mm256_mul_ps(ymm6, ymm0);
417                 ymm7 = _mm256_mul_ps(ymm7, ymm0);
418                 ymm8 = _mm256_mul_ps(ymm8, ymm0);
419                 ymm9 = _mm256_mul_ps(ymm9, ymm0);
420                 ymm10 = _mm256_mul_ps(ymm10, ymm0);
421                 ymm11 = _mm256_mul_ps(ymm11, ymm0);
422                 ymm12 = _mm256_mul_ps(ymm12, ymm0);
423                 ymm13 = _mm256_mul_ps(ymm13, ymm0);
424                 ymm14 = _mm256_mul_ps(ymm14, ymm0);
425                 ymm15 = _mm256_mul_ps(ymm15, ymm0);
426 
427                 if(is_beta_non_zero)
428                 {
429                     ymm1 = _mm256_broadcast_ss(beta_cast);
430                     // multiply C by beta and accumulate col 1.
431                     ymm2 = _mm256_loadu_ps(tC);
432                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
433                     ymm2 = _mm256_loadu_ps(tC + 8);
434                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
435                     ymm2 = _mm256_loadu_ps(tC + 16);
436                     ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
437                     ymm2 = _mm256_loadu_ps(tC + 24);
438                     ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
439 
440                     float* ttC = tC +ldc;
441                     ymm2 = _mm256_loadu_ps(ttC);
442                     ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
443                     ymm2 = _mm256_loadu_ps(ttC + 8);
444                     ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
445                     ymm2 = _mm256_loadu_ps(ttC + 16);
446                     ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
447                     ymm2 = _mm256_loadu_ps(ttC + 24);
448                     ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
449 
450                     ttC += ldc;
451                     ymm2 = _mm256_loadu_ps(ttC);
452                     ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
453                     ymm2 = _mm256_loadu_ps(ttC + 8);
454                     ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
455                     ymm2 = _mm256_loadu_ps(ttC + 16);
456                     ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
457                     ymm2 = _mm256_loadu_ps(ttC + 24);
458                     ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
459                 }
460                 _mm256_storeu_ps(tC, ymm4);
461                 _mm256_storeu_ps(tC + 8, ymm5);
462                 _mm256_storeu_ps(tC + 16, ymm6);
463                 _mm256_storeu_ps(tC + 24, ymm7);
464 
465                 // multiply C by beta and accumulate, col 2.
466                 tC += ldc;
467                 _mm256_storeu_ps(tC, ymm8);
468                 _mm256_storeu_ps(tC + 8, ymm9);
469                 _mm256_storeu_ps(tC + 16, ymm10);
470                 _mm256_storeu_ps(tC + 24, ymm11);
471 
472                 // multiply C by beta and accumulate, col 3.
473                 tC += ldc;
474                 _mm256_storeu_ps(tC, ymm12);
475                 _mm256_storeu_ps(tC + 8, ymm13);
476                 _mm256_storeu_ps(tC + 16, ymm14);
477                 _mm256_storeu_ps(tC + 24, ymm15);
478 
479                 // modify the pointer arithematic to use packed A matrix.
480                 col_idx_start = NR;
481                 tA_packed = A_pack;
482                 row_idx_packed = 0;
483                 lda_packed = MR;
484             }
485             // Process NR columns of C matrix at a time.
486             for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
487             {
488                 //pointer math to point to proper memory
489                 tC = C + ldc * col_idx + row_idx;
490                 tB = B + tb_inc_col * col_idx;
491                 tA = tA_packed + row_idx_packed;
492 
493 #ifdef BLIS_ENABLE_PREFETCH
494                 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
495                 _mm_prefetch((char*)(tC + 16), _MM_HINT_T0);
496                 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
497                 _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0);
498                 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
499                 _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0);
500 #endif
501                 // clear scratch registers.
502                 ymm4 = _mm256_setzero_ps();
503                 ymm5 = _mm256_setzero_ps();
504                 ymm6 = _mm256_setzero_ps();
505                 ymm7 = _mm256_setzero_ps();
506                 ymm8 = _mm256_setzero_ps();
507                 ymm9 = _mm256_setzero_ps();
508                 ymm10 = _mm256_setzero_ps();
509                 ymm11 = _mm256_setzero_ps();
510                 ymm12 = _mm256_setzero_ps();
511                 ymm13 = _mm256_setzero_ps();
512                 ymm14 = _mm256_setzero_ps();
513                 ymm15 = _mm256_setzero_ps();
514 
515                 for (k = 0; k < K; ++k)
516                 {
517                     // The inner loop broadcasts the B matrix data and
518                     // multiplies it with the A matrix.
519                     // This loop is processing MR x K
520                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
521                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
522                     ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
523                     tB += tb_inc_row;
524 
525                     //broadcasted matrix B elements are multiplied
526                     //with matrix A columns.
527                     ymm3 = _mm256_loadu_ps(tA);
528                     //                   ymm4 += ymm0 * ymm3;
529                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
530                     //                    ymm8 += ymm1 * ymm3;
531                     ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
532                     //                    ymm12 += ymm2 * ymm3;
533                     ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
534 
535                     ymm3 = _mm256_loadu_ps(tA + 8);
536                     //                    ymm5 += ymm0 * ymm3;
537                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
538                     //                    ymm9 += ymm1 * ymm3;
539                     ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
540                     //                    ymm13 += ymm2 * ymm3;
541                     ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
542 
543                     ymm3 = _mm256_loadu_ps(tA + 16);
544                     //                   ymm6 += ymm0 * ymm3;
545                     ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
546                     //                    ymm10 += ymm1 * ymm3;
547                     ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
548                     //                    ymm14 += ymm2 * ymm3;
549                     ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
550 
551                     ymm3 = _mm256_loadu_ps(tA + 24);
552                     //                    ymm7 += ymm0 * ymm3;
553                     ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
554                     //                    ymm11 += ymm1 * ymm3;
555                     ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
556                     //                   ymm15 += ymm2 * ymm3;
557                     ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
558 
559                     tA += lda_packed;
560                 }
561                 // alpha, beta multiplication.
562                 ymm0 = _mm256_broadcast_ss(alpha_cast);
563 
564                 //multiply A*B by alpha.
565                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
566                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
567                 ymm6 = _mm256_mul_ps(ymm6, ymm0);
568                 ymm7 = _mm256_mul_ps(ymm7, ymm0);
569                 ymm8 = _mm256_mul_ps(ymm8, ymm0);
570                 ymm9 = _mm256_mul_ps(ymm9, ymm0);
571                 ymm10 = _mm256_mul_ps(ymm10, ymm0);
572                 ymm11 = _mm256_mul_ps(ymm11, ymm0);
573                 ymm12 = _mm256_mul_ps(ymm12, ymm0);
574                 ymm13 = _mm256_mul_ps(ymm13, ymm0);
575                 ymm14 = _mm256_mul_ps(ymm14, ymm0);
576                 ymm15 = _mm256_mul_ps(ymm15, ymm0);
577 
578                 if(is_beta_non_zero)
579                 {
580                     ymm1 = _mm256_broadcast_ss(beta_cast);
581                     // multiply C by beta and accumulate col 1.
582                     ymm2 = _mm256_loadu_ps(tC);
583                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
584                     ymm2 = _mm256_loadu_ps(tC + 8);
585                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
586                     ymm2 = _mm256_loadu_ps(tC + 16);
587                     ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
588                     ymm2 = _mm256_loadu_ps(tC + 24);
589                     ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
590                     float* ttC = tC +ldc;
591                     ymm2 = _mm256_loadu_ps(ttC);
592                     ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
593                     ymm2 = _mm256_loadu_ps(ttC + 8);
594                     ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
595                     ymm2 = _mm256_loadu_ps(ttC + 16);
596                     ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
597                     ymm2 = _mm256_loadu_ps(ttC + 24);
598                     ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
599                     ttC = ttC +ldc;
600                     ymm2 = _mm256_loadu_ps(ttC);
601                     ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
602                     ymm2 = _mm256_loadu_ps(ttC + 8);
603                     ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
604                     ymm2 = _mm256_loadu_ps(ttC + 16);
605                     ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
606                     ymm2 = _mm256_loadu_ps(ttC + 24);
607                     ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
608                 }
609                 _mm256_storeu_ps(tC, ymm4);
610                 _mm256_storeu_ps(tC + 8, ymm5);
611                 _mm256_storeu_ps(tC + 16, ymm6);
612                 _mm256_storeu_ps(tC + 24, ymm7);
613 
614                 // multiply C by beta and accumulate, col 2.
615                 tC += ldc;
616                 _mm256_storeu_ps(tC, ymm8);
617                 _mm256_storeu_ps(tC + 8, ymm9);
618                 _mm256_storeu_ps(tC + 16, ymm10);
619                 _mm256_storeu_ps(tC + 24, ymm11);
620 
621                 // multiply C by beta and accumulate, col 3.
622                 tC += ldc;
623                 _mm256_storeu_ps(tC, ymm12);
624                 _mm256_storeu_ps(tC + 8, ymm13);
625                 _mm256_storeu_ps(tC + 16, ymm14);
626                 _mm256_storeu_ps(tC + 24, ymm15);
627 
628             }
629             n_remainder = N - col_idx;
630 
631             // if the N is not multiple of 3.
632             // handling edge case.
633             if (n_remainder == 2)
634             {
635                 //pointer math to point to proper memory
636                 tC = C + ldc * col_idx + row_idx;
637                 tB = B + tb_inc_col * col_idx;
638                 tA = A + row_idx;
639 
640                 // clear scratch registers.
641                 ymm8 = _mm256_setzero_ps();
642                 ymm9 = _mm256_setzero_ps();
643                 ymm10 = _mm256_setzero_ps();
644                 ymm11 = _mm256_setzero_ps();
645                 ymm12 = _mm256_setzero_ps();
646                 ymm13 = _mm256_setzero_ps();
647                 ymm14 = _mm256_setzero_ps();
648                 ymm15 = _mm256_setzero_ps();
649 
650                 for (k = 0; k < K; ++k)
651                 {
652                     // The inner loop broadcasts the B matrix data and
653                     // multiplies it with the A matrix.
654                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
655                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
656                     tB += tb_inc_row;
657 
658                     //broadcasted matrix B elements are multiplied
659                     //with matrix A columns.
660                     ymm3 = _mm256_loadu_ps(tA);
661                     ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
662                     ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
663 
664                     ymm3 = _mm256_loadu_ps(tA + 8);
665                     ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
666                     ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
667 
668                     ymm3 = _mm256_loadu_ps(tA + 16);
669                     ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
670                     ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
671 
672                     ymm3 = _mm256_loadu_ps(tA + 24);
673                     ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);
674                     ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15);
675 
676                     tA += lda;
677 
678                 }
679                 // alpha, beta multiplication.
680                 ymm0 = _mm256_broadcast_ss(alpha_cast);
681 
682                 //multiply A*B by alpha.
683                 ymm8 = _mm256_mul_ps(ymm8, ymm0);
684                 ymm9 = _mm256_mul_ps(ymm9, ymm0);
685                 ymm10 = _mm256_mul_ps(ymm10, ymm0);
686                 ymm11 = _mm256_mul_ps(ymm11, ymm0);
687                 ymm12 = _mm256_mul_ps(ymm12, ymm0);
688                 ymm13 = _mm256_mul_ps(ymm13, ymm0);
689                 ymm14 = _mm256_mul_ps(ymm14, ymm0);
690                 ymm15 = _mm256_mul_ps(ymm15, ymm0);
691 
692                 // multiply C by beta and accumulate, col 1.
693                 if(is_beta_non_zero)
694                 {
695                     ymm1 = _mm256_broadcast_ss(beta_cast);
696                     ymm2 = _mm256_loadu_ps(tC);
697                     ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
698                     ymm2 = _mm256_loadu_ps(tC + 8);
699                     ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
700                     ymm2 = _mm256_loadu_ps(tC + 16);
701                     ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
702                     ymm2 = _mm256_loadu_ps(tC + 24);
703                     ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);
704 
705                     float* ttC = tC +ldc;
706                     // multiply C by beta and accumulate, col 2.
707                     ymm2 = _mm256_loadu_ps(ttC);
708                     ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
709                     ymm2 = _mm256_loadu_ps(ttC + 8);
710                     ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
711                     ymm2 = _mm256_loadu_ps(ttC + 16);
712                     ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
713                     ymm2 = _mm256_loadu_ps(ttC + 24);
714                     ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
715                 }
716                 _mm256_storeu_ps(tC, ymm8);
717                 _mm256_storeu_ps(tC + 8, ymm9);
718                 _mm256_storeu_ps(tC + 16, ymm10);
719                 _mm256_storeu_ps(tC + 24, ymm11);
720                 tC += ldc;
721                 _mm256_storeu_ps(tC, ymm12);
722                 _mm256_storeu_ps(tC + 8, ymm13);
723                 _mm256_storeu_ps(tC + 16, ymm14);
724                 _mm256_storeu_ps(tC + 24, ymm15);
725 
726                 col_idx += 2;
727             }
728             // if the N is not multiple of 3.
729             // handling edge case.
730             if (n_remainder == 1)
731             {
732                 //pointer math to point to proper memory
733                 tC = C + ldc * col_idx + row_idx;
734                 tB = B + tb_inc_col * col_idx;
735                 tA = A + row_idx;
736 
737                 // clear scratch registers.
738                 ymm12 = _mm256_setzero_ps();
739                 ymm13 = _mm256_setzero_ps();
740                 ymm14 = _mm256_setzero_ps();
741                 ymm15 = _mm256_setzero_ps();
742 
743                 for (k = 0; k < K; ++k)
744                 {
745                     // The inner loop broadcasts the B matrix data and
746                     // multiplies it with the A matrix.
747                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
748                     tB += tb_inc_row;
749 
750                     //broadcasted matrix B elements are multiplied
751                     //with matrix A columns.
752                     ymm3 = _mm256_loadu_ps(tA);
753                     ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
754 
755                     ymm3 = _mm256_loadu_ps(tA + 8);
756                     ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
757 
758                     ymm3 = _mm256_loadu_ps(tA + 16);
759                     ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
760 
761                     ymm3 = _mm256_loadu_ps(tA + 24);
762                     ymm15 = _mm256_fmadd_ps(ymm0, ymm3, ymm15);
763 
764                     tA += lda;
765 
766                 }
767                 // alpha, beta multiplication.
768                 ymm0 = _mm256_broadcast_ss(alpha_cast);
769 
770                 //multiply A*B by alpha.
771                 ymm12 = _mm256_mul_ps(ymm12, ymm0);
772                 ymm13 = _mm256_mul_ps(ymm13, ymm0);
773                 ymm14 = _mm256_mul_ps(ymm14, ymm0);
774                 ymm15 = _mm256_mul_ps(ymm15, ymm0);
775 
776                 if(is_beta_non_zero)
777                 {
778                     ymm1 = _mm256_broadcast_ss(beta_cast);
779                     // multiply C by beta and accumulate.
780                     ymm2 = _mm256_loadu_ps(tC + 0);
781                     ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
782                     ymm2 = _mm256_loadu_ps(tC + 8);
783                     ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
784                     ymm2 = _mm256_loadu_ps(tC + 16);
785                     ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
786                     ymm2 = _mm256_loadu_ps(tC + 24);
787                     ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);
788                 }
789 
790                 _mm256_storeu_ps(tC + 0, ymm12);
791                 _mm256_storeu_ps(tC + 8, ymm13);
792                 _mm256_storeu_ps(tC + 16, ymm14);
793                 _mm256_storeu_ps(tC + 24, ymm15);
794             }
795         }
796 
797         m_remainder = M - row_idx;
798 
799         if (m_remainder >= 24)
800         {
801             m_remainder -= 24;
802 
803             for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
804             {
805                 //pointer math to point to proper memory
806                 tC = C + ldc * col_idx + row_idx;
807                 tB = B + tb_inc_col * col_idx;
808                 tA = A + row_idx;
809 
810                 // clear scratch registers.
811                 ymm4 = _mm256_setzero_ps();
812                 ymm5 = _mm256_setzero_ps();
813                 ymm6 = _mm256_setzero_ps();
814                 ymm8 = _mm256_setzero_ps();
815                 ymm9 = _mm256_setzero_ps();
816                 ymm10 = _mm256_setzero_ps();
817                 ymm12 = _mm256_setzero_ps();
818                 ymm13 = _mm256_setzero_ps();
819                 ymm14 = _mm256_setzero_ps();
820 
821                 for (k = 0; k < K; ++k)
822                 {
823                     // The inner loop broadcasts the B matrix data and
824                     // multiplies it with the A matrix.
825                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
826                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
827                     ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
828                     tB += tb_inc_row;
829 
830                     //broadcasted matrix B elements are multiplied
831                     //with matrix A columns.
832                     ymm3 = _mm256_loadu_ps(tA);
833                     //                   ymm4 += ymm0 * ymm3;
834                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
835                     //                    ymm8 += ymm1 * ymm3;
836                     ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
837                     //                    ymm12 += ymm2 * ymm3;
838                     ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
839 
840                     ymm3 = _mm256_loadu_ps(tA + 8);
841                     //                    ymm5 += ymm0 * ymm3;
842                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
843                     //                    ymm9 += ymm1 * ymm3;
844                     ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9);
845                     //                    ymm13 += ymm2 * ymm3;
846                     ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13);
847 
848                     ymm3 = _mm256_loadu_ps(tA + 16);
849                     //                   ymm6 += ymm0 * ymm3;
850                     ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);
851                     //                    ymm10 += ymm1 * ymm3;
852                     ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10);
853                     //                    ymm14 += ymm2 * ymm3;
854                     ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14);
855 
856                     tA += lda;
857                 }
858                 // alpha, beta multiplication.
859                 ymm0 = _mm256_broadcast_ss(alpha_cast);
860 
861                 //multiply A*B by alpha.
862                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
863                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
864                 ymm6 = _mm256_mul_ps(ymm6, ymm0);
865                 ymm8 = _mm256_mul_ps(ymm8, ymm0);
866                 ymm9 = _mm256_mul_ps(ymm9, ymm0);
867                 ymm10 = _mm256_mul_ps(ymm10, ymm0);
868                 ymm12 = _mm256_mul_ps(ymm12, ymm0);
869                 ymm13 = _mm256_mul_ps(ymm13, ymm0);
870                 ymm14 = _mm256_mul_ps(ymm14, ymm0);
871 
872                 if(is_beta_non_zero)
873                 {
874                     ymm1 = _mm256_broadcast_ss(beta_cast);
875                     // multiply C by beta and accumulate.
876                     ymm2 = _mm256_loadu_ps(tC);
877                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
878                     ymm2 = _mm256_loadu_ps(tC + 8);
879                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
880                     ymm2 = _mm256_loadu_ps(tC + 16);
881                     ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
882                     float* ttC = tC +ldc;
883                     ymm2 = _mm256_loadu_ps(ttC);
884                     ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
885                     ymm2 = _mm256_loadu_ps(ttC + 8);
886                     ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
887                     ymm2 = _mm256_loadu_ps(ttC + 16);
888                     ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
889                     ttC += ldc;
890                     ymm2 = _mm256_loadu_ps(ttC);
891                     ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
892                     ymm2 = _mm256_loadu_ps(ttC + 8);
893                     ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
894                     ymm2 = _mm256_loadu_ps(ttC + 16);
895                     ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
896                 }
897                 _mm256_storeu_ps(tC, ymm4);
898                 _mm256_storeu_ps(tC + 8, ymm5);
899                 _mm256_storeu_ps(tC + 16, ymm6);
900 
901                 // multiply C by beta and accumulate.
902                 tC += ldc;
903                 _mm256_storeu_ps(tC, ymm8);
904                 _mm256_storeu_ps(tC + 8, ymm9);
905                 _mm256_storeu_ps(tC + 16, ymm10);
906 
907                 // multiply C by beta and accumulate.
908                 tC += ldc;
909                 _mm256_storeu_ps(tC, ymm12);
910                 _mm256_storeu_ps(tC + 8, ymm13);
911                 _mm256_storeu_ps(tC + 16, ymm14);
912 
913             }
914             n_remainder = N - col_idx;
915             // if the N is not multiple of 3.
916             // handling edge case.
917             if (n_remainder == 2)
918             {
919                 //pointer math to point to proper memory
920                 tC = C + ldc * col_idx + row_idx;
921                 tB = B + tb_inc_col * col_idx;
922                 tA = A + row_idx;
923 
924                 // clear scratch registers.
925                 ymm8 = _mm256_setzero_ps();
926                 ymm9 = _mm256_setzero_ps();
927                 ymm10 = _mm256_setzero_ps();
928                 ymm12 = _mm256_setzero_ps();
929                 ymm13 = _mm256_setzero_ps();
930                 ymm14 = _mm256_setzero_ps();
931 
932                 for (k = 0; k < K; ++k)
933                 {
934                     // The inner loop broadcasts the B matrix data and
935                     // multiplies it with the A matrix.
936                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
937                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
938                     tB += tb_inc_row;
939 
940                     //broadcasted matrix B elements are multiplied
941                     //with matrix A columns.
942                     ymm3 = _mm256_loadu_ps(tA);
943                     ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8);
944                     ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12);
945 
946                     ymm3 = _mm256_loadu_ps(tA + 8);
947                     ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9);
948                     ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13);
949 
950                     ymm3 = _mm256_loadu_ps(tA + 16);
951                     ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
952                     ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
953 
954                     tA += lda;
955 
956                 }
957                 // alpha, beta multiplication.
958                 ymm0 = _mm256_broadcast_ss(alpha_cast);
959 
960                 //multiply A*B by alpha.
961                 ymm8 = _mm256_mul_ps(ymm8, ymm0);
962                 ymm9 = _mm256_mul_ps(ymm9, ymm0);
963                 ymm10 = _mm256_mul_ps(ymm10, ymm0);
964                 ymm12 = _mm256_mul_ps(ymm12, ymm0);
965                 ymm13 = _mm256_mul_ps(ymm13, ymm0);
966                 ymm14 = _mm256_mul_ps(ymm14, ymm0);
967 
968                 if(is_beta_non_zero)
969                 {
970                     ymm1 = _mm256_broadcast_ss(beta_cast);
971                     // multiply C by beta and accumulate.
972                     ymm2 = _mm256_loadu_ps(tC);
973                     ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
974                     ymm2 = _mm256_loadu_ps(tC + 8);
975                     ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
976                     ymm2 = _mm256_loadu_ps(tC + 16);
977                     ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);
978 
979                     float* ttC = tC +ldc;
980                     // multiply C by beta and accumulate.
981                     ymm2 = _mm256_loadu_ps(ttC);
982                     ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
983                     ymm2 = _mm256_loadu_ps(ttC + 8);
984                     ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
985                     ymm2 = _mm256_loadu_ps(ttC + 16);
986                     ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
987                 }
988 
989                 _mm256_storeu_ps(tC, ymm8);
990                 _mm256_storeu_ps(tC + 8, ymm9);
991                 _mm256_storeu_ps(tC + 16, ymm10);
992 
993                 tC += ldc;
994 
995                 _mm256_storeu_ps(tC, ymm12);
996                 _mm256_storeu_ps(tC + 8, ymm13);
997                 _mm256_storeu_ps(tC + 16, ymm14);
998 
999                 col_idx += 2;
1000             }
1001             // if the N is not multiple of 3.
1002             // handling edge case.
1003             if (n_remainder == 1)
1004             {
1005                 //pointer math to point to proper memory
1006                 tC = C + ldc * col_idx + row_idx;
1007                 tB = B + tb_inc_col * col_idx;
1008                 tA = A + row_idx;
1009 
1010                 // clear scratch registers.
1011                 ymm12 = _mm256_setzero_ps();
1012                 ymm13 = _mm256_setzero_ps();
1013                 ymm14 = _mm256_setzero_ps();
1014 
1015                 for (k = 0; k < K; ++k)
1016                 {
1017                     // The inner loop broadcasts the B matrix data and
1018                     // multiplies it with the A matrix.
1019                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1020                     tB += tb_inc_row;
1021 
1022                     //broadcasted matrix B elements are multiplied
1023                     //with matrix A columns.
1024                     ymm3 = _mm256_loadu_ps(tA);
1025                     ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12);
1026 
1027                     ymm3 = _mm256_loadu_ps(tA + 8);
1028                     ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
1029 
1030                     ymm3 = _mm256_loadu_ps(tA + 16);
1031                     ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);
1032 
1033                     tA += lda;
1034 
1035                 }
1036                 // alpha, beta multiplication.
1037                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1038 
1039                 //multiply A*B by alpha.
1040                 ymm12 = _mm256_mul_ps(ymm12, ymm0);
1041                 ymm13 = _mm256_mul_ps(ymm13, ymm0);
1042                 ymm14 = _mm256_mul_ps(ymm14, ymm0);
1043 
1044                 if(is_beta_non_zero)
1045                 {
1046                     ymm1 = _mm256_broadcast_ss(beta_cast);
1047                     // multiply C by beta and accumulate.
1048                     ymm2 = _mm256_loadu_ps(tC + 0);
1049                     ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);
1050                     ymm2 = _mm256_loadu_ps(tC + 8);
1051                     ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);
1052                     ymm2 = _mm256_loadu_ps(tC + 16);
1053                     ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);
1054                 }
1055                 _mm256_storeu_ps(tC + 0, ymm12);
1056                 _mm256_storeu_ps(tC + 8, ymm13);
1057                 _mm256_storeu_ps(tC + 16, ymm14);
1058             }
1059 
1060             row_idx += 24;
1061         }
1062 
1063         if (m_remainder >= 16)
1064         {
1065             m_remainder -= 16;
1066 
1067             for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
1068             {
1069                 //pointer math to point to proper memory
1070                 tC = C + ldc * col_idx + row_idx;
1071                 tB = B + tb_inc_col * col_idx;
1072                 tA = A + row_idx;
1073 
1074                 // clear scratch registers.
1075                 ymm4 = _mm256_setzero_ps();
1076                 ymm5 = _mm256_setzero_ps();
1077                 ymm6 = _mm256_setzero_ps();
1078                 ymm7 = _mm256_setzero_ps();
1079                 ymm8 = _mm256_setzero_ps();
1080                 ymm9 = _mm256_setzero_ps();
1081 
1082                 for (k = 0; k < K; ++k)
1083                 {
1084                     // The inner loop broadcasts the B matrix data and
1085                     // multiplies it with the A matrix.
1086                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1087                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1088                     ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1089                     tB += tb_inc_row;
1090 
1091                     //broadcasted matrix B elements are multiplied
1092                     //with matrix A columns.
1093                     ymm3 = _mm256_loadu_ps(tA);
1094                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1095                     ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
1096                     ymm8 = _mm256_fmadd_ps(ymm2, ymm3, ymm8);
1097 
1098                     ymm3 = _mm256_loadu_ps(tA + 8);
1099                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1100                     ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1101                     ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1102 
1103                     tA += lda;
1104                 }
1105                 // alpha, beta multiplication.
1106                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1107 
1108                 //multiply A*B by alpha.
1109                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1110                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1111                 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1112                 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1113                 ymm8 = _mm256_mul_ps(ymm8, ymm0);
1114                 ymm9 = _mm256_mul_ps(ymm9, ymm0);
1115 
1116                 if(is_beta_non_zero)
1117                 {
1118                     ymm1 = _mm256_broadcast_ss(beta_cast);
1119                     // multiply C by beta and accumulate.
1120                     ymm2 = _mm256_loadu_ps(tC);
1121                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1122                     ymm2 = _mm256_loadu_ps(tC + 8);
1123                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1124                     float* ttC = tC + ldc;
1125                     ymm2 = _mm256_loadu_ps(ttC);
1126                     ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
1127                     ymm2 = _mm256_loadu_ps(ttC + 8);
1128                     ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
1129                     ttC += ldc;
1130                     ymm2 = _mm256_loadu_ps(ttC);
1131                     ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);
1132                     ymm2 = _mm256_loadu_ps(ttC + 8);
1133                     ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
1134                 }
1135                 _mm256_storeu_ps(tC, ymm4);
1136                 _mm256_storeu_ps(tC + 8, ymm5);
1137 
1138                 // multiply C by beta and accumulate.
1139                 tC += ldc;
1140                 _mm256_storeu_ps(tC, ymm6);
1141                 _mm256_storeu_ps(tC + 8, ymm7);
1142 
1143                 // multiply C by beta and accumulate.
1144                 tC += ldc;
1145                 _mm256_storeu_ps(tC, ymm8);
1146                 _mm256_storeu_ps(tC + 8, ymm9);
1147 
1148             }
1149             n_remainder = N - col_idx;
1150             // if the N is not multiple of 3.
1151             // handling edge case.
1152             if (n_remainder == 2)
1153             {
1154                 //pointer math to point to proper memory
1155                 tC = C + ldc * col_idx + row_idx;
1156                 tB = B + tb_inc_col * col_idx;
1157                 tA = A + row_idx;
1158 
1159                 // clear scratch registers.
1160                 ymm4 = _mm256_setzero_ps();
1161                 ymm5 = _mm256_setzero_ps();
1162                 ymm6 = _mm256_setzero_ps();
1163                 ymm7 = _mm256_setzero_ps();
1164 
1165                 for (k = 0; k < K; ++k)
1166                 {
1167                     // The inner loop broadcasts the B matrix data and
1168                     // multiplies it with the A matrix.
1169                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1170                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1171                     tB += tb_inc_row;
1172 
1173                     //broadcasted matrix B elements are multiplied
1174                     //with matrix A columns.
1175                     ymm3 = _mm256_loadu_ps(tA);
1176                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1177                     ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6);
1178 
1179                     ymm3 = _mm256_loadu_ps(tA + 8);
1180                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1181                     ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1182 
1183                     tA += lda;
1184                 }
1185                 // alpha, beta multiplication.
1186                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1187 
1188                 //multiply A*B by alpha.
1189                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1190                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1191                 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1192                 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1193 
1194                 if(is_beta_non_zero)
1195                 {
1196                     ymm1 = _mm256_broadcast_ss(beta_cast);
1197                     // multiply C by beta and accumulate.
1198                     ymm2 = _mm256_loadu_ps(tC);
1199                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1200                     ymm2 = _mm256_loadu_ps(tC + 8);
1201                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1202                     float* ttC = tC + ldc;
1203                     ymm2 = _mm256_loadu_ps(ttC);
1204                     ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
1205                     ymm2 = _mm256_loadu_ps(ttC + 8);
1206                     ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
1207                 }
1208                 _mm256_storeu_ps(tC, ymm4);
1209                 _mm256_storeu_ps(tC + 8, ymm5);
1210 
1211                 // multiply C by beta and accumulate.
1212                 tC += ldc;
1213                 _mm256_storeu_ps(tC, ymm6);
1214                 _mm256_storeu_ps(tC + 8, ymm7);
1215 
1216                 col_idx += 2;
1217 
1218             }
1219             // if the N is not multiple of 3.
1220             // handling edge case.
1221             if (n_remainder == 1)
1222             {
1223                 //pointer math to point to proper memory
1224                 tC = C + ldc * col_idx + row_idx;
1225                 tB = B + tb_inc_col * col_idx;
1226                 tA = A + row_idx;
1227 
1228                 ymm4 = _mm256_setzero_ps();
1229                 ymm5 = _mm256_setzero_ps();
1230 
1231                 for (k = 0; k < K; ++k)
1232                 {
1233                     // The inner loop broadcasts the B matrix data and
1234                     // multiplies it with the A matrix.
1235                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1236                     tB += tb_inc_row;
1237 
1238                     //broadcasted matrix B elements are multiplied
1239                     //with matrix A columns.
1240                     ymm3 = _mm256_loadu_ps(tA);
1241                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1242 
1243                     ymm3 = _mm256_loadu_ps(tA + 8);
1244                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1245 
1246                     tA += lda;
1247                 }
1248                 // alpha, beta multiplication.
1249                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1250 
1251                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1252                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1253 
1254                 // multiply C by beta and accumulate.
1255                 if(is_beta_non_zero)
1256                 {
1257                     ymm1 = _mm256_broadcast_ss(beta_cast);
1258                     ymm2 = _mm256_loadu_ps(tC);
1259                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1260                     ymm2 = _mm256_loadu_ps(tC + 8);
1261                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1262                 }
1263                 _mm256_storeu_ps(tC, ymm4);
1264                 _mm256_storeu_ps(tC + 8, ymm5);
1265 
1266             }
1267 
1268             row_idx += 16;
1269         }
1270 
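        // Rows left over after the 16-row strip above are handled in
        // progressively narrower strips: an 8-row strip next, then a
        // zero-padded temporary buffer for fewer than 8 rows (when lda > 7),
        // and finally a scalar loop for whatever remains.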
1271         if (m_remainder >= 8)
1272         {
1273             m_remainder -= 8;
1274 
1275             for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
1276             {
1277                 //pointer math to point to proper memory
1278                 tC = C + ldc * col_idx + row_idx;
1279                 tB = B + tb_inc_col * col_idx;
1280                 tA = A + row_idx;
1281 
1282                 // clear scratch registers.
1283                 ymm4 = _mm256_setzero_ps();
1284                 ymm5 = _mm256_setzero_ps();
1285                 ymm6 = _mm256_setzero_ps();
1286 
1287                 for (k = 0; k < K; ++k)
1288                 {
1289                     // The inner loop broadcasts the B matrix data and
1290                     // multiplies it with the A matrix.
1291                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1292                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1293                     ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1294                     tB += tb_inc_row;
1295 
1296                     //broadcasted matrix B elements are multiplied
1297                     //with matrix A columns.
1298                     ymm3 = _mm256_loadu_ps(tA);
1299                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1300                     ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
1301                     ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
1302 
1303                     tA += lda;
1304                 }
1305                 // alpha, beta multiplication.
1306                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1307 
1308                 //multiply A*B by alpha.
1309                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1310                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1311                 ymm6 = _mm256_mul_ps(ymm6, ymm0);
1312 
1313                 if(is_beta_non_zero)
1314                 {
1315                     ymm1 = _mm256_broadcast_ss(beta_cast);
1316                     ymm2 = _mm256_loadu_ps(tC);
1317                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1318                     ymm2 = _mm256_loadu_ps(tC + ldc);
1319                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1320                     ymm2 = _mm256_loadu_ps(tC + 2*ldc);
1321                     ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);
1322                 }
1323                 _mm256_storeu_ps(tC, ymm4);
1324 
1325                 // store the results into the second column of C.
1326                 tC += ldc;
1327                 _mm256_storeu_ps(tC, ymm5);
1328 
1329                 // store the results into the third column of C.
1330                 tC += ldc;
1331                 _mm256_storeu_ps(tC, ymm6);
1332             }
1333             n_remainder = N - col_idx;
1334             // If N is not a multiple of 3,
1335             // handle the last two columns.
1336             if (n_remainder == 2)
1337             {
1338                 //pointer math to point to proper memory
1339                 tC = C + ldc * col_idx + row_idx;
1340                 tB = B + tb_inc_col * col_idx;
1341                 tA = A + row_idx;
1342 
1343                 ymm4 = _mm256_setzero_ps();
1344                 ymm5 = _mm256_setzero_ps();
1345 
1346                 for (k = 0; k < K; ++k)
1347                 {
1348                     // The inner loop broadcasts the B matrix data and
1349                     // multiplies it with the A matrix.
1350                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1351                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1352                     tB += tb_inc_row;
1353 
1354                     //broadcasted matrix B elements are multiplied
1355                     //with matrix A columns.
1356                     ymm3 = _mm256_loadu_ps(tA);
1357                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1358                     ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
1359 
1360                     tA += lda;
1361                 }
1362                 // alpha, beta multiplication.
1363                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1364 
1365                 //multiply A*B by alpha.
1366                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1367                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1368 
1369                 if(is_beta_non_zero)
1370                 {
1371                     ymm1 = _mm256_broadcast_ss(beta_cast);
1372                     // multiply C by beta and accumulate.
1373                     ymm2 = _mm256_loadu_ps(tC);
1374                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1375                     ymm2 = _mm256_loadu_ps(tC + ldc);
1376                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1377                 }
1378                 _mm256_storeu_ps(tC, ymm4);
1379                 // store the results into the second column of C.
1380                 tC += ldc;
1381                 _mm256_storeu_ps(tC, ymm5);
1382 
1383                 col_idx += 2;
1384 
1385             }
1386             // If N is not a multiple of 3,
1387             // handle the last remaining column.
1388             if (n_remainder == 1)
1389             {
1390                 //pointer math to point to proper memory
1391                 tC = C + ldc * col_idx + row_idx;
1392                 tB = B + tb_inc_col * col_idx;
1393                 tA = A + row_idx;
1394 
1395                 ymm4 = _mm256_setzero_ps();
1396 
1397                 for (k = 0; k < K; ++k)
1398                 {
1399                     // The inner loop broadcasts the B matrix data and
1400                     // multiplies it with the A matrix.
1401                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1402                     tB += tb_inc_row;
1403 
1404                     //broadcasted matrix B elements are multiplied
1405                     //with matrix A columns.
1406                     ymm3 = _mm256_loadu_ps(tA);
1407                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
1408 
1409                     tA += lda;
1410                 }
1411                 // alpha, beta multiplication.
1412                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1413                 ymm4 = _mm256_mul_ps(ymm4, ymm0);
1414 
1415                 if(is_beta_non_zero)
1416                 {
1417                     ymm1 = _mm256_broadcast_ss(beta_cast);
1418                     // multiply C by beta and accumulate.
1419                     ymm2 = _mm256_loadu_ps(tC);
1420                     ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);
1421                 }
1422                 _mm256_storeu_ps(tC, ymm4);
1423 
1424             }
1425 
1426             row_idx += 8;
1427         }
1428         // M is not a multiple of 32.
1429         // Handle the edge case where the remaining row
1430         // dimension is less than 8: the partial vectors are
1431         // zero-padded through a temporary buffer.
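        // Roughly, the padded path works as follows:
        //  - For the first K-1 values of k, full 8-float loads from A may read
        //    a few elements past the valid rows; the lda > 7 guard keeps such
        //    reads inside the A matrix, and the extra lanes only feed
        //    accumulator lanes that are never written back to C.
        //  - For the last value of k the valid elements are staged through the
        //    zero-initialized f_temp buffer before the 8-wide load, e.g. for
        //    m_remainder = 5: f_temp = { a0, a1, a2, a3, a4, 0, 0, 0 }.
        //  - C is likewise loaded and stored through f_temp, and only the
        //    first m_remainder lanes are copied back, so rows beyond
        //    m_remainder are never modified.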
1432         if ((m_remainder) && (lda > 7))
1433         {
1434             float f_temp[8] = {0.0};
1435 
1436             for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
1437             {
1438                 //pointer math to point to proper memory
1439                 tC = C + ldc * col_idx + row_idx;
1440                 tB = B + tb_inc_col * col_idx;
1441                 tA = A + row_idx;
1442 
1443                 // clear scratch registers.
1444                 ymm5 = _mm256_setzero_ps();
1445                 ymm7 = _mm256_setzero_ps();
1446                 ymm9 = _mm256_setzero_ps();
1447 
1448                 for (k = 0; k < (K - 1); ++k)
1449                 {
1450                     // The inner loop broadcasts the B matrix data and
1451                     // multiplies it with the A matrix.
1452                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1453                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1454                     ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1455                     tB += tb_inc_row;
1456 
1457                     //broadcasted matrix B elements are multiplied
1458                     //with matrix A columns.
1459                     ymm3 = _mm256_loadu_ps(tA);
1460                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1461                     ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1462                     ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1463 
1464                     tA += lda;
1465                 }
1466                 // last k iteration: A is staged through the zero-padded f_temp buffer.
1467                 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1468                 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1469                 ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2);
1470                 tB += tb_inc_row;
1471 
1472                 for (int i = 0; i < m_remainder; i++)
1473                 {
1474                     f_temp[i] = tA[i];
1475                 }
1476                 ymm3 = _mm256_loadu_ps(f_temp);
1477                 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1478                 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1479                 ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
1480 
1481                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1482                 ymm1 = _mm256_broadcast_ss(beta_cast);
1483 
1484                 //multiply A*B by alpha.
1485                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1486                 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1487                 ymm9 = _mm256_mul_ps(ymm9, ymm0);
1488 
1489 
1490                 for (int i = 0; i < m_remainder; i++)
1491                 {
1492                     f_temp[i] = tC[i];
1493                 }
1494                 ymm2 = _mm256_loadu_ps(f_temp);
1495                 if(is_beta_non_zero){
1496                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1497                 }
1498                 _mm256_storeu_ps(f_temp, ymm5);
1499                 for (int i = 0; i < m_remainder; i++)
1500                 {
1501                     tC[i] = f_temp[i];
1502                 }
1503 
1504                 tC += ldc;
1505                 for (int i = 0; i < m_remainder; i++)
1506                 {
1507                     f_temp[i] = tC[i];
1508                 }
1509                 ymm2 = _mm256_loadu_ps(f_temp);
1510                 if(is_beta_non_zero){
1511                     ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
1512                 }
1513                 _mm256_storeu_ps(f_temp, ymm7);
1514                 for (int i = 0; i < m_remainder; i++)
1515                 {
1516                     tC[i] = f_temp[i];
1517                 }
1518 
1519                 tC += ldc;
1520                 for (int i = 0; i < m_remainder; i++)
1521                 {
1522                     f_temp[i] = tC[i];
1523                 }
1524                 ymm2 = _mm256_loadu_ps(f_temp);
1525                 if(is_beta_non_zero){
1526                     ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);
1527                 }
1528                 _mm256_storeu_ps(f_temp, ymm9);
1529                 for (int i = 0; i < m_remainder; i++)
1530                 {
1531                     tC[i] = f_temp[i];
1532                 }
1533             }
1534             n_remainder = N - col_idx;
1535             // If N is not a multiple of 3,
1536             // handle the last two columns.
1537             if (n_remainder == 2)
1538             {
1539                 //pointer math to point to proper memory
1540                 tC = C + ldc * col_idx + row_idx;
1541                 tB = B + tb_inc_col * col_idx;
1542                 tA = A + row_idx;
1543 
1544                 ymm5 = _mm256_setzero_ps();
1545                 ymm7 = _mm256_setzero_ps();
1546 
1547                 for (k = 0; k < (K - 1); ++k)
1548                 {
1549                     // The inner loop broadcasts the B matrix data and
1550                     // multiplies it with the A matrix.
1551                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1552                     ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1553                     tB += tb_inc_row;
1554 
1555                     ymm3 = _mm256_loadu_ps(tA);
1556                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1557                     ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1558 
1559                     tA += lda;
1560                 }
1561 
1562                 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1563                 ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1);
1564                 tB += tb_inc_row;
1565 
1566                 for (int i = 0; i < m_remainder; i++)
1567                 {
1568                     f_temp[i] = tA[i];
1569                 }
1570                 ymm3 = _mm256_loadu_ps(f_temp);
1571                 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1572                 ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7);
1573 
1574                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1575                 ymm1 = _mm256_broadcast_ss(beta_cast);
1576 
1577                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1578                 ymm7 = _mm256_mul_ps(ymm7, ymm0);
1579 
1580                 for (int i = 0; i < m_remainder; i++)
1581                 {
1582                     f_temp[i] = tC[i];
1583                 }
1584                 ymm2 = _mm256_loadu_ps(f_temp);
1585                 if(is_beta_non_zero){
1586                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1587                 }
1588                 _mm256_storeu_ps(f_temp, ymm5);
1589                 for (int i = 0; i < m_remainder; i++)
1590                 {
1591                     tC[i] = f_temp[i];
1592                 }
1593 
1594                 tC += ldc;
1595                 for (int i = 0; i < m_remainder; i++)
1596                 {
1597                     f_temp[i] = tC[i];
1598                 }
1599                 ymm2 = _mm256_loadu_ps(f_temp);
1600                 if(is_beta_non_zero){
1601                     ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);
1602                 }
1603                 _mm256_storeu_ps(f_temp, ymm7);
1604                 for (int i = 0; i < m_remainder; i++)
1605                 {
1606                     tC[i] = f_temp[i];
1607                 }
1608             }
1609             // If N is not a multiple of 3,
1610             // handle the last remaining column.
1611             if (n_remainder == 1)
1612             {
1613                 //pointer math to point to proper memory
1614                 tC = C + ldc * col_idx + row_idx;
1615                 tB = B + tb_inc_col * col_idx;
1616                 tA = A + row_idx;
1617 
1618                 ymm5 = _mm256_setzero_ps();
1619 
1620                 for (k = 0; k < (K - 1); ++k)
1621                 {
1622                     // The inner loop broadcasts the B matrix data and
1623                     // multiplies it with the A matrix.
1624                     ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1625                     tB += tb_inc_row;
1626 
1627                     ymm3 = _mm256_loadu_ps(tA);
1628                     ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1629 
1630                     tA += lda;
1631                 }
1632 
1633                 ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0);
1634                 tB += tb_inc_row;
1635 
1636                 for (int i = 0; i < m_remainder; i++)
1637                 {
1638                     f_temp[i] = tA[i];
1639                 }
1640                 ymm3 = _mm256_loadu_ps(f_temp);
1641                 ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5);
1642 
1643                 ymm0 = _mm256_broadcast_ss(alpha_cast);
1644 
1645                 //multiply A*B by alpha.
1646                 ymm5 = _mm256_mul_ps(ymm5, ymm0);
1647 
1648                 for (int i = 0; i < m_remainder; i++)
1649                 {
1650                     f_temp[i] = tC[i];
1651                 }
1652                 ymm2 = _mm256_loadu_ps(f_temp);
1653                 if(is_beta_non_zero){
1654                     ymm1 = _mm256_broadcast_ss(beta_cast);
1655                     ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);
1656                 }
1657                 _mm256_storeu_ps(f_temp, ymm5);
1658                 for (int i = 0; i < m_remainder; i++)
1659                 {
1660                     tC[i] = f_temp[i];
1661                 }
1662             }
1663             m_remainder = 0;
1664         }
1665 
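        // Fallback for any remaining rows that could not use the padded vector
        // path (e.g. when lda <= 7): compute each remaining element of C as a
        // scalar dot product of a row of A and a column of B, then apply
        // alpha and beta.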
1666         if (m_remainder)
1667         {
1668             float result;
1669             for (; row_idx < M; row_idx += 1)
1670             {
1671                 for (col_idx = 0; col_idx < N; col_idx += 1)
1672                 {
1673                     //pointer math to point to proper memory
1674                     tC = C + ldc * col_idx + row_idx;
1675                     tB = B + tb_inc_col * col_idx;
1676                     tA = A + row_idx;
1677 
1678                     result = 0;
1679                     for (k = 0; k < K; ++k)
1680                     {
1681                         result += (*tA) * (*tB);
1682                         tA += lda;
1683                         tB += tb_inc_row;
1684                     }
1685 
1686                     result *= (*alpha_cast);
1687                     if(is_beta_non_zero){
1688                         (*tC) = (*tC) * (*beta_cast) + result;
1689                     }else{
1690                         (*tC) = result;
1691                     }
1692                 }
1693             }
1694         }
1695 
1696         // Return the buffer to the pool.
1697         if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s) ) {
1698 
1699 #ifdef BLIS_ENABLE_MEM_TRACING
1700         printf( "bli_sgemm_small(): releasing mem pool block\n" );
1701 #endif
1702             bli_membrk_release(&rntm,
1703                                &local_mem_buf_A_s);
1704         }
1705 
1706         AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
1707         return BLIS_SUCCESS;
1708     }
1709     else
1710     {
1711         AOCL_DTL_TRACE_EXIT_ERR(
1712             AOCL_DTL_LEVEL_INFO,
1713             "Invalid dimensions for small gemm."
1714             );
1715         return BLIS_NONCONFORMAL_DIMENSIONS;
1716     }
1717 
1718 }
1719 
1720 static err_t bli_dgemm_small
1721      (
1722        obj_t*  alpha,
1723        obj_t*  a,
1724        obj_t*  b,
1725        obj_t*  beta,
1726        obj_t*  c,
1727        cntx_t* cntx,
1728        cntl_t* cntl
1729      )
1730 {
1731 
1732     AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
1733 
1734     gint_t M = bli_obj_length( c ); // number of rows of Matrix C
1735     gint_t N = bli_obj_width( c );  // number of columns of Matrix C
1736     gint_t K = bli_obj_width( a );  // number of columns of OP(A); will be updated if OP(A) is Transpose(A).
1737     gint_t L = M * N;
1738 
1739     // When N is equal to 1, call GEMV instead of GEMM.
1740     if (N == 1)
1741     {
1742         bli_gemv
1743         (
1744             alpha,
1745             a,
1746             b,
1747             beta,
1748             c
1749         );
1750         AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
1751         return BLIS_SUCCESS;
1752     }
1753 
1754     if (N<3) //Implementation assumes that N is at least 3.
1755     {
1756         AOCL_DTL_TRACE_EXIT_ERR(
1757             AOCL_DTL_LEVEL_INFO,
1758             "N < 3, cannot be processed by small gemm"
1759             );
1760         return BLIS_NOT_YET_IMPLEMENTED;
1761     }
1762 
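    // Small-matrix gate: this kernel handles the problem only when it is small
    // enough -- either M*N is below D_BLIS_SMALL_MATRIX_THRES squared, or both
    // M and K are below their rectangular thresholds (on Rome builds the check
    // is based on D_BLIS_SMALL_MATRIX_K_THRES_ROME and
    // BLIS_SMALL_MATRIX_THRES_ROME instead). Larger problems are declined so
    // the regular BLIS path can handle them.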
1763 #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
1764     if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME))))
1765 #else
1766     if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES))
1767         || ((M  < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
1768 #endif
1769     {
1770         guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
1771         guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.
1772         guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
1773         guint_t row_idx, col_idx, k;
1774         double *A = bli_obj_buffer_at_off(a); // pointer to elements of Matrix A
1775         double *B = bli_obj_buffer_at_off(b); // pointer to elements of Matrix B
1776         double *C = bli_obj_buffer_at_off(c); // pointer to elements of Matrix C
1777 
1778         double *tA = A, *tB = B, *tC = C;//, *tA_pack;
1779         double *tA_packed; // temporary pointer to the packed A buffer
1780         guint_t row_idx_packed; //packed A memory row index
1781         guint_t lda_packed; //lda of packed A
1782         guint_t col_idx_start; //starting index after A matrix is packed.
1783         dim_t tb_inc_row = 1; // row stride of matrix B
1784         dim_t tb_inc_col = ldb; // column stride of matrix B
1785         __m256d ymm4, ymm5, ymm6, ymm7;
1786         __m256d ymm8, ymm9, ymm10, ymm11;
1787         __m256d ymm12, ymm13, ymm14, ymm15;
1788         __m256d ymm0, ymm1, ymm2, ymm3;
1789 
1790         gint_t n_remainder; // remainder columns when N is not a multiple of 3 (N % 3)
1791         gint_t m_remainder; // remainder rows when M is not a multiple of 16 (M % 16)
1792 
1793         double *alpha_cast, *beta_cast; // alpha, beta multiples
1794         alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha);
1795         beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta);
1796 
1797         gint_t required_packing_A = 1;
1798         mem_t local_mem_buf_A_s;
1799         double *D_A_pack = NULL;
1800         rntm_t rntm;
1801 
1802         //update the pointer math if matrix B needs to be transposed.
1803         if (bli_obj_has_trans( b ))
1804         {
1805             tb_inc_col = 1; //switch row and column strides
1806             tb_inc_row = ldb;
1807         }
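        // With these strides, element B(k, j) is addressed as
        // B[ j * tb_inc_col + k * tb_inc_row ]: column-major access when B is
        // not transposed (tb_inc_col = ldb, tb_inc_row = 1) and row-major
        // access when it is (tb_inc_col = 1, tb_inc_row = ldb).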
1808 
1809         //check whether beta is zero.
1810         //if it is, perform the C = alpha * (A * B) operation
1811         //instead of C = beta * C + alpha * (A * B)
1812         bool is_beta_non_zero = 0;
1813         if(!bli_obj_equals(beta, &BLIS_ZERO))
1814                 is_beta_non_zero = 1;
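        // When beta is zero, the C-read/FMA step is skipped entirely below, so
        // C is simply overwritten rather than read; this also avoids touching
        // possibly uninitialized C memory.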
1815 
1816         /*
1817          * This function previously used a global array to pack part of the A input
1818          * when needed. However, using a global array makes the function non-reentrant.
1819          * Instead of a global array, we should allocate a buffer for each invocation.
1820          * Since the buffer is too big for the stack and calling malloc on every
1821          * invocation would be too expensive, the better approach is to get the buffer
1822          * from the pre-allocated pool and return it to the pool once we are done.
1823          *
1824          * In order to get the buffer from the pool, we need access to the memory
1825          * broker; currently this function is not invoked in such a way that it can
1826          * receive the memory broker (via rntm). The following workaround obtains the
1827          * global memory broker, which we can then use to access the pool.
1828          *
1829          * Note that there will be a memory allocation at least on the first
1830          * invocation, as no pool will have been created yet for this size.
1831          * Subsequent invocations will just reuse the buffer from the pool.
1832          */
1833 
1834         bli_rntm_init_from_global( &rntm );
1835         bli_rntm_set_num_threads_only( 1, &rntm );
1836         bli_membrk_rntm_set_membrk( &rntm );
1837 
1838         // Get the current size of the buffer pool for A block packing.
1839         // We will use the same size to avoid pool re-initialization.
1840         siz_t buffer_size = bli_pool_block_size(
1841             bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
1842                             bli_rntm_membrk(&rntm)));
1843 
1844         //
1845         // This kernel assumes that "A" will be unpacked if N <= 3.
1846         // Usually this range (N <= 3) is handled by SUP; however,
1847         // if SUP is disabled, or if for any other reason we do
1848         // enter this kernel with N <= 3, we want to make sure that
1849         // "A" remains unpacked.
1850         //
1851         // If this check is removed, it will result in the crash
1852         // reported in CPUPL-587.
1853         //
1854 
1855         if ((N <= 3) || ((D_MR * K) << 3) > buffer_size)
1856         {
1857             required_packing_A = 0;
1858         }
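        // In the check above, (D_MR * K) << 3 is the size in bytes of the
        // packed A panel (D_MR * K doubles at 8 bytes each). If the panel
        // would not fit in the pool's current block size, or if N <= 3,
        // packing is skipped and A is read in place.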
1859 
1860         if (required_packing_A == 1)
1861         {
1862 #ifdef BLIS_ENABLE_MEM_TRACING
1863             printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", buffer_size);
1864 #endif
1865             // Get the buffer from the pool.
1866             bli_membrk_acquire_m(&rntm,
1867                                  buffer_size,
1868                                  BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
1869                                  &local_mem_buf_A_s);
1870 
1871             D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
1872         }
1873 
1874         /*
1875         * The computation loop processes a D_MR x N panel of the C matrix at a time,
1876         * accessing a D_MR x K panel of A and K x NR panels of B.
1877         * The computation is organized as inner loops over D_MR x NR tiles.
1878         */
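        // For D_MR = 16 and NR = 3, each inner tile update keeps a 16x3 block
        // of C in twelve accumulators (ymm4-ymm15): four doubles per register
        // and three registers per column, while ymm0-ymm2 hold the broadcast B
        // values and ymm3 the current A vector.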
1879         // Process D_MR rows of C matrix at a time.
1880         for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR)
1881         {
1882 
1883             col_idx_start = 0;
1884             tA_packed = A;
1885             row_idx_packed = row_idx;
1886             lda_packed = lda;
1887 
1888             // This is part of the pack-and-compute optimization.
1889             // During the first column iteration, we store the accessed A panel into
1890             // contiguous scratch memory. This helps to keep the A matrix in cache and
1891             // avoids TLB misses.
1892             if (required_packing_A)
1893             {
1894                 col_idx = 0;
1895 
1896                 //pointer math to point to proper memory
1897                 tC = C + ldc * col_idx + row_idx;
1898                 tB = B + tb_inc_col * col_idx;
1899                 tA = A + row_idx;
1900                 tA_packed = D_A_pack;
1901 
1902 #ifdef BLIS_ENABLE_PREFETCH
1903                 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
1904                 _mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
1905                 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
1906                 _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
1907                 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
1908                 _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
1909 #endif
1910                 // clear scratch registers.
1911                 ymm4 = _mm256_setzero_pd();
1912                 ymm5 = _mm256_setzero_pd();
1913                 ymm6 = _mm256_setzero_pd();
1914                 ymm7 = _mm256_setzero_pd();
1915                 ymm8 = _mm256_setzero_pd();
1916                 ymm9 = _mm256_setzero_pd();
1917                 ymm10 = _mm256_setzero_pd();
1918                 ymm11 = _mm256_setzero_pd();
1919                 ymm12 = _mm256_setzero_pd();
1920                 ymm13 = _mm256_setzero_pd();
1921                 ymm14 = _mm256_setzero_pd();
1922                 ymm15 = _mm256_setzero_pd();
1923 
1924                 for (k = 0; k < K; ++k)
1925                 {
1926                     // The inner loop broadcasts the B matrix data and
1927                     // multiplies it with the A matrix.
1928                     // This loop is processing D_MR x K
1929                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
1930                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
1931                     ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
1932                     tB += tb_inc_row;
1933 
1934                     //broadcasted matrix B elements are multiplied
1935                     //with matrix A columns.
1936                     ymm3 = _mm256_loadu_pd(tA);
1937                     _mm256_storeu_pd(tA_packed, ymm3); // the packing of matrix A
1938                                                        //                   ymm4 += ymm0 * ymm3;
1939                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
1940                     //                    ymm8 += ymm1 * ymm3;
1941                     ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
1942                     //                    ymm12 += ymm2 * ymm3;
1943                     ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
1944 
1945                     ymm3 = _mm256_loadu_pd(tA + 4);
1946                     _mm256_storeu_pd(tA_packed + 4, ymm3); // the packing of matrix A
1947                                                            //                    ymm5 += ymm0 * ymm3;
1948                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
1949                     //                    ymm9 += ymm1 * ymm3;
1950                     ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
1951                     //                    ymm13 += ymm2 * ymm3;
1952                     ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
1953 
1954                     ymm3 = _mm256_loadu_pd(tA + 8);
1955                     _mm256_storeu_pd(tA_packed + 8, ymm3); // the packing of matrix A
1956                                                            //                   ymm6 += ymm0 * ymm3;
1957                     ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
1958                     //                    ymm10 += ymm1 * ymm3;
1959                     ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
1960                     //                    ymm14 += ymm2 * ymm3;
1961                     ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
1962 
1963                     ymm3 = _mm256_loadu_pd(tA + 12);
1964                     _mm256_storeu_pd(tA_packed + 12, ymm3); // the packing of matrix A
1965                                                             //                    ymm7 += ymm0 * ymm3;
1966                     ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
1967                     //                    ymm11 += ymm1 * ymm3;
1968                     ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
1969                     //                   ymm15 += ymm2 * ymm3;
1970                     ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
1971 
1972                     tA += lda;
1973                     tA_packed += D_MR;
1974                 }
1975                 // alpha, beta multiplication.
1976                 ymm0 = _mm256_broadcast_sd(alpha_cast);
1977                 ymm1 = _mm256_broadcast_sd(beta_cast);
1978 
1979                 //multiply A*B by alpha.
1980                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
1981                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
1982                 ymm6 = _mm256_mul_pd(ymm6, ymm0);
1983                 ymm7 = _mm256_mul_pd(ymm7, ymm0);
1984                 ymm8 = _mm256_mul_pd(ymm8, ymm0);
1985                 ymm9 = _mm256_mul_pd(ymm9, ymm0);
1986                 ymm10 = _mm256_mul_pd(ymm10, ymm0);
1987                 ymm11 = _mm256_mul_pd(ymm11, ymm0);
1988                 ymm12 = _mm256_mul_pd(ymm12, ymm0);
1989                 ymm13 = _mm256_mul_pd(ymm13, ymm0);
1990                 ymm14 = _mm256_mul_pd(ymm14, ymm0);
1991                 ymm15 = _mm256_mul_pd(ymm15, ymm0);
1992 
1993                 if(is_beta_non_zero)
1994                 {
1995                     // multiply C by beta and accumulate col 1.
1996                     ymm2 = _mm256_loadu_pd(tC);
1997                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
1998                     ymm2 = _mm256_loadu_pd(tC + 4);
1999                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2000                     ymm2 = _mm256_loadu_pd(tC + 8);
2001                     ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2002                     ymm2 = _mm256_loadu_pd(tC + 12);
2003                     ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
2004 
2005                     double* ttC = tC + ldc;
2006 
2007                     // multiply C by beta and accumulate, col 2.
2008                     ymm2 = _mm256_loadu_pd(ttC);
2009                     ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2010                     ymm2 = _mm256_loadu_pd(ttC + 4);
2011                     ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2012                     ymm2 = _mm256_loadu_pd(ttC + 8);
2013                     ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2014                     ymm2 = _mm256_loadu_pd(ttC + 12);
2015                     ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
2016 
2017                     ttC += ldc;
2018 
2019                     // multiply C by beta and accumulate, col 3.
2020                     ymm2 = _mm256_loadu_pd(ttC);
2021                     ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2022                     ymm2 = _mm256_loadu_pd(ttC + 4);
2023                     ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2024                     ymm2 = _mm256_loadu_pd(ttC + 8);
2025                     ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2026                     ymm2 = _mm256_loadu_pd(ttC + 12);
2027                     ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
2028                 }
2029                 _mm256_storeu_pd(tC, ymm4);
2030                 _mm256_storeu_pd(tC + 4, ymm5);
2031                 _mm256_storeu_pd(tC + 8, ymm6);
2032                 _mm256_storeu_pd(tC + 12, ymm7);
2033 
2034                 tC += ldc;
2035 
2036                 _mm256_storeu_pd(tC, ymm8);
2037                 _mm256_storeu_pd(tC + 4, ymm9);
2038                 _mm256_storeu_pd(tC + 8, ymm10);
2039                 _mm256_storeu_pd(tC + 12, ymm11);
2040 
2041                 tC += ldc;
2042 
2043                 _mm256_storeu_pd(tC, ymm12);
2044                 _mm256_storeu_pd(tC + 4, ymm13);
2045                 _mm256_storeu_pd(tC + 8, ymm14);
2046                 _mm256_storeu_pd(tC + 12, ymm15);
2047 
2048                 // modify the pointer arithmetic to use the packed A matrix.
2049                 col_idx_start = NR;
2050                 tA_packed = D_A_pack;
2051                 row_idx_packed = 0;
2052                 lda_packed = D_MR;
2053             }
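            // From this point on, the remaining column panels of this row
            // block read A from the packed buffer: tA_packed points at
            // D_A_pack, row_idx_packed is 0, and lda_packed is D_MR, so the
            // packed panel is traversed contiguously, D_MR doubles per k.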
2054             // Process NR columns of C matrix at a time.
2055             for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR)
2056             {
2057                 //pointer math to point to proper memory
2058                 tC = C + ldc * col_idx + row_idx;
2059                 tB = B + tb_inc_col * col_idx;
2060                 tA = tA_packed + row_idx_packed;
2061 
2062 #ifdef BLIS_ENABLE_PREFETCH
2063                 _mm_prefetch((char*)(tC + 0), _MM_HINT_T0);
2064                 _mm_prefetch((char*)(tC + 8), _MM_HINT_T0);
2065                 _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0);
2066                 _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0);
2067                 _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0);
2068                 _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0);
2069 #endif
2070                 // clear scratch registers.
2071                 ymm4 = _mm256_setzero_pd();
2072                 ymm5 = _mm256_setzero_pd();
2073                 ymm6 = _mm256_setzero_pd();
2074                 ymm7 = _mm256_setzero_pd();
2075                 ymm8 = _mm256_setzero_pd();
2076                 ymm9 = _mm256_setzero_pd();
2077                 ymm10 = _mm256_setzero_pd();
2078                 ymm11 = _mm256_setzero_pd();
2079                 ymm12 = _mm256_setzero_pd();
2080                 ymm13 = _mm256_setzero_pd();
2081                 ymm14 = _mm256_setzero_pd();
2082                 ymm15 = _mm256_setzero_pd();
2083 
2084                 for (k = 0; k < K; ++k)
2085                 {
2086                     // The inner loop broadcasts the B matrix data and
2087                     // multiplies it with the A matrix.
2088                     // This loop is processing D_MR x K
2089                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2090                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2091                     ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2092                     tB += tb_inc_row;
2093 
2094                     //broadcasted matrix B elements are multiplied
2095                     //with matrix A columns.
2096                     ymm3 = _mm256_loadu_pd(tA);
2097                     //                   ymm4 += ymm0 * ymm3;
2098                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2099                     //                    ymm8 += ymm1 * ymm3;
2100                     ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
2101                     //                    ymm12 += ymm2 * ymm3;
2102                     ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
2103 
2104                     ymm3 = _mm256_loadu_pd(tA + 4);
2105                     //                    ymm5 += ymm0 * ymm3;
2106                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2107                     //                    ymm9 += ymm1 * ymm3;
2108                     ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
2109                     //                    ymm13 += ymm2 * ymm3;
2110                     ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
2111 
2112                     ymm3 = _mm256_loadu_pd(tA + 8);
2113                     //                   ymm6 += ymm0 * ymm3;
2114                     ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
2115                     //                    ymm10 += ymm1 * ymm3;
2116                     ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
2117                     //                    ymm14 += ymm2 * ymm3;
2118                     ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
2119 
2120                     ymm3 = _mm256_loadu_pd(tA + 12);
2121                     //                    ymm7 += ymm0 * ymm3;
2122                     ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
2123                     //                    ymm11 += ymm1 * ymm3;
2124                     ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
2125                     //                   ymm15 += ymm2 * ymm3;
2126                     ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);
2127 
2128                     tA += lda_packed;
2129                 }
2130                 // alpha, beta multiplication.
2131                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2132                 ymm1 = _mm256_broadcast_sd(beta_cast);
2133 
2134                 //multiply A*B by alpha.
2135                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2136                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2137                 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2138                 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2139                 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2140                 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2141                 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2142                 ymm11 = _mm256_mul_pd(ymm11, ymm0);
2143                 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2144                 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2145                 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2146                 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2147 
2148                 if(is_beta_non_zero)
2149                 {
2150                     // multiply C by beta and accumulate col 1.
2151                     ymm2 = _mm256_loadu_pd(tC);
2152                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2153                     ymm2 = _mm256_loadu_pd(tC + 4);
2154                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2155                     ymm2 = _mm256_loadu_pd(tC + 8);
2156                     ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2157                     ymm2 = _mm256_loadu_pd(tC + 12);
2158                     ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
2159 
2160                     // multiply C by beta and accumulate, col 2.
2161                     double* ttC  = tC + ldc;
2162                     ymm2 = _mm256_loadu_pd(ttC);
2163                     ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2164                     ymm2 = _mm256_loadu_pd(ttC + 4);
2165                     ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2166                     ymm2 = _mm256_loadu_pd(ttC + 8);
2167                     ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2168                     ymm2 = _mm256_loadu_pd(ttC + 12);
2169                     ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
2170 
2171                     // multiply C by beta and accumulate, col 3.
2172                     ttC += ldc;
2173                     ymm2 = _mm256_loadu_pd(ttC);
2174                     ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2175                     ymm2 = _mm256_loadu_pd(ttC + 4);
2176                     ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2177                     ymm2 = _mm256_loadu_pd(ttC + 8);
2178                     ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2179                     ymm2 = _mm256_loadu_pd(ttC + 12);
2180                     ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
2181                 }
2182                 _mm256_storeu_pd(tC, ymm4);
2183                 _mm256_storeu_pd(tC + 4, ymm5);
2184                 _mm256_storeu_pd(tC + 8, ymm6);
2185                 _mm256_storeu_pd(tC + 12, ymm7);
2186 
2187                 tC += ldc;
2188 
2189                 _mm256_storeu_pd(tC, ymm8);
2190                 _mm256_storeu_pd(tC + 4, ymm9);
2191                 _mm256_storeu_pd(tC + 8, ymm10);
2192                 _mm256_storeu_pd(tC + 12, ymm11);
2193 
2194                 tC += ldc;
2195 
2196                 _mm256_storeu_pd(tC, ymm12);
2197                 _mm256_storeu_pd(tC + 4, ymm13);
2198                 _mm256_storeu_pd(tC + 8, ymm14);
2199                 _mm256_storeu_pd(tC + 12, ymm15);
2200 
2201             }
2202             n_remainder = N - col_idx;
2203 
2204             // If N is not a multiple of 3,
2205             // handle the last two columns.
2206             if (n_remainder == 2)
2207             {
2208                 //pointer math to point to proper memory
2209                 tC = C + ldc * col_idx + row_idx;
2210                 tB = B + tb_inc_col * col_idx;
2211                 tA = A + row_idx;
2212 
2213                 // clear scratch registers.
2214                 ymm8 = _mm256_setzero_pd();
2215                 ymm9 = _mm256_setzero_pd();
2216                 ymm10 = _mm256_setzero_pd();
2217                 ymm11 = _mm256_setzero_pd();
2218                 ymm12 = _mm256_setzero_pd();
2219                 ymm13 = _mm256_setzero_pd();
2220                 ymm14 = _mm256_setzero_pd();
2221                 ymm15 = _mm256_setzero_pd();
2222 
2223                 for (k = 0; k < K; ++k)
2224                 {
2225                     // The inner loop broadcasts the B matrix data and
2226                     // multiplies it with the A matrix.
2227                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2228                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2229                     tB += tb_inc_row;
2230 
2231                     //broadcasted matrix B elements are multiplied
2232                     //with matrix A columns.
2233                     ymm3 = _mm256_loadu_pd(tA);
2234                     ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
2235                     ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
2236 
2237                     ymm3 = _mm256_loadu_pd(tA + 4);
2238                     ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
2239                     ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
2240 
2241                     ymm3 = _mm256_loadu_pd(tA + 8);
2242                     ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
2243                     ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
2244 
2245                     ymm3 = _mm256_loadu_pd(tA + 12);
2246                     ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11);
2247                     ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);
2248 
2249                     tA += lda;
2250 
2251                 }
2252                 // alpha, beta multiplication.
2253                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2254                 ymm1 = _mm256_broadcast_sd(beta_cast);
2255 
2256                 //multiply A*B by alpha.
2257                 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2258                 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2259                 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2260                 ymm11 = _mm256_mul_pd(ymm11, ymm0);
2261                 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2262                 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2263                 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2264                 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2265 
2266                 if(is_beta_non_zero)
2267                 {
2268                     // multiply C by beta and accumulate, col 1.
2269                     ymm2 = _mm256_loadu_pd(tC + 0);
2270                     ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2271                     ymm2 = _mm256_loadu_pd(tC + 4);
2272                     ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2273                     ymm2 = _mm256_loadu_pd(tC + 8);
2274                     ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2275                     ymm2 = _mm256_loadu_pd(tC + 12);
2276                     ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);
2277 
2278                     // multiply C by beta and accumulate, col 2.
2279                     double *ttC = tC + ldc;
2280 
2281                     ymm2 = _mm256_loadu_pd(ttC);
2282                     ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2283                     ymm2 = _mm256_loadu_pd(ttC + 4);
2284                     ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2285                     ymm2 = _mm256_loadu_pd(ttC + 8);
2286                     ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2287                     ymm2 = _mm256_loadu_pd(ttC + 12);
2288                     ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
2289                 }
2290 
2291                 _mm256_storeu_pd(tC + 0, ymm8);
2292                 _mm256_storeu_pd(tC + 4, ymm9);
2293                 _mm256_storeu_pd(tC + 8, ymm10);
2294                 _mm256_storeu_pd(tC + 12, ymm11);
2295 
2296                 tC += ldc;
2297 
2298                 _mm256_storeu_pd(tC, ymm12);
2299                 _mm256_storeu_pd(tC + 4, ymm13);
2300                 _mm256_storeu_pd(tC + 8, ymm14);
2301                 _mm256_storeu_pd(tC + 12, ymm15);
2302                 col_idx += 2;
2303             }
2304             // If N is not a multiple of 3,
2305             // handle the last remaining column.
2306             if (n_remainder == 1)
2307             {
2308                 //pointer math to point to proper memory
2309                 tC = C + ldc * col_idx + row_idx;
2310                 tB = B + tb_inc_col * col_idx;
2311                 tA = A + row_idx;
2312 
2313                 // clear scratch registers.
2314                 ymm12 = _mm256_setzero_pd();
2315                 ymm13 = _mm256_setzero_pd();
2316                 ymm14 = _mm256_setzero_pd();
2317                 ymm15 = _mm256_setzero_pd();
2318 
2319                 for (k = 0; k < K; ++k)
2320                 {
2321                     // The inner loop broadcasts the B matrix data and
2322                     // multiplies it with the A matrix.
2323                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2324                     tB += tb_inc_row;
2325 
2326                     //broadcasted matrix B elements are multiplied
2327                     //with matrix A columns.
2328                     ymm3 = _mm256_loadu_pd(tA);
2329                     ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
2330 
2331                     ymm3 = _mm256_loadu_pd(tA + 4);
2332                     ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
2333 
2334                     ymm3 = _mm256_loadu_pd(tA + 8);
2335                     ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
2336 
2337                     ymm3 = _mm256_loadu_pd(tA + 12);
2338                     ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15);
2339 
2340                     tA += lda;
2341 
2342                 }
2343                 // alpha, beta multiplication.
2344                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2345                 ymm1 = _mm256_broadcast_sd(beta_cast);
2346 
2347                 //multiply A*B by alpha.
2348                 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2349                 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2350                 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2351                 ymm15 = _mm256_mul_pd(ymm15, ymm0);
2352 
2353                 if(is_beta_non_zero)
2354                 {
2355                     // multiply C by beta and accumulate.
2356                     ymm2 = _mm256_loadu_pd(tC + 0);
2357                     ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2358                     ymm2 = _mm256_loadu_pd(tC + 4);
2359                     ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2360                     ymm2 = _mm256_loadu_pd(tC + 8);
2361                     ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2362                     ymm2 = _mm256_loadu_pd(tC + 12);
2363                     ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);
2364                 }
2365 
2366                 _mm256_storeu_pd(tC + 0, ymm12);
2367                 _mm256_storeu_pd(tC + 4, ymm13);
2368                 _mm256_storeu_pd(tC + 8, ymm14);
2369                 _mm256_storeu_pd(tC + 12, ymm15);
2370             }
2371         }
2372 
2373         m_remainder = M - row_idx;
2374 
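        // Rows left over after the 16-row panels are processed in a 12-row
        // strip first: three 4-double accumulators per column of the 3-wide
        // B panel (nine accumulators in total), with the usual edge handling
        // when N is not a multiple of 3.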
2375         if (m_remainder >= 12)
2376         {
2377             m_remainder -= 12;
2378 
2379             for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2380             {
2381                 //pointer math to point to proper memory
2382                 tC = C + ldc * col_idx + row_idx;
2383                 tB = B + tb_inc_col * col_idx;
2384                 tA = A + row_idx;
2385 
2386                 // clear scratch registers.
2387                 ymm4 = _mm256_setzero_pd();
2388                 ymm5 = _mm256_setzero_pd();
2389                 ymm6 = _mm256_setzero_pd();
2390                 ymm8 = _mm256_setzero_pd();
2391                 ymm9 = _mm256_setzero_pd();
2392                 ymm10 = _mm256_setzero_pd();
2393                 ymm12 = _mm256_setzero_pd();
2394                 ymm13 = _mm256_setzero_pd();
2395                 ymm14 = _mm256_setzero_pd();
2396 
2397                 for (k = 0; k < K; ++k)
2398                 {
2399                     // The inner loop broadcasts the B matrix data and
2400                     // multiplies it with the A matrix.
2401                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2402                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2403                     ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2404                     tB += tb_inc_row;
2405 
2406                     //broadcasted matrix B elements are multiplied
2407                     //with matrix A columns.
2408                     ymm3 = _mm256_loadu_pd(tA);
2409                     //                   ymm4 += ymm0 * ymm3;
2410                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2411                     //                    ymm8 += ymm1 * ymm3;
2412                     ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
2413                     //                    ymm12 += ymm2 * ymm3;
2414                     ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);
2415 
2416                     ymm3 = _mm256_loadu_pd(tA + 4);
2417                     //                    ymm5 += ymm0 * ymm3;
2418                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2419                     //                    ymm9 += ymm1 * ymm3;
2420                     ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9);
2421                     //                    ymm13 += ymm2 * ymm3;
2422                     ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13);
2423 
2424                     ymm3 = _mm256_loadu_pd(tA + 8);
2425                     //                   ymm6 += ymm0 * ymm3;
2426                     ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);
2427                     //                    ymm10 += ymm1 * ymm3;
2428                     ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10);
2429                     //                    ymm14 += ymm2 * ymm3;
2430                     ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14);
2431 
2432                     tA += lda;
2433                 }
2434                 // alpha, beta multiplication.
2435                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2436                 ymm1 = _mm256_broadcast_sd(beta_cast);
2437 
2438                 //multiply A*B by alpha.
2439                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2440                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2441                 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2442                 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2443                 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2444                 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2445                 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2446                 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2447                 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2448 
2449                 if(is_beta_non_zero)
2450                 {
2451                     // multiply C by beta and accumulate.
2452                     ymm2 = _mm256_loadu_pd(tC);
2453                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2454                     ymm2 = _mm256_loadu_pd(tC + 4);
2455                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2456                     ymm2 = _mm256_loadu_pd(tC + 8);
2457                     ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2458 
2459                     // multiply C by beta and accumulate.
2460                     double *ttC = tC +ldc;
2461                     ymm2 = _mm256_loadu_pd(ttC);
2462                     ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2463                     ymm2 = _mm256_loadu_pd(ttC + 4);
2464                     ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2465                     ymm2 = _mm256_loadu_pd(ttC + 8);
2466                     ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2467 
2468                     // multiply C by beta and accumulate.
2469                     ttC += ldc;
2470                     ymm2 = _mm256_loadu_pd(ttC);
2471                     ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2472                     ymm2 = _mm256_loadu_pd(ttC + 4);
2473                     ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2474                     ymm2 = _mm256_loadu_pd(ttC + 8);
2475                     ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2476 
2477                 }
2478                 _mm256_storeu_pd(tC, ymm4);
2479                 _mm256_storeu_pd(tC + 4, ymm5);
2480                 _mm256_storeu_pd(tC + 8, ymm6);
2481 
2482                 tC += ldc;
2483 
2484                 _mm256_storeu_pd(tC, ymm8);
2485                 _mm256_storeu_pd(tC + 4, ymm9);
2486                 _mm256_storeu_pd(tC + 8, ymm10);
2487 
2488                 tC += ldc;
2489 
2490                 _mm256_storeu_pd(tC, ymm12);
2491                 _mm256_storeu_pd(tC + 4, ymm13);
2492                 _mm256_storeu_pd(tC + 8, ymm14);
2493             }
2494             n_remainder = N - col_idx;
2495             // If N is not a multiple of 3,
2496             // handle the last two columns.
2497             if (n_remainder == 2)
2498             {
2499                 //pointer math to point to proper memory
2500                 tC = C + ldc * col_idx + row_idx;
2501                 tB = B + tb_inc_col * col_idx;
2502                 tA = A + row_idx;
2503 
2504                 // clear scratch registers.
2505                 ymm8 = _mm256_setzero_pd();
2506                 ymm9 = _mm256_setzero_pd();
2507                 ymm10 = _mm256_setzero_pd();
2508                 ymm12 = _mm256_setzero_pd();
2509                 ymm13 = _mm256_setzero_pd();
2510                 ymm14 = _mm256_setzero_pd();
2511 
2512                 for (k = 0; k < K; ++k)
2513                 {
2514                     // The inner loop broadcasts the B matrix data and
2515                     // multiplies it with the A matrix.
2516                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2517                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2518                     tB += tb_inc_row;
2519 
2520                     //broadcasted matrix B elements are multiplied
2521                     //with matrix A columns.
2522                     ymm3 = _mm256_loadu_pd(tA);
2523                     ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8);
2524                     ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12);
2525 
2526                     ymm3 = _mm256_loadu_pd(tA + 4);
2527                     ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9);
2528                     ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13);
2529 
2530                     ymm3 = _mm256_loadu_pd(tA + 8);
2531                     ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
2532                     ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
2533 
2534                     tA += lda;
2535 
2536                 }
2537                 // alpha, beta multiplication.
2538                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2539                 ymm1 = _mm256_broadcast_sd(beta_cast);
2540 
2541                 //multiply A*B by alpha.
2542                 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2543                 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2544                 ymm10 = _mm256_mul_pd(ymm10, ymm0);
2545                 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2546                 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2547                 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2548 
2549 
2550                 if(is_beta_non_zero)
2551                 {
2552                     // multiply C by beta and accumulate.
2553                     ymm2 = _mm256_loadu_pd(tC + 0);
2554                     ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2555                     ymm2 = _mm256_loadu_pd(tC + 4);
2556                     ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2557                     ymm2 = _mm256_loadu_pd(tC + 8);
2558                     ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);
2559 
2560                     double *ttC = tC + ldc;
2561 
2562                     // multiply C by beta and accumulate.
2563                     ymm2 = _mm256_loadu_pd(ttC);
2564                     ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2565                     ymm2 = _mm256_loadu_pd(ttC + 4);
2566                     ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2567                     ymm2 = _mm256_loadu_pd(ttC + 8);
2568                     ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2569 
2570                 }
2571                 _mm256_storeu_pd(tC + 0, ymm8);
2572                 _mm256_storeu_pd(tC + 4, ymm9);
2573                 _mm256_storeu_pd(tC + 8, ymm10);
2574 
2575                 tC += ldc;
2576 
2577                 _mm256_storeu_pd(tC, ymm12);
2578                 _mm256_storeu_pd(tC + 4, ymm13);
2579                 _mm256_storeu_pd(tC + 8, ymm14);
2580 
2581                 col_idx += 2;
2582             }
2583             // Edge case: N is not a multiple of 3.
2584             // Handle the last remaining column here.
2585             if (n_remainder == 1)
2586             {
2587                 //pointer math to point to proper memory
2588                 tC = C + ldc * col_idx + row_idx;
2589                 tB = B + tb_inc_col * col_idx;
2590                 tA = A + row_idx;
2591 
2592                 // clear scratch registers.
2593                 ymm12 = _mm256_setzero_pd();
2594                 ymm13 = _mm256_setzero_pd();
2595                 ymm14 = _mm256_setzero_pd();
2596 
2597                 for (k = 0; k < K; ++k)
2598                 {
2599                     // The inner loop broadcasts the B matrix data and
2600                     // multiplies it with the A matrix.
2601                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2602                     tB += tb_inc_row;
2603 
2604                     //broadcasted matrix B elements are multiplied
2605                     //with matrix A columns.
2606                     ymm3 = _mm256_loadu_pd(tA);
2607                     ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12);
2608 
2609                     ymm3 = _mm256_loadu_pd(tA + 4);
2610                     ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
2611 
2612                     ymm3 = _mm256_loadu_pd(tA + 8);
2613                     ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);
2614 
2615                     tA += lda;
2616 
2617                 }
2618                 // alpha, beta multiplication.
2619                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2620                 ymm1 = _mm256_broadcast_sd(beta_cast);
2621 
2622                 //multiply A*B by alpha.
2623                 ymm12 = _mm256_mul_pd(ymm12, ymm0);
2624                 ymm13 = _mm256_mul_pd(ymm13, ymm0);
2625                 ymm14 = _mm256_mul_pd(ymm14, ymm0);
2626 
2627 
2628                 if(is_beta_non_zero)
2629                 {
2630                     // multiply C by beta and accumulate.
2631                     ymm2 = _mm256_loadu_pd(tC + 0);
2632                     ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);
2633                     ymm2 = _mm256_loadu_pd(tC + 4);
2634                     ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);
2635                     ymm2 = _mm256_loadu_pd(tC + 8);
2636                     ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);
2637 
2638                 }
2639                 _mm256_storeu_pd(tC + 0, ymm12);
2640                 _mm256_storeu_pd(tC + 4, ymm13);
2641                 _mm256_storeu_pd(tC + 8, ymm14);
2642             }
2643 
2644             row_idx += 12;
2645         }
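        // Row blocking cascade: the 12-row kernel above is followed by progressively
        // narrower kernels for the remaining rows -- an 8-row block, a 4-row block,
        // a padded vector path for 1-3 leftover rows (when lda > 3), and finally a
        // scalar fallback loop.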
2646 
2647         if (m_remainder >= 8)
2648         {
2649             m_remainder -= 8;
2650 
2651             for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2652             {
2653                 //pointer math to point to proper memory
2654                 tC = C + ldc * col_idx + row_idx;
2655                 tB = B + tb_inc_col * col_idx;
2656                 tA = A + row_idx;
2657 
2658                 // clear scratch registers.
2659                 ymm4 = _mm256_setzero_pd();
2660                 ymm5 = _mm256_setzero_pd();
2661                 ymm6 = _mm256_setzero_pd();
2662                 ymm7 = _mm256_setzero_pd();
2663                 ymm8 = _mm256_setzero_pd();
2664                 ymm9 = _mm256_setzero_pd();
2665 
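                // 8x3 accumulator layout (two YMM registers per column of C):
                //   column col_idx+0: ymm4 (rows 0-3), ymm5 (rows 4-7)
                //   column col_idx+1: ymm6,            ymm7
                //   column col_idx+2: ymm8,            ymm9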
2666                 for (k = 0; k < K; ++k)
2667                 {
2668                     // The inner loop broadcasts the B matrix data and
2669                     // multiplies it with the A matrix.
2670                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2671                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2672                     ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2673                     tB += tb_inc_row;
2674 
2675                     //broadcasted matrix B elements are multiplied
2676                     //with matrix A columns.
2677                     ymm3 = _mm256_loadu_pd(tA);
2678                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2679                     ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
2680                     ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8);
2681 
2682                     ymm3 = _mm256_loadu_pd(tA + 4);
2683                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2684                     ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
2685                     ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
2686 
2687                     tA += lda;
2688                 }
2689                 // alpha, beta multiplication.
2690                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2691                 ymm1 = _mm256_broadcast_sd(beta_cast);
2692 
2693                 //multiply A*B by alpha.
2694                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2695                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2696                 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2697                 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2698                 ymm8 = _mm256_mul_pd(ymm8, ymm0);
2699                 ymm9 = _mm256_mul_pd(ymm9, ymm0);
2700 
2701                 if(is_beta_non_zero)
2702                 {
2703                     // multiply C by beta and accumulate.
2704                     ymm2 = _mm256_loadu_pd(tC);
2705                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2706                     ymm2 = _mm256_loadu_pd(tC + 4);
2707                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2708 
2709                     double* ttC = tC + ldc;
2710 
2711                     // multiply C by beta and accumulate.
2712                     ymm2 = _mm256_loadu_pd(ttC);
2713                     ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2714                     ymm2 = _mm256_loadu_pd(ttC + 4);
2715                     ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
2716 
2717                     ttC += ldc;
2718 
2719                     // multiply C by beta and accumulate.
2720                     ymm2 = _mm256_loadu_pd(ttC);
2721                     ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);
2722                     ymm2 = _mm256_loadu_pd(ttC + 4);
2723                     ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
2724                 }
2725 
2726                 _mm256_storeu_pd(tC, ymm4);
2727                 _mm256_storeu_pd(tC + 4, ymm5);
2728 
2729                 tC += ldc;
2730                 _mm256_storeu_pd(tC, ymm6);
2731                 _mm256_storeu_pd(tC + 4, ymm7);
2732 
2733                 tC += ldc;
2734                 _mm256_storeu_pd(tC, ymm8);
2735                 _mm256_storeu_pd(tC + 4, ymm9);
2736 
2737             }
2738             n_remainder = N - col_idx;
2739             // Edge case: N is not a multiple of 3.
2740             // Handle the remaining two columns here.
2741             if (n_remainder == 2)
2742             {
2743                 //pointer math to point to proper memory
2744                 tC = C + ldc * col_idx + row_idx;
2745                 tB = B + tb_inc_col * col_idx;
2746                 tA = A + row_idx;
2747 
2748                 // clear scratch registers.
2749                 ymm4 = _mm256_setzero_pd();
2750                 ymm5 = _mm256_setzero_pd();
2751                 ymm6 = _mm256_setzero_pd();
2752                 ymm7 = _mm256_setzero_pd();
2753 
2754                 for (k = 0; k < K; ++k)
2755                 {
2756                     // The inner loop broadcasts the B matrix data and
2757                     // multiplies it with the A matrix.
2758                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2759                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2760                     tB += tb_inc_row;
2761 
2762                     //broadcasted matrix B elements are multiplied
2763                     //with matrix A columns.
2764                     ymm3 = _mm256_loadu_pd(tA);
2765                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2766                     ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6);
2767 
2768                     ymm3 = _mm256_loadu_pd(tA + 4);
2769                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2770                     ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
2771 
2772                     tA += lda;
2773                 }
2774                 // alpha, beta multiplication.
2775                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2776                 ymm1 = _mm256_broadcast_sd(beta_cast);
2777 
2778                 //multiply A*B by alpha.
2779                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2780                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2781                 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2782                 ymm7 = _mm256_mul_pd(ymm7, ymm0);
2783 
2784                 if(is_beta_non_zero)
2785                 {
2786                     // multiply C by beta and accumulate.
2787                     ymm2 = _mm256_loadu_pd(tC);
2788                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2789                     ymm2 = _mm256_loadu_pd(tC + 4);
2790                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2791 
2792                     double* ttC = tC + ldc;
2793 
2794                     // multiply C by beta and accumulate.
2795                     ymm2 = _mm256_loadu_pd(ttC);
2796                     ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2797                     ymm2 = _mm256_loadu_pd(ttC + 4);
2798                     ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
2799                 }
2800                 _mm256_storeu_pd(tC, ymm4);
2801                 _mm256_storeu_pd(tC + 4, ymm5);
2802 
2803                 tC += ldc;
2804                 _mm256_storeu_pd(tC, ymm6);
2805                 _mm256_storeu_pd(tC + 4, ymm7);
2806 
2807                 col_idx += 2;
2808 
2809             }
2810             // Edge case: N is not a multiple of 3.
2811             // Handle the last remaining column here.
2812             if (n_remainder == 1)
2813             {
2814                 //pointer math to point to proper memory
2815                 tC = C + ldc * col_idx + row_idx;
2816                 tB = B + tb_inc_col * col_idx;
2817                 tA = A + row_idx;
2818 
2819                 ymm4 = _mm256_setzero_pd();
2820                 ymm5 = _mm256_setzero_pd();
2821 
2822                 for (k = 0; k < K; ++k)
2823                 {
2824                     // The inner loop broadcasts the B matrix data and
2825                     // multiplies it with the A matrix.
2826                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2827                     tB += tb_inc_row;
2828 
2829                     //broadcasted matrix B elements are multiplied
2830                     //with matrix A columns.
2831                     ymm3 = _mm256_loadu_pd(tA);
2832                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2833 
2834                     ymm3 = _mm256_loadu_pd(tA + 4);
2835                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
2836 
2837                     tA += lda;
2838                 }
2839                 // alpha, beta multiplication.
2840                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2841                 ymm1 = _mm256_broadcast_sd(beta_cast);
2842 
2843                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2844                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2845 
2846                 if(is_beta_non_zero)
2847                 {
2848                     // multiply C by beta and accumulate.
2849                     ymm2 = _mm256_loadu_pd(tC);
2850                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2851                     ymm2 = _mm256_loadu_pd(tC + 4);
2852                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2853                 }
2854                 _mm256_storeu_pd(tC, ymm4);
2855                 _mm256_storeu_pd(tC + 4, ymm5);
2856 
2857             }
2858 
2859             row_idx += 8;
2860         }
2861 
2862         if (m_remainder >= 4)
2863         {
2865             m_remainder -= 4;
2866 
2867             for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
2868             {
2869                 //pointer math to point to proper memory
2870                 tC = C + ldc * col_idx + row_idx;
2871                 tB = B + tb_inc_col * col_idx;
2872                 tA = A + row_idx;
2873 
2874                 // clear scratch registers.
2875                 ymm4 = _mm256_setzero_pd();
2876                 ymm5 = _mm256_setzero_pd();
2877                 ymm6 = _mm256_setzero_pd();
2878 
2879                 for (k = 0; k < K; ++k)
2880                 {
2881                     // The inner loop broadcasts the B matrix data and
2882                     // multiplies it with the A matrix.
2883                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2884                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2885                     ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
2886                     tB += tb_inc_row;
2887 
2888                     //broadcasted matrix B elements are multiplied
2889                     //with matrix A columns.
2890                     ymm3 = _mm256_loadu_pd(tA);
2891                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2892                     ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
2893                     ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);
2894 
2895                     tA += lda;
2896                 }
2897                 // alpha, beta multiplication.
2898                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2899                 ymm1 = _mm256_broadcast_sd(beta_cast);
2900 
2901                 //multiply A*B by alpha.
2902                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2903                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2904                 ymm6 = _mm256_mul_pd(ymm6, ymm0);
2905 
2906                 if(is_beta_non_zero)
2907                 {
2908                     // multiply C by beta and accumulate.
2909                     ymm2 = _mm256_loadu_pd(tC);
2910                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2911 
2912                     double* ttC = tC + ldc;
2913 
2914                     // multiply C by beta and accumulate.
2915                     ymm2 = _mm256_loadu_pd(ttC);
2916                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2917 
2918                     ttC += ldc;
2919 
2920                     // multiply C by beta and accumulate.
2921                     ymm2 = _mm256_loadu_pd(ttC);
2922                     ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);
2923                 }
2924                 _mm256_storeu_pd(tC, ymm4);
2925 
2926                 tC += ldc;
2927                 _mm256_storeu_pd(tC, ymm5);
2928 
2929                 tC += ldc;
2930                 _mm256_storeu_pd(tC, ymm6);
2931             }
2932             n_remainder = N - col_idx;
2933             // Edge case: N is not a multiple of 3.
2934             // Handle the remaining two columns here.
2935             if (n_remainder == 2)
2936             {
2937                 //pointer math to point to proper memory
2938                 tC = C + ldc * col_idx + row_idx;
2939                 tB = B + tb_inc_col * col_idx;
2940                 tA = A + row_idx;
2941 
2942                 ymm4 = _mm256_setzero_pd();
2943                 ymm5 = _mm256_setzero_pd();
2944 
2945                 for (k = 0; k < K; ++k)
2946                 {
2947                     // The inner loop broadcasts the B matrix data and
2948                     // multiplies it with the A matrix.
2949                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
2950                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
2951                     tB += tb_inc_row;
2952 
2953                     //broadcasted matrix B elements are multiplied
2954                     //with matrix A columns.
2955                     ymm3 = _mm256_loadu_pd(tA);
2956                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
2957                     ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
2958 
2959                     tA += lda;
2960                 }
2961                 // alpha, beta multiplication.
2962                 ymm0 = _mm256_broadcast_sd(alpha_cast);
2963                 ymm1 = _mm256_broadcast_sd(beta_cast);
2964 
2965                 //multiply A*B by alpha.
2966                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
2967                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
2968 
2969                 if(is_beta_non_zero)
2970                 {
2971                     // multiply C by beta and accumulate.
2972                     ymm2 = _mm256_loadu_pd(tC);
2973                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
2974 
2975                     double* ttC = tC + ldc;
2976 
2977                     // multiply C by beta and accumulate.
2978                     ymm2 = _mm256_loadu_pd(ttC);
2979                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
2980                 }
2981                 _mm256_storeu_pd(tC, ymm4);
2982 
2983                 tC += ldc;
2984                 _mm256_storeu_pd(tC, ymm5);
2985 
2986                 col_idx += 2;
2987 
2988             }
2989             // Edge case: N is not a multiple of 3.
2990             // Handle the last remaining column here.
2991             if (n_remainder == 1)
2992             {
2993                 //pointer math to point to proper memory
2994                 tC = C + ldc * col_idx + row_idx;
2995                 tB = B + tb_inc_col * col_idx;
2996                 tA = A + row_idx;
2997 
2998                 ymm4 = _mm256_setzero_pd();
2999 
3000                 for (k = 0; k < K; ++k)
3001                 {
3002                     // The inner loop broadcasts the B matrix data and
3003                     // multiplies it with the A matrix.
3004                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3005                     tB += tb_inc_row;
3006 
3007                     //broadcasted matrix B elements are multiplied
3008                     //with matrix A columns.
3009                     ymm3 = _mm256_loadu_pd(tA);
3010                     ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
3011 
3012                     tA += lda;
3013                 }
3014                 // alpha, beta multiplication.
3015                 ymm0 = _mm256_broadcast_sd(alpha_cast);
3016                 ymm1 = _mm256_broadcast_sd(beta_cast);
3017 
3018                 ymm4 = _mm256_mul_pd(ymm4, ymm0);
3019 
3020                 if(is_beta_non_zero)
3021                 {
3022                     // multiply C by beta and accumulate.
3023                     ymm2 = _mm256_loadu_pd(tC);
3024                     ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);
3025 
3026                 }
3027                 _mm256_storeu_pd(tC, ymm4);
3028 
3029             }
3030 
3031             row_idx += 4;
3032         }
3033         // Edge case: fewer than 4 rows remain after the 12/8/4-row blocks above.
3034         // The leftover rows are padded into a temporary buffer so that full-width
3035         // vector loads and stores can still be used.
3036         // This path is taken only when lda > 3.
3037         if ((m_remainder) && (lda > 3))
3038         {
3039             double f_temp[8] = {0.0};
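            // f_temp is a zero-initialized staging buffer: the 1-3 valid remainder
            // elements are copied into it so that full 4-wide vector loads and stores
            // can be used without reading or writing past the edges of A and C; the
            // zero lanes contribute nothing to the FMAs.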
3040 
3041             for (col_idx = 0; (col_idx + 2) < N; col_idx += 3)
3042             {
3043                 //pointer math to point to proper memory
3044                 tC = C + ldc * col_idx + row_idx;
3045                 tB = B + tb_inc_col * col_idx;
3046                 tA = A + row_idx;
3047 
3048                 // clear scratch registers.
3049                 ymm5 = _mm256_setzero_pd();
3050                 ymm7 = _mm256_setzero_pd();
3051                 ymm9 = _mm256_setzero_pd();
3052 
3053                 for (k = 0; k < (K - 1); ++k)
3054                 {
3055                     // The inner loop broadcasts the B matrix data and
3056                     // multiplies it with the A matrix.
3057                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3058                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3059                     ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
3060                     tB += tb_inc_row;
3061 
3062                     //broadcasted matrix B elements are multiplied
3063                     //with matrix A columns.
3064                     ymm3 = _mm256_loadu_pd(tA);
3065                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3066                     ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3067                     ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
3068 
3069                     tA += lda;
3070                 }
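                // The loop above deliberately stops at K-1: while another column of A
                // still follows, a full 4-wide load from tA is safe even though only
                // m_remainder rows are valid. The final k iteration is peeled below and
                // staged through f_temp so the last load cannot run past the end of the
                // A buffer.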
3071                 // Final (peeled) k iteration: broadcast the last B values.
3072                 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3073                 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3074                 ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2);
3075                 tB += tb_inc_row;
3076 
3077                 for (int i = 0; i < m_remainder; i++)
3078                 {
3079                     f_temp[i] = tA[i];
3080                 }
3081                 ymm3 = _mm256_loadu_pd(f_temp);
3082                 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3083                 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3084                 ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);
3085 
3086                 ymm0 = _mm256_broadcast_sd(alpha_cast);
3087                 ymm1 = _mm256_broadcast_sd(beta_cast);
3088 
3089                 //multiply A*B by alpha.
3090                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
3091                 ymm7 = _mm256_mul_pd(ymm7, ymm0);
3092                 ymm9 = _mm256_mul_pd(ymm9, ymm0);
3093 
3094                 if(is_beta_non_zero)
3095                 {
3096                     for (int i = 0; i < m_remainder; i++)
3097                     {
3098                         f_temp[i] = tC[i];
3099                     }
3100                     ymm2 = _mm256_loadu_pd(f_temp);
3101                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
3102 
3103 
3104                     double* ttC = tC + ldc;
3105 
3106                     for (int i = 0; i < m_remainder; i++)
3107                     {
3108                         f_temp[i] = ttC[i];
3109                     }
3110                     ymm2 = _mm256_loadu_pd(f_temp);
3111                     ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
3112 
3113                     ttC += ldc;
3114                     for (int i = 0; i < m_remainder; i++)
3115                     {
3116                         f_temp[i] = ttC[i];
3117                     }
3118                     ymm2 = _mm256_loadu_pd(f_temp);
3119                     ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);
3120                 }
3121                 _mm256_storeu_pd(f_temp, ymm5);
3122                 for (int i = 0; i < m_remainder; i++)
3123                 {
3124                     tC[i] = f_temp[i];
3125                 }
3126 
3127                 tC += ldc;
3128                 _mm256_storeu_pd(f_temp, ymm7);
3129                 for (int i = 0; i < m_remainder; i++)
3130                 {
3131                     tC[i] = f_temp[i];
3132                 }
3133 
3134                 tC += ldc;
3135                 _mm256_storeu_pd(f_temp, ymm9);
3136                 for (int i = 0; i < m_remainder; i++)
3137                 {
3138                     tC[i] = f_temp[i];
3139                 }
3140             }
3141             n_remainder = N - col_idx;
3142             // Edge case: N is not a multiple of 3.
3143             // Handle the remaining two columns here.
3144             if (n_remainder == 2)
3145             {
3146                 //pointer math to point to proper memory
3147                 tC = C + ldc * col_idx + row_idx;
3148                 tB = B + tb_inc_col * col_idx;
3149                 tA = A + row_idx;
3150 
3151                 ymm5 = _mm256_setzero_pd();
3152                 ymm7 = _mm256_setzero_pd();
3153 
3154                 for (k = 0; k < (K - 1); ++k)
3155                 {
3156                     // The inner loop broadcasts the B matrix data and
3157                     // multiplies it with the A matrix.
3158                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3159                     ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3160                     tB += tb_inc_row;
3161 
3162                     ymm3 = _mm256_loadu_pd(tA);
3163                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3164                     ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3165 
3166                     tA += lda;
3167                 }
3168 
3169                 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3170                 ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1);
3171                 tB += tb_inc_row;
3172 
3173                 for (int i = 0; i < m_remainder; i++)
3174                 {
3175                     f_temp[i] = tA[i];
3176                 }
3177                 ymm3 = _mm256_loadu_pd(f_temp);
3178                 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3179                 ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);
3180 
3181                 ymm0 = _mm256_broadcast_sd(alpha_cast);
3182                 ymm1 = _mm256_broadcast_sd(beta_cast);
3183 
3184                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
3185                 ymm7 = _mm256_mul_pd(ymm7, ymm0);
3186 
3187                 if(is_beta_non_zero)
3188                 {
3189                     for (int i = 0; i < m_remainder; i++)
3190                     {
3191                         f_temp[i] = tC[i];
3192                     }
3193                     ymm2 = _mm256_loadu_pd(f_temp);
3194                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
3195 
3196                     double* ttC = tC + ldc;
3197 
3198                     for (int i = 0; i < m_remainder; i++)
3199                     {
3200                         f_temp[i] = ttC[i];
3201                     }
3202                     ymm2 = _mm256_loadu_pd(f_temp);
3203                     ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);
3204 
3205                 }
3206                 _mm256_storeu_pd(f_temp, ymm5);
3207                 for (int i = 0; i < m_remainder; i++)
3208                 {
3209                     tC[i] = f_temp[i];
3210                 }
3211 
3212                 tC += ldc;
3213                 _mm256_storeu_pd(f_temp, ymm7);
3214                 for (int i = 0; i < m_remainder; i++)
3215                 {
3216                     tC[i] = f_temp[i];
3217                 }
3218             }
3219             // Edge case: N is not a multiple of 3.
3220             // Handle the last remaining column here.
3221             if (n_remainder == 1)
3222             {
3223                 //pointer math to point to proper memory
3224                 tC = C + ldc * col_idx + row_idx;
3225                 tB = B + tb_inc_col * col_idx;
3226                 tA = A + row_idx;
3227 
3228                 ymm5 = _mm256_setzero_pd();
3229 
3230                 for (k = 0; k < (K - 1); ++k)
3231                 {
3232                     // The inner loop broadcasts the B matrix data and
3233                     // multiplies it with the A matrix.
3234                     ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3235                     tB += tb_inc_row;
3236 
3237                     ymm3 = _mm256_loadu_pd(tA);
3238                     ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3239 
3240                     tA += lda;
3241                 }
3242 
3243                 ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0);
3244                 tB += tb_inc_row;
3245 
3246                 for (int i = 0; i < m_remainder; i++)
3247                 {
3248                     f_temp[i] = tA[i];
3249                 }
3250                 ymm3 = _mm256_loadu_pd(f_temp);
3251                 ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
3252 
3253                 ymm0 = _mm256_broadcast_sd(alpha_cast);
3254                 ymm1 = _mm256_broadcast_sd(beta_cast);
3255 
3256                 // multiply A*B by alpha.
3257                 ymm5 = _mm256_mul_pd(ymm5, ymm0);
3258 
3259                 if(is_beta_non_zero)
3260                 {
3261 
3262                     for (int i = 0; i < m_remainder; i++)
3263                     {
3264                         f_temp[i] = tC[i];
3265                     }
3266                     ymm2 = _mm256_loadu_pd(f_temp);
3267                     ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);
3268                 }
3269                 _mm256_storeu_pd(f_temp, ymm5);
3270                 for (int i = 0; i < m_remainder; i++)
3271                 {
3272                     tC[i] = f_temp[i];
3273                 }
3274             }
3275             m_remainder = 0;
3276         }
3277 
3278         if (m_remainder)
3279         {
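            // Scalar fallback: rows not handled by the vector paths above (in
            // particular when lda <= 3, which skips the padded path) are computed one
            // C element at a time with a plain dot-product loop.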
3280             double result;
3281             for (; row_idx < M; row_idx += 1)
3282             {
3283                 for (col_idx = 0; col_idx < N; col_idx += 1)
3284                 {
3285                     //pointer math to point to proper memory
3286                     tC = C + ldc * col_idx + row_idx;
3287                     tB = B + tb_inc_col * col_idx;
3288                     tA = A + row_idx;
3289 
3290                     result = 0;
3291                     for (k = 0; k < K; ++k)
3292                     {
3293                         result += (*tA) * (*tB);
3294                         tA += lda;
3295                         tB += tb_inc_row;
3296                     }
3297 
3298                     result *= (*alpha_cast);
3299                     if(is_beta_non_zero)
3300                         (*tC) = (*tC) * (*beta_cast) + result;
3301                     else
3302                         (*tC) = result;
3303                 }
3304             }
3305         }
3306 
3307         // Return the packing buffer to the memory pool.
3308         if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
3309 #ifdef BLIS_ENABLE_MEM_TRACING
3310             printf( "bli_dgemm_small(): releasing mem pool block\n" );
3311 #endif
3312             bli_membrk_release(&rntm,
3313                                &local_mem_buf_A_s);
3314         }
3315         AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
3316         return BLIS_SUCCESS;
3317     }
3318     else
3319     {
3320         AOCL_DTL_TRACE_EXIT_ERR(
3321             AOCL_DTL_LEVEL_INFO,
3322             "Invalid dimensions for small gemm."
3323             );
3324         return BLIS_NONCONFORMAL_DIMENSIONS;
3325     }
3326 };
3327 
3328 static err_t bli_sgemm_small_atbn
3329      (
3330        obj_t*  alpha,
3331        obj_t*  a,
3332        obj_t*  b,
3333        obj_t*  beta,
3334        obj_t*  c,
3335        cntx_t* cntx,
3336        cntl_t* cntl
3337      )
3338 {
3339     AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
3340 
3341     gint_t M = bli_obj_length( c ); // number of rows of Matrix C
3342     gint_t N = bli_obj_width( c );  // number of columns of Matrix C
3343     gint_t K = bli_obj_length( b ); // number of rows of Matrix B
3344 
3345     guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA is enabled.
3346     guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB is enabled.
3347     guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
3348 
3349     int row_idx = 0, col_idx = 0, k;
3350 
3351     float *A = bli_obj_buffer_at_off(a); // pointer to matrix A elements, stored in row major format
3352     float *B = bli_obj_buffer_at_off(b); // pointer to matrix B elements, stored in column major format
3353     float *C = bli_obj_buffer_at_off(c); // pointer to matrix C elements, stored in column major format
3354 
3355     float *tA = A, *tB = B, *tC = C;
3356 
3357     __m256 ymm4, ymm5, ymm6, ymm7;
3358     __m256 ymm8, ymm9, ymm10, ymm11;
3359     __m256 ymm12, ymm13, ymm14, ymm15;
3360     __m256 ymm0, ymm1, ymm2, ymm3;
3361 
3362     float result;
3363     float scratch[8] = {0.0};
3364     const num_t    dt_exec   = bli_obj_dt( c );
3365     float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha );
3366     float* restrict beta_cast  = bli_obj_buffer_for_1x1( dt_exec, beta );
3367 
3368     /* Check whether beta is zero; if it is, C is overwritten rather than accumulated. */
3369     bool is_beta_non_zero = 0;
3370     if ( !bli_obj_equals( beta, &BLIS_ZERO ) ){
3371         is_beta_non_zero = 1;
3372     }
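    // In this A^T * B kernel each element of C is a dot product over K. Eight partial
    // products are accumulated per FMA into one YMM register, and each register is
    // reduced to a single scalar with horizontal adds after the k-loop (see the hadd
    // sequences below).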
3373 
3374     // The non-copy version of the A^T GEMM gives better performance for the small M cases.
3375     // The threshold is controlled by BLIS_ATBN_M_THRES
3376     if (M <= BLIS_ATBN_M_THRES)
3377     {
3378         for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
3379         {
3380             for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
3381             {
3382                 tA = A + row_idx * lda;
3383                 tB = B + col_idx * ldb;
3384                 tC = C + col_idx * ldc + row_idx;
3385                 // clear scratch registers.
3386                 ymm4 = _mm256_setzero_ps();
3387                 ymm5 = _mm256_setzero_ps();
3388                 ymm6 = _mm256_setzero_ps();
3389                 ymm7 = _mm256_setzero_ps();
3390                 ymm8 = _mm256_setzero_ps();
3391                 ymm9 = _mm256_setzero_ps();
3392                 ymm10 = _mm256_setzero_ps();
3393                 ymm11 = _mm256_setzero_ps();
3394                 ymm12 = _mm256_setzero_ps();
3395                 ymm13 = _mm256_setzero_ps();
3396                 ymm14 = _mm256_setzero_ps();
3397                 ymm15 = _mm256_setzero_ps();
3398 
3399                 //The inner loop computes the 4x3 values of the matrix.
3400                 //The computation pattern is:
3401                 // ymm4  ymm5  ymm6
3402                 // ymm7  ymm8  ymm9
3403                 // ymm10 ymm11 ymm12
3404                 // ymm13 ymm14 ymm15
3405 
3406                 //The dot product is computed in the inner loop; 8 float elements fit
3407                 //in a YMM register, hence the loop counter is incremented by 8.
3408                 for (k = 0; (k + 7) < K; k += 8)
3409                 {
3410                     ymm0 = _mm256_loadu_ps(tB + 0);
3411                     ymm1 = _mm256_loadu_ps(tB + ldb);
3412                     ymm2 = _mm256_loadu_ps(tB + 2 * ldb);
3413 
3414                     ymm3 = _mm256_loadu_ps(tA);
3415                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3416                     ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
3417                     ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
3418 
3419                     ymm3 = _mm256_loadu_ps(tA + lda);
3420                     ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3421                     ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
3422                     ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
3423 
3424                     ymm3 = _mm256_loadu_ps(tA + 2 * lda);
3425                     ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3426                     ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
3427                     ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
3428 
3429                     ymm3 = _mm256_loadu_ps(tA + 3 * lda);
3430                     ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3431                     ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
3432                     ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
3433 
3434                     tA += 8;
3435                     tB += 8;
3436 
3437                 }
3438 
3439                 // If K is not a multiple of 8, the remaining elements are zero-padded into a temporary array before loading.
3440                 if (k < K)
3441                 {
3442                     int iter;
3443                     float data_feeder[8] = { 0.0 };
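                    // data_feeder is zero-initialized; copying only the K-k valid
                    // elements into it lets the same full-width loads and FMAs be
                    // reused for the tail, since the zero lanes add nothing to the
                    // accumulators.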
3444 
3445                     for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3446                     ymm0 = _mm256_loadu_ps(data_feeder);
3447                     for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
3448                     ymm1 = _mm256_loadu_ps(data_feeder);
3449                     for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
3450                     ymm2 = _mm256_loadu_ps(data_feeder);
3451 
3452                     for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3453                     ymm3 = _mm256_loadu_ps(data_feeder);
3454                     ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3455                     ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5);
3456                     ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6);
3457 
3458                     for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
3459                     ymm3 = _mm256_loadu_ps(data_feeder);
3460                     ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3461                     ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8);
3462                     ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9);
3463 
3464                     for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
3465                     ymm3 = _mm256_loadu_ps(data_feeder);
3466                     ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3467                     ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11);
3468                     ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12);
3469 
3470                     for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
3471                     ymm3 = _mm256_loadu_ps(data_feeder);
3472                     ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3473                     ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14);
3474                     ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15);
3475 
3476                 }
3477 
3478                 //Horizontal addition and storage of the data.
3479                 //Results for the 4x3 block of C are stored here.
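                // _mm256_hadd_ps adds adjacent pairs within each 128-bit lane, so after
                // two passes lane 0 holds s0+s1+s2+s3 and lane 4 holds s4+s5+s6+s7; the
                // full 8-element dot product is therefore scratch[0] + scratch[4] after
                // the store.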
3480                 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3481                 ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3482                 _mm256_storeu_ps(scratch, ymm4);
3483                 result = scratch[0] + scratch[4];
3484                 result *= (*alpha_cast);
3485                 if(is_beta_non_zero){
3486                     tC[0] = result + tC[0] * (*beta_cast);
3487                 }else{
3488                     tC[0] = result;
3489                 }
3490 
3491                 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3492                 ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3493                 _mm256_storeu_ps(scratch, ymm7);
3494                 result = scratch[0] + scratch[4];
3495                 result *= (*alpha_cast);
3496                 if(is_beta_non_zero){
3497                     tC[1] = result + tC[1] * (*beta_cast);
3498                 }else{
3499                     tC[1] = result;
3500                 }
3501 
3502                 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3503                 ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3504                 _mm256_storeu_ps(scratch, ymm10);
3505                 result = scratch[0] + scratch[4];
3506                 result *= (*alpha_cast);
3507                 if(is_beta_non_zero){
3508                     tC[2] = result + tC[2] * (*beta_cast);
3509                 }else{
3510                     tC[2] = result;
3511                 }
3512 
3513                 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3514                 ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3515                 _mm256_storeu_ps(scratch, ymm13);
3516                 result = scratch[0] + scratch[4];
3517                 result *= (*alpha_cast);
3518                 if(is_beta_non_zero){
3519                     tC[3] = result + tC[3] * (*beta_cast);
3520                 }else{
3521                     tC[3] = result;
3522                 }
3523 
3524                 tC += ldc;
3525                 ymm5 = _mm256_hadd_ps(ymm5, ymm5);
3526                 ymm5 = _mm256_hadd_ps(ymm5, ymm5);
3527                 _mm256_storeu_ps(scratch, ymm5);
3528                 result = scratch[0] + scratch[4];
3529                 result *= (*alpha_cast);
3530                 if(is_beta_non_zero){
3531                     tC[0] = result + tC[0] * (*beta_cast);
3532                 }else{
3533                     tC[0] = result;
3534                 }
3535 
3536                 ymm8 = _mm256_hadd_ps(ymm8, ymm8);
3537                 ymm8 = _mm256_hadd_ps(ymm8, ymm8);
3538                 _mm256_storeu_ps(scratch, ymm8);
3539                 result = scratch[0] + scratch[4];
3540                 result *= (*alpha_cast);
3541                 if(is_beta_non_zero){
3542                     tC[1] = result + tC[1] * (*beta_cast);
3543                 }else{
3544                     tC[1] = result;
3545                 }
3546 
3547                 ymm11 = _mm256_hadd_ps(ymm11, ymm11);
3548                 ymm11 = _mm256_hadd_ps(ymm11, ymm11);
3549                 _mm256_storeu_ps(scratch, ymm11);
3550                 result = scratch[0] + scratch[4];
3551                 result *= (*alpha_cast);
3552                 if(is_beta_non_zero){
3553                     tC[2] = result + tC[2] * (*beta_cast);
3554                 }else{
3555                     tC[2] = result;
3556                 }
3557 
3558                 ymm14 = _mm256_hadd_ps(ymm14, ymm14);
3559                 ymm14 = _mm256_hadd_ps(ymm14, ymm14);
3560                 _mm256_storeu_ps(scratch, ymm14);
3561                 result = scratch[0] + scratch[4];
3562                 result *= (*alpha_cast);
3563                 if(is_beta_non_zero){
3564                     tC[3] = result + tC[3] * (*beta_cast);
3565                 }else{
3566                     tC[3] = result;
3567                 }
3568 
3569                 tC += ldc;
3570                 ymm6 = _mm256_hadd_ps(ymm6, ymm6);
3571                 ymm6 = _mm256_hadd_ps(ymm6, ymm6);
3572                 _mm256_storeu_ps(scratch, ymm6);
3573                 result = scratch[0] + scratch[4];
3574                 result *= (*alpha_cast);
3575                 if(is_beta_non_zero){
3576                     tC[0] = result + tC[0] * (*beta_cast);
3577                 }else{
3578                     tC[0] = result;
3579                 }
3580 
3581                 ymm9 = _mm256_hadd_ps(ymm9, ymm9);
3582                 ymm9 = _mm256_hadd_ps(ymm9, ymm9);
3583                 _mm256_storeu_ps(scratch, ymm9);
3584                 result = scratch[0] + scratch[4];
3585                 result *= (*alpha_cast);
3586                 if(is_beta_non_zero){
3587                     tC[1] = result + tC[1] * (*beta_cast);
3588                 }else{
3589                     tC[1] = result;
3590                 }
3591 
3592                 ymm12 = _mm256_hadd_ps(ymm12, ymm12);
3593                 ymm12 = _mm256_hadd_ps(ymm12, ymm12);
3594                 _mm256_storeu_ps(scratch, ymm12);
3595                 result = scratch[0] + scratch[4];
3596                 result *= (*alpha_cast);
3597                 if(is_beta_non_zero){
3598                     tC[2] = result + tC[2] * (*beta_cast);
3599                 }else{
3600                     tC[2] = result;
3601                 }
3602 
3603                 ymm15 = _mm256_hadd_ps(ymm15, ymm15);
3604                 ymm15 = _mm256_hadd_ps(ymm15, ymm15);
3605                 _mm256_storeu_ps(scratch, ymm15);
3606                 result = scratch[0] + scratch[4];
3607                 result *= (*alpha_cast);
3608                 if(is_beta_non_zero){
3609                     tC[3] = result + tC[3] * (*beta_cast);
3610                 }else{
3611                     tC[3] = result;
3612                 }
3613             }
3614         }
3615 
3616         int processed_col = col_idx;
3617         int processed_row = row_idx;
3618 
3619         //Edge-case handling where N is not a multiple of 3 (NR).
3620         if (processed_col < N)
3621         {
3622             for (col_idx = processed_col; col_idx < N; col_idx += 1)
3623             {
3624                 for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
3625                 {
3626                     tA = A + row_idx * lda;
3627                     tB = B + col_idx * ldb;
3628                     tC = C + col_idx * ldc + row_idx;
3629                     // clear scratch registers.
3630                     ymm4 = _mm256_setzero_ps();
3631                     ymm7 = _mm256_setzero_ps();
3632                     ymm10 = _mm256_setzero_ps();
3633                     ymm13 = _mm256_setzero_ps();
3634 
3635                     //The inner loop computes the 4x1 values of the matrix.
3636                     //The computation pattern is:
3637                     // ymm4
3638                     // ymm7
3639                     // ymm10
3640                     // ymm13
3641 
3642                     for (k = 0; (k + 7) < K; k += 8)
3643                     {
3644                         ymm0 = _mm256_loadu_ps(tB + 0);
3645 
3646                         ymm3 = _mm256_loadu_ps(tA);
3647                         ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3648 
3649                         ymm3 = _mm256_loadu_ps(tA + lda);
3650                         ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3651 
3652                         ymm3 = _mm256_loadu_ps(tA + 2 * lda);
3653                         ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3654 
3655                         ymm3 = _mm256_loadu_ps(tA + 3 * lda);
3656                         ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3657 
3658                         tA += 8;
3659                         tB += 8;
3660                     }
3661 
3662                     // If K is not a multiple of 8, the remaining elements are zero-padded into a temporary array before loading.
3663                     if (k < K)
3664                     {
3665                         int iter;
3666                         float data_feeder[8] = { 0.0 };
3667 
3668                         for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3669                         ymm0 = _mm256_loadu_ps(data_feeder);
3670 
3671                         for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3672                         ymm3 = _mm256_loadu_ps(data_feeder);
3673                         ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3674 
3675                         for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
3676                         ymm3 = _mm256_loadu_ps(data_feeder);
3677                         ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7);
3678 
3679                         for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
3680                         ymm3 = _mm256_loadu_ps(data_feeder);
3681                         ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10);
3682 
3683                         for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
3684                         ymm3 = _mm256_loadu_ps(data_feeder);
3685                         ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13);
3686 
3687                     }
3688 
3689                     //Horizontal addition and storage of the data.
3690                     //Results for the 4x1 block of C are stored here.
3691                     ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3692                     ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3693                     _mm256_storeu_ps(scratch, ymm4);
3694                     result = scratch[0] + scratch[4];
3695                     result *= (*alpha_cast);
3696                     if(is_beta_non_zero){
3697                         tC[0] = result + tC[0] * (*beta_cast);
3698                     }else{
3699                         tC[0] = result;
3700                     }
3701 
3702                     ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3703                     ymm7 = _mm256_hadd_ps(ymm7, ymm7);
3704                     _mm256_storeu_ps(scratch, ymm7);
3705                     result = scratch[0] + scratch[4];
3706                     result *= (*alpha_cast);
3707                     if(is_beta_non_zero){
3708                         tC[1] = result + tC[1] * (*beta_cast);
3709                     }else{
3710                         tC[1] = result;
3711                     }
3712 
3713                     ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3714                     ymm10 = _mm256_hadd_ps(ymm10, ymm10);
3715                     _mm256_storeu_ps(scratch, ymm10);
3716                     result = scratch[0] + scratch[4];
3717                     result *= (*alpha_cast);
3718                     if(is_beta_non_zero){
3719                         tC[2] = result + tC[2] * (*beta_cast);
3720                     }else{
3721                         tC[2] = result;
3722                     }
3723 
3724                     ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3725                     ymm13 = _mm256_hadd_ps(ymm13, ymm13);
3726                     _mm256_storeu_ps(scratch, ymm13);
3727                     result = scratch[0] + scratch[4];
3728                     result *= (*alpha_cast);
3729                     if(is_beta_non_zero){
3730                         tC[3] = result + tC[3] * (*beta_cast);
3731                     }else{
3732                         tC[3] = result;
3733                     }
3734                 }
3735             }
3736             processed_row = row_idx;
3737         }
3738 
3739         //Edge-case handling where M is not a multiple of 4 (AT_MR).
3740         if (processed_row < M)
3741         {
3742             for (row_idx = processed_row; row_idx < M; row_idx += 1)
3743             {
3744                 for (col_idx = 0; col_idx < N; col_idx += 1)
3745                 {
3746                     tA = A + row_idx * lda;
3747                     tB = B + col_idx * ldb;
3748                     tC = C + col_idx * ldc + row_idx;
3749                     // clear scratch registers.
3750                     ymm4 = _mm256_setzero_ps();
3751 
3752                     for (k = 0; (k + 7) < K; k += 8)
3753                     {
3754                         ymm0 = _mm256_loadu_ps(tB + 0);
3755                         ymm3 = _mm256_loadu_ps(tA);
3756                         ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3757 
3758                         tA += 8;
3759                         tB += 8;
3760                     }
3761 
3762                     // If K is not a multiple of 8, the remaining elements are zero-padded into a temporary array before loading.
3763                     if (k < K)
3764                     {
3765                         int iter;
3766                         float data_feeder[8] = { 0.0 };
3767 
3768                         for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
3769                         ymm0 = _mm256_loadu_ps(data_feeder);
3770 
3771                         for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
3772                         ymm3 = _mm256_loadu_ps(data_feeder);
3773                         ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);
3774 
3775                     }
3776 
3777                     //horizontal addition and storage of the data.
3778                     ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3779                     ymm4 = _mm256_hadd_ps(ymm4, ymm4);
3780                     _mm256_storeu_ps(scratch, ymm4);
3781                     result = scratch[0] + scratch[4];
3782                     result *= (*alpha_cast);
3783                     if(is_beta_non_zero){
3784                         tC[0] = result + tC[0] * (*beta_cast);
3785                     }else{
3786                         tC[0] = result;
3787                     }
3788 
3789                 }
3790             }
3791         }
3792         AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
3793         return BLIS_SUCCESS;
3794     }
3795     else
3796     {
3797         AOCL_DTL_TRACE_EXIT_ERR(
3798             AOCL_DTL_LEVEL_INFO,
3799             "Invalid dimensions for small gemm."
3800             );
3801         return BLIS_NONCONFORMAL_DIMENSIONS;
3802     }
3803 }
3804 
3805 static err_t bli_dgemm_small_atbn
3806      (
3807        obj_t*  alpha,
3808        obj_t*  a,
3809        obj_t*  b,
3810        obj_t*  beta,
3811        obj_t*  c,
3812        cntx_t* cntx,
3813        cntl_t* cntl
3814      )
3815 {
3816 	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
3817 
3818     gint_t M = bli_obj_length( c ); // number of rows of Matrix C
3819     gint_t N = bli_obj_width( c );  // number of columns of Matrix C
3820     gint_t K = bli_obj_length( b ); // number of rows of Matrix B
3821 
3822     // The non-copy version of the A^T GEMM gives better performance for the small M cases.
3823     // The threshold is controlled by BLIS_ATBN_M_THRES
3824     if (M <= BLIS_ATBN_M_THRES)
3825     {
3826     guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA is enabled.
3827     guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB is enabled.
3828     guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
3829     guint_t row_idx = 0, col_idx = 0, k;
3830     double *A = bli_obj_buffer_at_off(a); // pointer to matrix A elements, stored in row major format
3831     double *B = bli_obj_buffer_at_off(b); // pointer to matrix B elements, stored in column major format
3832     double *C = bli_obj_buffer_at_off(c); // pointer to matrix C elements, stored in column major format
3833 
3834     double *tA = A, *tB = B, *tC = C;
3835 
3836     __m256d ymm4, ymm5, ymm6, ymm7;
3837     __m256d ymm8, ymm9, ymm10, ymm11;
3838     __m256d ymm12, ymm13, ymm14, ymm15;
3839     __m256d ymm0, ymm1, ymm2, ymm3;
3840 
3841     double result;
3842     double scratch[8] = {0.0};
3843     double *alpha_cast, *beta_cast; // alpha, beta multiples
3844     alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha);
3845     beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta);
3846 
3847     //check if beta is zero
3848     //if true, we need to perform C = alpha * (A * B)
3849     //instead of C = beta * C  + alpha * (A * B)
3850     bool is_beta_non_zero = 0;
3851     if(!bli_obj_equals(beta,&BLIS_ZERO))
3852         is_beta_non_zero = 1;
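    // Each element of C below is a K-length dot product:
    //   C[m + n*ldc] = alpha * sum_k A[k + m*lda] * B[k + n*ldb]
    // plus beta * C[m + n*ldc] when beta is non-zero.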

    for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
    {
        for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
        {
                tA = A + row_idx * lda;
                tB = B + col_idx * ldb;
                tC = C + col_idx * ldc + row_idx;
                // clear scratch registers.
                ymm4 = _mm256_setzero_pd();
                ymm5 = _mm256_setzero_pd();
                ymm6 = _mm256_setzero_pd();
                ymm7 = _mm256_setzero_pd();
                ymm8 = _mm256_setzero_pd();
                ymm9 = _mm256_setzero_pd();
                ymm10 = _mm256_setzero_pd();
                ymm11 = _mm256_setzero_pd();
                ymm12 = _mm256_setzero_pd();
                ymm13 = _mm256_setzero_pd();
                ymm14 = _mm256_setzero_pd();
                ymm15 = _mm256_setzero_pd();

                // The inner loop computes a 4x3 block of the C matrix.
                // The computation pattern is:
                // ymm4  ymm5  ymm6
                // ymm7  ymm8  ymm9
                // ymm10 ymm11 ymm12
                // ymm13 ymm14 ymm15

                // The dot products are accumulated in the inner loop; four double
                // elements fit in a YMM register, hence the loop counter is incremented by 4.
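                // Each accumulator holds four k-partial products for one (row, column)
                // pair of the 4x3 tile; they are reduced to a single scalar after the k loop.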
                for (k = 0; (k + 3) < K; k += 4)
                {
                    ymm0 = _mm256_loadu_pd(tB + 0);
                    ymm1 = _mm256_loadu_pd(tB + ldb);
                    ymm2 = _mm256_loadu_pd(tB + 2 * ldb);

                    ymm3 = _mm256_loadu_pd(tA);
                    ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
                    ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
                    ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);

                    ymm3 = _mm256_loadu_pd(tA + lda);
                    ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
                    ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
                    ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);

                    ymm3 = _mm256_loadu_pd(tA + 2 * lda);
                    ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
                    ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
                    ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);

                    ymm3 = _mm256_loadu_pd(tA + 3 * lda);
                    ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
                    ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
                    ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);

                    tA += 4;
                    tB += 4;

                }

                // if K is not a multiple of 4, padding is done before load using a temporary array.
                if (k < K)
                {
                    int iter;
                    double data_feeder[4] = { 0.0 };

                    for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
                    ymm0 = _mm256_loadu_pd(data_feeder);
                    for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb];
                    ymm1 = _mm256_loadu_pd(data_feeder);
                    for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb];
                    ymm2 = _mm256_loadu_pd(data_feeder);

                    for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
                    ymm3 = _mm256_loadu_pd(data_feeder);
                    ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
                    ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);
                    ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6);

                    for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
                    ymm3 = _mm256_loadu_pd(data_feeder);
                    ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);
                    ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8);
                    ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9);

                    for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
                    ymm3 = _mm256_loadu_pd(data_feeder);
                    ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);
                    ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11);
                    ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12);

                    for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
                    ymm3 = _mm256_loadu_pd(data_feeder);
                    ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);
                    ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14);
                    ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15);

                }

                // horizontal addition and storage of the data.
                // Results for the 4x3 block of C are stored here.
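                // _mm256_hadd_pd adds adjacent pairs within each 128-bit lane, so after the
                // hadd the low-lane sum sits in element 0 and the high-lane sum in element 2;
                // their sum is the full 4-element dot product. The three columns of the tile
                // are written one after another, advancing tC by ldc between columns.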
                ymm4 = _mm256_hadd_pd(ymm4, ymm4);
                _mm256_storeu_pd(scratch, ymm4);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[0] = result + tC[0] * (*beta_cast);
                else
                    tC[0] = result;

                ymm7 = _mm256_hadd_pd(ymm7, ymm7);
                _mm256_storeu_pd(scratch, ymm7);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[1] = result + tC[1] * (*beta_cast);
                else
                    tC[1] = result;

                ymm10 = _mm256_hadd_pd(ymm10, ymm10);
                _mm256_storeu_pd(scratch, ymm10);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[2] = result + tC[2] * (*beta_cast);
                else
                    tC[2] = result;

                ymm13 = _mm256_hadd_pd(ymm13, ymm13);
                _mm256_storeu_pd(scratch, ymm13);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[3] = result + tC[3] * (*beta_cast);
                else
                    tC[3] = result;

                tC += ldc;
                ymm5 = _mm256_hadd_pd(ymm5, ymm5);
                _mm256_storeu_pd(scratch, ymm5);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[0] = result + tC[0] * (*beta_cast);
                else
                    tC[0] = result;

                ymm8 = _mm256_hadd_pd(ymm8, ymm8);
                _mm256_storeu_pd(scratch, ymm8);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[1] = result + tC[1] * (*beta_cast);
                else
                    tC[1] = result;

                ymm11 = _mm256_hadd_pd(ymm11, ymm11);
                _mm256_storeu_pd(scratch, ymm11);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[2] = result + tC[2] * (*beta_cast);
                else
                    tC[2] = result;

                ymm14 = _mm256_hadd_pd(ymm14, ymm14);
                _mm256_storeu_pd(scratch, ymm14);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[3] = result + tC[3] * (*beta_cast);
                else
                    tC[3] = result;

                tC += ldc;
                ymm6 = _mm256_hadd_pd(ymm6, ymm6);
                _mm256_storeu_pd(scratch, ymm6);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[0] = result + tC[0] * (*beta_cast);
                else
                    tC[0] = result;

                ymm9 = _mm256_hadd_pd(ymm9, ymm9);
                _mm256_storeu_pd(scratch, ymm9);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[1] = result + tC[1] * (*beta_cast);
                else
                    tC[1] = result;

                ymm12 = _mm256_hadd_pd(ymm12, ymm12);
                _mm256_storeu_pd(scratch, ymm12);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[2] = result + tC[2] * (*beta_cast);
                else
                    tC[2] = result;

                ymm15 = _mm256_hadd_pd(ymm15, ymm15);
                _mm256_storeu_pd(scratch, ymm15);
                result = scratch[0] + scratch[2];
                result *= (*alpha_cast);
                if(is_beta_non_zero)
                    tC[3] = result + tC[3] * (*beta_cast);
                else
                    tC[3] = result;
            }
        }

        int processed_col = col_idx;
        int processed_row = row_idx;
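        // processed_col/processed_row record how far the full 4x3 blocking got; any
        // leftover columns (N not a multiple of NR) and rows (M not a multiple of AT_MR)
        // are handled by the edge-case loops below.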

        // The edge case handling where N is not a multiple of NR (3).
        if (processed_col < N)
        {
            for (col_idx = processed_col; col_idx < N; col_idx += 1)
            {
                for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR)
                {
                    tA = A + row_idx * lda;
                    tB = B + col_idx * ldb;
                    tC = C + col_idx * ldc + row_idx;
                    // clear scratch registers.
                    ymm4 = _mm256_setzero_pd();
                    ymm7 = _mm256_setzero_pd();
                    ymm10 = _mm256_setzero_pd();
                    ymm13 = _mm256_setzero_pd();

                    // The inner loop computes a 4x1 block of the C matrix.
                    // The computation pattern is:
                    // ymm4
                    // ymm7
                    // ymm10
                    // ymm13

                    for (k = 0; (k + 3) < K; k += 4)
                    {
                        ymm0 = _mm256_loadu_pd(tB + 0);

                        ymm3 = _mm256_loadu_pd(tA);
                        ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);

                        ymm3 = _mm256_loadu_pd(tA + lda);
                        ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);

                        ymm3 = _mm256_loadu_pd(tA + 2 * lda);
                        ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);

                        ymm3 = _mm256_loadu_pd(tA + 3 * lda);
                        ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);

                        tA += 4;
                        tB += 4;
                    }
                    // if K is not a multiple of 4, padding is done before load using a temporary array.
                    if (k < K)
                    {
                        int iter;
                        double data_feeder[4] = { 0.0 };

                        for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
                        ymm0 = _mm256_loadu_pd(data_feeder);

                        for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
                        ymm3 = _mm256_loadu_pd(data_feeder);
                        ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);

                        for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter];
                        ymm3 = _mm256_loadu_pd(data_feeder);
                        ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7);

                        for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter];
                        ymm3 = _mm256_loadu_pd(data_feeder);
                        ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10);

                        for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter];
                        ymm3 = _mm256_loadu_pd(data_feeder);
                        ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13);

                    }

                    // horizontal addition and storage of the data.
                    // Results for the 4x1 block of C are stored here.
                    ymm4 = _mm256_hadd_pd(ymm4, ymm4);
                    _mm256_storeu_pd(scratch, ymm4);
                    result = scratch[0] + scratch[2];
                    result *= (*alpha_cast);
                    if(is_beta_non_zero)
                        tC[0] = result + tC[0] * (*beta_cast);
                    else
                        tC[0] = result;

                    ymm7 = _mm256_hadd_pd(ymm7, ymm7);
                    _mm256_storeu_pd(scratch, ymm7);
                    result = scratch[0] + scratch[2];
                    result *= (*alpha_cast);
                    if(is_beta_non_zero)
                        tC[1] = result + tC[1] * (*beta_cast);
                    else
                        tC[1] = result;

                    ymm10 = _mm256_hadd_pd(ymm10, ymm10);
                    _mm256_storeu_pd(scratch, ymm10);
                    result = scratch[0] + scratch[2];
                    result *= (*alpha_cast);
                    if(is_beta_non_zero)
                        tC[2] = result + tC[2] * (*beta_cast);
                    else
                        tC[2] = result;

                    ymm13 = _mm256_hadd_pd(ymm13, ymm13);
                    _mm256_storeu_pd(scratch, ymm13);
                    result = scratch[0] + scratch[2];
                    result *= (*alpha_cast);
                    if(is_beta_non_zero)
                        tC[3] = result + tC[3] * (*beta_cast);
                    else
                        tC[3] = result;
                }
            }
            processed_row = row_idx;
        }

        // The edge case handling where M is not a multiple of AT_MR (4).
        if (processed_row < M)
        {
            for (row_idx = processed_row; row_idx < M; row_idx += 1)
            {
                for (col_idx = 0; col_idx < N; col_idx += 1)
                {
                    tA = A + row_idx * lda;
                    tB = B + col_idx * ldb;
                    tC = C + col_idx * ldc + row_idx;
                    // clear scratch registers.
                    ymm4 = _mm256_setzero_pd();

                    for (k = 0; (k + 3) < K; k += 4)
                    {
                        ymm0 = _mm256_loadu_pd(tB + 0);
                        ymm3 = _mm256_loadu_pd(tA);
                        ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);

                        tA += 4;
                        tB += 4;
                    }

                    // if K is not a multiple of 4, padding is done before load using a temporary array.
                    if (k < K)
                    {
                        int iter;
                        double data_feeder[4] = { 0.0 };

                        for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter];
                        ymm0 = _mm256_loadu_pd(data_feeder);

                        for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter];
                        ymm3 = _mm256_loadu_pd(data_feeder);
                        ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);

                    }

                    // horizontal addition and storage of the data.
                    ymm4 = _mm256_hadd_pd(ymm4, ymm4);
                    _mm256_storeu_pd(scratch, ymm4);
                    result = scratch[0] + scratch[2];
                    result *= (*alpha_cast);
                    if(is_beta_non_zero)
                        tC[0] = result + tC[0] * (*beta_cast);
                    else
                        tC[0] = result;
                }
            }
        }
        AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
        return BLIS_SUCCESS;
    }
    else
    {
        AOCL_DTL_TRACE_EXIT_ERR(
            AOCL_DTL_LEVEL_INFO,
            "Invalid dimensions for small gemm."
            );
        return BLIS_NONCONFORMAL_DIMENSIONS;
    }
}
#endif
