1 #include "util/cpuid.hpp"
2 #include "config.hpp"
3 
4 #include "blis.h"
5 
6 template <typename T>
7 using bli_packm_t = void(*)(conj_t conja, len_type n, const T* kappa,
8                             const T* a, stride_type rs_a, stride_type cs_a,
9                                   T* p,                   stride_type cs_p);
10 
11 template <typename T>
12 using bli_packm_func = typename std::remove_pointer<bli_packm_t<T>>::type;
13 
14 extern "C" bli_packm_func<double> bli_dpackm_30xk_opt;
15 extern "C" bli_packm_func<double> bli_dpackm_24xk_opt;
16 extern "C" bli_packm_func<double> bli_dpackm_8xk_opt;
17 extern "C" bli_packm_func<float> bli_spackm_24xk_opt;
18 extern "C" bli_packm_func<float> bli_spackm_16xk_opt;
19 
20 namespace tblis
21 {
22 
knl_dpackm_30xk(len_type m,len_type k,const double * p_a,stride_type rs_a,stride_type cs_a,double * p_ap)23 void knl_dpackm_30xk(len_type m, len_type k,
24                      const double* p_a, stride_type rs_a, stride_type cs_a,
25                      double* p_ap)
26 {
27     constexpr double one = 1.0;
28 
29     if (m == 30)
30     {
31         bli_dpackm_30xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 32);
32     }
33     else
34     {
35         pack_nn_ukr_def<knl_d30x8_knc_config, double, matrix_constants::MAT_A>
36             (m, k, p_a, rs_a, cs_a, p_ap);
37     }
38 }
39 
knl_dpackm_24xk(len_type m,len_type k,const double * p_a,stride_type rs_a,stride_type cs_a,double * p_ap)40 void knl_dpackm_24xk(len_type m, len_type k,
41                      const double* p_a, stride_type rs_a, stride_type cs_a,
42                      double* p_ap)
43 {
44     constexpr double one = 1.0;
45 
46     if (m == 24)
47     {
48         bli_dpackm_24xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 24);
49     }
50     else
51     {
52         pack_nn_ukr_def<knl_d24x8_config, double, matrix_constants::MAT_A>
53             (m, k, p_a, rs_a, cs_a, p_ap);
54     }
55 }
56 
knl_dpackm_8xk(len_type m,len_type k,const double * p_a,stride_type rs_a,stride_type cs_a,double * p_ap)57 void knl_dpackm_8xk(len_type m, len_type k,
58                     const double* p_a, stride_type rs_a, stride_type cs_a,
59                     double* p_ap)
60 {
61     constexpr double one = 1.0;
62 
63     if (m == 8)
64     {
65         bli_dpackm_8xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 8);
66     }
67     else
68     {
69         pack_nn_ukr_def<knl_d24x8_config, double, matrix_constants::MAT_B>
70             (m, k, p_a, rs_a, cs_a, p_ap);
71     }
72 }
73 
knl_spackm_24xk(len_type m,len_type k,const float * p_a,stride_type rs_a,stride_type cs_a,float * p_ap)74 void knl_spackm_24xk(len_type m, len_type k,
75                      const float* p_a, stride_type rs_a, stride_type cs_a,
76                      float* p_ap)
77 {
78     constexpr float one = 1.0;
79 
80     if (m == 24)
81     {
82         bli_spackm_24xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 24);
83     }
84     else
85     {
86         pack_nn_ukr_def<knl_d24x8_config, float, matrix_constants::MAT_A>
87             (m, k, p_a, rs_a, cs_a, p_ap);
88     }
89 }
90 
knl_spackm_16xk(len_type m,len_type k,const float * p_a,stride_type rs_a,stride_type cs_a,float * p_ap)91 void knl_spackm_16xk(len_type m, len_type k,
92                      const float* p_a, stride_type rs_a, stride_type cs_a,
93                      float* p_ap)
94 {
95     constexpr float one = 1.0;
96 
97     if (m == 16)
98     {
99         bli_spackm_16xk_opt(BLIS_NO_CONJUGATE, k, &one, p_a, rs_a, cs_a, p_ap, 16);
100     }
101     else
102     {
103         pack_nn_ukr_def<knl_d24x8_config, float, matrix_constants::MAT_B>
104             (m, k, p_a, rs_a, cs_a, p_ap);
105     }
106 }
107 
knl_check()108 int knl_check()
109 {
110     int family, model, features;
111     int vendor = get_cpu_type(family, model, features);
112 
113     if (vendor != VENDOR_INTEL)
114     {
115         if (get_verbose() >= 1) printf("tblis: knl: Wrong vendor.\n");
116         return -1;
117     }
118 
119     if (!check_features(features, FEATURE_AVX))
120     {
121         if (get_verbose() >= 1) printf("tblis: knl: Doesn't support AVX.\n");
122         return -1;
123     }
124 
125     if (!check_features(features, FEATURE_FMA3))
126     {
127         if (get_verbose() >= 1) printf("tblis: knl: Doesn't support FMA3.\n");
128         return -1;
129     }
130 
131     if (!check_features(features, FEATURE_AVX2))
132     {
133         if (get_verbose() >= 1) printf("tblis: knl: Doesn't support AVX2.\n");
134         return -1;
135     }
136 
137     if (!check_features(features, FEATURE_AVX512F))
138     {
139         if (get_verbose() >= 1) printf("tblis: knl: Doesn't support AVX512F.\n");
140         return -1;
141     }
142 
143     if (!check_features(features, FEATURE_AVX512PF))
144     {
145         if (get_verbose() >= 1) printf("tblis: knl: Doesn't support AVX512PF.\n");
146         return -1;
147     }
148 
149     return 4;
150 }
151 
// Only the knl_d24x8 configuration is instantiated; the other knl variants
// listed here are disabled.
//TBLIS_CONFIG_INSTANTIATE(knl_d30x8_knc);
//TBLIS_CONFIG_INSTANTIATE(knl_d30x8);
TBLIS_CONFIG_INSTANTIATE(knl_d24x8);
//TBLIS_CONFIG_INSTANTIATE(knl_d8x24);
156 
157 }
158