1 /**
2 * Copyright 2015 Kurt Kanzenbach
3 * Copyright 2016 Andreas Schäfer
4 *
5 * Distributed under the Boost Software License, Version 1.0. (See accompanying
6 * file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt)
7 */
8
9 #ifndef FLAT_ARRAY_DETAIL_SHORT_VEC_AVX512_FLOAT_32_HPP
10 #define FLAT_ARRAY_DETAIL_SHORT_VEC_AVX512_FLOAT_32_HPP
11
12 #if LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_AVX512F
13
14 #include <immintrin.h>
15 #include <libflatarray/detail/sqrt_reference.hpp>
16 #include <libflatarray/detail/short_vec_helpers.hpp>
17 #include <libflatarray/config.h>
18
19 #ifdef LIBFLATARRAY_WITH_CPP14
20 #include <initializer_list>
21 #endif
22
23 namespace LibFlatArray {
24
25 template<typename CARGO, int ARITY>
26 class short_vec;
27
28 template<typename CARGO, int ARITY>
29 class sqrt_reference;
30
31 #ifdef __ICC
32 // disabling this warning as implicit type conversion is exactly our goal here:
33 #pragma warning push
34 #pragma warning (disable: 2304)
35 #endif
36
37 template<>
38 class short_vec<float, 32>
39 {
40 public:
41 static const int ARITY = 32;
42 typedef unsigned mask_type;
43 typedef short_vec_strategy::avx512f strategy;
44
45 template<typename _CharT, typename _Traits>
46 friend std::basic_ostream<_CharT, _Traits>& operator<<(
47 std::basic_ostream<_CharT, _Traits>& __os,
48 const short_vec<float, 32>& vec);
49
50 inline
short_vec(const float data=0)51 short_vec(const float data = 0) :
52 val1(_mm512_set1_ps(data)),
53 val2(_mm512_set1_ps(data))
54 {}
55
56 inline
short_vec(const float * data)57 short_vec(const float *data)
58 {
59 load(data);
60 }
61
62 inline
short_vec(const __m512 & val1,const __m512 & val2)63 short_vec(const __m512& val1, const __m512& val2) :
64 val1(val1), val2(val2)
65 {}
66
67 #ifdef LIBFLATARRAY_WITH_CPP14
68 inline
short_vec(const std::initializer_list<float> & il)69 short_vec(const std::initializer_list<float>& il)
70 {
71 const float *ptr = static_cast<const float *>(&(*il.begin()));
72 load(ptr);
73 }
74 #endif
75
76 inline
77 short_vec(const sqrt_reference<float, 32>& other);
78
79 inline
any() const80 bool any() const
81 {
82 __m512 buf0 = _mm512_or_ps(val1, val2);
83 __m128 buf1 = _mm_or_ps(
84 _mm_or_ps(
85 _mm512_extractf32x4_ps(buf0, 0),
86 _mm512_extractf32x4_ps(buf0, 1)),
87 _mm_or_ps(
88 _mm512_extractf32x4_ps(buf0, 2),
89 _mm512_extractf32x4_ps(buf0, 3)));
90 // shuffle upper 64-bit half down to first 64 bits so we can
91 // "or" both together:
92 __m128 buf2 = _mm_shuffle_ps(buf1, buf1, (3 << 2) | (2 << 0));
93 buf2 = _mm_or_ps(buf1, buf2);
94 // another shuffle to extract 2nd least significant float
95 // member and or it together with least significant float
96 // member:
97 buf1 = _mm_shuffle_ps(buf2, buf2, (1 << 0));
98 return _mm_cvtss_f32(buf1) || _mm_cvtss_f32(buf2);
99 }
100
101 inline
get(int i) const102 float get(int i) const
103 {
104 __m512 buf0;
105 if (i < 16) {
106 buf0 = val1;
107 } else {
108 buf0 = val2;
109 }
110
111 i &= 15;
112
113 __m128 buf1;
114 if (i < 8) {
115 if (i < 4) {
116 buf1 = _mm512_extractf32x4_ps(buf0, 0);
117 } else {
118 buf1 = _mm512_extractf32x4_ps(buf0, 1);
119 }
120 } else {
121 if (i < 12) {
122 buf1 = _mm512_extractf32x4_ps(buf0, 2);
123 } else {
124 buf1 = _mm512_extractf32x4_ps(buf0, 3);
125 }
126 }
127
128 i &= 3;
129
130 if (i == 3) {
131 return _mm_cvtss_f32(_mm_shuffle_ps(buf1, buf1, 3));
132 }
133 if (i == 2) {
134 return _mm_cvtss_f32(_mm_shuffle_ps(buf1, buf1, 2));
135 }
136 if (i == 1) {
137 return _mm_cvtss_f32(_mm_shuffle_ps(buf1, buf1, 1));
138 }
139
140 return _mm_cvtss_f32(buf1);
141 }
142
143 inline
operator -=(const short_vec<float,32> & other)144 void operator-=(const short_vec<float, 32>& other)
145 {
146 val1 = _mm512_sub_ps(val1, other.val1);
147 val2 = _mm512_sub_ps(val2, other.val2);
148 }
149
150 inline
operator -(const short_vec<float,32> & other) const151 short_vec<float, 32> operator-(const short_vec<float, 32>& other) const
152 {
153 return short_vec<float, 32>(
154 _mm512_sub_ps(val1, other.val1),
155 _mm512_sub_ps(val2, other.val2));
156 }
157
158 inline
operator +=(const short_vec<float,32> & other)159 void operator+=(const short_vec<float, 32>& other)
160 {
161 val1 = _mm512_add_ps(val1, other.val1);
162 val2 = _mm512_add_ps(val2, other.val2);
163 }
164
165 inline
operator +(const short_vec<float,32> & other) const166 short_vec<float, 32> operator+(const short_vec<float, 32>& other) const
167 {
168 return short_vec<float, 32>(
169 _mm512_add_ps(val1, other.val1),
170 _mm512_add_ps(val2, other.val2));
171 }
172
173 inline
operator *=(const short_vec<float,32> & other)174 void operator*=(const short_vec<float, 32>& other)
175 {
176 val1 = _mm512_mul_ps(val1, other.val1);
177 val2 = _mm512_mul_ps(val2, other.val2);
178 }
179
180 inline
operator *(const short_vec<float,32> & other) const181 short_vec<float, 32> operator*(const short_vec<float, 32>& other) const
182 {
183 return short_vec<float, 32>(
184 _mm512_mul_ps(val1, other.val1),
185 _mm512_mul_ps(val2, other.val2));
186 }
187
188 inline
operator /=(const short_vec<float,32> & other)189 void operator/=(const short_vec<float, 32>& other)
190 {
191 val1 = _mm512_mul_ps(val1, _mm512_rcp14_ps(other.val1));
192 val2 = _mm512_mul_ps(val2, _mm512_rcp14_ps(other.val2));
193 }
194
195 inline
196 void operator/=(const sqrt_reference<float, 32>& other);
197
198 inline
operator /(const short_vec<float,32> & other) const199 short_vec<float, 32> operator/(const short_vec<float, 32>& other) const
200 {
201 return short_vec<float, 32>(
202 _mm512_mul_ps(val1, _mm512_rcp14_ps(other.val1)),
203 _mm512_mul_ps(val2, _mm512_rcp14_ps(other.val2)));
204 }
205
206 inline
207 short_vec<float, 32> operator/(const sqrt_reference<float, 32>& other) const;
208
209 inline
operator <(const short_vec<float,32> & other) const210 mask_type operator<(const short_vec<float, 32>& other) const
211 {
212 return
213 (_mm512_cmp_ps_mask(val1, other.val1, _CMP_LT_OS) << 0) +
214 (_mm512_cmp_ps_mask(val2, other.val2, _CMP_LT_OS) << 16);
215 }
216
217 inline
operator <=(const short_vec<float,32> & other) const218 mask_type operator<=(const short_vec<float, 32>& other) const
219 {
220 return
221 (_mm512_cmp_ps_mask(val1, other.val1, _CMP_LE_OS) << 0) +
222 (_mm512_cmp_ps_mask(val2, other.val2, _CMP_LE_OS) << 16);
223 }
224
225 inline
operator ==(const short_vec<float,32> & other) const226 mask_type operator==(const short_vec<float, 32>& other) const
227 {
228 return
229 (_mm512_cmp_ps_mask(val1, other.val1, _CMP_EQ_OQ) << 0) +
230 (_mm512_cmp_ps_mask(val2, other.val2, _CMP_EQ_OQ) << 16);
231 }
232
233 inline
operator >(const short_vec<float,32> & other) const234 mask_type operator>(const short_vec<float, 32>& other) const
235 {
236 return
237 (_mm512_cmp_ps_mask(val1, other.val1, _CMP_GT_OS) << 0) +
238 (_mm512_cmp_ps_mask(val2, other.val2, _CMP_GT_OS) << 16);
239 }
240
241 inline
operator >=(const short_vec<float,32> & other) const242 mask_type operator>=(const short_vec<float, 32>& other) const
243 {
244 return
245 (_mm512_cmp_ps_mask(val1, other.val1, _CMP_GE_OS) << 0) +
246 (_mm512_cmp_ps_mask(val2, other.val2, _CMP_GE_OS) << 16);
247 }
248
249 inline
sqrt() const250 short_vec<float, 32> sqrt() const
251 {
252 return short_vec<float, 32>(
253 _mm512_sqrt_ps(val1),
254 _mm512_sqrt_ps(val2));
255 }
256
257 inline
load(const float * data)258 void load(const float *data)
259 {
260 val1 = _mm512_loadu_ps(data + 0);
261 val2 = _mm512_loadu_ps(data + 16);
262 }
263
264 inline
load_aligned(const float * data)265 void load_aligned(const float *data)
266 {
267 SHORTVEC_ASSERT_ALIGNED(data, 64);
268 val1 = _mm512_load_ps(data + 0);
269 val2 = _mm512_load_ps(data + 16);
270 }
271
272 inline
store(float * data) const273 void store(float *data) const
274 {
275 _mm512_storeu_ps(data + 0, val1);
276 _mm512_storeu_ps(data + 16, val2);
277 }
278
279 inline
store_aligned(float * data) const280 void store_aligned(float *data) const
281 {
282 SHORTVEC_ASSERT_ALIGNED(data, 64);
283 _mm512_store_ps(data + 0, val1);
284 _mm512_store_ps(data + 16, val2);
285 }
286
287 inline
store_nt(float * data) const288 void store_nt(float *data) const
289 {
290 SHORTVEC_ASSERT_ALIGNED(data, 64);
291 _mm512_stream_ps(data + 0, val1);
292 _mm512_stream_ps(data + 16, val2);
293 }
294
295 inline
gather(const float * ptr,const int * offsets)296 void gather(const float *ptr, const int *offsets)
297 {
298 __m512i indices;
299 SHORTVEC_ASSERT_ALIGNED(offsets, 64);
300 indices = _mm512_load_epi32(offsets);
301 val1 = _mm512_i32gather_ps(indices, ptr, 4);
302 indices = _mm512_load_epi32(offsets + 16);
303 val2 = _mm512_i32gather_ps(indices, ptr, 4);
304 }
305
306 inline
scatter(float * ptr,const int * offsets) const307 void scatter(float *ptr, const int *offsets) const
308 {
309 __m512i indices;
310 SHORTVEC_ASSERT_ALIGNED(offsets, 64);
311 indices = _mm512_load_epi32(offsets);
312 _mm512_i32scatter_ps(ptr, indices, val1, 4);
313 indices = _mm512_load_epi32(offsets + 16);
314 _mm512_i32scatter_ps(ptr, indices, val2, 4);
315 }
316
317 private:
318 __m512 val1;
319 __m512 val2;
320 };
321
322 inline
operator <<(float * data,const short_vec<float,32> & vec)323 void operator<<(float *data, const short_vec<float, 32>& vec)
324 {
325 vec.store(data);
326 }
327
328 template<>
329 class sqrt_reference<float, 32>
330 {
331 public:
332 template<typename OTHER_CARGO, int OTHER_ARITY>
333 friend class short_vec;
334
sqrt_reference(const short_vec<float,32> & vec)335 sqrt_reference(const short_vec<float, 32>& vec) :
336 vec(vec)
337 {}
338
339 private:
340 short_vec<float, 32> vec;
341 };
342
343 #ifdef __ICC
344 #pragma warning pop
345 #endif
346
347 inline
short_vec(const sqrt_reference<float,32> & other)348 short_vec<float, 32>::short_vec(const sqrt_reference<float, 32>& other) :
349 val1(_mm512_sqrt_ps(other.vec.val1)),
350 val2(_mm512_sqrt_ps(other.vec.val2))
351 {}
352
353 inline
operator /=(const sqrt_reference<float,32> & other)354 void short_vec<float, 32>::operator/=(const sqrt_reference<float, 32>& other)
355 {
356 val1 = _mm512_mul_ps(val1, _mm512_rsqrt14_ps(other.vec.val1));
357 val2 = _mm512_mul_ps(val2, _mm512_rsqrt14_ps(other.vec.val2));
358 }
359
360 inline
operator /(const sqrt_reference<float,32> & other) const361 short_vec<float, 32> short_vec<float, 32>::operator/(const sqrt_reference<float, 32>& other) const
362 {
363 return short_vec<float, 32>(
364 _mm512_mul_ps(val1, _mm512_rsqrt14_ps(other.vec.val1)),
365 _mm512_mul_ps(val2, _mm512_rsqrt14_ps(other.vec.val2)));
366 }
367
368 inline
sqrt(const short_vec<float,32> & vec)369 sqrt_reference<float, 32> sqrt(const short_vec<float, 32>& vec)
370 {
371 return sqrt_reference<float, 32>(vec);
372 }
373
374 template<typename _CharT, typename _Traits>
375 std::basic_ostream<_CharT, _Traits>&
operator <<(std::basic_ostream<_CharT,_Traits> & __os,const short_vec<float,32> & vec)376 operator<<(std::basic_ostream<_CharT, _Traits>& __os,
377 const short_vec<float, 32>& vec)
378 {
379 const float *data1 = reinterpret_cast<const float *>(&vec.val1);
380 const float *data2 = reinterpret_cast<const float *>(&vec.val2);
381 __os << "[" << data1[ 0] << ", " << data1[ 1] << ", " << data1[ 2] << ", " << data1[ 3]
382 << ", " << data1[ 4] << ", " << data1[ 5] << ", " << data1[ 6] << ", " << data1[ 7]
383 << ", " << data1[ 8] << ", " << data1[ 9] << ", " << data1[10] << ", " << data1[11]
384 << ", " << data1[12] << ", " << data1[13] << ", " << data1[14] << ", " << data1[15]
385 << ", " << data2[ 0] << ", " << data2[ 1] << ", " << data2[ 2] << ", " << data2[ 3]
386 << ", " << data2[ 4] << ", " << data2[ 5] << ", " << data2[ 6] << ", " << data2[ 7]
387 << ", " << data2[ 8] << ", " << data2[ 9] << ", " << data2[10] << ", " << data2[11]
388 << ", " << data2[12] << ", " << data2[13] << ", " << data2[14] << ", " << data2[15]
389 << "]";
390 return __os;
391 }
392
393 }
394
395 #endif
396
397 #endif
398