1 #ifndef _TBLIS_KERNELS_3M_GEMM_HPP_
2 #define _TBLIS_KERNELS_3M_GEMM_HPP_
3 
4 #include "util/basic_types.h"
5 #include "blis.h"
6 
7 #include <type_traits>
8 
9 namespace tblis
10 {
11 
// Declares an externally defined GEMM micro-kernel called `name` whose element
// type is T. The declared parameter list mirrors gemm_ukr_t<T>:
// (k, alpha, a, b, beta, c, rs_c, cs_c, aux).
// The expansion already ends in ';', so a use site such as
//     EXTERN_GEMM_UKR(double, my_dgemm_ukr)
// needs no trailing semicolon. stride_type is spelled tblis::stride_type so the
// macro may be expanded outside namespace tblis (auxinfo_t is presumably the
// BLIS type from blis.h -- NOTE(review): confirm it is visible at use sites).
#define EXTERN_GEMM_UKR(T, name)  \
    extern void name(             \
        tblis::stride_type k,     \
        const T* alpha,           \
        const T* a,               \
        const T* b,               \
        const T* beta,            \
        T* c,                     \
        tblis::stride_type rs_c,  \
        tblis::stride_type cs_c,  \
        auxinfo_t* aux);
20 
21 template <typename T>
22 using gemm_ukr_t =
23 void (*)(stride_type k,
24         const T* alpha,
25         const T* a, const T* b,
26         const T* beta,
27         T* c, stride_type rs_c, stride_type cs_c,
28         auxinfo_t* aux);
29 
30 template <typename Config, typename T>
gemm_ukr_def(stride_type k,const T * TBLIS_RESTRICT alpha,const T * TBLIS_RESTRICT p_a,const T * TBLIS_RESTRICT p_b,const T * TBLIS_RESTRICT beta,T * TBLIS_RESTRICT p_c,stride_type rs_c,stride_type cs_c,auxinfo_t *)31 void gemm_ukr_def(stride_type k,
32                   const T* TBLIS_RESTRICT alpha,
33                   const T* TBLIS_RESTRICT p_a, const T* TBLIS_RESTRICT p_b,
34                   const T* TBLIS_RESTRICT beta,
35                   T* TBLIS_RESTRICT p_c, stride_type rs_c, stride_type cs_c,
36                   auxinfo_t*)
37 {
38     constexpr len_type MR = Config::template gemm_mr<T>::def;
39     constexpr len_type NR = Config::template gemm_nr<T>::def;
40 
41     T p_ab[MR*NR] __attribute__((aligned(64))) = {};
42 
43     while (k --> 0)
44     {
45         for (int i = 0;i < MR;i++)
46             #pragma omp simd
47             for (int j = 0;j < NR;j++)
48                 p_ab[i*NR + j] += p_a[i] * p_b[j];
49 
50         p_a += MR;
51         p_b += NR;
52     }
53 
54     if (*beta == T(0))
55     {
56         for (len_type i = 0;i < MR;i++)
57             for (len_type j = 0;j < NR;j++)
58                 p_c[i*rs_c + j*cs_c] = (*alpha)*p_ab[i*NR + j];
59     }
60     else
61     {
62         for (len_type i = 0;i < MR;i++)
63             for (len_type j = 0;j < NR;j++)
64                 p_c[i*rs_c + j*cs_c] = (*alpha)*p_ab[i*NR + j] +
65                                        (*beta)*p_c[i*rs_c + j*cs_c];
66     }
67 }
68 
69 }
70 
71 #endif
72