1 #pragma once
2 
3 #include <immintrin.h>
4 #include <stdint.h>
5 #include <math.h>
6 
7 #include "simint/vectorization/intrinsics_avx.h"
8 
9 #ifdef __cplusplus
10 #include "simint/cpp_restrict.hpp"
11 extern "C" {
12 #endif
13 
/*
 * Type-punning view of one 512-bit double vector.
 *   v - the packed vector value
 *   d - the same 64 bytes viewed as 8 individual doubles
 * v is the first member, so a brace initializer such as `= { vec }`
 * initializes v; the lanes are then read or written through d[].
 */
union simint_double8
{
    __m512d v;
    double d[8];
};
19 
20 
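// Scalar per-lane fallbacks for vector exp and pow, used by SIMINT_EXP /
// SIMINT_POW when the compiler (e.g. GCC) does not provide _mm512_exp_pd /
// _mm512_pow_pd.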
simint_exp_vec8(__m512d x)22 static inline __m512d simint_exp_vec8(__m512d x)
23 {
24     union simint_double8 u = { x };
25     union simint_double8 res;
26     for(int i = 0; i < 8; i++)
27         res.d[i] = exp(u.d[i]);
28     return res.v;
29 }
30 
simint_pow_vec8(__m512d a,__m512d p)31 static inline __m512d simint_pow_vec8(__m512d a, __m512d p)
32 {
33     union simint_double8 ua = { a };
34     union simint_double8 up = { p };
35     union simint_double8 res;
36     for(int i = 0; i < 8; i++)
37         res.d[i] = pow(ua.d[i], up.d[i]);
38     return res.v;
39 }
40 
41 #if defined SIMINT_AVX512 || defined SIMINT_MICAVX512
42 
    // Number of doubles per SIMINT_DBLTYPE vector (one __m512d holds 8).
    #define SIMINT_SIMD_LEN 8

    // Generic SIMD vocabulary mapped onto AVX-512 double-precision intrinsics.
    #define SIMINT_DBLTYPE         __m512d
    // Aligned load: _mm512_load_pd expects (p) + (i) to be 64-byte aligned.
    #define SIMINT_DBLLOAD(p,i)    _mm512_load_pd((p) + (i))
    // Broadcast a scalar to every lane.
    #define SIMINT_DBLSET1(a)      _mm512_set1_pd((a))
    // Negation implemented as multiplication by -1.0.
    #define SIMINT_NEG(a)          (SIMINT_MUL((a), (SIMINT_DBLSET1(-1.0))))
    #define SIMINT_ADD(a,b)        _mm512_add_pd((a), (b))
    #define SIMINT_SUB(a,b)        _mm512_sub_pd((a), (b))
    #define SIMINT_MUL(a,b)        _mm512_mul_pd((a), (b))
    #define SIMINT_DIV(a,b)        _mm512_div_pd((a), (b))
    #define SIMINT_SQRT(a)         _mm512_sqrt_pd((a))
    // Fused multiply-add / multiply-subtract: a*b + c and a*b - c.
    #define SIMINT_FMADD(a,b,c)    _mm512_fmadd_pd((a), (b), (c))
    #define SIMINT_FMSUB(a,b,c)    _mm512_fmsub_pd((a), (b), (c))

    // The Intel compiler provides vector exp/pow intrinsics; elsewhere fall
    // back to the per-lane scalar helpers simint_exp_vec8 / simint_pow_vec8.
    #if defined __INTEL_COMPILER
        #define SIMINT_EXP(a)       _mm512_exp_pd((a))
        #define SIMINT_POW(a,p)     _mm512_pow_pd((a), (p))
    #else
        #define SIMINT_EXP(a)       simint_exp_vec8((a))
        #define SIMINT_POW(a,p)     simint_pow_vec8((a), (p))
    #endif
64 
65 
66 
67     ////////////////////////////////////////
68     // Special functions
69     ////////////////////////////////////////
70 
71     static inline
contract(int ncart,int const * restrict offsets,__m512d const * restrict src,double * restrict dest)72     void contract(int ncart,
73                   int const * restrict offsets,
74                   __m512d const * restrict src,
75                   double * restrict dest)
76     {
77         for(int n = 0; n < SIMINT_SIMD_LEN; ++n)
78         {
79             double const * restrict src_tmp = (double *)src + n;
80             double * restrict dest_tmp = dest + offsets[n]*ncart;
81 
82             for(int np = 0; np < ncart; ++np)
83             {
84                 dest_tmp[np] += *src_tmp;
85                 src_tmp += SIMINT_SIMD_LEN;
86             }
87         }
88     }
89 
90 
91     static inline
contract_all(int ncart,__m512d const * restrict src,double * restrict dest)92     void contract_all(int ncart,
93                       __m512d const * restrict src,
94                       double * restrict dest)
95     {
96         #if defined __clang__ || defined __INTEL_COMPILER
97 
98         for(int np = 0; np < ncart; np++)
99             dest[np] += _mm512_reduce_add_pd(src[np]);
100 
101         #else
102 
103         int offsets[8] = {0, 0, 0, 0, 0, 0, 0, 0};
104         contract(ncart, offsets, src, dest);
105 
106         #endif
107     }
108 
109 
110     static inline
contract_fac(int ncart,const __m512d factor,int const * restrict offsets,__m512d const * restrict src,double * restrict dest)111     void contract_fac(int ncart,
112                       const __m512d factor,
113                       int const * restrict offsets,
114                       __m512d const * restrict src,
115                       double * restrict dest)
116     {
117         for(int np = 0; np < ncart; ++np)
118         {
119             union simint_double8 vtmp = { SIMINT_MUL(src[np], factor) };
120 
121             for(int n = 0; n < SIMINT_SIMD_LEN; ++n)
122                 dest[offsets[n]*ncart+np] += vtmp.d[n];
123         }
124     }
125 
126 
127     static inline
contract_all_fac(int ncart,const __m512d factor,__m512d const * restrict src,double * restrict dest)128     void contract_all_fac(int ncart,
129                           const __m512d factor,
130                           __m512d const * restrict src,
131                           double * restrict dest)
132     {
133         #if defined __clang__ || defined __INTEL_COMPILER
134 
135         for(int np = 0; np < ncart; np++)
136             dest[np] += _mm512_reduce_add_pd(_mm512_mul_pd(factor, src[np]));
137 
138         #else
139 
140         int offsets[8] = {0, 0, 0, 0, 0, 0, 0, 0};
141         contract_fac(ncart, offsets, factor, src, dest);
142 
143         #endif
144     }
145 
146 
147     static inline
vector_min(__m512d v)148     double vector_min(__m512d v)
149     {
150         #if defined __clang__ || defined __INTEL_COMPILER
151             return _mm512_reduce_min_pd(v);
152         #else
153             union simint_double8 u = { v };
154             double min = u.d[0];
155             for(int i = 1; i < 8; i++)
156                 min = (u.d[i] < min ? u.d[i] : min);
157             return min;
158         #endif
159     }
160 
161 
162     static inline
vector_max(__m512d v)163     double vector_max(__m512d v)
164     {
165         #if defined __clang__ || defined __INTEL_COMPILER
166             return _mm512_reduce_max_pd(v);
167         #else
168             union simint_double8 u = { v };
169             double max = u.d[0];
170             for(int i = 1; i < 8; i++)
171                 max = (u.d[i] > max ? u.d[i] : max);
172             return max;
173         #endif
174     }
175 
176     static inline
mask_load(int nlane,double * memaddr)177     __m512d mask_load(int nlane, double * memaddr)
178     {
179         union simint_double8 u = { _mm512_load_pd(memaddr) };
180         for(int n = nlane; n < SIMINT_SIMD_LEN; n++)
181             u.d[n] = 0.0;
182         return u.v;
183     }
184 
185 #endif // defined SIMINT_AVX512 || defined SIMINT_MICAVX512
186 
187 #ifdef __cplusplus
188 }
189 #endif
190 
191