/* quick_e.h -- Fast exponential function for Intel and ARM intrinsics

   Copyright (C) 2020 European Centre for Medium-Range Weather Forecasts

   Author: Robin Hogan <r.j.hogan@ecmwf.int>

   This file is part of the Adept library, although it can be used
   stand-alone.

   The exponential function for real arguments is used in many areas
   of physics, yet is not vectorized by many compilers.  This C++
   header file provides a fast exponential function (quick_e::exp)
   for single- and double-precision floating-point numbers, for Intel
   intrinsics representing packets of 2, 4, 8 and 16 such numbers,
   and for ARM NEON intrinsics representing 2 doubles or 4 floats.
   The algorithm has been taken from Agner Fog's Vector Class
   Library.  It is designed to be used in other libraries that make
   use of Intel or ARM intrinsics.  Since such libraries often define
   their own classes for representing vectors of numbers, this file
   does not define any such classes itself.

   Also in the namespace quick_e, this file defines the following
   inline functions that work on intrinsics of type "Vec" and the
   corresponding scalar type "Sca":

     Vec add(Vec x, Vec y)   Add the elements of x and y
     Vec sub(Vec x, Vec y)   Subtract the elements of x and y
     Vec mul(Vec x, Vec y)   Multiply the elements of x and y
     Vec div(Vec x, Vec y)   Divide the elements of x and y
     Vec set0<Vec>()         Returns zero in all elements
     Vec set1<Vec>(Sca a)    Returns all elements set to a
     Vec sqrt(Vec x)         Square root of all elements
     Vec fmin(Vec x, Vec y)  Minimum of elements of x and y
     Vec fmax(Vec x, Vec y)  Maximum of elements of x and y
     Vec load(const Sca* d)  Aligned load from memory location d
     Vec loadu(const Sca* d) Unaligned load from memory location d
     void store(Sca* d, Vec x)  Aligned store of x to d
     void storeu(Sca* d, Vec x) Unaligned store of x to d
     Sca hsum(Vec x)         Horizontal sum of elements of x
     Sca hmul(Vec x)         Horizontal product of elements of x
     Sca hmin(Vec x)         Horizontal minimum of elements of x
     Sca hmax(Vec x)         Horizontal maximum of elements of x
     Vec fma(Vec x, Vec y, Vec z)  Fused multiply-add: (x*y)+z
     Vec fnma(Vec x, Vec y, Vec z) Returns z-(x*y)
     Vec pow2n(Vec x)        Returns 2 to the power of x
     Vec exp(Vec x)          Returns exponential of x

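   Example usage (a minimal sketch; it assumes compilation for an x86
   target with AVX enabled so that the __m256d overloads below are
   available):

     #include "quick_e.h"

     double  s = quick_e::exp(1.0);            // scalar exponential
     __m256d v = quick_e::set1<__m256d>(1.0);  // packet of 4 doubles
     __m256d e = quick_e::exp(v);              // element-wise exponential
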
 */

#ifndef QuickE_H
#define QuickE_H 1

#include <algorithm> // std::min, std::max
#include <cmath>
#include <limits>    // std::numeric_limits
#include <ostream>   // std::ostream

// The Microsoft compiler doesn't define __SSE2__ even if __AVX__ is
// defined
#ifdef __AVX__
#ifndef __SSE2__
#define __SSE2__ 1
#endif
#endif

// Headers needed for x86 vector intrinsics
#ifdef __SSE2__
  #include <xmmintrin.h> // SSE
  #include <emmintrin.h> // SSE2
  // Numerous platforms don't declare _mm_undefined_ps in
  // xmmintrin.h, so we assume none do except GCC >= 4.9.1 and Clang
  // versions providing the __builtin_ia32_undef128 builtin.  For the
  // others we use an equivalent function that sets the elements to
  // zero.
  #define QE_MM_UNDEFINED_PS _mm_setzero_ps
  #ifdef __clang__
    #if __has_builtin(__builtin_ia32_undef128)
      #undef QE_MM_UNDEFINED_PS
      #define QE_MM_UNDEFINED_PS _mm_undefined_ps
    #endif
  #elif defined(__GNUC__)
    #define GCC_VERSION (__GNUC__ * 10000 \
                         + __GNUC_MINOR__ * 100 \
                         + __GNUC_PATCHLEVEL__)
    #if GCC_VERSION >= 40901
      #undef QE_MM_UNDEFINED_PS
      #define QE_MM_UNDEFINED_PS _mm_undefined_ps
    #endif
    #undef GCC_VERSION
  #endif // __clang__/__GNUC__
#endif // __SSE2__

#ifdef __SSE4_1__
  #include <smmintrin.h>
#endif

#ifdef __AVX__
  #include <tmmintrin.h> // SSSE3
  #include <immintrin.h> // AVX
#endif

#ifdef __AVX512F__
  #include <immintrin.h>
#endif

#ifdef __ARM_NEON
  #include <arm_neon.h>
#endif

namespace quick_e {

  // -------------------------------------------------------------------
  // Traits
  // -------------------------------------------------------------------

  template <typename Type, int Size> struct packet {
    static const bool is_available = false;
    static const int  size         = 1;
    typedef Type type;
  };
  template <typename Type> struct longest_packet {
    typedef Type type;
    static const int size = 1;
  };

  // g++ issues ugly warnings if VEC is an Intel intrinsic; they can
  // be disabled with -Wno-ignored-attributes
#define QE_DEFINE_TRAITS(TYPE, SIZE, VEC, HALF_TYPE)   \
  template <> struct packet<TYPE,SIZE> {              \
    static const bool is_available = true;            \
    static const int  size = SIZE;                     \
    typedef VEC type;                                   \
    typedef HALF_TYPE half_type;                        \
  };

#define QE_DEFINE_LONGEST(VECS, VECD)                           \
  template <> struct longest_packet<float> {                    \
    typedef VECS type;                                           \
    static const int size = sizeof(VECS)/sizeof(float);          \
  };                                                             \
  template <> struct longest_packet<double> {                    \
    typedef VECD type;                                            \
    static const int size = sizeof(VECD)/sizeof(double);          \
  };

#ifdef __SSE2__
  #define QE_HAVE_FAST_EXP 1
  QE_DEFINE_TRAITS(float, 4, __m128, __m128)
  QE_DEFINE_TRAITS(double, 2, __m128d, double)
  #ifdef __AVX__
    QE_DEFINE_TRAITS(float, 8, __m256, __m128)
    QE_DEFINE_TRAITS(double, 4, __m256d, __m128d)
    #ifdef __AVX512F__
      QE_DEFINE_TRAITS(float, 16, __m512, __m256)
      QE_DEFINE_TRAITS(double, 8, __m512d, __m256d)
      QE_DEFINE_LONGEST(__m512, __m512d)
      #define QE_LONGEST_FLOAT_PACKET 16
      #define QE_LONGEST_DOUBLE_PACKET 8
    #else
      QE_DEFINE_LONGEST(__m256, __m256d)
      #define QE_LONGEST_FLOAT_PACKET 8
      #define QE_LONGEST_DOUBLE_PACKET 4
    #endif
  #else
    QE_DEFINE_LONGEST(__m128, __m128d)
    #define QE_LONGEST_FLOAT_PACKET 4
    #define QE_LONGEST_DOUBLE_PACKET 2
  #endif
  // If QE_AVAILABLE is defined then we can use the fast exponential
  #define QE_AVAILABLE
#elif defined(__ARM_NEON)
  #define QE_HAVE_FAST_EXP 1
  QE_DEFINE_TRAITS(float, 4, float32x4_t, float32x4_t)
  QE_DEFINE_TRAITS(double, 2, float64x2_t, double)
  QE_DEFINE_LONGEST(float32x4_t, float64x2_t)
  #define QE_LONGEST_FLOAT_PACKET 4
  #define QE_LONGEST_DOUBLE_PACKET 2
#else
  // No vectorization available: longest packet is of size 1
  QE_DEFINE_LONGEST(float, double)
  #define QE_LONGEST_FLOAT_PACKET 1
  #define QE_LONGEST_DOUBLE_PACKET 1
#endif
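  // As an illustration, when this header is compiled with AVX (but
  // not AVX512F) enabled, the traits above give:
  //   packet<float,8>::type           -> __m256
  //   packet<double,4>::type          -> __m256d
  //   packet<double,8>::is_available  -> false (no 512-bit packet)
  //   longest_packet<double>::type    -> __m256d, with size 4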


  // -------------------------------------------------------------------
  // Scalars
  // -------------------------------------------------------------------

  // Define a few functions for scalars so that the same
  // implementation of "exp" can be used for both scalars and SIMD
  // vectors
  template <typename T> T add(T x, T y) { return x+y; }
  template <typename T> T sub(T x, T y) { return x-y; }
  template <typename T> T mul(T x, T y) { return x*y; }
  template <typename T> T div(T x, T y) { return x/y; }
  template <typename T> T neg(T x)      { return -x;  }
  template <typename T, typename V> void store(T* d, V x) { *d = x;  }
  template <typename T, typename V> void storeu(T* d, V x){ *d = x;  }
  template <typename V, typename T> V load(const T* d) { return *d;  }
  template <typename V, typename T> V loadu(const T* d){ return *d;  }
  template <typename V, typename T> V set1(T x) { return x;   }
  template <typename V> inline V set0() { return 0.0; }
  template <typename T> T sqrt(T x) { return std::sqrt(x); }

  template <typename T> T hsum(T x) { return x; }
  template <typename T> T hmul(T x) { return x; }
  template <typename T> T hmin(T x) { return x; }
  template <typename T> T hmax(T x) { return x; }

  template <typename T> T fma(T x, T y, T z)  { return (x*y)+z; }
  template <typename T> T fnma(T x, T y, T z) { return z-(x*y); }
  template <typename T> T fmin(T x, T y)  { return std::min(x,y); }
  template <typename T> T fmax(T x, T y)  { return std::max(x,y); }

#if __cplusplus > 199711L
  template <> inline float  fmin(float x, float y)   { return std::fmin(x,y); }
  template <> inline double fmin(double x, double y) { return std::fmin(x,y); }
  template <> inline float  fmax(float x, float y)   { return std::fmax(x,y); }
  template <> inline double fmax(double x, double y) { return std::fmax(x,y); }
#endif

  inline float select_gt(float x1, float x2, float y1, float y2) {
    if (x1 > x2) { return y1; } else { return y2; }
  }
  inline double select_gt(double x1, double x2, double y1, double y2) {
    if (x1 > x2) { return y1; } else { return y2; }
  }

  inline bool all_in_range(float x, float low_bound, float high_bound) {
    return x >= low_bound && x <= high_bound;
  }
  inline bool all_in_range(double x, double low_bound, double high_bound) {
    return x >= low_bound && x <= high_bound;
  }

  // -------------------------------------------------------------------
  // Macros to define mathematical operations
  // -------------------------------------------------------------------

  // Basic load, store, arithmetic, sqrt, min and max
#define QE_DEFINE_BASIC(TYPE, VEC, LOAD, LOADU, SET0, SET1,     \
                        STORE, STOREU, ADD, SUB, MUL, DIV,      \
                        SQRT, FMIN, FMAX)                       \
  inline VEC add(VEC x, VEC y)       { return ADD(x, y); }      \
  inline VEC sub(VEC x, VEC y)       { return SUB(x, y); }      \
  inline VEC mul(VEC x, VEC y)       { return MUL(x, y); }      \
  inline VEC div(VEC x, VEC y)       { return DIV(x, y); }      \
  inline VEC neg(VEC x)              { return SUB(SET0(), x); } \
  template <> inline VEC set0<VEC>()        { return SET0();  } \
  template <> inline VEC set1<VEC>(TYPE x)  { return SET1(x); } \
  inline VEC sqrt(VEC x)             { return SQRT(x);   }      \
  inline VEC fmin(VEC x, VEC y)      { return FMIN(x,y); }      \
  inline VEC fmax(VEC x, VEC y)      { return FMAX(x,y); }      \
  template <> inline VEC load<VEC,TYPE>(const TYPE* d)          \
  { return LOAD(d);  }                                          \
  template <> inline VEC loadu<VEC,TYPE>(const TYPE* d)         \
  { return LOADU(d); }                                          \
  inline void store(TYPE* d, VEC x)  { STORE(d, x);      }      \
  inline void storeu(TYPE* d, VEC x) { STOREU(d, x);     }      \
  inline std::ostream& operator<<(std::ostream& os, VEC x) {    \
    static const int size = sizeof(VEC)/sizeof(TYPE);           \
    union { VEC v; TYPE d[size]; };                             \
    v = x; os << "{";                                           \
    for (int i = 0; i < size; ++i)                              \
      { os << " " << d[i]; }                                    \
    os << "}"; return os;                                       \
  }

#define QE_DEFINE_CHOP(VEC, HALF_TYPE, LOW, HIGH, PACK)         \
  inline HALF_TYPE low(VEC x)   { return LOW;       }           \
  inline HALF_TYPE high(VEC x)  { return HIGH;      }           \
  inline VEC pack(HALF_TYPE x, HALF_TYPE y) { return PACK; }

  // Reduction operations: horizontal sum, product, min and max
#define QE_DEFINE_HORIZ(TYPE, VEC, HSUM, HMUL, HMIN, HMAX)      \
  inline TYPE hsum(VEC x)            { return HSUM(x);   }      \
  inline TYPE hmul(VEC x)            { return HMUL(x);   }      \
  inline TYPE hmin(VEC x)            { return HMIN(x);   }      \
  inline TYPE hmax(VEC x)            { return HMAX(x);   }

  // Define fused multiply-add functions
#define QE_DEFINE_FMA(TYPE, VEC, FMA, FNMA)                     \
  inline VEC fma(VEC x,VEC y,VEC z)  { return FMA(x,y,z); }     \
  inline VEC fma(VEC x,TYPE y,VEC z)                            \
  { return FMA(x,set1<VEC>(y),z); }                             \
  inline VEC fma(TYPE x, VEC y, TYPE z)                         \
  { return FMA(set1<VEC>(x),y,set1<VEC>(z)); }                  \
  inline VEC fma(VEC x, VEC y, TYPE z)                          \
  { return FMA(x,y,set1<VEC>(z)); }                             \
  inline VEC fnma(VEC x,VEC y,VEC z) { return FNMA(x,y,z);}

  // Alternative order of arguments for ARM NEON
#define QE_DEFINE_FMA_ALT(TYPE, VEC, FMA, FNMA)                 \
  inline VEC fma(VEC x,VEC y,VEC z)  { return FMA(z,x,y); }     \
  inline VEC fma(VEC x,TYPE y,VEC z)                            \
  { return FMA(z,x,set1<VEC>(y)); }                             \
  inline VEC fma(TYPE x, VEC y, TYPE z)                         \
  { return FMA(set1<VEC>(z),set1<VEC>(x),y); }                  \
  inline VEC fma(VEC x, VEC y, TYPE z)                          \
  { return FMA(set1<VEC>(z),x,y); }                             \
  inline VEC fnma(VEC x,VEC y,VEC z) { return FNMA(z,x,y);}

  // Emulate fused multiply-add if instruction not available
#define QE_EMULATE_FMA(TYPE, VEC)                               \
  inline VEC fma(VEC x,VEC y,VEC z)  { return add(mul(x,y),z);} \
  inline VEC fma(VEC x,TYPE y,VEC z)                            \
  { return add(mul(x,set1<VEC>(y)),z); }                        \
  inline VEC fma(TYPE x, VEC y, TYPE z)                         \
  { return add(mul(set1<VEC>(x),y),set1<VEC>(z)); }             \
  inline VEC fma(VEC x, VEC y, TYPE z)                          \
  { return add(mul(x,y),set1<VEC>(z)); }                        \
  inline VEC fnma(VEC x,VEC y,VEC z) { return sub(z,mul(x,y));}

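  // The pow2n macros below compute 2^n for integer-valued n held in
  // a floating-point vector, by exploiting the IEEE-754 layout:
  // adding n to (bias + 2^mantissa_bits) stores bias+n exactly in the
  // low mantissa bits, and shifting the raw bits left by the number
  // of mantissa bits moves bias+n into the exponent field, which then
  // represents 2^n.  A scalar double-precision sketch of the same
  // trick (illustration only, not used by this header; it would need
  // <cstdint> and <cstring>):
  //
  //   double pow2n_scalar(double n) {
  //     double a = n + (1023.0 + 4503599627370496.0); // bias + 2^52
  //     uint64_t b;
  //     std::memcpy(&b, &a, sizeof(b)); // reinterpret the bits
  //     b <<= 52;                       // move bias+n into the exponent
  //     std::memcpy(&a, &b, sizeof(a));
  //     return a;                       // equals 2^n
  //   }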
#define QE_DEFINE_POW2N_S(VEC, VECI, CASTTO, CASTBACK, SHIFTL,  \
                          SETELEM)                              \
  inline VEC pow2n(VEC n) {                                     \
    const float pow2_23 = 8388608.0;                            \
    const float bias = 127.0;                                   \
    VEC  a = add(n, set1<VEC>(bias+pow2_23));                   \
    VECI b = CASTTO(a);                                         \
    VECI c = SHIFTL(b, SETELEM(23));                            \
    VEC  d = CASTBACK(c);                                       \
    return d;                                                   \
  }
#define QE_DEFINE_POW2N_D(VEC, VECI, CASTTO, CASTBACK, SHIFTL,  \
                          SETELEM)                              \
  inline VEC pow2n(VEC n) {                                     \
    const double pow2_52 = 4503599627370496.0;                  \
    const double bias = 1023.0;                                 \
    VEC  a = add(n, set1<VEC>(bias+pow2_52));                   \
    VECI b = CASTTO(a);                                         \
    VECI c = SHIFTL(b, SETELEM(52));                            \
    VEC  d = CASTBACK(c);                                       \
    return d;                                                   \
  }

  // -------------------------------------------------------------------
  // Define operations for SSE2: vector of 4 floats or 2 doubles
  // -------------------------------------------------------------------


#ifdef __SSE2__
  QE_DEFINE_BASIC(float, __m128, _mm_load_ps, _mm_loadu_ps,
                  _mm_setzero_ps, _mm_set1_ps, _mm_store_ps, _mm_storeu_ps,
                  _mm_add_ps, _mm_sub_ps, _mm_mul_ps, _mm_div_ps,
                  _mm_sqrt_ps, _mm_min_ps, _mm_max_ps)
  QE_DEFINE_BASIC(double, __m128d, _mm_load_pd, _mm_loadu_pd,
                  _mm_setzero_pd, _mm_set1_pd, _mm_store_pd, _mm_storeu_pd,
                  _mm_add_pd, _mm_sub_pd, _mm_mul_pd, _mm_div_pd,
                  _mm_sqrt_pd, _mm_min_pd, _mm_max_pd)
  // Don't define chop operations for __m128 because we don't have a
  // container for two floats
  QE_DEFINE_CHOP(__m128d, double, _mm_cvtsd_f64(x),
                 _mm_cvtsd_f64(_mm_unpackhi_pd(x,x)),
                 _mm_set_pd(y,x))

  // No built-in horizontal operations for SSE2, so we need to
  // implement them by hand
#define QE_DEFINE_HORIZ_SSE2(FUNC, OP_PS, OP_SS, OP_PD)                 \
  inline float FUNC(__m128 x) {                                         \
    __m128 shuf = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));        \
    __m128 sums = OP_PS(x, shuf);                                       \
    shuf        = _mm_movehl_ps(shuf, sums);                            \
    return _mm_cvtss_f32(OP_SS(sums, shuf));                            \
  }                                                                     \
  inline double FUNC(__m128d x) {                                       \
    __m128 shuftmp= _mm_movehl_ps(QE_MM_UNDEFINED_PS(),                 \
                                  _mm_castpd_ps(x));                    \
    __m128d shuf  = _mm_castps_pd(shuftmp);                             \
    return  _mm_cvtsd_f64(OP_PD(x, shuf));                              \
  }
  QE_DEFINE_HORIZ_SSE2(hsum, _mm_add_ps, _mm_add_ss, _mm_add_pd)
  QE_DEFINE_HORIZ_SSE2(hmul, _mm_mul_ps, _mm_mul_ss, _mm_mul_pd)
  QE_DEFINE_HORIZ_SSE2(hmin, _mm_min_ps, _mm_min_ss, _mm_min_pd)
  QE_DEFINE_HORIZ_SSE2(hmax, _mm_max_ps, _mm_max_ss, _mm_max_pd)

#undef QE_MM_UNDEFINED_PS
#undef QE_DEFINE_HORIZ_SSE2

#ifdef __FMA__
  QE_DEFINE_FMA(float, __m128, _mm_fmadd_ps, _mm_fnmadd_ps)
  QE_DEFINE_FMA(double, __m128d, _mm_fmadd_pd, _mm_fnmadd_pd)
#else
  QE_EMULATE_FMA(float, __m128)
  QE_EMULATE_FMA(double, __m128d)
#endif
#ifdef __SSE4_1__
  inline __m128 unchecked_round(__m128 x)
  { return _mm_round_ps(x, (_MM_FROUND_TO_NEAREST_INT
                              |_MM_FROUND_NO_EXC)); }
  inline __m128d unchecked_round(__m128d x)
  { return _mm_round_pd(x, (_MM_FROUND_TO_NEAREST_INT
                              |_MM_FROUND_NO_EXC)); }
#else
  // No native rounding function is available, but since the
  // arguments are limited to around +/- 700 we don't need to check
  // for going out of bounds of the integer range
  inline __m128 unchecked_round(__m128 x)
  { return _mm_cvtepi32_ps(_mm_cvtps_epi32(x)); }
  inline __m128d unchecked_round(__m128d x)
  { return _mm_cvtepi32_pd(_mm_cvtpd_epi32(x)); }

#endif
  inline float unchecked_round(float x)
  { return _mm_cvtss_f32(unchecked_round(_mm_set_ss(x))); }
  inline double unchecked_round(double x)
  { return low(unchecked_round(_mm_set_sd(x))); }

  QE_DEFINE_POW2N_S(__m128, __m128i, _mm_castps_si128,
                    _mm_castsi128_ps, _mm_sll_epi32, _mm_cvtsi32_si128)
  QE_DEFINE_POW2N_D(__m128d, __m128i, _mm_castpd_si128,
                    _mm_castsi128_pd, _mm_sll_epi64, _mm_cvtsi32_si128)
  inline float pow2n(float x)
  { return _mm_cvtss_f32(pow2n(quick_e::set1<__m128>(x))); }
  inline double pow2n(double x)
  { return low(pow2n(quick_e::set1<__m128d>(x))); }


  inline bool horiz_and(__m128i a) {
#ifdef __SSE4_1__
    return _mm_testc_si128(a, _mm_set1_epi32(-1)) != 0;
#else
    __m128i t1 = _mm_unpackhi_epi64(a, a); // get upper 64 bits down
    __m128i t2 = _mm_and_si128(a, t1);     // AND the two 64-bit halves
#ifdef __x86_64__
    int64_t t5 = _mm_cvtsi128_si64(t2);    // transfer 64 bits to integer
    return  t5 == int64_t(-1);
#else
    __m128i t3 = _mm_srli_epi64(t2, 32);   // get upper 32 bits down
    __m128i t4 = _mm_and_si128(t2, t3);    // AND the two 32-bit halves
    int     t5 = _mm_cvtsi128_si32(t4);    // transfer 32 bits to integer
    return  t5 == -1;
#endif  // __x86_64__
#endif  // __SSE4_1__
  }
  inline bool all_in_range(__m128 x, float low_bound, float high_bound) {
    return horiz_and(_mm_castps_si128(_mm_and_ps(
                         _mm_cmpge_ps(x,set1<__m128>(low_bound)),
                         _mm_cmple_ps(x,set1<__m128>(high_bound)))));
  }
  inline bool all_in_range(__m128d x, double low_bound, double high_bound) {
    return horiz_and(_mm_castpd_si128(_mm_and_pd(
                         _mm_cmpge_pd(x,set1<__m128d>(low_bound)),
                         _mm_cmple_pd(x,set1<__m128d>(high_bound)))));
  }

  // If x1 > x2, select y1, otherwise select y2
  inline __m128 select_gt(__m128 x1, __m128 x2,
                          __m128 y1, __m128 y2) {
    __m128 mask = _mm_cmpgt_ps(x1,x2);
#ifdef __SSE4_1__
    return _mm_blendv_ps(y2, y1, mask);
#else
    return _mm_or_ps(_mm_and_ps(mask, y1),
                     _mm_andnot_ps(mask, y2));
#endif
  }
  inline __m128d select_gt(__m128d x1, __m128d x2,
                           __m128d y1, __m128d y2) {
    __m128d mask = _mm_cmpgt_pd(x1,x2);
#ifdef __SSE4_1__
    return _mm_blendv_pd(y2, y1, mask);
#else
    return _mm_or_pd(_mm_and_pd(mask, y1),
                     _mm_andnot_pd(mask, y2));
#endif
  }
#endif

  // -------------------------------------------------------------------
  // Define operations for AVX: vector of 8 floats or 4 doubles
  // -------------------------------------------------------------------
#ifdef __AVX__
  QE_DEFINE_BASIC(float, __m256, _mm256_load_ps, _mm256_loadu_ps,
                  _mm256_setzero_ps, _mm256_set1_ps,
                  _mm256_store_ps, _mm256_storeu_ps,
                  _mm256_add_ps, _mm256_sub_ps,
                  _mm256_mul_ps, _mm256_div_ps, _mm256_sqrt_ps,
                  _mm256_min_ps, _mm256_max_ps)
  QE_DEFINE_BASIC(double, __m256d, _mm256_load_pd, _mm256_loadu_pd,
                  _mm256_setzero_pd, _mm256_set1_pd,
                  _mm256_store_pd, _mm256_storeu_pd,
                  _mm256_add_pd, _mm256_sub_pd,
                  _mm256_mul_pd, _mm256_div_pd, _mm256_sqrt_pd,
                  _mm256_min_pd, _mm256_max_pd)
  QE_DEFINE_CHOP(__m256, __m128,
                 _mm256_castps256_ps128(x), _mm256_extractf128_ps(x,1),
                 _mm256_permute2f128_ps(_mm256_castps128_ps256(x),
                                        _mm256_castps128_ps256(y), 0x20))
  QE_DEFINE_CHOP(__m256d, __m128d, _mm256_castpd256_pd128(x),
                 _mm256_extractf128_pd(x,1),
                 _mm256_permute2f128_pd(_mm256_castpd128_pd256(x),
                                        _mm256_castpd128_pd256(y), 0x20))

  // Implement the horizontal operations by calling the SSE2 versions
  // on the low and high halves
  inline float  hsum(__m256 x)  { return hsum(add(low(x), high(x))); }
  inline float  hmul(__m256 x)  { return hmul(mul(low(x), high(x))); }
  inline float  hmin(__m256 x)  { return hmin(fmin(low(x), high(x))); }
  inline float  hmax(__m256 x)  { return hmax(fmax(low(x), high(x))); }
  inline double hsum(__m256d x) { return hsum(add(low(x),  high(x))); } // Alternative would be to use _mm_hadd_pd
  inline double hmul(__m256d x) { return hmul(mul(low(x),  high(x))); }
  inline double hmin(__m256d x) { return hmin(fmin(low(x), high(x))); }
  inline double hmax(__m256d x) { return hmax(fmax(low(x), high(x))); }

  // Define extras
#ifdef __FMA__
  QE_DEFINE_FMA(float, __m256,  _mm256_fmadd_ps, _mm256_fnmadd_ps)
  QE_DEFINE_FMA(double, __m256d, _mm256_fmadd_pd, _mm256_fnmadd_pd)
#else
  QE_EMULATE_FMA(float, __m256)
  QE_EMULATE_FMA(double, __m256d)
#endif

  inline __m256 unchecked_round(__m256 x)
  { return _mm256_round_ps(x, (_MM_FROUND_TO_NEAREST_INT
                               |_MM_FROUND_NO_EXC)); }
  inline __m256d unchecked_round(__m256d x)
  { return _mm256_round_pd(x, (_MM_FROUND_TO_NEAREST_INT
                               |_MM_FROUND_NO_EXC)); }
  #ifdef __AVX2__
    QE_DEFINE_POW2N_S(__m256, __m256i, _mm256_castps_si256,
                      _mm256_castsi256_ps, _mm256_sll_epi32, _mm_cvtsi32_si128)
    QE_DEFINE_POW2N_D(__m256d, __m256i, _mm256_castpd_si256,
                      _mm256_castsi256_pd, _mm256_sll_epi64, _mm_cvtsi32_si128)
  #else
    // Suboptimal versions that call the SSE2 functions on the lower
    // and upper halves
    inline __m256 pow2n(__m256 n) {
      return pack(pow2n(low(n)), pow2n(high(n)));
    }
    inline __m256d pow2n(__m256d n) {
      return pack(pow2n(low(n)), pow2n(high(n)));
    }
  #endif

  // Return true if all elements of x are in the range (inclusive)
  // low_bound to high_bound.  If so, the exp call can exit before the
  // more costly case of working out what to do with inputs that are
  // out of bounds.  Note that _CMP_GE_OS means compare
  // greater-than-or-equal-to, ordered, signaling, where "ordered"
  // means that if either operand is NaN the result is false.
  inline bool all_in_range(__m256 x, float low_bound, float high_bound) {
    return _mm256_testc_si256(_mm256_castps_si256(_mm256_and_ps(
                 _mm256_cmp_ps(x,set1<__m256>(low_bound), _CMP_GE_OS),
                 _mm256_cmp_ps(x,set1<__m256>(high_bound), _CMP_LE_OS))),
                              _mm256_set1_epi32(-1)) != 0;
  }
  inline bool all_in_range(__m256d x, double low_bound, double high_bound) {
    return _mm256_testc_si256(_mm256_castpd_si256(_mm256_and_pd(
                 _mm256_cmp_pd(x,set1<__m256d>(low_bound), _CMP_GE_OS),
                 _mm256_cmp_pd(x,set1<__m256d>(high_bound), _CMP_LE_OS))),
                              _mm256_set1_epi32(-1)) != 0;
  }
  inline __m256 select_gt(__m256 x1, __m256 x2,
                          __m256 y1, __m256 y2) {
    return _mm256_blendv_ps(y2, y1, _mm256_cmp_ps(x1,x2,_CMP_GT_OS));
  }
  inline __m256d select_gt(__m256d x1, __m256d x2,
                           __m256d y1, __m256d y2) {
    return _mm256_blendv_pd(y2, y1, _mm256_cmp_pd(x1,x2,_CMP_GT_OS));
  }

#endif


  // -------------------------------------------------------------------
  // Define operations for AVX512: vector of 16 floats or 8 doubles
  // -------------------------------------------------------------------
#ifdef __AVX512F__
  QE_DEFINE_BASIC(float, __m512, _mm512_load_ps, _mm512_loadu_ps,
                  _mm512_setzero_ps, _mm512_set1_ps,
                  _mm512_store_ps, _mm512_storeu_ps,
                  _mm512_add_ps, _mm512_sub_ps,
                  _mm512_mul_ps, _mm512_div_ps, _mm512_sqrt_ps,
                  _mm512_min_ps, _mm512_max_ps)
  QE_DEFINE_HORIZ(float, __m512,
                  _mm512_reduce_add_ps, _mm512_reduce_mul_ps,
                  _mm512_reduce_min_ps, _mm512_reduce_max_ps)
  QE_DEFINE_BASIC(double, __m512d, _mm512_load_pd, _mm512_loadu_pd,
                  _mm512_setzero_pd, _mm512_set1_pd,
                  _mm512_store_pd, _mm512_storeu_pd,
                  _mm512_add_pd, _mm512_sub_pd,
                  _mm512_mul_pd, _mm512_div_pd, _mm512_sqrt_pd,
                  _mm512_min_pd, _mm512_max_pd)
  QE_DEFINE_HORIZ(double, __m512d,
                  _mm512_reduce_add_pd, _mm512_reduce_mul_pd,
                  _mm512_reduce_min_pd, _mm512_reduce_max_pd)

  inline __m512 unchecked_round(__m512 x)   { return _mm512_roundscale_ps(x, 0); }
  inline __m512d unchecked_round(__m512d x) { return _mm512_roundscale_pd(x, 0); }

  QE_DEFINE_FMA(float, __m512,  _mm512_fmadd_ps, _mm512_fnmadd_ps)
  QE_DEFINE_FMA(double, __m512d, _mm512_fmadd_pd, _mm512_fnmadd_pd)

  QE_DEFINE_POW2N_S(__m512, __m512i, _mm512_castps_si512,
                    _mm512_castsi512_ps, _mm512_sll_epi32, _mm_cvtsi32_si128)
  QE_DEFINE_POW2N_D(__m512d, __m512i, _mm512_castpd_si512,
                    _mm512_castsi512_pd, _mm512_sll_epi64, _mm_cvtsi32_si128)

  inline bool all_in_range(__m512 x, float low_bound, float high_bound) {
    return static_cast<unsigned short int>(_mm512_kand(
              _mm512_cmp_ps_mask(x,set1<__m512>(low_bound),_CMP_GE_OS),
              _mm512_cmp_ps_mask(x,set1<__m512>(high_bound),_CMP_LE_OS)))
      == static_cast<unsigned short int>(65535);
  }
  inline bool all_in_range(__m512d x, double low_bound, double high_bound) {
    return static_cast<unsigned short int>(_mm512_kand(
              _mm512_cmp_pd_mask(x,set1<__m512d>(low_bound),_CMP_GE_OS),
              _mm512_cmp_pd_mask(x,set1<__m512d>(high_bound),_CMP_LE_OS)))
      == static_cast<unsigned short int>(255);
  }
  inline __m512 select_gt(__m512 x1, __m512 x2,
                          __m512 y1, __m512 y2) {
    return _mm512_mask_mov_ps(y2, _mm512_cmp_ps_mask(x1,x2,_CMP_GT_OS), y1);
  }
  inline __m512d select_gt(__m512d x1, __m512d x2,
                           __m512d y1, __m512d y2) {
    return _mm512_mask_mov_pd(y2, _mm512_cmp_pd_mask(x1,x2,_CMP_GT_OS), y1);
  }

#endif


#ifdef __ARM_NEON

  // Implement an ARM version of the x86 setzero functions
  inline float32x4_t vzeroq_f32() { return vdupq_n_f32(0.0); }
  inline float64x2_t vzeroq_f64() { return vdupq_n_f64(0.0); }
  // Horizontal multiply across a vector
  inline float vmulvq_f32(float32x4_t x) {
    union {
      float32x2_t v;
      float data[2];
    };
    v = vmul_f32(vget_low_f32(x), vget_high_f32(x));
    return data[0] * data[1];
  }
  inline double vmulvq_f64(float64x2_t x) {
    union {
      float64x2_t v;
      double data[2];
    };
    v = x;
    return data[0] * data[1];
  }

  QE_DEFINE_BASIC(float, float32x4_t, vld1q_f32, vld1q_f32,
                  vzeroq_f32, vdupq_n_f32, vst1q_f32, vst1q_f32,
                  vaddq_f32, vsubq_f32, vmulq_f32, vdivq_f32,
                  vsqrtq_f32, vminq_f32, vmaxq_f32)
  QE_DEFINE_HORIZ(float, float32x4_t,
                  vaddvq_f32, vmulvq_f32,
                  vminvq_f32, vmaxvq_f32)
  QE_DEFINE_BASIC(double, float64x2_t, vld1q_f64, vld1q_f64,
                  vzeroq_f64, vdupq_n_f64, vst1q_f64, vst1q_f64,
                  vaddq_f64, vsubq_f64, vmulq_f64, vdivq_f64,
                  vsqrtq_f64, vminq_f64, vmaxq_f64)
  QE_DEFINE_HORIZ(double, float64x2_t,
                  vaddvq_f64, vmulvq_f64,
                  vminvq_f64, vmaxvq_f64)
  QE_DEFINE_POW2N_S(float32x4_t, int32x4_t, vreinterpretq_s32_f32,
                    vreinterpretq_f32_s32, vshlq_s32, vdupq_n_s32)
  QE_DEFINE_POW2N_D(float64x2_t, int64x2_t, vreinterpretq_s64_f64,
                    vreinterpretq_f64_s64, vshlq_s64, vdupq_n_s64)
  QE_DEFINE_FMA_ALT(float, float32x4_t, vfmaq_f32, vfmsq_f32)
  QE_DEFINE_FMA_ALT(double, float64x2_t, vfmaq_f64, vfmsq_f64)
  inline bool all_in_range(float32x4_t x, float low_bound, float high_bound) {
    union {
      uint32x2_t v;
      uint32_t data[2];
    };
    uint32x4_t tmp = vandq_u32(vcgeq_f32(x,vdupq_n_f32(low_bound)),
                               vcleq_f32(x,vdupq_n_f32(high_bound)));
    v = vand_u32(vget_low_u32(tmp), vget_high_u32(tmp));
    return data[0] && data[1];
  }
  inline bool all_in_range(float64x2_t x, double low_bound, double high_bound) {
    union {
      uint64x2_t v;
      uint64_t data[2];
    };
    v = vandq_u64(vcgeq_f64(x,vdupq_n_f64(low_bound)),
                  vcleq_f64(x,vdupq_n_f64(high_bound)));
    return data[0] && data[1];
  }

  inline float32x4_t unchecked_round(float32x4_t x) {
    return vcvtq_f32_s32(vcvtaq_s32_f32(x));
  }
  inline float64x2_t unchecked_round(float64x2_t x) {
    return vcvtq_f64_s64(vcvtaq_s64_f64(x));
  }
  inline float32x4_t select_gt(float32x4_t x1, float32x4_t x2,
                               float32x4_t y1, float32x4_t y2) {
    return vbslq_f32(vcgtq_f32(x1,x2), y1, y2);
  }
  inline float64x2_t select_gt(float64x2_t x1, float64x2_t x2,
                               float64x2_t y1, float64x2_t y2) {
    return vbslq_f64(vcgtq_f64(x1,x2), y1, y2);
  }

  inline float unchecked_round(float x)
  { return vgetq_lane_f32(unchecked_round(vdupq_n_f32(x)), 0); }
  inline double unchecked_round(double x)
  { return vgetq_lane_f64(unchecked_round(vdupq_n_f64(x)), 0); }

  inline float pow2n(float x) {
    return vgetq_lane_f32(pow2n(vdupq_n_f32(x)),0);
  }
  inline double pow2n(double x) {
    return vgetq_lane_f64(pow2n(vdupq_n_f64(x)),0);
  }

#endif


#ifdef QE_HAVE_FAST_EXP

  // -------------------------------------------------------------------
  // Implementation of fast exponential
  // -------------------------------------------------------------------

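  // Both fastexp_float and fastexp_double below use the same
  // strategy, which the header comment attributes to Agner Fog's
  // Vector Class Library: write
  //   exp(x) = 2^r * exp(x - r*ln2),  with r = round(x * (1/ln2)),
  // so that the remainder x' = x - r*ln2 lies in roughly
  // [-0.5*ln2, +0.5*ln2], where a short Taylor polynomial for
  // exp(x')-1 is accurate.  The subtraction is done in two steps with
  // ln2 split into high and low parts to limit rounding error, and
  // 2^r is formed cheaply from the floating-point bit pattern by
  // pow2n.  Finally z = p(x')*2^r + 2^r is assembled with an FMA.
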
  template<typename Type, typename Vec>
  static inline
  Vec polynomial_5(Vec const x, Type c0, Type c1, Type c2,
                   Type c3, Type c4, Type c5) {
    // Calculate the polynomial c5*x^5 + c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
    // using an Estrin-style grouping to shorten the dependency chain
    using quick_e::fma;
    Vec x2 = mul(x, x);
    Vec x4 = mul(x2, x2);
    return fma(fma(c3, x, c2), x2, fma(fma(c5, x, c4), x4, fma(c1, x, c0)));
  }

  template<typename Vec>
  inline
  Vec fastexp_float(Vec const initial_x) {
    using namespace quick_e;
    using quick_e::unchecked_round;
    using quick_e::fma;

    // Taylor coefficients
    const float P0expf   =  1.f/2.f;
    const float P1expf   =  1.f/6.f;
    const float P2expf   =  1.f/24.f;
    const float P3expf   =  1.f/120.f;
    const float P4expf   =  1.f/720.f;
    const float P5expf   =  1.f/5040.f;
    const float VM_LOG2E = 1.44269504088896340736f;  // 1/log(2)
    const float ln2f_hi  =  0.693359375f;
    const float ln2f_lo  = -2.12194440e-4f;
#ifndef __FAST_MATH__
    const float min_x    = -87.3f;
    const float max_x    = +89.0f;
#endif

    Vec r = unchecked_round(mul(initial_x,set1<Vec>(VM_LOG2E)));
    Vec x = fnma(r, set1<Vec>(ln2f_hi), initial_x); //  x -= r * ln2f_hi;
    x = fnma(r, set1<Vec>(ln2f_lo), x);             //  x -= r * ln2f_lo;

    Vec z = polynomial_5(x,P0expf,P1expf,P2expf,P3expf,P4expf,P5expf);

    Vec x2 = mul(x, x);
    z = fma(z, x2, x);                       // z = z*x2 + x;

    // Multiply by power of 2
    Vec n2 = pow2n(r);

    z = fma(z,n2,n2);                        // z = (z+1) * 2^r;

#ifdef __FAST_MATH__
    return z;
#else
    if (all_in_range(initial_x, min_x, max_x)) {
      return z;
    }
    else {
      // When initial_x < -87.3, set exp(x) to 0
      z = select_gt(set1<Vec>(min_x), initial_x, set0<Vec>(), z);
      // When initial_x > +89.0, set exp(x) to +Inf
      z = select_gt(initial_x, set1<Vec>(max_x),
                    set1<Vec>(std::numeric_limits<float>::infinity()),
                    z);
      return z;
    }
#endif
  }


  template <typename Type, typename Vec>
  Vec polynomial_13m(Vec const x,
                     Type c2, Type c3, Type c4, Type c5, Type c6, Type c7,
                     Type c8, Type c9, Type c10, Type c11, Type c12, Type c13) {
    // Calculate the polynomial c13*x^13 + c12*x^12 + ... + c2*x^2 + x
    // (there is no constant term and the linear coefficient is 1)
    using quick_e::fma;

    Vec x2 = mul(x, x);
    Vec x4 = mul(x2, x2);
    //    Vec x8 = mul(x4, x4);
    return fma(fma(fma(c13, x, c12), x4,
                   fma(fma(c11, x, c10), x2, fma(c9, x, c8))), mul(x4, x4),
               fma(fma(fma(c7, x, c6), x2, fma(c5, x, c4)), x4,
                   fma(fma(c3, x, c2), x2, x)));
    //return fma(fma(fma(fma(fma(fma(fma(fma(fma(fma(fma(fma(c13, x, c12), x, c11), x, c10), x, c9), x, c8), x, c7), x, c6), x, c5), x, c4), x, c3), x, c2), mul(x,x), x);

  }


  // Template function implementing the fast exponential, where Vec
  // can be double, __m128d, __m256d, __m512d or (on ARM) float64x2_t
  template <typename Vec>
  inline
  Vec fastexp_double(Vec const initial_x) {
    using namespace quick_e;
    using quick_e::unchecked_round;
    using quick_e::fma;

    // Taylor coefficients: p_i = 1/i!
    const double p2  = 1./2.;
    const double p3  = 1./6.;
    const double p4  = 1./24.;
    const double p5  = 1./120.;
    const double p6  = 1./720.;
    const double p7  = 1./5040.;
    const double p8  = 1./40320.;
    const double p9  = 1./362880.;
    const double p10 = 1./3628800.;
    const double p11 = 1./39916800.;
    const double p12 = 1./479001600.;
    const double p13 = 1./6227020800.;
    const double VM_LOG2E = 1.44269504088896340736;  // 1/log(2)
    const double ln2d_hi = 0.693145751953125;
    const double ln2d_lo = 1.42860682030941723212E-6;
#ifndef __FAST_MATH__
    const double min_x = -708.39;
    const double max_x = +709.70;
#endif

    Vec r = unchecked_round(mul(initial_x,set1<Vec>(VM_LOG2E)));
    // Subtraction in two steps for higher precision
    Vec x = fnma(r, set1<Vec>(ln2d_hi), initial_x);   //  x -= r * ln2d_hi;
    x = fnma(r, set1<Vec>(ln2d_lo), x);               //  x -= r * ln2d_lo;

    // Multiply by power of 2
    Vec n2 = pow2n(r);

    Vec z = polynomial_13m(x, p2, p3, p4, p5, p6, p7,
                           p8, p9, p10, p11, p12, p13);
    z = fma(z,n2,n2);                                 // z = (z+1) * 2^r;
#ifdef __FAST_MATH__
    return z;
#else
    if (all_in_range(initial_x, min_x, max_x)) {
      // Fast normal path
      return z;
    }
    else {
      // When initial_x < -708.39, set exp(x) to 0.0
      z = select_gt(set1<Vec>(min_x), initial_x, set0<Vec>(), z);
      // When initial_x > +709.70, set exp(x) to +Inf
      z = select_gt(initial_x, set1<Vec>(max_x),
                    set1<Vec>(std::numeric_limits<double>::infinity()),
                    z);
      return z;
    }
#endif
  }
#endif


  // Define the various overloads of the quick_e::exp function taking
  // Intel or ARM intrinsics as an argument

#ifdef __SSE2__
  inline __m128  exp(__m128 x)  { return fastexp_float(x);  }
  inline __m128d exp(__m128d x) { return fastexp_double(x); }
#endif

#ifdef __AVX__
  inline __m256  exp(__m256 x)  { return fastexp_float(x);  }
  inline __m256d exp(__m256d x) { return fastexp_double(x); }
#endif

#ifdef __AVX512F__
  inline __m512  exp(__m512 x)  { return fastexp_float(x);  }
  inline __m512d exp(__m512d x) { return fastexp_double(x); }
#endif

#ifdef __ARM_NEON
  inline float32x4_t exp(float32x4_t x) { return fastexp_float(x);  }
  inline float64x2_t exp(float64x2_t x) { return fastexp_double(x); }
#endif

  // Define the quick_e::exp function for scalar arguments
#ifdef QE_HAVE_FAST_EXP
  inline float  exp(float x)  { return quick_e::fastexp_float(x); }
  inline double exp(double x) { return quick_e::fastexp_double(x); }
#else
  // If no vectorization is available then we fall back to the
  // standard-library scalar version
  inline float  exp(float x)  { return std::exp(x); }
  inline double exp(double x) { return std::exp(x); }
#endif
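
  // Example of using the functions above to vectorize a loop (a
  // sketch only: "exp_array" is an illustrative name and n is
  // assumed to be a multiple of the packet size):
  //
  //   void exp_array(int n, const double* in, double* out) {
  //     typedef quick_e::longest_packet<double>::type Vec;
  //     const int vlen = quick_e::longest_packet<double>::size;
  //     for (int i = 0; i < n; i += vlen) {
  //       Vec x = quick_e::loadu<Vec>(in + i);
  //       quick_e::storeu(out + i, quick_e::exp(x));
  //     }
  //   }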

#undef QE_DEFINE_TRAITS
#undef QE_DEFINE_LONGEST
#undef QE_DEFINE_BASIC
#undef QE_DEFINE_CHOP
#undef QE_DEFINE_HORIZ
#undef QE_DEFINE_FMA
#undef QE_DEFINE_FMA_ALT
#undef QE_EMULATE_FMA
#undef QE_DEFINE_POW2N_S
#undef QE_DEFINE_POW2N_D
#undef QE_HAVE_FAST_EXP
}

#endif