/*
 *  This file is part of libsharp2.
 *
 *  libsharp2 is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  libsharp2 is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with libsharp2; if not, write to the Free Software
 *  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */

/* libsharp2 is being developed at the Max-Planck-Institut fuer Astrophysik */

/*! \file sharp_vecsupport.h
 *  Convenience functions for vector arithmetics
 *
 *  Copyright (C) 2012-2019 Max-Planck-Society
 *  Author: Martin Reinecke
 */

#ifndef SHARP2_VECSUPPORT_H
#define SHARP2_VECSUPPORT_H

#include <math.h>
#include <complex.h>

#ifndef VLEN

#if (defined(__AVX512F__))
#define VLEN 8
#elif (defined (__AVX__))
#define VLEN 4
#elif (defined (__SSE2__))
#define VLEN 2
#else
#define VLEN 1
#endif

#endif
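
/* VLEN is the number of double-precision lanes per SIMD vector. It is derived
   from the instruction-set macros that the compiler predefines: with GCC or
   Clang, for example, building with -mavx512f, -mavx or -msse2 defines
   __AVX512F__, __AVX__ or __SSE2__ and selects 8, 4 or 2 lanes respectively;
   otherwise the scalar fallback (VLEN==1) is used. Because of the #ifndef
   guard above, VLEN can also be forced from the command line, e.g. -DVLEN=1. */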

typedef double Ts;

#if (VLEN==1)

typedef double Tv;
typedef int Tm;

#define vload(a) (a)
#define vzero 0.
#define vone 1.

#define vaddeq_mask(mask,a,b) if (mask) (a)+=(b);
#define vsubeq_mask(mask,a,b) if (mask) (a)-=(b);
#define vmuleq_mask(mask,a,b) if (mask) (a)*=(b);
#define vneg(a) (-(a))
#define vabs(a) fabs(a)
#define vsqrt(a) sqrt(a)
#define vlt(a,b) ((a)<(b))
#define vgt(a,b) ((a)>(b))
#define vge(a,b) ((a)>=(b))
#define vne(a,b) ((a)!=(b))
#define vand_mask(a,b) ((a)&&(b))
#define vor_mask(a,b) ((a)||(b))
static inline Tv vmin (Tv a, Tv b) { return (a<b) ? a : b; }
static inline Tv vmax (Tv a, Tv b) { return (a>b) ? a : b; }
#define vanyTrue(a) (a)
#define vallTrue(a) (a)

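/* vhsum_cmplx_special accumulates the horizontal sums of a, b, c, d into two
   complex numbers: cc[0] += sum(a) + i*sum(b), cc[1] += sum(c) + i*sum(d).
   In the scalar case each "sum" is simply the value itself. */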
static inline void vhsum_cmplx_special (Tv a, Tv b, Tv c, Tv d,
  _Complex double * restrict cc)
  { cc[0] += a+_Complex_I*b; cc[1] += c+_Complex_I*d; }


#endif
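
/* Illustrative usage sketch (not part of the library interface): the macros
   above let the same caller code compile unchanged for every VLEN. For
   instance, a hypothetical snippet accumulating sqrt(x) only in the lanes
   where x lies below a limit could read

     Tv x = vload(2.), limit = vload(10.), acc = vzero;
     Tm below = vlt(x, limit);            // per-lane comparison mask
     vaddeq_mask(below, acc, vsqrt(x));   // acc += sqrt(x) where the mask is set
     if (vallTrue(below)) { ... }         // scalar branch on the whole mask

   With VLEN==1 this reduces to ordinary double arithmetic; the SIMD branches
   below provide the same interface via intrinsics. */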

#if (VLEN==2)

#include <emmintrin.h>

#if defined (__SSE3__)
#include <pmmintrin.h>
#endif
#if defined (__SSE4_1__)
#include <smmintrin.h>
#endif

typedef __m128d Tv;
typedef __m128d Tm;

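/* In this branch (as in the AVX one) a mask (Tm) is a full-width vector whose
   lanes hold either all-one or all-zero bits. vblend__(m,a,b) returns a in the
   lanes where m is set and b elsewhere: SSE4.1 has a dedicated blend
   instruction, the plain SSE2 fallback emulates it with and/andnot/or. */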
#if defined(__SSE4_1__)
#define vblend__(m,a,b) _mm_blendv_pd(b,a,m)
#else
static inline Tv vblend__(Tv m, Tv a, Tv b)
  { return _mm_or_pd(_mm_and_pd(a,m),_mm_andnot_pd(m,b)); }
#endif
#define vload(a) _mm_set1_pd(a)
#define vzero _mm_setzero_pd()
#define vone vload(1.)

#define vaddeq_mask(mask,a,b) a+=vblend__(mask,b,vzero)
#define vsubeq_mask(mask,a,b) a-=vblend__(mask,b,vzero)
#define vmuleq_mask(mask,a,b) a*=vblend__(mask,b,vone)
#define vneg(a) _mm_xor_pd(vload(-0.),a)
#define vabs(a) _mm_andnot_pd(vload(-0.),a)
#define vsqrt(a) _mm_sqrt_pd(a)
#define vlt(a,b) _mm_cmplt_pd(a,b)
#define vgt(a,b) _mm_cmpgt_pd(a,b)
#define vge(a,b) _mm_cmpge_pd(a,b)
#define vne(a,b) _mm_cmpneq_pd(a,b)
#define vand_mask(a,b) _mm_and_pd(a,b)
#define vor_mask(a,b) _mm_or_pd(a,b)
#define vmin(a,b) _mm_min_pd(a,b)
#define vmax(a,b) _mm_max_pd(a,b)
#define vanyTrue(a) (_mm_movemask_pd(a)!=0)
#define vallTrue(a) (_mm_movemask_pd(a)==3)

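/* Horizontal complex accumulation: with SSE3, one haddpd per pair yields
   {sum(a),sum(b)} and {sum(c),sum(d)}; the plain SSE2 fallback builds the same
   result from two shuffles plus an add. The vectors are then reinterpreted as
   _Complex double and added onto cc[0] and cc[1]. */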
static inline void vhsum_cmplx_special (Tv a, Tv b, Tv c,
  Tv d, _Complex double * restrict cc)
  {
  union {Tv v; _Complex double c; } u1, u2;
#if defined(__SSE3__)
  u1.v = _mm_hadd_pd(a,b); u2.v=_mm_hadd_pd(c,d);
#else
  u1.v = _mm_shuffle_pd(a,b,_MM_SHUFFLE2(0,1)) +
         _mm_shuffle_pd(a,b,_MM_SHUFFLE2(1,0));
  u2.v = _mm_shuffle_pd(c,d,_MM_SHUFFLE2(0,1)) +
         _mm_shuffle_pd(c,d,_MM_SHUFFLE2(1,0));
#endif
  cc[0]+=u1.c; cc[1]+=u2.c;
  }

#endif

#if (VLEN==4)

#include <immintrin.h>

typedef __m256d Tv;
typedef __m256d Tm;

#define vblend__(m,a,b) _mm256_blendv_pd(b,a,m)
#define vload(a) _mm256_set1_pd(a)
#define vzero _mm256_setzero_pd()
#define vone vload(1.)

#define vaddeq_mask(mask,a,b) a+=vblend__(mask,b,vzero)
#define vsubeq_mask(mask,a,b) a-=vblend__(mask,b,vzero)
#define vmuleq_mask(mask,a,b) a*=vblend__(mask,b,vone)
#define vneg(a) _mm256_xor_pd(vload(-0.),a)
#define vabs(a) _mm256_andnot_pd(vload(-0.),a)
#define vsqrt(a) _mm256_sqrt_pd(a)
#define vlt(a,b) _mm256_cmp_pd(a,b,_CMP_LT_OQ)
#define vgt(a,b) _mm256_cmp_pd(a,b,_CMP_GT_OQ)
#define vge(a,b) _mm256_cmp_pd(a,b,_CMP_GE_OQ)
#define vne(a,b) _mm256_cmp_pd(a,b,_CMP_NEQ_OQ)
#define vand_mask(a,b) _mm256_and_pd(a,b)
#define vor_mask(a,b) _mm256_or_pd(a,b)
#define vmin(a,b) _mm256_min_pd(a,b)
#define vmax(a,b) _mm256_max_pd(a,b)
#define vanyTrue(a) (_mm256_movemask_pd(a)!=0)
#define vallTrue(a) (_mm256_movemask_pd(a)==15)

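/* Horizontal complex accumulation: _mm256_hadd_pd sums adjacent elements
   within each 128-bit lane, and the two permute2f128 operations recombine the
   lanes so that tmp3+tmp4 equals {sum(a),sum(b),sum(c),sum(d)}, which is then
   accumulated into cc[0] and cc[1]. */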
static inline void vhsum_cmplx_special (Tv a, Tv b, Tv c, Tv d,
  _Complex double * restrict cc)
  {
  Tv tmp1=_mm256_hadd_pd(a,b), tmp2=_mm256_hadd_pd(c,d);
  Tv tmp3=_mm256_permute2f128_pd(tmp1,tmp2,49),
     tmp4=_mm256_permute2f128_pd(tmp1,tmp2,32);
  tmp1=tmp3+tmp4;
  union {Tv v; _Complex double c[2]; } u;
  u.v=tmp1;
  cc[0]+=u.c[0]; cc[1]+=u.c[1];
  }

#endif

#if (VLEN==8)

#include <immintrin.h>

typedef __m512d Tv;
typedef __mmask8 Tm;
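/* AVX-512 comparisons return bit masks (__mmask8) instead of full-width
   vectors, so the mask logic below is plain integer bit arithmetic and the
   masked updates use the native masked add/sub/mul intrinsics. */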

#define vload(a) _mm512_set1_pd(a)
#define vzero _mm512_setzero_pd()
#define vone vload(1.)

#define vaddeq_mask(mask,a,b) a=_mm512_mask_add_pd(a,mask,a,b);
#define vsubeq_mask(mask,a,b) a=_mm512_mask_sub_pd(a,mask,a,b);
#define vmuleq_mask(mask,a,b) a=_mm512_mask_mul_pd(a,mask,a,b);
#define vneg(a) _mm512_mul_pd(a,vload(-1.))
#define vabs(a) (__m512d)_mm512_andnot_epi64((__m512i)vload(-0.),(__m512i)a)
#define vsqrt(a) _mm512_sqrt_pd(a)
#define vlt(a,b) _mm512_cmp_pd_mask(a,b,_CMP_LT_OQ)
#define vgt(a,b) _mm512_cmp_pd_mask(a,b,_CMP_GT_OQ)
#define vge(a,b) _mm512_cmp_pd_mask(a,b,_CMP_GE_OQ)
#define vne(a,b) _mm512_cmp_pd_mask(a,b,_CMP_NEQ_OQ)
#define vand_mask(a,b) ((a)&(b))
#define vor_mask(a,b) ((a)|(b))
#define vmin(a,b) _mm512_min_pd(a,b)
#define vmax(a,b) _mm512_max_pd(a,b)
#define vanyTrue(a) (a!=0)
#define vallTrue(a) (a==255)

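/* Horizontal complex accumulation: AVX-512 provides _mm512_reduce_add_pd,
   so each vector is reduced to its sum directly and the results are combined
   into the two complex accumulators. */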
static inline void vhsum_cmplx_special (Tv a, Tv b, Tv c, Tv d,
  _Complex double * restrict cc)
  {
  cc[0] += _mm512_reduce_add_pd(a)+_Complex_I*_mm512_reduce_add_pd(b);
  cc[1] += _mm512_reduce_add_pd(c)+_Complex_I*_mm512_reduce_add_pd(d);
  }

#endif

#endif