#pragma once

#include <immintrin.h>
#include <stdint.h>
#include <math.h>

#include "simint/vectorization/intrinsics_avx.h"

#ifdef __cplusplus
#include "simint/cpp_restrict.hpp"
extern "C" {
#endif

/* View of one AVX-512 register as 8 scalar doubles.
 * Used below to apply scalar libm functions (exp, pow) lane-by-lane
 * and to do horizontal reductions without extract intrinsics. */
union simint_double8
{
    __m512d v;
    double d[8];
};


21 // Missing GCC vectorized exp and pow
simint_exp_vec8(__m512d x)22 static inline __m512d simint_exp_vec8(__m512d x)
23 {
24 union simint_double8 u = { x };
25 union simint_double8 res;
26 for(int i = 0; i < 8; i++)
27 res.d[i] = exp(u.d[i]);
28 return res.v;
29 }
30
simint_pow_vec8(__m512d a,__m512d p)31 static inline __m512d simint_pow_vec8(__m512d a, __m512d p)
32 {
33 union simint_double8 ua = { a };
34 union simint_double8 up = { p };
35 union simint_double8 res;
36 for(int i = 0; i < 8; i++)
37 res.d[i] = pow(ua.d[i], up.d[i]);
38 return res.v;
39 }
40
#if defined SIMINT_AVX512 || defined SIMINT_MICAVX512

/* Number of doubles per SIMD register (512 bits / 64 bits) */
#define SIMINT_SIMD_LEN 8

/* Thin wrappers over the AVX-512 double-precision intrinsics so the
 * generated integral code is ISA-agnostic. */
#define SIMINT_DBLTYPE __m512d
#define SIMINT_DBLLOAD(p,i) _mm512_load_pd((p) + (i))
#define SIMINT_DBLSET1(a) _mm512_set1_pd((a))
/* Negate by multiplying with -1.0 */
#define SIMINT_NEG(a) (SIMINT_MUL((a), (SIMINT_DBLSET1(-1.0))))
#define SIMINT_ADD(a,b) _mm512_add_pd((a), (b))
#define SIMINT_SUB(a,b) _mm512_sub_pd((a), (b))
#define SIMINT_MUL(a,b) _mm512_mul_pd((a), (b))
#define SIMINT_DIV(a,b) _mm512_div_pd((a), (b))
#define SIMINT_SQRT(a) _mm512_sqrt_pd((a))
#define SIMINT_FMADD(a,b,c) _mm512_fmadd_pd((a), (b), (c))
#define SIMINT_FMSUB(a,b,c) _mm512_fmsub_pd((a), (b), (c))

/* Intel's compiler ships SVML vector exp/pow intrinsics; other
 * compilers fall back to the scalar-loop helpers defined above. */
#if defined __INTEL_COMPILER
#define SIMINT_EXP(a) _mm512_exp_pd((a))
#define SIMINT_POW(a,p) _mm512_pow_pd((a), (p))
#else
#define SIMINT_EXP(a) simint_exp_vec8((a))
#define SIMINT_POW(a,p) simint_pow_vec8((a), (p))
#endif



////////////////////////////////////////
// Special functions
////////////////////////////////////////

71 static inline
contract(int ncart,int const * restrict offsets,__m512d const * restrict src,double * restrict dest)72 void contract(int ncart,
73 int const * restrict offsets,
74 __m512d const * restrict src,
75 double * restrict dest)
76 {
77 for(int n = 0; n < SIMINT_SIMD_LEN; ++n)
78 {
79 double const * restrict src_tmp = (double *)src + n;
80 double * restrict dest_tmp = dest + offsets[n]*ncart;
81
82 for(int np = 0; np < ncart; ++np)
83 {
84 dest_tmp[np] += *src_tmp;
85 src_tmp += SIMINT_SIMD_LEN;
86 }
87 }
88 }
89
90
/* Contract every lane of each vector in src into the single output
 * block dest (i.e. all lane offsets are zero). */
static inline
void contract_all(int ncart,
                  __m512d const * restrict src,
                  double * restrict dest)
{
#if defined __clang__ || defined __INTEL_COMPILER

    /* Horizontal add of each vector straight into dest */
    for(int np = 0; np < ncart; np++)
        dest[np] += _mm512_reduce_add_pd(src[np]);

#else

    /* Generic path: every lane maps to block 0 */
    int offsets[8] = {0};
    contract(ncart, offsets, src, dest);

#endif
}


110 static inline
contract_fac(int ncart,const __m512d factor,int const * restrict offsets,__m512d const * restrict src,double * restrict dest)111 void contract_fac(int ncart,
112 const __m512d factor,
113 int const * restrict offsets,
114 __m512d const * restrict src,
115 double * restrict dest)
116 {
117 for(int np = 0; np < ncart; ++np)
118 {
119 union simint_double8 vtmp = { SIMINT_MUL(src[np], factor) };
120
121 for(int n = 0; n < SIMINT_SIMD_LEN; ++n)
122 dest[offsets[n]*ncart+np] += vtmp.d[n];
123 }
124 }
125
126
/* Contract every lane of factor*src into the single output block dest
 * (all lane offsets zero). */
static inline
void contract_all_fac(int ncart,
                      const __m512d factor,
                      __m512d const * restrict src,
                      double * restrict dest)
{
#if defined __clang__ || defined __INTEL_COMPILER

    for(int np = 0; np < ncart; np++)
        dest[np] += _mm512_reduce_add_pd(_mm512_mul_pd(factor, src[np]));

#else

    int offsets[8] = {0, 0, 0, 0, 0, 0, 0, 0};
    /* BUGFIX: arguments were previously passed as
     * (ncart, offsets, factor, src, dest), which does not match
     * contract_fac(int, __m512d, int const*, ...). */
    contract_fac(ncart, factor, offsets, src, dest);

#endif
}


147 static inline
vector_min(__m512d v)148 double vector_min(__m512d v)
149 {
150 #if defined __clang__ || defined __INTEL_COMPILER
151 return _mm512_reduce_min_pd(v);
152 #else
153 union simint_double8 u = { v };
154 double min = u.d[0];
155 for(int i = 1; i < 8; i++)
156 min = (u.d[i] < min ? u.d[i] : min);
157 return min;
158 #endif
159 }
160
161
162 static inline
vector_max(__m512d v)163 double vector_max(__m512d v)
164 {
165 #if defined __clang__ || defined __INTEL_COMPILER
166 return _mm512_reduce_max_pd(v);
167 #else
168 union simint_double8 u = { v };
169 double max = u.d[0];
170 for(int i = 1; i < 8; i++)
171 max = (u.d[i] > max ? u.d[i] : max);
172 return max;
173 #endif
174 }
175
/* Load nlane doubles from memaddr (which must be 64-byte aligned, as
 * required by _mm512_load_pd), zero-filling lanes >= nlane.
 *
 * Uses a zero-masking load so the masked-out lanes are never read
 * (AVX-512 masked loads suppress faults on masked elements); the
 * original loaded all 64 bytes and then zeroed the tail.
 * nlane is assumed to be in [0, SIMINT_SIMD_LEN]. */
static inline
__m512d mask_load(int nlane, double * memaddr)
{
    const __mmask8 mask = (__mmask8)((1u << nlane) - 1u);
    return _mm512_maskz_load_pd(mask, memaddr);
}

#endif // defined SIMINT_AVX512 || defined SIMINT_MICAVX512

#ifdef __cplusplus
}
#endif
