1 /****************************************************************************
2  * Copyright (C) 2017 Intel Corporation.   All Rights Reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  ****************************************************************************/
23 #pragma once
24 
// Everything below relies on C++ features; refuse to build as C.
#ifndef __cplusplus
#error C++ compilation required
#endif
28 
29 #include <immintrin.h>
30 #include <inttypes.h>
31 #include <stdint.h>
32 
// Instruction-set tiers that SIMD_ARCH can be set to.  The values are ordered
// so tiers can be compared with >= in preprocessor conditionals.
#define SIMD_ARCH_AVX 0
#define SIMD_ARCH_AVX2 1
#define SIMD_ARCH_AVX512 2

// Use the AVX baseline unless the build has already picked a tier.
#ifndef SIMD_ARCH
#define SIMD_ARCH SIMD_ARCH_AVX
#endif
40 
// Compiler-specific decoration used by the SIMD wrapper types in this header:
//   SIMDCALL                 calling convention (__vectorcall on MSVC, default elsewhere)
//   SIMDINLINE               inlining request (__forceinline on MSVC, inline elsewhere)
//   SIMDALIGN(type_, align_) spells `type_` with a minimum alignment of `align_` bytes
#ifndef _MSC_VER
#define SIMDCALL
#define SIMDINLINE inline
#define SIMDALIGN(type_, align_) type_ __attribute__((aligned(align_)))
#else
#define SIMDCALL __vectorcall
#define SIMDINLINE __forceinline
#define SIMDALIGN(type_, align_) __declspec(align(align_)) type_
#endif
50 
51 // For documentation, please see the following include...
52 // #include "simdlib_interface.hpp"
53 
54 namespace SIMDImpl
55 {
56     enum class CompareType
57     {
58         EQ_OQ    = 0x00, // Equal (ordered, nonsignaling)
59         LT_OS    = 0x01, // Less-than (ordered, signaling)
60         LE_OS    = 0x02, // Less-than-or-equal (ordered, signaling)
61         UNORD_Q  = 0x03, // Unordered (nonsignaling)
62         NEQ_UQ   = 0x04, // Not-equal (unordered, nonsignaling)
63         NLT_US   = 0x05, // Not-less-than (unordered, signaling)
64         NLE_US   = 0x06, // Not-less-than-or-equal (unordered, signaling)
65         ORD_Q    = 0x07, // Ordered (nonsignaling)
66         EQ_UQ    = 0x08, // Equal (unordered, non-signaling)
67         NGE_US   = 0x09, // Not-greater-than-or-equal (unordered, signaling)
68         NGT_US   = 0x0A, // Not-greater-than (unordered, signaling)
69         FALSE_OQ = 0x0B, // False (ordered, nonsignaling)
70         NEQ_OQ   = 0x0C, // Not-equal (ordered, non-signaling)
71         GE_OS    = 0x0D, // Greater-than-or-equal (ordered, signaling)
72         GT_OS    = 0x0E, // Greater-than (ordered, signaling)
73         TRUE_UQ  = 0x0F, // True (unordered, non-signaling)
74         EQ_OS    = 0x10, // Equal (ordered, signaling)
75         LT_OQ    = 0x11, // Less-than (ordered, nonsignaling)
76         LE_OQ    = 0x12, // Less-than-or-equal (ordered, nonsignaling)
77         UNORD_S  = 0x13, // Unordered (signaling)
78         NEQ_US   = 0x14, // Not-equal (unordered, signaling)
79         NLT_UQ   = 0x15, // Not-less-than (unordered, nonsignaling)
80         NLE_UQ   = 0x16, // Not-less-than-or-equal (unordered, nonsignaling)
81         ORD_S    = 0x17, // Ordered (signaling)
82         EQ_US    = 0x18, // Equal (unordered, signaling)
83         NGE_UQ   = 0x19, // Not-greater-than-or-equal (unordered, nonsignaling)
84         NGT_UQ   = 0x1A, // Not-greater-than (unordered, nonsignaling)
85         FALSE_OS = 0x1B, // False (ordered, signaling)
86         NEQ_OS   = 0x1C, // Not-equal (ordered, signaling)
87         GE_OQ    = 0x1D, // Greater-than-or-equal (ordered, nonsignaling)
88         GT_OQ    = 0x1E, // Greater-than (ordered, nonsignaling)
89         TRUE_US  = 0x1F, // True (unordered, signaling)
90     };
91 
92 #if SIMD_ARCH >= SIMD_ARCH_AVX512
93     enum class CompareTypeInt
94     {
95         EQ = _MM_CMPINT_EQ, // Equal
96         LT = _MM_CMPINT_LT, // Less than
97         LE = _MM_CMPINT_LE, // Less than or Equal
98         NE = _MM_CMPINT_NE, // Not Equal
99         GE = _MM_CMPINT_GE, // Greater than or Equal
100         GT = _MM_CMPINT_GT, // Greater than
101     };
102 #endif // SIMD_ARCH >= SIMD_ARCH_AVX512
103 
104     enum class ScaleFactor
105     {
106         SF_1 = 1, // No scaling
107         SF_2 = 2, // Scale offset by 2
108         SF_4 = 4, // Scale offset by 4
109         SF_8 = 8, // Scale offset by 8
110     };
111 
112     enum class RoundMode
113     {
114         TO_NEAREST_INT = 0x00, // Round to nearest integer == TRUNCATE(value + 0.5)
115         TO_NEG_INF     = 0x01, // Round to negative infinity
116         TO_POS_INF     = 0x02, // Round to positive infinity
117         TO_ZERO        = 0x03, // Round to 0 a.k.a. truncate
118         CUR_DIRECTION  = 0x04, // Round in direction set in MXCSR register
119 
120         RAISE_EXC = 0x00, // Raise exception on overflow
121         NO_EXC    = 0x08, // Suppress exceptions
122 
123         NINT        = static_cast<int>(TO_NEAREST_INT) | static_cast<int>(RAISE_EXC),
124         NINT_NOEXC  = static_cast<int>(TO_NEAREST_INT) | static_cast<int>(NO_EXC),
125         FLOOR       = static_cast<int>(TO_NEG_INF) | static_cast<int>(RAISE_EXC),
126         FLOOR_NOEXC = static_cast<int>(TO_NEG_INF) | static_cast<int>(NO_EXC),
127         CEIL        = static_cast<int>(TO_POS_INF) | static_cast<int>(RAISE_EXC),
128         CEIL_NOEXC  = static_cast<int>(TO_POS_INF) | static_cast<int>(NO_EXC),
129         TRUNC       = static_cast<int>(TO_ZERO) | static_cast<int>(RAISE_EXC),
130         TRUNC_NOEXC = static_cast<int>(TO_ZERO) | static_cast<int>(NO_EXC),
131         RINT        = static_cast<int>(CUR_DIRECTION) | static_cast<int>(RAISE_EXC),
132         NEARBYINT   = static_cast<int>(CUR_DIRECTION) | static_cast<int>(NO_EXC),
133     };
134 
135     struct Traits
136     {
137         using CompareType = SIMDImpl::CompareType;
138         using ScaleFactor = SIMDImpl::ScaleFactor;
139         using RoundMode   = SIMDImpl::RoundMode;
140     };
141 
142     // Attribute, 4-dimensional attribute in SIMD SOA layout
143     template <typename Float, typename Integer, typename Double>
144     union Vec4
145     {
146         Float   v[4];
147         Integer vi[4];
148         Double  vd[4];
149         struct
150         {
151             Float x;
152             Float y;
153             Float z;
154             Float w;
155         };
operator [](const int i)156         SIMDINLINE Float& SIMDCALL operator[](const int i) { return v[i]; }
operator [](const int i) const157         SIMDINLINE Float const& SIMDCALL operator[](const int i) const { return v[i]; }
operator =(Vec4 const & in)158         SIMDINLINE Vec4& SIMDCALL operator=(Vec4 const& in)
159         {
160             v[0] = in.v[0];
161             v[1] = in.v[1];
162             v[2] = in.v[2];
163             v[3] = in.v[3];
164             return *this;
165         }
166     };
167 
168     namespace SIMD128Impl
169     {
170         union Float
171         {
172             SIMDINLINE Float() = default;
Float(__m128 in)173             SIMDINLINE Float(__m128 in) : v(in) {}
operator =(__m128 in)174             SIMDINLINE Float& SIMDCALL operator=(__m128 in)
175             {
176                 v = in;
177                 return *this;
178             }
operator =(Float const & in)179             SIMDINLINE Float& SIMDCALL operator=(Float const& in)
180             {
181                 v = in.v;
182                 return *this;
183             }
operator __m128() const184             SIMDINLINE SIMDCALL operator __m128() const { return v; }
185 
186             SIMDALIGN(__m128, 16) v;
187         };
188 
189         union Integer
190         {
191             SIMDINLINE Integer() = default;
Integer(__m128i in)192             SIMDINLINE Integer(__m128i in) : v(in) {}
operator =(__m128i in)193             SIMDINLINE Integer& SIMDCALL operator=(__m128i in)
194             {
195                 v = in;
196                 return *this;
197             }
operator =(Integer const & in)198             SIMDINLINE Integer& SIMDCALL operator=(Integer const& in)
199             {
200                 v = in.v;
201                 return *this;
202             }
operator __m128i() const203             SIMDINLINE SIMDCALL operator __m128i() const { return v; }
204 
205             SIMDALIGN(__m128i, 16) v;
206         };
207 
208         union Double
209         {
210             SIMDINLINE Double() = default;
Double(__m128d in)211             SIMDINLINE Double(__m128d in) : v(in) {}
operator =(__m128d in)212             SIMDINLINE Double& SIMDCALL operator=(__m128d in)
213             {
214                 v = in;
215                 return *this;
216             }
operator =(Double const & in)217             SIMDINLINE Double& SIMDCALL operator=(Double const& in)
218             {
219                 v = in.v;
220                 return *this;
221             }
operator __m128d() const222             SIMDINLINE SIMDCALL operator __m128d() const { return v; }
223 
224             SIMDALIGN(__m128d, 16) v;
225         };
226 
227         using Vec4 = SIMDImpl::Vec4<Float, Integer, Double>;
228         using Mask = uint8_t;
229 
230         static const uint32_t SIMD_WIDTH = 4;
231     } // namespace SIMD128Impl
232 
233     namespace SIMD256Impl
234     {
235         union Float
236         {
237             SIMDINLINE Float() = default;
Float(__m256 in)238             SIMDINLINE Float(__m256 in) : v(in) {}
Float(SIMD128Impl::Float const & in_lo,SIMD128Impl::Float const & in_hi=_mm_setzero_ps ())239             SIMDINLINE Float(SIMD128Impl::Float const& in_lo,
240                              SIMD128Impl::Float const& in_hi = _mm_setzero_ps())
241             {
242                 v = _mm256_insertf128_ps(_mm256_castps128_ps256(in_lo), in_hi, 0x1);
243             }
operator =(__m256 in)244             SIMDINLINE Float& SIMDCALL operator=(__m256 in)
245             {
246                 v = in;
247                 return *this;
248             }
operator =(Float const & in)249             SIMDINLINE Float& SIMDCALL operator=(Float const& in)
250             {
251                 v = in.v;
252                 return *this;
253             }
operator __m256() const254             SIMDINLINE SIMDCALL operator __m256() const { return v; }
255 
256             SIMDALIGN(__m256, 32) v;
257             SIMD128Impl::Float v4[2];
258         };
259 
260         union Integer
261         {
262             SIMDINLINE Integer() = default;
Integer(__m256i in)263             SIMDINLINE Integer(__m256i in) : v(in) {}
Integer(SIMD128Impl::Integer const & in_lo,SIMD128Impl::Integer const & in_hi=_mm_setzero_si128 ())264             SIMDINLINE Integer(SIMD128Impl::Integer const& in_lo,
265                                SIMD128Impl::Integer const& in_hi = _mm_setzero_si128())
266             {
267                 v = _mm256_insertf128_si256(_mm256_castsi128_si256(in_lo), in_hi, 0x1);
268             }
operator =(__m256i in)269             SIMDINLINE Integer& SIMDCALL operator=(__m256i in)
270             {
271                 v = in;
272                 return *this;
273             }
operator =(Integer const & in)274             SIMDINLINE Integer& SIMDCALL operator=(Integer const& in)
275             {
276                 v = in.v;
277                 return *this;
278             }
operator __m256i() const279             SIMDINLINE SIMDCALL operator __m256i() const { return v; }
280 
281             SIMDALIGN(__m256i, 32) v;
282             SIMD128Impl::Integer v4[2];
283         };
284 
285         union Double
286         {
287             SIMDINLINE Double() = default;
Double(__m256d const & in)288             SIMDINLINE Double(__m256d const& in) : v(in) {}
Double(SIMD128Impl::Double const & in_lo,SIMD128Impl::Double const & in_hi=_mm_setzero_pd ())289             SIMDINLINE Double(SIMD128Impl::Double const& in_lo,
290                               SIMD128Impl::Double const& in_hi = _mm_setzero_pd())
291             {
292                 v = _mm256_insertf128_pd(_mm256_castpd128_pd256(in_lo), in_hi, 0x1);
293             }
operator =(__m256d in)294             SIMDINLINE Double& SIMDCALL operator=(__m256d in)
295             {
296                 v = in;
297                 return *this;
298             }
operator =(Double const & in)299             SIMDINLINE Double& SIMDCALL operator=(Double const& in)
300             {
301                 v = in.v;
302                 return *this;
303             }
operator __m256d() const304             SIMDINLINE SIMDCALL operator __m256d() const { return v; }
305 
306             SIMDALIGN(__m256d, 32) v;
307             SIMD128Impl::Double v4[2];
308         };
309 
310         using Vec4 = SIMDImpl::Vec4<Float, Integer, Double>;
311         using Mask = uint8_t;
312 
313         static const uint32_t SIMD_WIDTH = 8;
314     } // namespace SIMD256Impl
315 
316     namespace SIMD512Impl
317     {
#if !defined(__AVX512F__) && !defined(_ZMMINTRIN_H_INCLUDED)
        // Stand-in AVX512 types for builds where immintrin.h did not supply them.
        // Each one only reserves 64 bytes of storage; the data members are ONLY
        // there to be viewed in a debugger.  Do NOT access them via code!
        union __m512
        {
        private:
            float m512_f32[16];
        };

        struct __m512d
        {
        private:
            double m512d_f64[8];
        };

        union __m512i
        {
        private:
            int8_t   m512i_i8[64];
            int16_t  m512i_i16[32];
            int32_t  m512i_i32[16];
            int64_t  m512i_i64[8];
            uint8_t  m512i_u8[64];
            uint16_t m512i_u16[32];
            uint32_t m512i_u32[16];
            uint64_t m512i_u64[8];
        };

        // 16-bit mask type to go with the stand-in registers.
        using __mmask16 = uint16_t;
#endif
348 
// Alignment requested for the 512-bit wrapper unions that follow: 64 bytes
// when targeting AVX512 or building with the Intel compiler, otherwise the
// 32 bytes needed by the 256-bit halves.
#if (SIMD_ARCH >= SIMD_ARCH_AVX512) || defined(__INTEL_COMPILER)
#define SIMD_ALIGNMENT_BYTES 64
#else
#define SIMD_ALIGNMENT_BYTES 32
#endif
354 
355         union Float
356         {
357             SIMDINLINE Float() = default;
Float(__m512 in)358             SIMDINLINE Float(__m512 in) : v(in) {}
Float(SIMD256Impl::Float const & in_lo,SIMD256Impl::Float const & in_hi=_mm256_setzero_ps ())359             SIMDINLINE Float(SIMD256Impl::Float const& in_lo,
360                              SIMD256Impl::Float const& in_hi = _mm256_setzero_ps())
361             {
362                 v8[0] = in_lo;
363                 v8[1] = in_hi;
364             }
operator =(__m512 in)365             SIMDINLINE Float& SIMDCALL operator=(__m512 in)
366             {
367                 v = in;
368                 return *this;
369             }
operator =(Float const & in)370             SIMDINLINE Float& SIMDCALL operator=(Float const& in)
371             {
372 #if SIMD_ARCH >= SIMD_ARCH_AVX512
373                 v = in.v;
374 #else
375                 v8[0] = in.v8[0];
376                 v8[1] = in.v8[1];
377 #endif
378                 return *this;
379             }
operator __m512() const380             SIMDINLINE SIMDCALL operator __m512() const { return v; }
381 
382             SIMDALIGN(__m512, SIMD_ALIGNMENT_BYTES) v;
383             SIMD256Impl::Float v8[2];
384         };
385 
386         union Integer
387         {
388             SIMDINLINE Integer() = default;
Integer(__m512i in)389             SIMDINLINE Integer(__m512i in) : v(in) {}
Integer(SIMD256Impl::Integer const & in_lo,SIMD256Impl::Integer const & in_hi=_mm256_setzero_si256 ())390             SIMDINLINE Integer(SIMD256Impl::Integer const& in_lo,
391                                SIMD256Impl::Integer const& in_hi = _mm256_setzero_si256())
392             {
393                 v8[0] = in_lo;
394                 v8[1] = in_hi;
395             }
operator =(__m512i in)396             SIMDINLINE Integer& SIMDCALL operator=(__m512i in)
397             {
398                 v = in;
399                 return *this;
400             }
operator =(Integer const & in)401             SIMDINLINE Integer& SIMDCALL operator=(Integer const& in)
402             {
403 #if SIMD_ARCH >= SIMD_ARCH_AVX512
404                 v = in.v;
405 #else
406                 v8[0] = in.v8[0];
407                 v8[1] = in.v8[1];
408 #endif
409                 return *this;
410             }
411 
operator __m512i() const412             SIMDINLINE SIMDCALL operator __m512i() const { return v; }
413 
414             SIMDALIGN(__m512i, SIMD_ALIGNMENT_BYTES) v;
415             SIMD256Impl::Integer v8[2];
416         };
417 
418         union Double
419         {
420             SIMDINLINE Double() = default;
Double(__m512d in)421             SIMDINLINE Double(__m512d in) : v(in) {}
Double(SIMD256Impl::Double const & in_lo,SIMD256Impl::Double const & in_hi=_mm256_setzero_pd ())422             SIMDINLINE Double(SIMD256Impl::Double const& in_lo,
423                               SIMD256Impl::Double const& in_hi = _mm256_setzero_pd())
424             {
425                 v8[0] = in_lo;
426                 v8[1] = in_hi;
427             }
operator =(__m512d in)428             SIMDINLINE Double& SIMDCALL operator=(__m512d in)
429             {
430                 v = in;
431                 return *this;
432             }
operator =(Double const & in)433             SIMDINLINE Double& SIMDCALL operator=(Double const& in)
434             {
435 #if SIMD_ARCH >= SIMD_ARCH_AVX512
436                 v = in.v;
437 #else
438                 v8[0] = in.v8[0];
439                 v8[1] = in.v8[1];
440 #endif
441                 return *this;
442             }
443 
operator __m512d() const444             SIMDINLINE SIMDCALL operator __m512d() const { return v; }
445 
446             SIMDALIGN(__m512d, SIMD_ALIGNMENT_BYTES) v;
447             SIMD256Impl::Double v8[2];
448         };
449 
450         typedef SIMDImpl::Vec4<Float, Integer, Double> SIMDALIGN(Vec4, 64);
451         using Mask = __mmask16;
452 
453         static const uint32_t SIMD_WIDTH = 16;
454 
455 #undef SIMD_ALIGNMENT_BYTES
456     } // namespace SIMD512Impl
457 } // namespace SIMDImpl
458