1 /**
2  * Copyright 2015 Kurt Kanzenbach
3  * Copyright 2016 Andreas Schäfer
4  *
5  * Distributed under the Boost Software License, Version 1.0. (See accompanying
6  * file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt)
7  */
8 
9 #ifndef FLAT_ARRAY_DETAIL_SHORT_VEC_AVX512_FLOAT_32_HPP
10 #define FLAT_ARRAY_DETAIL_SHORT_VEC_AVX512_FLOAT_32_HPP
11 
12 #if LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_AVX512F
13 
14 #include <immintrin.h>
15 #include <libflatarray/detail/sqrt_reference.hpp>
16 #include <libflatarray/detail/short_vec_helpers.hpp>
17 #include <libflatarray/config.h>
18 
19 #ifdef LIBFLATARRAY_WITH_CPP14
20 #include <initializer_list>
21 #endif
22 
23 namespace LibFlatArray {
24 
25 template<typename CARGO, int ARITY>
26 class short_vec;
27 
28 template<typename CARGO, int ARITY>
29 class sqrt_reference;
30 
31 #ifdef __ICC
32 // disabling this warning as implicit type conversion is exactly our goal here:
33 #pragma warning push
34 #pragma warning (disable: 2304)
35 #endif
36 
37 template<>
38 class short_vec<float, 32>
39 {
40 public:
41     static const int ARITY = 32;
42     typedef unsigned mask_type;
43     typedef short_vec_strategy::avx512f strategy;
44 
45     template<typename _CharT, typename _Traits>
46     friend std::basic_ostream<_CharT, _Traits>& operator<<(
47         std::basic_ostream<_CharT, _Traits>& __os,
48         const short_vec<float, 32>& vec);
49 
50     inline
short_vec(const float data=0)51     short_vec(const float data = 0) :
52         val1(_mm512_set1_ps(data)),
53         val2(_mm512_set1_ps(data))
54     {}
55 
56     inline
short_vec(const float * data)57     short_vec(const float *data)
58     {
59         load(data);
60     }
61 
62     inline
short_vec(const __m512 & val1,const __m512 & val2)63     short_vec(const __m512& val1, const __m512& val2) :
64         val1(val1), val2(val2)
65     {}
66 
67 #ifdef LIBFLATARRAY_WITH_CPP14
68     inline
short_vec(const std::initializer_list<float> & il)69     short_vec(const std::initializer_list<float>& il)
70     {
71         const float *ptr = static_cast<const float *>(&(*il.begin()));
72         load(ptr);
73     }
74 #endif
75 
76     inline
77     short_vec(const sqrt_reference<float, 32>& other);
78 
79     inline
any() const80     bool any() const
81     {
82         __m512 buf0 = _mm512_or_ps(val1, val2);
83         __m128 buf1 = _mm_or_ps(
84             _mm_or_ps(
85                 _mm512_extractf32x4_ps(buf0, 0),
86                 _mm512_extractf32x4_ps(buf0, 1)),
87             _mm_or_ps(
88                 _mm512_extractf32x4_ps(buf0, 2),
89                 _mm512_extractf32x4_ps(buf0, 3)));
90         // shuffle upper 64-bit half down to first 64 bits so we can
91         // "or" both together:
92         __m128 buf2 = _mm_shuffle_ps(buf1, buf1, (3 << 2) | (2 << 0));
93         buf2 = _mm_or_ps(buf1, buf2);
94         // another shuffle to extract 2nd least significant float
95         // member and or it together with least significant float
96         // member:
97         buf1 = _mm_shuffle_ps(buf2, buf2, (1 << 0));
98         return _mm_cvtss_f32(buf1) || _mm_cvtss_f32(buf2);
99     }
100 
101     inline
get(int i) const102     float get(int i) const
103     {
104         __m512 buf0;
105         if (i < 16) {
106             buf0 = val1;
107         } else {
108             buf0 = val2;
109         }
110 
111         i &= 15;
112 
113         __m128 buf1;
114         if (i < 8) {
115             if (i < 4) {
116                 buf1 =  _mm512_extractf32x4_ps(buf0, 0);
117             } else {
118                 buf1 =  _mm512_extractf32x4_ps(buf0, 1);
119             }
120         } else {
121             if (i < 12)  {
122                 buf1 =  _mm512_extractf32x4_ps(buf0, 2);
123             } else {
124                 buf1 =  _mm512_extractf32x4_ps(buf0, 3);
125             }
126         }
127 
128         i &= 3;
129 
130         if (i == 3) {
131             return _mm_cvtss_f32(_mm_shuffle_ps(buf1, buf1, 3));
132         }
133         if (i == 2) {
134             return _mm_cvtss_f32(_mm_shuffle_ps(buf1, buf1, 2));
135         }
136         if (i == 1) {
137             return _mm_cvtss_f32(_mm_shuffle_ps(buf1, buf1, 1));
138         }
139 
140         return _mm_cvtss_f32(buf1);
141     }
142 
143     inline
operator -=(const short_vec<float,32> & other)144     void operator-=(const short_vec<float, 32>& other)
145     {
146         val1 = _mm512_sub_ps(val1, other.val1);
147         val2 = _mm512_sub_ps(val2, other.val2);
148     }
149 
150     inline
operator -(const short_vec<float,32> & other) const151     short_vec<float, 32> operator-(const short_vec<float, 32>& other) const
152     {
153         return short_vec<float, 32>(
154             _mm512_sub_ps(val1, other.val1),
155             _mm512_sub_ps(val2, other.val2));
156     }
157 
158     inline
operator +=(const short_vec<float,32> & other)159     void operator+=(const short_vec<float, 32>& other)
160     {
161         val1 = _mm512_add_ps(val1, other.val1);
162         val2 = _mm512_add_ps(val2, other.val2);
163     }
164 
165     inline
operator +(const short_vec<float,32> & other) const166     short_vec<float, 32> operator+(const short_vec<float, 32>& other) const
167     {
168         return short_vec<float, 32>(
169             _mm512_add_ps(val1, other.val1),
170             _mm512_add_ps(val2, other.val2));
171     }
172 
173     inline
operator *=(const short_vec<float,32> & other)174     void operator*=(const short_vec<float, 32>& other)
175     {
176         val1 = _mm512_mul_ps(val1, other.val1);
177         val2 = _mm512_mul_ps(val2, other.val2);
178     }
179 
180     inline
operator *(const short_vec<float,32> & other) const181     short_vec<float, 32> operator*(const short_vec<float, 32>& other) const
182     {
183         return short_vec<float, 32>(
184             _mm512_mul_ps(val1, other.val1),
185             _mm512_mul_ps(val2, other.val2));
186     }
187 
188     inline
operator /=(const short_vec<float,32> & other)189     void operator/=(const short_vec<float, 32>& other)
190     {
191         val1 = _mm512_mul_ps(val1, _mm512_rcp14_ps(other.val1));
192         val2 = _mm512_mul_ps(val2, _mm512_rcp14_ps(other.val2));
193     }
194 
195     inline
196     void operator/=(const sqrt_reference<float, 32>& other);
197 
198     inline
operator /(const short_vec<float,32> & other) const199     short_vec<float, 32> operator/(const short_vec<float, 32>& other) const
200     {
201         return short_vec<float, 32>(
202             _mm512_mul_ps(val1, _mm512_rcp14_ps(other.val1)),
203             _mm512_mul_ps(val2, _mm512_rcp14_ps(other.val2)));
204     }
205 
206     inline
207     short_vec<float, 32> operator/(const sqrt_reference<float, 32>& other) const;
208 
209     inline
operator <(const short_vec<float,32> & other) const210     mask_type operator<(const short_vec<float, 32>& other) const
211     {
212         return
213             (_mm512_cmp_ps_mask(val1, other.val1, _CMP_LT_OS) <<  0) +
214             (_mm512_cmp_ps_mask(val2, other.val2, _CMP_LT_OS) << 16);
215     }
216 
217     inline
operator <=(const short_vec<float,32> & other) const218     mask_type operator<=(const short_vec<float, 32>& other) const
219     {
220         return
221             (_mm512_cmp_ps_mask(val1, other.val1, _CMP_LE_OS) <<  0) +
222             (_mm512_cmp_ps_mask(val2, other.val2, _CMP_LE_OS) << 16);
223     }
224 
225     inline
operator ==(const short_vec<float,32> & other) const226     mask_type operator==(const short_vec<float, 32>& other) const
227     {
228         return
229             (_mm512_cmp_ps_mask(val1, other.val1, _CMP_EQ_OQ) <<  0) +
230             (_mm512_cmp_ps_mask(val2, other.val2, _CMP_EQ_OQ) << 16);
231     }
232 
233     inline
operator >(const short_vec<float,32> & other) const234     mask_type operator>(const short_vec<float, 32>& other) const
235     {
236         return
237             (_mm512_cmp_ps_mask(val1, other.val1, _CMP_GT_OS) <<  0) +
238             (_mm512_cmp_ps_mask(val2, other.val2, _CMP_GT_OS) << 16);
239     }
240 
241     inline
operator >=(const short_vec<float,32> & other) const242     mask_type operator>=(const short_vec<float, 32>& other) const
243     {
244         return
245             (_mm512_cmp_ps_mask(val1, other.val1, _CMP_GE_OS) <<  0) +
246             (_mm512_cmp_ps_mask(val2, other.val2, _CMP_GE_OS) << 16);
247     }
248 
249     inline
sqrt() const250     short_vec<float, 32> sqrt() const
251     {
252         return short_vec<float, 32>(
253             _mm512_sqrt_ps(val1),
254             _mm512_sqrt_ps(val2));
255     }
256 
257     inline
load(const float * data)258     void load(const float *data)
259     {
260         val1 = _mm512_loadu_ps(data +  0);
261         val2 = _mm512_loadu_ps(data + 16);
262     }
263 
264     inline
load_aligned(const float * data)265     void load_aligned(const float *data)
266     {
267         SHORTVEC_ASSERT_ALIGNED(data, 64);
268         val1 = _mm512_load_ps(data +  0);
269         val2 = _mm512_load_ps(data + 16);
270     }
271 
272     inline
store(float * data) const273     void store(float *data) const
274     {
275         _mm512_storeu_ps(data +  0, val1);
276         _mm512_storeu_ps(data + 16, val2);
277     }
278 
279     inline
store_aligned(float * data) const280     void store_aligned(float *data) const
281     {
282         SHORTVEC_ASSERT_ALIGNED(data, 64);
283         _mm512_store_ps(data +  0, val1);
284         _mm512_store_ps(data + 16, val2);
285     }
286 
287     inline
store_nt(float * data) const288     void store_nt(float *data) const
289     {
290         SHORTVEC_ASSERT_ALIGNED(data, 64);
291         _mm512_stream_ps(data +  0, val1);
292         _mm512_stream_ps(data + 16, val2);
293     }
294 
295     inline
gather(const float * ptr,const int * offsets)296     void gather(const float *ptr, const int *offsets)
297     {
298         __m512i indices;
299         SHORTVEC_ASSERT_ALIGNED(offsets, 64);
300         indices = _mm512_load_epi32(offsets);
301         val1    = _mm512_i32gather_ps(indices, ptr, 4);
302         indices = _mm512_load_epi32(offsets + 16);
303         val2    = _mm512_i32gather_ps(indices, ptr, 4);
304     }
305 
306     inline
scatter(float * ptr,const int * offsets) const307     void scatter(float *ptr, const int *offsets) const
308     {
309         __m512i indices;
310         SHORTVEC_ASSERT_ALIGNED(offsets, 64);
311         indices = _mm512_load_epi32(offsets);
312         _mm512_i32scatter_ps(ptr, indices, val1, 4);
313         indices = _mm512_load_epi32(offsets + 16);
314         _mm512_i32scatter_ps(ptr, indices, val2, 4);
315     }
316 
317 private:
318     __m512 val1;
319     __m512 val2;
320 };
321 
322 inline
operator <<(float * data,const short_vec<float,32> & vec)323 void operator<<(float *data, const short_vec<float, 32>& vec)
324 {
325     vec.store(data);
326 }
327 
328 template<>
329 class sqrt_reference<float, 32>
330 {
331 public:
332     template<typename OTHER_CARGO, int OTHER_ARITY>
333     friend class short_vec;
334 
sqrt_reference(const short_vec<float,32> & vec)335     sqrt_reference(const short_vec<float, 32>& vec) :
336         vec(vec)
337     {}
338 
339 private:
340     short_vec<float, 32> vec;
341 };
342 
343 #ifdef __ICC
344 #pragma warning pop
345 #endif
346 
347 inline
short_vec(const sqrt_reference<float,32> & other)348 short_vec<float, 32>::short_vec(const sqrt_reference<float, 32>& other) :
349     val1(_mm512_sqrt_ps(other.vec.val1)),
350     val2(_mm512_sqrt_ps(other.vec.val2))
351 {}
352 
353 inline
operator /=(const sqrt_reference<float,32> & other)354 void short_vec<float, 32>::operator/=(const sqrt_reference<float, 32>& other)
355 {
356     val1 = _mm512_mul_ps(val1, _mm512_rsqrt14_ps(other.vec.val1));
357     val2 = _mm512_mul_ps(val2, _mm512_rsqrt14_ps(other.vec.val2));
358 }
359 
360 inline
operator /(const sqrt_reference<float,32> & other) const361 short_vec<float, 32> short_vec<float, 32>::operator/(const sqrt_reference<float, 32>& other) const
362 {
363     return short_vec<float, 32>(
364         _mm512_mul_ps(val1, _mm512_rsqrt14_ps(other.vec.val1)),
365         _mm512_mul_ps(val2, _mm512_rsqrt14_ps(other.vec.val2)));
366 }
367 
368 inline
sqrt(const short_vec<float,32> & vec)369 sqrt_reference<float, 32> sqrt(const short_vec<float, 32>& vec)
370 {
371     return sqrt_reference<float, 32>(vec);
372 }
373 
374 template<typename _CharT, typename _Traits>
375 std::basic_ostream<_CharT, _Traits>&
operator <<(std::basic_ostream<_CharT,_Traits> & __os,const short_vec<float,32> & vec)376 operator<<(std::basic_ostream<_CharT, _Traits>& __os,
377            const short_vec<float, 32>& vec)
378 {
379     const float *data1 = reinterpret_cast<const float *>(&vec.val1);
380     const float *data2 = reinterpret_cast<const float *>(&vec.val2);
381     __os << "["  << data1[ 0] << ", " << data1[ 1] << ", " << data1[ 2] << ", " << data1[ 3]
382          << ", " << data1[ 4] << ", " << data1[ 5] << ", " << data1[ 6] << ", " << data1[ 7]
383          << ", " << data1[ 8] << ", " << data1[ 9] << ", " << data1[10] << ", " << data1[11]
384          << ", " << data1[12] << ", " << data1[13] << ", " << data1[14] << ", " << data1[15]
385          << ", " << data2[ 0] << ", " << data2[ 1] << ", " << data2[ 2] << ", " << data2[ 3]
386          << ", " << data2[ 4] << ", " << data2[ 5] << ", " << data2[ 6] << ", " << data2[ 7]
387          << ", " << data2[ 8] << ", " << data2[ 9] << ", " << data2[10] << ", " << data2[11]
388          << ", " << data2[12] << ", " << data2[13] << ", " << data2[14] << ", " << data2[15]
389          << "]";
390     return __os;
391 }
392 
393 }
394 
395 #endif
396 
397 #endif
398