1 /**
2 * Copyright 2014-2016 Andreas Schäfer
3 * Copyright 2015 Kurt Kanzenbach
4 *
5 * Distributed under the Boost Software License, Version 1.0. (See accompanying
6 * file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt)
7 */
8
9 #ifndef FLAT_ARRAY_DETAIL_SHORT_VEC_AVX_FLOAT_8_HPP
10 #define FLAT_ARRAY_DETAIL_SHORT_VEC_AVX_FLOAT_8_HPP
11
12 #if (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_AVX) || \
13 (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_AVX2) || \
14 (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_AVX512F)
15
16 #include <immintrin.h>
17 #include <libflatarray/detail/sqrt_reference.hpp>
18 #include <libflatarray/detail/short_vec_helpers.hpp>
19 #include <libflatarray/config.h>
20
21 #ifdef LIBFLATARRAY_WITH_CPP14
22 #include <initializer_list>
23 #endif
24
25 namespace LibFlatArray {
26
27 template<typename CARGO, int ARITY>
28 class short_vec;
29
30 template<typename CARGO, int ARITY>
31 class sqrt_reference;
32
33 #ifdef __ICC
34 // disabling this warning as implicit type conversion is exactly our goal here:
35 #pragma warning push
36 #pragma warning (disable: 2304)
37 #endif
38
39 template<>
40 class short_vec<float, 8>
41 {
42 public:
43 static const int ARITY = 8;
44 typedef short_vec<float, 8> mask_type;
45 typedef short_vec_strategy::avx strategy;
46
47 template<typename _CharT, typename _Traits>
48 friend std::basic_ostream<_CharT, _Traits>& operator<<(
49 std::basic_ostream<_CharT, _Traits>& __os,
50 const short_vec<float, 8>& vec);
51
52 inline
short_vec(const float data=0)53 short_vec(const float data = 0) :
54 val1(_mm256_broadcast_ss(&data))
55 {}
56
57 inline
short_vec(const float * data)58 short_vec(const float *data)
59 {
60 load(data);
61 }
62
63 inline
short_vec(const __m256 & val1)64 short_vec(const __m256& val1) :
65 val1(val1)
66 {}
67
68 #ifdef LIBFLATARRAY_WITH_CPP14
69 inline
short_vec(const std::initializer_list<float> & il)70 short_vec(const std::initializer_list<float>& il)
71 {
72 const float *ptr = static_cast<const float *>(&(*il.begin()));
73 load(ptr);
74 }
75 #endif
76
77 inline
78 short_vec(const sqrt_reference<float, 8>& other);
79
80 inline
any() const81 bool any() const
82 {
83 // merge both 128-bit lanes of AVX register:
84 __m128 buf0 = _mm_or_ps(
85 _mm256_extractf128_ps(val1, 0),
86 _mm256_extractf128_ps(val1, 1));
87 // shuffle upper 64-bit half down to first 64 bits so we can
88 // "or" both together:
89 __m128 buf1 = _mm_shuffle_ps(buf0, buf0, (3 << 2) | (2 << 0));
90 buf1 = _mm_or_ps(buf0, buf1);
91 // another shuffle to extract 2nd least significant float
92 // member and or it together with least significant float
93 // member:
94 __m128 buf2 = _mm_shuffle_ps(buf1, buf1, (1 << 0));
95 return _mm_cvtss_f32(buf1) || _mm_cvtss_f32(buf2);
96 }
97
98 inline
get(int i) const99 float get(int i) const
100 {
101 __m128 buf;
102 if (i < 4) {
103 buf = _mm256_extractf128_ps(val1, 0);
104 } else {
105 buf = _mm256_extractf128_ps(val1, 1);
106 }
107
108 i &= 3;
109
110 if (i == 3) {
111 return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 3));
112 }
113 if (i == 2) {
114 return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 2));
115 }
116 if (i == 1) {
117 return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 1));
118 }
119
120 return _mm_cvtss_f32(buf);
121 }
122
123 inline
operator -=(const short_vec<float,8> & other)124 void operator-=(const short_vec<float, 8>& other)
125 {
126 val1 = _mm256_sub_ps(val1, other.val1);
127 }
128
129 inline
operator -(const short_vec<float,8> & other) const130 short_vec<float, 8> operator-(const short_vec<float, 8>& other) const
131 {
132 return _mm256_sub_ps(val1, other.val1);
133 }
134
135 inline
operator +=(const short_vec<float,8> & other)136 void operator+=(const short_vec<float, 8>& other)
137 {
138 val1 = _mm256_add_ps(val1, other.val1);
139 }
140
141 inline
operator +(const short_vec<float,8> & other) const142 short_vec<float, 8> operator+(const short_vec<float, 8>& other) const
143 {
144 return _mm256_add_ps(val1, other.val1);
145 }
146
147 inline
operator *=(const short_vec<float,8> & other)148 void operator*=(const short_vec<float, 8>& other)
149 {
150 val1 = _mm256_mul_ps(val1, other.val1);
151 }
152
153 inline
operator *(const short_vec<float,8> & other) const154 short_vec<float, 8> operator*(const short_vec<float, 8>& other) const
155 {
156 return _mm256_mul_ps(val1, other.val1);
157 }
158
159 inline
operator /=(const short_vec<float,8> & other)160 void operator/=(const short_vec<float, 8>& other)
161 {
162 val1 = _mm256_mul_ps(val1, _mm256_rcp_ps(other.val1));
163 }
164
165 inline
166 void operator/=(const sqrt_reference<float, 8>& other);
167
168 inline
operator /(const short_vec<float,8> & other) const169 short_vec<float, 8> operator/(const short_vec<float, 8>& other) const
170 {
171 return _mm256_mul_ps(val1, _mm256_rcp_ps(other.val1));
172 }
173
174 inline
175 short_vec<float, 8> operator/(const sqrt_reference<float, 8>& other) const;
176
177 inline
operator <(const short_vec<float,8> & other) const178 short_vec<float, 8> operator<(const short_vec<float, 8>& other) const
179 {
180 return short_vec<float, 8>(
181 _mm256_cmp_ps(val1, other.val1, _CMP_LT_OS));
182 }
183
184 inline
operator <=(const short_vec<float,8> & other) const185 short_vec<float, 8> operator<=(const short_vec<float, 8>& other) const
186 {
187 return short_vec<float, 8>(
188 _mm256_cmp_ps(val1, other.val1, _CMP_LE_OS));
189 }
190
191 inline
operator ==(const short_vec<float,8> & other) const192 short_vec<float, 8> operator==(const short_vec<float, 8>& other) const
193 {
194 return short_vec<float, 8>(
195 _mm256_cmp_ps(val1, other.val1, _CMP_EQ_OQ));
196 }
197
198 inline
operator >(const short_vec<float,8> & other) const199 short_vec<float, 8> operator>(const short_vec<float, 8>& other) const
200 {
201 return short_vec<float, 8>(
202 _mm256_cmp_ps(val1, other.val1, _CMP_GT_OS));
203 }
204
205 inline
operator >=(const short_vec<float,8> & other) const206 short_vec<float, 8> operator>=(const short_vec<float, 8>& other) const
207 {
208 return short_vec<float, 8>(
209 _mm256_cmp_ps(val1, other.val1, _CMP_GE_OS));
210 }
211
212 inline
sqrt() const213 short_vec<float, 8> sqrt() const
214 {
215 return _mm256_sqrt_ps(val1);
216 }
217
218 inline
load(const float * data)219 void load(const float *data)
220 {
221 val1 = _mm256_loadu_ps(data);
222 }
223
224 inline
load_aligned(const float * data)225 void load_aligned(const float *data)
226 {
227 SHORTVEC_ASSERT_ALIGNED(data, 32);
228 val1 = _mm256_load_ps(data);
229 }
230
231 inline
store(float * data) const232 void store(float *data) const
233 {
234 _mm256_storeu_ps(data, val1);
235 }
236
237 inline
store_aligned(float * data) const238 void store_aligned(float *data) const
239 {
240 SHORTVEC_ASSERT_ALIGNED(data, 32);
241 _mm256_store_ps(data, val1);
242 }
243
244 inline
store_nt(float * data) const245 void store_nt(float *data) const
246 {
247 SHORTVEC_ASSERT_ALIGNED(data, 32);
248 _mm256_stream_ps(data, val1);
249 }
250
251 #ifdef __AVX2__
252 inline
gather(const float * ptr,const int * offsets)253 void gather(const float *ptr, const int *offsets)
254 {
255 __m256i indices;
256 indices = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(offsets));
257 val1 = _mm256_i32gather_ps(ptr, indices, 4);
258 }
259 #else
260 inline
gather(const float * ptr,const int * offsets)261 void gather(const float *ptr, const int *offsets)
262 {
263 __m128 tmp;
264 tmp = _mm_load_ss(ptr + offsets[0]);
265 SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[1], _MM_MK_INSERTPS_NDX(0,1,0));
266 SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[2], _MM_MK_INSERTPS_NDX(0,2,0));
267 SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[3], _MM_MK_INSERTPS_NDX(0,3,0));
268 val1 = _mm256_insertf128_ps(val1, tmp, 0);
269 tmp = _mm_load_ss(ptr + offsets[4]);
270 SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[5], _MM_MK_INSERTPS_NDX(0,1,0));
271 SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[6], _MM_MK_INSERTPS_NDX(0,2,0));
272 SHORTVEC_INSERT_PS_AVX(tmp, ptr, offsets[7], _MM_MK_INSERTPS_NDX(0,3,0));
273 val1 = _mm256_insertf128_ps(val1, tmp, 1);
274 }
275 #endif
276
277 inline
scatter(float * ptr,const int * offsets) const278 void scatter(float *ptr, const int *offsets) const
279 {
280 __m128 tmp;
281 tmp = _mm256_extractf128_ps(val1, 0);
282 _MM_EXTRACT_FLOAT(ptr[offsets[0]], tmp, 0);
283 _MM_EXTRACT_FLOAT(ptr[offsets[1]], tmp, 1);
284 _MM_EXTRACT_FLOAT(ptr[offsets[2]], tmp, 2);
285 _MM_EXTRACT_FLOAT(ptr[offsets[3]], tmp, 3);
286 tmp = _mm256_extractf128_ps(val1, 1);
287 _MM_EXTRACT_FLOAT(ptr[offsets[4]], tmp, 0);
288 _MM_EXTRACT_FLOAT(ptr[offsets[5]], tmp, 1);
289 _MM_EXTRACT_FLOAT(ptr[offsets[6]], tmp, 2);
290 _MM_EXTRACT_FLOAT(ptr[offsets[7]], tmp, 3);
291 }
292
293 private:
294 __m256 val1;
295 };
296
297 inline
operator <<(float * data,const short_vec<float,8> & vec)298 void operator<<(float *data, const short_vec<float, 8>& vec)
299 {
300 vec.store(data);
301 }
302
303 template<>
304 class sqrt_reference<float, 8>
305 {
306 public:
307 template<typename OTHER_CARGO, int OTHER_ARITY>
308 friend class short_vec;
309
sqrt_reference(const short_vec<float,8> & vec)310 sqrt_reference(const short_vec<float, 8>& vec) :
311 vec(vec)
312 {}
313
314 private:
315 short_vec<float, 8> vec;
316 };
317
318 #ifdef __ICC
319 #pragma warning pop
320 #endif
321
322 inline
short_vec(const sqrt_reference<float,8> & other)323 short_vec<float, 8>::short_vec(const sqrt_reference<float, 8>& other) :
324 val1(_mm256_sqrt_ps(other.vec.val1))
325 {}
326
327 inline
operator /=(const sqrt_reference<float,8> & other)328 void short_vec<float, 8>::operator/=(const sqrt_reference<float, 8>& other)
329 {
330 val1 = _mm256_mul_ps(val1, _mm256_rsqrt_ps(other.vec.val1));
331 }
332
333 inline
operator /(const sqrt_reference<float,8> & other) const334 short_vec<float, 8> short_vec<float, 8>::operator/(const sqrt_reference<float, 8>& other) const
335 {
336 return _mm256_mul_ps(val1, _mm256_rsqrt_ps(other.vec.val1));
337 }
338
339 inline
sqrt(const short_vec<float,8> & vec)340 sqrt_reference<float, 8> sqrt(const short_vec<float, 8>& vec)
341 {
342 return sqrt_reference<float, 8>(vec);
343 }
344
345 template<typename _CharT, typename _Traits>
346 std::basic_ostream<_CharT, _Traits>&
operator <<(std::basic_ostream<_CharT,_Traits> & __os,const short_vec<float,8> & vec)347 operator<<(std::basic_ostream<_CharT, _Traits>& __os,
348 const short_vec<float, 8>& vec)
349 {
350 const float *data1 = reinterpret_cast<const float *>(&vec.val1);
351 __os << "[" << data1[0] << ", " << data1[1] << ", " << data1[2] << ", " << data1[3] << ", " << data1[4] << ", " << data1[5] << ", " << data1[6] << ", " << data1[7] << "]";
352 return __os;
353 }
354
355 }
356
357 #endif
358
359 #endif
360