/*
 * This file is part of libsharp2.
 *
 * libsharp2 is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * libsharp2 is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with libsharp2; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
 */

/* libsharp2 is being developed at the Max-Planck-Institut fuer Astrophysik */

/* \file sharp_vecsupport.h
 * Convenience functions for vector arithmetics
 *
 * Copyright (C) 2012-2019 Max-Planck-Society
 * Author: Martin Reinecke
 */

28 #ifndef SHARP2_VECSUPPORT_H
29 #define SHARP2_VECSUPPORT_H
30
31 #include <math.h>
32
33 #ifndef VLEN
34
35 #if (defined(__AVX512F__))
36 #define VLEN 8
37 #elif (defined (__AVX__))
38 #define VLEN 4
39 #elif (defined (__SSE2__))
40 #define VLEN 2
41 #else
42 #define VLEN 1
43 #endif
44
45 #endif
46
47 typedef double Ts;
48
#if (VLEN==1)

/* Scalar fallback: "vectors" are plain doubles, "masks" plain ints. */
typedef double Tv;
typedef int Tm;

#define vload(a) (a)
#define vzero 0.
#define vone 1.

/* Masked in-place update: a op= b where mask is true.
 * Wrapped in do{}while(0) so the macros behave as a single statement
 * inside unbraced if/else (the bare `if (mask) ...;` form did not),
 * and so callers terminate them with ';' exactly as for the SIMD
 * variants of these macros. */
#define vaddeq_mask(mask,a,b) do { if (mask) (a)+=(b); } while(0)
#define vsubeq_mask(mask,a,b) do { if (mask) (a)-=(b); } while(0)
#define vmuleq_mask(mask,a,b) do { if (mask) (a)*=(b); } while(0)
#define vneg(a) (-(a))
#define vabs(a) fabs(a)
#define vsqrt(a) sqrt(a)
#define vlt(a,b) ((a)<(b))
#define vgt(a,b) ((a)>(b))
#define vge(a,b) ((a)>=(b))
#define vne(a,b) ((a)!=(b))
#define vand_mask(a,b) ((a)&&(b))
#define vor_mask(a,b) ((a)||(b))
/* elementwise minimum/maximum */
static inline Tv vmin (Tv a, Tv b) { return (a<b) ? a : b; }
static inline Tv vmax (Tv a, Tv b) { return (a>b) ? a : b; }
#define vanyTrue(a) (a)
#define vallTrue(a) (a)

/* Accumulate the "horizontal sums" of (a,b) and (c,d) into the complex
 * numbers cc[0] and cc[1]; in the scalar case this is a plain add:
 * cc[0] += a + I*b, cc[1] += c + I*d. */
static inline void vhsum_cmplx_special (Tv a, Tv b, Tv c, Tv d,
  _Complex double * restrict cc)
  { cc[0] += a+_Complex_I*b; cc[1] += c+_Complex_I*d; }


#endif

#if (VLEN==2)

#include <emmintrin.h>

#if defined (__SSE3__)
#include <pmmintrin.h>
#endif
#if defined (__SSE4_1__)
#include <smmintrin.h>
#endif

/* Two doubles per vector; masks are all-ones/all-zeros double lanes. */
typedef __m128d Tv;
typedef __m128d Tm;

/* vblend__(m,a,b): per-lane select, a where mask m is set, else b. */
#if defined(__SSE4_1__)
#define vblend__(m,a,b) _mm_blendv_pd(b,a,m)
#else
static inline Tv vblend__(Tv m, Tv a, Tv b)
  { return _mm_or_pd(_mm_and_pd(a,m),_mm_andnot_pd(m,b)); }
#endif
#define vload(a) _mm_set1_pd(a)
#define vzero _mm_setzero_pd()
#define vone vload(1.)

/* masked in-place update: a op= b where mask is set */
#define vaddeq_mask(mask,a,b) a+=vblend__(mask,b,vzero)
#define vsubeq_mask(mask,a,b) a-=vblend__(mask,b,vzero)
#define vmuleq_mask(mask,a,b) a*=vblend__(mask,b,vone)
/* sign-bit tricks: xor with -0. flips the sign, andnot clears it */
#define vneg(a) _mm_xor_pd(vload(-0.),a)
#define vabs(a) _mm_andnot_pd(vload(-0.),a)
#define vsqrt(a) _mm_sqrt_pd(a)
#define vlt(a,b) _mm_cmplt_pd(a,b)
#define vgt(a,b) _mm_cmpgt_pd(a,b)
#define vge(a,b) _mm_cmpge_pd(a,b)
#define vne(a,b) _mm_cmpneq_pd(a,b)
#define vand_mask(a,b) _mm_and_pd(a,b)
#define vor_mask(a,b) _mm_or_pd(a,b)
#define vmin(a,b) _mm_min_pd(a,b)
/* (a stray trailing ';' removed here: it broke vmax in expression context) */
#define vmax(a,b) _mm_max_pd(a,b)
#define vanyTrue(a) (_mm_movemask_pd(a)!=0)
#define vallTrue(a) (_mm_movemask_pd(a)==3)

/* Accumulate horizontal sums into complex accumulators:
 * cc[0] += (a0+a1) + I*(b0+b1), cc[1] += (c0+c1) + I*(d0+d1). */
