1 #ifndef _TBLIS_KERNELS_3M_GEMM_HPP_
2 #define _TBLIS_KERNELS_3M_GEMM_HPP_
3
4 #include "util/basic_types.h"
5 #include "blis.h"
6
7 #include <type_traits>
8
9 namespace tblis
10 {
11
12 #define EXTERN_GEMM_UKR(T, name) \
13 extern void name(tblis::stride_type k, \
14 const T* alpha, \
15 const T* a, const T* b, \
16 const T* beta, \
17 T* c, tblis::stride_type rs_c, \
18 tblis::stride_type cs_c, \
19 auxinfo_t* aux);
20
21 template <typename T>
22 using gemm_ukr_t =
23 void (*)(stride_type k,
24 const T* alpha,
25 const T* a, const T* b,
26 const T* beta,
27 T* c, stride_type rs_c, stride_type cs_c,
28 auxinfo_t* aux);
29
30 template <typename Config, typename T>
gemm_ukr_def(stride_type k,const T * TBLIS_RESTRICT alpha,const T * TBLIS_RESTRICT p_a,const T * TBLIS_RESTRICT p_b,const T * TBLIS_RESTRICT beta,T * TBLIS_RESTRICT p_c,stride_type rs_c,stride_type cs_c,auxinfo_t *)31 void gemm_ukr_def(stride_type k,
32 const T* TBLIS_RESTRICT alpha,
33 const T* TBLIS_RESTRICT p_a, const T* TBLIS_RESTRICT p_b,
34 const T* TBLIS_RESTRICT beta,
35 T* TBLIS_RESTRICT p_c, stride_type rs_c, stride_type cs_c,
36 auxinfo_t*)
37 {
38 constexpr len_type MR = Config::template gemm_mr<T>::def;
39 constexpr len_type NR = Config::template gemm_nr<T>::def;
40
41 T p_ab[MR*NR] __attribute__((aligned(64))) = {};
42
43 while (k --> 0)
44 {
45 for (int i = 0;i < MR;i++)
46 #pragma omp simd
47 for (int j = 0;j < NR;j++)
48 p_ab[i*NR + j] += p_a[i] * p_b[j];
49
50 p_a += MR;
51 p_b += NR;
52 }
53
54 if (*beta == T(0))
55 {
56 for (len_type i = 0;i < MR;i++)
57 for (len_type j = 0;j < NR;j++)
58 p_c[i*rs_c + j*cs_c] = (*alpha)*p_ab[i*NR + j];
59 }
60 else
61 {
62 for (len_type i = 0;i < MR;i++)
63 for (len_type j = 0;j < NR;j++)
64 p_c[i*rs_c + j*cs_c] = (*alpha)*p_ab[i*NR + j] +
65 (*beta)*p_c[i*rs_c + j*cs_c];
66 }
67 }
68
69 }
70
71 #endif
72