1 #include "util/cpuid.hpp"
2 #include "config.hpp"
3
4 #include "blis.h"
5
6 template <typename T>
7 using bli_packm_t = void(*)(conj_t conja, len_type n, const T* kappa,
8 const T* a, stride_type rs_a, stride_type cs_a,
9 T* p, stride_type cs_p);
10
11 template <typename T>
12 using bli_packm_func = typename std::remove_pointer<bli_packm_t<T>>::type;
13
14 extern "C" bli_packm_func<double> bli_dpackm_30xk_opt;
15 extern "C" bli_packm_func<double> bli_dpackm_24xk_opt;
16 extern "C" bli_packm_func<double> bli_dpackm_8xk_opt;
17 extern "C" bli_packm_func<float> bli_spackm_24xk_opt;
18 extern "C" bli_packm_func<float> bli_spackm_16xk_opt;
19
20 namespace tblis
21 {
22
knl_dpackm_30xk(len_type m,len_type k,const double * p_a,stride_type rs_a,stride_type cs_a,double * p_ap)23 void knl_dpackm_30xk(len_type m, len_type k,
24 const double* p_a, stride_type rs_a, stride_type cs_a,
25 double* p_ap)
26 {
27 constexpr double one = 1.0;
28
29 if (m == 30)
30 {
31 bli_dpackm_30xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 32);
32 }
33 else
34 {
35 pack_nn_ukr_def<knl_d30x8_knc_config, double, matrix_constants::MAT_A>
36 (m, k, p_a, rs_a, cs_a, p_ap);
37 }
38 }
39
knl_dpackm_24xk(len_type m,len_type k,const double * p_a,stride_type rs_a,stride_type cs_a,double * p_ap)40 void knl_dpackm_24xk(len_type m, len_type k,
41 const double* p_a, stride_type rs_a, stride_type cs_a,
42 double* p_ap)
43 {
44 constexpr double one = 1.0;
45
46 if (m == 24)
47 {
48 bli_dpackm_24xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 24);
49 }
50 else
51 {
52 pack_nn_ukr_def<knl_d24x8_config, double, matrix_constants::MAT_A>
53 (m, k, p_a, rs_a, cs_a, p_ap);
54 }
55 }
56
knl_dpackm_8xk(len_type m,len_type k,const double * p_a,stride_type rs_a,stride_type cs_a,double * p_ap)57 void knl_dpackm_8xk(len_type m, len_type k,
58 const double* p_a, stride_type rs_a, stride_type cs_a,
59 double* p_ap)
60 {
61 constexpr double one = 1.0;
62
63 if (m == 8)
64 {
65 bli_dpackm_8xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 8);
66 }
67 else
68 {
69 pack_nn_ukr_def<knl_d24x8_config, double, matrix_constants::MAT_B>
70 (m, k, p_a, rs_a, cs_a, p_ap);
71 }
72 }
73
knl_spackm_24xk(len_type m,len_type k,const float * p_a,stride_type rs_a,stride_type cs_a,float * p_ap)74 void knl_spackm_24xk(len_type m, len_type k,
75 const float* p_a, stride_type rs_a, stride_type cs_a,
76 float* p_ap)
77 {
78 constexpr float one = 1.0;
79
80 if (m == 24)
81 {
82 bli_spackm_24xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 24);
83 }
84 else
85 {
86 pack_nn_ukr_def<knl_d24x8_config, float, matrix_constants::MAT_A>
87 (m, k, p_a, rs_a, cs_a, p_ap);
88 }
89 }
90
knl_spackm_16xk(len_type m,len_type k,const float * p_a,stride_type rs_a,stride_type cs_a,float * p_ap)91 void knl_spackm_16xk(len_type m, len_type k,
92 const float* p_a, stride_type rs_a, stride_type cs_a,
93 float* p_ap)
94 {
95 constexpr float one = 1.0;
96
97 if (m == 16)
98 {
99 bli_spackm_16xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 16);
100 }
101 else
102 {
103 pack_nn_ukr_def<knl_d24x8_config, float, matrix_constants::MAT_B>
104 (m, k, p_a, rs_a, cs_a, p_ap);
105 }
106 }
107
knl_check()108 int knl_check()
109 {
110 int family, model, features;
111 int vendor = get_cpu_type(family, model, features);
112
113 if (vendor != VENDOR_INTEL)
114 {
115 if (get_verbose() >= 1) printf("tblis: knl: Wrong vendor.\n");
116 return -1;
117 }
118
119 if (!check_features(features, FEATURE_AVX))
120 {
121 if (get_verbose() >= 1) printf("tblis: knl: Doesn't support AVX.\n");
122 return -1;
123 }
124
125 if (!check_features(features, FEATURE_FMA3))
126 {
127 if (get_verbose() >= 1) printf("tblis: knl: Doesn't support FMA3.\n");
128 return -1;
129 }
130
131 if (!check_features(features, FEATURE_AVX2))
132 {
133 if (get_verbose() >= 1) printf("tblis: knl: Doesn't support AVX2.\n");
134 return -1;
135 }
136
137 if (!check_features(features, FEATURE_AVX512F))
138 {
139 if (get_verbose() >= 1) printf("tblis: knl: Doesn't support AVX512F.\n");
140 return -1;
141 }
142
143 if (!check_features(features, FEATURE_AVX512PF))
144 {
145 if (get_verbose() >= 1) printf("tblis: knl: Doesn't support AVX512PF.\n");
146 return -1;
147 }
148
149 return 4;
150 }
151
152 //TBLIS_CONFIG_INSTANTIATE(knl_d30x8_knc);
153 //TBLIS_CONFIG_INSTANTIATE(knl_d30x8);
154 TBLIS_CONFIG_INSTANTIATE(knl_d24x8);
155 //TBLIS_CONFIG_INSTANTIATE(knl_d8x24);
156
157 }
158