static inline void vhsum_cmplx_special (Tv a, Tv b, Tv c,
  Tv d, _Complex double * restrict cc)
  {
  union {Tv v; _Complex double c; } u1, u2;
#if defined(__SSE3__)
  u1.v = _mm_hadd_pd(a,b); u2.v=_mm_hadd_pd(c,d);
#else
  /* without SSE3, build {a0+a1, b0+b1} from two cross-lane shuffles */
  u1.v = _mm_shuffle_pd(a,b,_MM_SHUFFLE2(0,1)) +
         _mm_shuffle_pd(a,b,_MM_SHUFFLE2(1,0));
  u2.v = _mm_shuffle_pd(c,d,_MM_SHUFFLE2(0,1)) +
         _mm_shuffle_pd(c,d,_MM_SHUFFLE2(1,0));
#endif
  cc[0]+=u1.c; cc[1]+=u2.c;
  }

#endif

#if (VLEN==4)

#include <immintrin.h>

/* Four doubles per vector; masks are all-ones/all-zeros double lanes. */
typedef __m256d Tv;
typedef __m256d Tm;

/* vblend__(m,a,b): per-lane select, a where mask m is set, else b. */
#define vblend__(m,a,b) _mm256_blendv_pd(b,a,m)
#define vload(a) _mm256_set1_pd(a)
#define vzero _mm256_setzero_pd()
#define vone vload(1.)

/* masked in-place update: a op= b where mask is set */
#define vaddeq_mask(mask,a,b) a+=vblend__(mask,b,vzero)
#define vsubeq_mask(mask,a,b) a-=vblend__(mask,b,vzero)
#define vmuleq_mask(mask,a,b) a*=vblend__(mask,b,vone)
/* sign-bit tricks: xor with -0. flips the sign, andnot clears it */
#define vneg(a) _mm256_xor_pd(vload(-0.),a)
#define vabs(a) _mm256_andnot_pd(vload(-0.),a)
#define vsqrt(a) _mm256_sqrt_pd(a)
#define vlt(a,b) _mm256_cmp_pd(a,b,_CMP_LT_OQ)
#define vgt(a,b) _mm256_cmp_pd(a,b,_CMP_GT_OQ)
#define vge(a,b) _mm256_cmp_pd(a,b,_CMP_GE_OQ)
#define vne(a,b) _mm256_cmp_pd(a,b,_CMP_NEQ_OQ)
#define vand_mask(a,b) _mm256_and_pd(a,b)
#define vor_mask(a,b) _mm256_or_pd(a,b)
#define vmin(a,b) _mm256_min_pd(a,b)
#define vmax(a,b) _mm256_max_pd(a,b)
#define vanyTrue(a) (_mm256_movemask_pd(a)!=0)
#define vallTrue(a) (_mm256_movemask_pd(a)==15)

/* Accumulate horizontal sums: cc[0] += sum(a) + I*sum(b), and likewise
 * cc[1] from (c,d).  hadd yields {a0+a1, b0+b1, a2+a3, b2+b3}; the two
 * permute2f128 calls pair the low and high 128-bit halves so that adding
 * them produces full four-element sums in each lane. */
static inline void vhsum_cmplx_special (Tv a, Tv b, Tv c, Tv d,
  _Complex double * restrict cc)
  {
  Tv tmp1=_mm256_hadd_pd(a,b), tmp2=_mm256_hadd_pd(c,d);
  Tv tmp3=_mm256_permute2f128_pd(tmp1,tmp2,49),
     tmp4=_mm256_permute2f128_pd(tmp1,tmp2,32);
  tmp1=tmp3+tmp4;
  union {Tv v; _Complex double c[2]; } u;
  u.v=tmp1;
  cc[0]+=u.c[0]; cc[1]+=u.c[1];
  }

#endif

183 #if (VLEN==8)
184
185 #include <immintrin.h>
186
187 typedef __m512d Tv;
188 typedef __mmask8 Tm;
189
190 #define vload(a) _mm512_set1_pd(a)
191 #define vzero _mm512_setzero_pd()
192 #define vone vload(1.)
193
194 #define vaddeq_mask(mask,a,b) a=_mm512_mask_add_pd(a,mask,a,b);
195 #define vsubeq_mask(mask,a,b) a=_mm512_mask_sub_pd(a,mask,a,b);
196 #define vmuleq_mask(mask,a,b) a=_mm512_mask_mul_pd(a,mask,a,b);
197 #define vneg(a) _mm512_mul_pd(a,vload(-1.))
198 #define vabs(a) (__m512d)_mm512_andnot_epi64((__m512i)vload(-0.),(__m512i)a)
199 #define vsqrt(a) _mm512_sqrt_pd(a)
200 #define vlt(a,b) _mm512_cmp_pd_mask(a,b,_CMP_LT_OQ)
201 #define vgt(a,b) _mm512_cmp_pd_mask(a,b,_CMP_GT_OQ)
202 #define vge(a,b) _mm512_cmp_pd_mask(a,b,_CMP_GE_OQ)
203 #define vne(a,b) _mm512_cmp_pd_mask(a,b,_CMP_NEQ_OQ)
204 #define vand_mask(a,b) ((a)&(b))
205 #define vor_mask(a,b) ((a)|(b))
206 #define vmin(a,b) _mm512_min_pd(a,b)
207 #define vmax(a,b) _mm512_max_pd(a,b)
208 #define vanyTrue(a) (a!=0)
209 #define vallTrue(a) (a==255)
210
vhsum_cmplx_special(Tv a,Tv b,Tv c,Tv d,_Complex double * restrict cc)211 static inline void vhsum_cmplx_special (Tv a, Tv b, Tv c, Tv d,
212 _Complex double * restrict cc)
213 {
214 cc[0] += _mm512_reduce_add_pd(a)+_Complex_I*_mm512_reduce_add_pd(b);
215 cc[1] += _mm512_reduce_add_pd(c)+_Complex_I*_mm512_reduce_add_pd(d);
216 }
217
218 #endif
219
220 #endif
221