/**
 * Copyright 2014-2016 Andreas Schäfer
 * Copyright 2015 Kurt Kanzenbach
 *
 * Distributed under the Boost Software License, Version 1.0. (See accompanying
 * file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt)
 */
8
9 #ifndef FLAT_ARRAY_DETAIL_SHORT_VEC_SSE_FLOAT_16_HPP
10 #define FLAT_ARRAY_DETAIL_SHORT_VEC_SSE_FLOAT_16_HPP
11
12 #if (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_SSE) || \
13 (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_SSE2) || \
14 (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_SSE4_1)
15
16 #include <emmintrin.h>
17 #include <libflatarray/detail/sqrt_reference.hpp>
18 #include <libflatarray/detail/short_vec_helpers.hpp>
19 #include <libflatarray/config.h>
20
21 #ifdef __SSE4_1__
22 #include <smmintrin.h>
23 #endif
24
25 #ifdef LIBFLATARRAY_WITH_CPP14
26 #include <initializer_list>
27 #endif
28
29 namespace LibFlatArray {
30
31 template<typename CARGO, int ARITY>
32 class short_vec;
33
34 template<typename CARGO, int ARITY>
35 class sqrt_reference;
36
37 #ifdef __ICC
// Disable ICC warning 2304 (implicit type conversion through a non-explicit constructor): such conversion is exactly our goal here.
39 #pragma warning push
40 #pragma warning (disable: 2304)
41 #endif
42
43 template<>
44 class short_vec<float, 16>
45 {
46 public:
47 static const int ARITY = 16;
48 typedef short_vec<float, 16> mask_type;
49 typedef short_vec_strategy::sse strategy;
50
51 template<typename _CharT, typename _Traits>
52 friend std::basic_ostream<_CharT, _Traits>& operator<<(
53 std::basic_ostream<_CharT, _Traits>& __os,
54 const short_vec<float, 16>& vec);
55
56 inline
short_vec(const float data=0)57 short_vec(const float data = 0) :
58 val1(_mm_set1_ps(data)),
59 val2(_mm_set1_ps(data)),
60 val3(_mm_set1_ps(data)),
61 val4(_mm_set1_ps(data))
62 {}
63
64 inline
short_vec(const float * data)65 short_vec(const float *data)
66 {
67 load(data);
68 }
69
70 inline
short_vec(const __m128 & val1,const __m128 & val2,const __m128 & val3,const __m128 & val4)71 short_vec(const __m128& val1, const __m128& val2, const __m128& val3, const __m128& val4) :
72 val1(val1),
73 val2(val2),
74 val3(val3),
75 val4(val4)
76 {}
77
78 #ifdef LIBFLATARRAY_WITH_CPP14
79 inline
short_vec(const std::initializer_list<float> & il)80 short_vec(const std::initializer_list<float>& il)
81 {
82 const float *ptr = static_cast<const float *>(&(*il.begin()));
83 load(ptr);
84 }
85 #endif
86
87 inline
88 short_vec(const sqrt_reference<float, 16>& other);
89
90 inline
any() const91 bool any() const
92 {
93 __m128 buf1 = _mm_or_ps(
94 _mm_or_ps(val1, val2),
95 _mm_or_ps(val3, val4));
96 __m128 buf2 = _mm_shuffle_ps(buf1, buf1, (3 << 2) | (2 << 0));
97 buf1 = _mm_or_ps(buf1, buf2);
98 buf2 = _mm_shuffle_ps(buf1, buf1, (1 << 0));
99 return _mm_cvtss_f32(buf1) || _mm_cvtss_f32(buf2);
100 }
101
102 inline
get(int i) const103 float get(int i) const
104 {
105 __m128 buf;
106 if (i < 8) {
107 if (i < 4) {
108 buf = val1;
109 } else {
110 buf = val2;
111 }
112 } else {
113 if (i < 12) {
114 buf = val3;
115 } else {
116 buf = val4;
117 }
118 }
119
120 i &= 3;
121
122 if (i == 3) {
123 return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 3));
124 }
125 if (i == 2) {
126 return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 2));
127 }
128 if (i == 1) {
129 return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 1));
130 }
131
132 return _mm_cvtss_f32(buf);
133 }
134
135 inline
operator -=(const short_vec<float,16> & other)136 void operator-=(const short_vec<float, 16>& other)
137 {
138 val1 = _mm_sub_ps(val1, other.val1);
139 val2 = _mm_sub_ps(val2, other.val2);
140 val3 = _mm_sub_ps(val3, other.val3);
141 val4 = _mm_sub_ps(val4, other.val4);
142 }
143
144 inline
operator -(const short_vec<float,16> & other) const145 short_vec<float, 16> operator-(const short_vec<float, 16>& other) const
146 {
147 return short_vec<float, 16>(
148 _mm_sub_ps(val1, other.val1),
149 _mm_sub_ps(val2, other.val2),
150 _mm_sub_ps(val3, other.val3),
151 _mm_sub_ps(val4, other.val4));
152 }
153
154 inline
operator +=(const short_vec<float,16> & other)155 void operator+=(const short_vec<float, 16>& other)
156 {
157 val1 = _mm_add_ps(val1, other.val1);
158 val2 = _mm_add_ps(val2, other.val2);
159 val3 = _mm_add_ps(val3, other.val3);
160 val4 = _mm_add_ps(val4, other.val4);
161 }
162
163 inline
operator +(const short_vec<float,16> & other) const164 short_vec<float, 16> operator+(const short_vec<float, 16>& other) const
165 {
166 return short_vec<float, 16>(
167 _mm_add_ps(val1, other.val1),
168 _mm_add_ps(val2, other.val2),
169 _mm_add_ps(val3, other.val3),
170 _mm_add_ps(val4, other.val4));
171 }
172
173 inline
operator *=(const short_vec<float,16> & other)174 void operator*=(const short_vec<float, 16>& other)
175 {
176 val1 = _mm_mul_ps(val1, other.val1);
177 val2 = _mm_mul_ps(val2, other.val2);
178 val3 = _mm_mul_ps(val3, other.val3);
179 val4 = _mm_mul_ps(val4, other.val4);
180 }
181
182 inline
operator *(const short_vec<float,16> & other) const183 short_vec<float, 16> operator*(const short_vec<float, 16>& other) const
184 {
185 return short_vec<float, 16>(
186 _mm_mul_ps(val1, other.val1),
187 _mm_mul_ps(val2, other.val2),
188 _mm_mul_ps(val3, other.val3),
189 _mm_mul_ps(val4, other.val4));
190 }
191
192 inline
operator /=(const short_vec<float,16> & other)193 void operator/=(const short_vec<float, 16>& other)
194 {
195 val1 = _mm_div_ps(val1, other.val1);
196 val2 = _mm_div_ps(val2, other.val2);
197 val3 = _mm_div_ps(val3, other.val3);
198 val4 = _mm_div_ps(val4, other.val4);
199 }
200
201 inline
202 void operator/=(const sqrt_reference<float, 16>& other);
203
204 inline
operator /(const short_vec<float,16> & other) const205 short_vec<float, 16> operator/(const short_vec<float, 16>& other) const
206 {
207 return short_vec<float, 16>(
208 _mm_div_ps(val1, other.val1),
209 _mm_div_ps(val2, other.val2),
210 _mm_div_ps(val3, other.val3),
211 _mm_div_ps(val4, other.val4));
212 }
213
214 inline
215 short_vec<float, 16> operator/(const sqrt_reference<float, 16>& other) const;
216
217 inline
operator <(const short_vec<float,16> & other) const218 short_vec<float, 16> operator<(const short_vec<float, 16>& other) const
219 {
220 return short_vec<float, 16>(
221 _mm_cmplt_ps(val1, other.val1),
222 _mm_cmplt_ps(val2, other.val2),
223 _mm_cmplt_ps(val3, other.val3),
224 _mm_cmplt_ps(val4, other.val4));
225 }
226
227 inline
operator <=(const short_vec<float,16> & other) const228 short_vec<float, 16> operator<=(const short_vec<float, 16>& other) const
229 {
230 return short_vec<float, 16>(
231 _mm_cmple_ps(val1, other.val1),
232 _mm_cmple_ps(val2, other.val2),
233 _mm_cmple_ps(val3, other.val3),
234 _mm_cmple_ps(val4, other.val4));
235 }
236
237 inline
operator ==(const short_vec<float,16> & other) const238 short_vec<float, 16> operator==(const short_vec<float, 16>& other) const
239 {
240 return short_vec<float, 16>(
241 _mm_cmpeq_ps(val1, other.val1),
242 _mm_cmpeq_ps(val2, other.val2),
243 _mm_cmpeq_ps(val3, other.val3),
244 _mm_cmpeq_ps(val4, other.val4));
245 }
246
247 inline
operator >(const short_vec<float,16> & other) const248 short_vec<float, 16> operator>(const short_vec<float, 16>& other) const
249 {
250 return short_vec<float, 16>(
251 _mm_cmpgt_ps(val1, other.val1),
252 _mm_cmpgt_ps(val2, other.val2),
253 _mm_cmpgt_ps(val3, other.val3),
254 _mm_cmpgt_ps(val4, other.val4));
255 }
256
257 inline
operator >=(const short_vec<float,16> & other) const258 short_vec<float, 16> operator>=(const short_vec<float, 16>& other) const
259 {
260 return short_vec<float, 16>(
261 _mm_cmpge_ps(val1, other.val1),
262 _mm_cmpge_ps(val2, other.val2),
263 _mm_cmpge_ps(val3, other.val3),
264 _mm_cmpge_ps(val4, other.val4));
265 }
266
267 inline
sqrt() const268 short_vec<float, 16> sqrt() const
269 {
270 return short_vec<float, 16>(
271 _mm_sqrt_ps(val1),
272 _mm_sqrt_ps(val2),
273 _mm_sqrt_ps(val3),
274 _mm_sqrt_ps(val4));
275 }
276
277 inline
load(const float * data)278 void load(const float *data)
279 {
280 val1 = _mm_loadu_ps(data + 0);
281 val2 = _mm_loadu_ps(data + 4);
282 val3 = _mm_loadu_ps(data + 8);
283 val4 = _mm_loadu_ps(data + 12);
284 }
285
286 inline
load_aligned(const float * data)287 void load_aligned(const float *data)
288 {
289 SHORTVEC_ASSERT_ALIGNED(data, 16);
290 val1 = _mm_load_ps(data + 0);
291 val2 = _mm_load_ps(data + 4);
292 val3 = _mm_load_ps(data + 8);
293 val4 = _mm_load_ps(data + 12);
294 }
295
296 inline
store(float * data) const297 void store(float *data) const
298 {
299 _mm_storeu_ps(data + 0, val1);
300 _mm_storeu_ps(data + 4, val2);
301 _mm_storeu_ps(data + 8, val3);
302 _mm_storeu_ps(data + 12, val4);
303 }
304
305 inline
store_aligned(float * data) const306 void store_aligned(float *data) const
307 {
308 SHORTVEC_ASSERT_ALIGNED(data, 16);
309 _mm_store_ps(data + 0, val1);
310 _mm_store_ps(data + 4, val2);
311 _mm_store_ps(data + 8, val3);
312 _mm_store_ps(data + 12, val4);
313 }
314
315 inline
store_nt(float * data) const316 void store_nt(float *data) const
317 {
318 SHORTVEC_ASSERT_ALIGNED(data, 16);
319 _mm_stream_ps(data + 0, val1);
320 _mm_stream_ps(data + 4, val2);
321 _mm_stream_ps(data + 8, val3);
322 _mm_stream_ps(data + 12, val4);
323 }
324
325 #ifdef __SSE4_1__
326 inline
gather(const float * ptr,const int * offsets)327 void gather(const float *ptr, const int *offsets)
328 {
329 val1 = _mm_load_ss(ptr + offsets[0]);
330 SHORTVEC_INSERT_PS(val1, ptr, offsets[ 1], _MM_MK_INSERTPS_NDX(0,1,0));
331 SHORTVEC_INSERT_PS(val1, ptr, offsets[ 2], _MM_MK_INSERTPS_NDX(0,2,0));
332 SHORTVEC_INSERT_PS(val1, ptr, offsets[ 3], _MM_MK_INSERTPS_NDX(0,3,0));
333 val2 = _mm_load_ss(ptr + offsets[4]);
334 SHORTVEC_INSERT_PS(val2, ptr, offsets[ 5], _MM_MK_INSERTPS_NDX(0,1,0));
335 SHORTVEC_INSERT_PS(val2, ptr, offsets[ 6], _MM_MK_INSERTPS_NDX(0,2,0));
336 SHORTVEC_INSERT_PS(val2, ptr, offsets[ 7], _MM_MK_INSERTPS_NDX(0,3,0));
337 val3 = _mm_load_ss(ptr + offsets[8]);
338 SHORTVEC_INSERT_PS(val3, ptr, offsets[ 9], _MM_MK_INSERTPS_NDX(0,1,0));
339 SHORTVEC_INSERT_PS(val3, ptr, offsets[10], _MM_MK_INSERTPS_NDX(0,2,0));
340 SHORTVEC_INSERT_PS(val3, ptr, offsets[11], _MM_MK_INSERTPS_NDX(0,3,0));
341 val4 = _mm_load_ss(ptr + offsets[12]);
342 SHORTVEC_INSERT_PS(val4, ptr, offsets[13], _MM_MK_INSERTPS_NDX(0,1,0));
343 SHORTVEC_INSERT_PS(val4, ptr, offsets[14], _MM_MK_INSERTPS_NDX(0,2,0));
344 SHORTVEC_INSERT_PS(val4, ptr, offsets[15], _MM_MK_INSERTPS_NDX(0,3,0));
345 }
346
347 inline
scatter(float * ptr,const int * offsets) const348 void scatter(float *ptr, const int *offsets) const
349 {
350 ShortVecHelpers::ExtractResult r1, r2, r3, r4;
351 r1.i = _mm_extract_ps(val1, 0);
352 r2.i = _mm_extract_ps(val1, 1);
353 r3.i = _mm_extract_ps(val1, 2);
354 r4.i = _mm_extract_ps(val1, 3);
355 ptr[offsets[0]] = r1.f;
356 ptr[offsets[1]] = r2.f;
357 ptr[offsets[2]] = r3.f;
358 ptr[offsets[3]] = r4.f;
359 r1.i = _mm_extract_ps(val2, 0);
360 r2.i = _mm_extract_ps(val2, 1);
361 r3.i = _mm_extract_ps(val2, 2);
362 r4.i = _mm_extract_ps(val2, 3);
363 ptr[offsets[4]] = r1.f;
364 ptr[offsets[5]] = r2.f;
365 ptr[offsets[6]] = r3.f;
366 ptr[offsets[7]] = r4.f;
367 r1.i = _mm_extract_ps(val3, 0);
368 r2.i = _mm_extract_ps(val3, 1);
369 r3.i = _mm_extract_ps(val3, 2);
370 r4.i = _mm_extract_ps(val3, 3);
371 ptr[offsets[ 8]] = r1.f;
372 ptr[offsets[ 9]] = r2.f;
373 ptr[offsets[10]] = r3.f;
374 ptr[offsets[11]] = r4.f;
375 r1.i = _mm_extract_ps(val4, 0);
376 r2.i = _mm_extract_ps(val4, 1);
377 r3.i = _mm_extract_ps(val4, 2);
378 r4.i = _mm_extract_ps(val4, 3);
379 ptr[offsets[12]] = r1.f;
380 ptr[offsets[13]] = r2.f;
381 ptr[offsets[14]] = r3.f;
382 ptr[offsets[15]] = r4.f;
383 }
384 #else
385 inline
gather(const float * ptr,const int * offsets)386 void gather(const float *ptr, const int *offsets)
387 {
388 __m128 f1, f2, f3, f4;
389 f1 = _mm_load_ss(ptr + offsets[0]);
390 f2 = _mm_load_ss(ptr + offsets[2]);
391 f1 = _mm_unpacklo_ps(f1, f2);
392 f3 = _mm_load_ss(ptr + offsets[1]);
393 f4 = _mm_load_ss(ptr + offsets[3]);
394 f3 = _mm_unpacklo_ps(f3, f4);
395 val1 = _mm_unpacklo_ps(f1, f3);
396 f1 = _mm_load_ss(ptr + offsets[4]);
397 f2 = _mm_load_ss(ptr + offsets[6]);
398 f1 = _mm_unpacklo_ps(f1, f2);
399 f3 = _mm_load_ss(ptr + offsets[5]);
400 f4 = _mm_load_ss(ptr + offsets[7]);
401 f3 = _mm_unpacklo_ps(f3, f4);
402 val2 = _mm_unpacklo_ps(f1, f3);
403 f1 = _mm_load_ss(ptr + offsets[ 8]);
404 f2 = _mm_load_ss(ptr + offsets[10]);
405 f1 = _mm_unpacklo_ps(f1, f2);
406 f3 = _mm_load_ss(ptr + offsets[ 9]);
407 f4 = _mm_load_ss(ptr + offsets[11]);
408 f3 = _mm_unpacklo_ps(f3, f4);
409 val3 = _mm_unpacklo_ps(f1, f3);
410 f1 = _mm_load_ss(ptr + offsets[12]);
411 f2 = _mm_load_ss(ptr + offsets[14]);
412 f1 = _mm_unpacklo_ps(f1, f2);
413 f3 = _mm_load_ss(ptr + offsets[13]);
414 f4 = _mm_load_ss(ptr + offsets[15]);
415 f3 = _mm_unpacklo_ps(f3, f4);
416 val4 = _mm_unpacklo_ps(f1, f3);
417 }
418
419 inline
scatter(float * ptr,const int * offsets) const420 void scatter(float *ptr, const int *offsets) const
421 {
422 __m128 tmp = val1;
423 _mm_store_ss(ptr + offsets[0], tmp);
424 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
425 _mm_store_ss(ptr + offsets[1], tmp);
426 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
427 _mm_store_ss(ptr + offsets[2], tmp);
428 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
429 _mm_store_ss(ptr + offsets[3], tmp);
430 tmp = val2;
431 _mm_store_ss(ptr + offsets[4], tmp);
432 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
433 _mm_store_ss(ptr + offsets[5], tmp);
434 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
435 _mm_store_ss(ptr + offsets[6], tmp);
436 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
437 _mm_store_ss(ptr + offsets[7], tmp);
438 tmp = val3;
439 _mm_store_ss(ptr + offsets[8], tmp);
440 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
441 _mm_store_ss(ptr + offsets[9], tmp);
442 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
443 _mm_store_ss(ptr + offsets[10], tmp);
444 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
445 _mm_store_ss(ptr + offsets[11], tmp);
446 tmp = val4;
447 _mm_store_ss(ptr + offsets[12], tmp);
448 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
449 _mm_store_ss(ptr + offsets[13], tmp);
450 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
451 _mm_store_ss(ptr + offsets[14], tmp);
452 tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
453 _mm_store_ss(ptr + offsets[15], tmp);
454 }
455 #endif
456
457 private:
458 __m128 val1;
459 __m128 val2;
460 __m128 val3;
461 __m128 val4;
462 };
463
464 inline
operator <<(float * data,const short_vec<float,16> & vec)465 void operator<<(float *data, const short_vec<float, 16>& vec)
466 {
467 vec.store(data);
468 }
469
470 template<>
471 class sqrt_reference<float, 16>
472 {
473 public:
474 template<typename OTHER_CARGO, int OTHER_ARITY>
475 friend class short_vec;
476
sqrt_reference(const short_vec<float,16> & vec)477 sqrt_reference(const short_vec<float, 16>& vec) :
478 vec(vec)
479 {}
480
481 private:
482 short_vec<float, 16> vec;
483 };
484
485 #ifdef __ICC
486 #pragma warning pop
487 #endif
488
489 inline
short_vec(const sqrt_reference<float,16> & other)490 short_vec<float, 16>::short_vec(const sqrt_reference<float, 16>& other) :
491 val1(_mm_sqrt_ps(other.vec.val1)),
492 val2(_mm_sqrt_ps(other.vec.val2)),
493 val3(_mm_sqrt_ps(other.vec.val3)),
494 val4(_mm_sqrt_ps(other.vec.val4))
495 {}
496
497 inline
operator /=(const sqrt_reference<float,16> & other)498 void short_vec<float, 16>::operator/=(const sqrt_reference<float, 16>& other)
499 {
500 val1 = _mm_mul_ps(val1, _mm_rsqrt_ps(other.vec.val1));
501 val2 = _mm_mul_ps(val2, _mm_rsqrt_ps(other.vec.val2));
502 val3 = _mm_mul_ps(val3, _mm_rsqrt_ps(other.vec.val3));
503 val4 = _mm_mul_ps(val4, _mm_rsqrt_ps(other.vec.val4));
504 }
505
506 inline
operator /(const sqrt_reference<float,16> & other) const507 short_vec<float, 16> short_vec<float, 16>::operator/(const sqrt_reference<float, 16>& other) const
508 {
509 return short_vec<float, 16>(
510 _mm_mul_ps(val1, _mm_rsqrt_ps(other.vec.val1)),
511 _mm_mul_ps(val2, _mm_rsqrt_ps(other.vec.val2)),
512 _mm_mul_ps(val3, _mm_rsqrt_ps(other.vec.val3)),
513 _mm_mul_ps(val4, _mm_rsqrt_ps(other.vec.val4)));
514 }
515
516 inline
sqrt(const short_vec<float,16> & vec)517 sqrt_reference<float, 16> sqrt(const short_vec<float, 16>& vec)
518 {
519 return sqrt_reference<float, 16>(vec);
520 }
521
522 template<typename _CharT, typename _Traits>
523 std::basic_ostream<_CharT, _Traits>&
operator <<(std::basic_ostream<_CharT,_Traits> & __os,const short_vec<float,16> & vec)524 operator<<(std::basic_ostream<_CharT, _Traits>& __os,
525 const short_vec<float, 16>& vec)
526 {
527 const float *data1 = reinterpret_cast<const float *>(&vec.val1);
528 const float *data2 = reinterpret_cast<const float *>(&vec.val2);
529 const float *data3 = reinterpret_cast<const float *>(&vec.val3);
530 const float *data4 = reinterpret_cast<const float *>(&vec.val4);
531 __os << "["
532 << data1[0] << ", " << data1[1] << ", " << data1[2] << ", " << data1[3] << ", "
533 << data2[0] << ", " << data2[1] << ", " << data2[2] << ", " << data2[3] << ", "
534 << data3[0] << ", " << data3[1] << ", " << data3[2] << ", " << data3[3] << ", "
535 << data4[0] << ", " << data4[1] << ", " << data4[2] << ", " << data4[3] << "]";
536 return __os;
537 }
538
539 }
540
541 #endif
542
543 #endif
544