1 /**
2  * Copyright 2014-2016 Andreas Schäfer
3  * Copyright 2015 Kurt Kanzenbach
4  *
5  * Distributed under the Boost Software License, Version 1.0. (See accompanying
6  * file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt)
7  */
8 
9 #ifndef FLAT_ARRAY_DETAIL_SHORT_VEC_AVX_FLOAT_8_HPP
10 #define FLAT_ARRAY_DETAIL_SHORT_VEC_AVX_FLOAT_8_HPP
11 
12 #if (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_AVX) ||     \
13     (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_AVX2) ||    \
14     (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_AVX512F)
15 
16 #include <immintrin.h>
17 #include <libflatarray/detail/sqrt_reference.hpp>
18 #include <libflatarray/detail/short_vec_helpers.hpp>
19 #include <libflatarray/config.h>
20 
21 #ifdef LIBFLATARRAY_WITH_CPP14
22 #include <initializer_list>
23 #endif
24 
25 namespace LibFlatArray {
26 
// Primary templates, declared here so that this header can provide the
// <float, 8> specializations of both without including their definitions.
template<typename CARGO, int ARITY> class short_vec;
template<typename CARGO, int ARITY> class sqrt_reference;
32 
33 #ifdef __ICC
34 // disabling this warning as implicit type conversion is exactly our goal here:
35 #pragma warning push
36 #pragma warning (disable: 2304)
37 #endif
38 
39 template<>
40 class short_vec<float, 8>
41 {
42 public:
43     static const int ARITY = 8;
44     typedef short_vec<float, 8> mask_type;
45     typedef short_vec_strategy::avx strategy;
46 
47     template<typename _CharT, typename _Traits>
48     friend std::basic_ostream<_CharT, _Traits>& operator<<(
49         std::basic_ostream<_CharT, _Traits>& __os,
50         const short_vec<float, 8>& vec);
51 
52     inline
short_vec(const float data=0)53     short_vec(const float data = 0) :
54         val1(_mm256_broadcast_ss(&data))
55     {}
56 
57     inline
short_vec(const float * data)58     short_vec(const float *data)
59     {
60         load(data);
61     }
62 
63     inline
short_vec(const __m256 & val1)64     short_vec(const __m256& val1) :
65         val1(val1)
66     {}
67 
68 #ifdef LIBFLATARRAY_WITH_CPP14
69     inline
short_vec(const std::initializer_list<float> & il)70     short_vec(const std::initializer_list<float>& il)
71     {
72         const float *ptr = static_cast<const float *>(&(*il.begin()));
73         load(ptr);
74     }
75 #endif
76 
77     inline
78     short_vec(const sqrt_reference<float, 8>& other);
79 
80     inline
any() const81     bool any() const
82     {
83         // merge both 128-bit lanes of AVX register:
84         __m128 buf0 = _mm_or_ps(
85             _mm256_extractf128_ps(val1, 0),
86             _mm256_extractf128_ps(val1, 1));
87         // shuffle upper 64-bit half down to first 64 bits so we can
88         // "or" both together:
89         __m128 buf1 = _mm_shuffle_ps(buf0, buf0, (3 << 2) | (2 << 0));
90         buf1 = _mm_or_ps(buf0, buf1);
91         // another shuffle to extract 2nd least significant float
92         // member and or it together with least significant float
93         // member:
94         __m128 buf2 = _mm_shuffle_ps(buf1, buf1, (1 << 0));
95         return _mm_cvtss_f32(buf1) || _mm_cvtss_f32(buf2);
96     }
97 
98     inline
get(int i) const99     float get(int i) const
100     {
101         __m128 buf;
102         if (i < 4) {
103             buf = _mm256_extractf128_ps(val1, 0);
104         } else {
105             buf = _mm256_extractf128_ps(val1, 1);
106         }
107 
108         i &= 3;
109 
110         if (i == 3) {
111             return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 3));
112         }
113         if (i == 2) {
114             return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 2));
115         }
116         if (i == 1) {
117             return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 1));
118         }
119 
120         return _mm_cvtss_f32(buf);
121     }
122 
123     inline
operator -=(const short_vec<float,8> & other)124     void operator-=(const short_vec<float, 8>& other)
125     {
126         val1 = _mm256_sub_ps(val1, other.val1);
127     }
128 
129     inline
operator -(const short_vec<float,8> & other) const130     short_vec<float, 8> operator-(const short_vec<float, 8>& other) const
131     {
132         return _mm256_sub_ps(val1, other.val1);
133     }
134 
135     inline
operator +=(const short_vec<float,8> & other)136     void operator+=(const short_vec<float, 8>& other)
137     {
138         val1 = _mm256_add_ps(val1, other.val1);
139     }
140 
141     inline
operator +(const short_vec<float,8> & other) const142     short_vec<float, 8> operator+(const short_vec<float, 8>& other) const
143     {
144         return _mm256_add_ps(val1, other.val1);
145     }
146 
147     inline
operator *=(const short_vec<float,8> & other)148     void operator*=(const short_vec<float, 8>& other)
149     {
150         val1 = _mm256_mul_ps(val1, other.val1);
151     }
152 
153     inline
operator *(const short_vec<float,8> & other) const154     short_vec<float, 8> operator*(const short_vec<float, 8>& other) const
155     {
156         return _mm256_mul_ps(val1, other.val1);
157     }
158 
159     inline
operator /=(const short_vec<float,8> & other)160     void operator/=(const short_vec<float, 8>& other)
161     {
162         val1 = _mm256_mul_ps(val1, _mm256_rcp_ps(other.val1));
163     }
164 
165     inline
166     void operator/=(const sqrt_reference<float, 8>& other);
167 
168     inline
operator /(const short_vec<float,8> & other) const169     short_vec<float, 8> operator/(const short_vec<float, 8>& other) const
170     {
171         return _mm256_mul_ps(val1, _mm256_rcp_ps(other.val1));
172     }
173 
174     inline
175     short_vec<float, 8> operator/(const sqrt_reference<float, 8>& other) const;
176 
177     inline
operator <(const short_vec<float,8> & other) const178     short_vec<float, 8> operator<(const short_vec<float, 8>& other) const
179     {
180         return short_vec<float, 8>(
181             _mm256_cmp_ps(val1, other.val1, _CMP_LT_OS));
182     }
183 
184     inline
operator <=(const short_vec<float,8> & other) const185     short_vec<float, 8> operator<=(const short_vec<float, 8>& other) const
186     {
187         return short_vec<float, 8>(
188             _mm256_cmp_ps(val1, other.val1, _CMP_LE_OS));
189     }
190 
191     inline
operator ==(const short_vec<float,8> & other) const192     short_vec<float, 8> operator==(const short_vec<float, 8>& other) const
193     {
194         return short_vec<float, 8>(
195             _mm256_cmp_ps(val1, other.val1, _CMP_EQ_OQ));
196     }
197 
198     inline
operator >(const short_vec<float,8> & other) const199     short_vec<float, 8> operator>(const short_vec<float, 8>& other) const
200     {
201         return short_vec<float, 8>(
202             _mm256_cmp_ps(val1, other.val1, _CMP_GT_OS));
203     }
204 
205     inline
operator >=(const short_vec<float,8> & other) const206     short_vec<float, 8> operator>=(const short_vec<float, 8>& other) const
207     {
208         return short_vec<float, 8>(
209             _mm256_cmp_ps(val1, other.val1, _CMP_GE_OS));
210     }
211 
212     inline
sqrt() const213     short_vec<float, 8> sqrt() const
214     {
215         return _mm256_sqrt_ps(val1);
216     }
217 
218     inline
load(const float * data)219     void load(const float *data)
220     {
221         val1 = _mm256_loadu_ps(data);
222     }
223 
224     inline
load_aligned(const float * data)225     void load_aligned(const float *data)
226     {
227         SHORTVEC_ASSERT_ALIGNED(data, 32);
228         val1 = _mm256_load_ps(data);
229     }
230 
231     inline
store(float * data) const232     void store(float *data) const
233     {
234         _mm256_storeu_ps(data, val1);
235     }
236 
237     inline
store_aligned(float * data) const238     void store_aligned(float *data) const
239     {
240         SHORTVEC_ASSERT_ALIGNED(data, 32);
241         _mm256_store_ps(data, val1);
242     }
243 
244     inline
store_nt(float * data) const245     void store_nt(float *data) const
246     {
247         SHORTVEC_ASSERT_ALIGNED(data, 32);
248         _mm256_stream_ps(data, val1);
249     }
250 
251 #ifdef __AVX2__
252     inline
gather(const float * ptr,const int * offsets)253     void gather(const float *ptr, const int *offsets)
254     {
255         __m256i indices;
256         indices = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(offsets));
257         val1    = _mm256_i32gather_ps(ptr, indices, 4);
258     }
259 #else
260     inline
gather(const float * ptr,const int * offsets)261     void gather(const float *ptr, const int *offsets)
262     {
263         __m128 tmp;
264         tmp  = _mm_load_ss(ptr + offsets[0]);
265         SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[1], _MM_MK_INSERTPS_NDX(0,1,0));
266         SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[2], _MM_MK_INSERTPS_NDX(0,2,0));
267         SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[3], _MM_MK_INSERTPS_NDX(0,3,0));
268         val1 = _mm256_insertf128_ps(val1, tmp, 0);
269         tmp  = _mm_load_ss(ptr + offsets[4]);
270         SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[5], _MM_MK_INSERTPS_NDX(0,1,0));
271         SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[6], _MM_MK_INSERTPS_NDX(0,2,0));
272         SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[7], _MM_MK_INSERTPS_NDX(0,3,0));
273         val1 = _mm256_insertf128_ps(val1, tmp, 1);
274     }
275 #endif
276 
277     inline
scatter(float * ptr,const int * offsets) const278     void scatter(float *ptr, const int *offsets) const
279     {
280         __m128 tmp;
281         tmp = _mm256_extractf128_ps(val1, 0);
282         _MM_EXTRACT_FLOAT(ptr[offsets[0]], tmp, 0);
283         _MM_EXTRACT_FLOAT(ptr[offsets[1]], tmp, 1);
284         _MM_EXTRACT_FLOAT(ptr[offsets[2]], tmp, 2);
285         _MM_EXTRACT_FLOAT(ptr[offsets[3]], tmp, 3);
286         tmp = _mm256_extractf128_ps(val1, 1);
287         _MM_EXTRACT_FLOAT(ptr[offsets[4]], tmp, 0);
288         _MM_EXTRACT_FLOAT(ptr[offsets[5]], tmp, 1);
289         _MM_EXTRACT_FLOAT(ptr[offsets[6]], tmp, 2);
290         _MM_EXTRACT_FLOAT(ptr[offsets[7]], tmp, 3);
291     }
292 
293 private:
294     __m256 val1;
295 };
296 
297 inline
operator <<(float * data,const short_vec<float,8> & vec)298 void operator<<(float *data, const short_vec<float, 8>& vec)
299 {
300     vec.store(data);
301 }
302 
303 template<>
304 class sqrt_reference<float, 8>
305 {
306 public:
307     template<typename OTHER_CARGO, int OTHER_ARITY>
308     friend class short_vec;
309 
sqrt_reference(const short_vec<float,8> & vec)310     sqrt_reference(const short_vec<float, 8>& vec) :
311         vec(vec)
312     {}
313 
314 private:
315     short_vec<float, 8> vec;
316 };
317 
318 #ifdef __ICC
319 #pragma warning pop
320 #endif
321 
322 inline
short_vec(const sqrt_reference<float,8> & other)323 short_vec<float, 8>::short_vec(const sqrt_reference<float, 8>& other) :
324     val1(_mm256_sqrt_ps(other.vec.val1))
325 {}
326 
327 inline
operator /=(const sqrt_reference<float,8> & other)328 void short_vec<float, 8>::operator/=(const sqrt_reference<float, 8>& other)
329 {
330     val1 = _mm256_mul_ps(val1, _mm256_rsqrt_ps(other.vec.val1));
331 }
332 
333 inline
operator /(const sqrt_reference<float,8> & other) const334 short_vec<float, 8> short_vec<float, 8>::operator/(const sqrt_reference<float, 8>& other) const
335 {
336     return _mm256_mul_ps(val1, _mm256_rsqrt_ps(other.vec.val1));
337 }
338 
339 inline
sqrt(const short_vec<float,8> & vec)340 sqrt_reference<float, 8> sqrt(const short_vec<float, 8>& vec)
341 {
342     return sqrt_reference<float, 8>(vec);
343 }
344 
345 template<typename _CharT, typename _Traits>
346 std::basic_ostream<_CharT, _Traits>&
operator <<(std::basic_ostream<_CharT,_Traits> & __os,const short_vec<float,8> & vec)347 operator<<(std::basic_ostream<_CharT, _Traits>& __os,
348            const short_vec<float, 8>& vec)
349 {
350     const float *data1 = reinterpret_cast<const float *>(&vec.val1);
351     __os << "[" << data1[0] << ", " << data1[1]  << ", " << data1[2]  << ", " << data1[3]  << ", " << data1[4]  << ", " << data1[5]  << ", " << data1[6]  << ", " << data1[7] << "]";
352     return __os;
353 }
354 
355 }
356 
357 #endif
358 
359 #endif
360