1 /**
2  * Copyright 2014-2016 Andreas Schäfer
3  * Copyright 2015 Kurt Kanzenbach
4  *
5  * Distributed under the Boost Software License, Version 1.0. (See accompanying
6  * file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt)
7  */
8 
9 #ifndef FLAT_ARRAY_DETAIL_SHORT_VEC_SSE_FLOAT_16_HPP
10 #define FLAT_ARRAY_DETAIL_SHORT_VEC_SSE_FLOAT_16_HPP
11 
12 #if (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_SSE) ||             \
13     (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_SSE2) ||            \
14     (LIBFLATARRAY_WIDEST_VECTOR_ISA == LIBFLATARRAY_SSE4_1)
15 
16 #include <emmintrin.h>
17 #include <libflatarray/detail/sqrt_reference.hpp>
18 #include <libflatarray/detail/short_vec_helpers.hpp>
19 #include <libflatarray/config.h>
20 
21 #ifdef __SSE4_1__
22 #include <smmintrin.h>
23 #endif
24 
25 #ifdef LIBFLATARRAY_WITH_CPP14
26 #include <initializer_list>
27 #endif
28 
29 namespace LibFlatArray {
30 
31 template<typename CARGO, int ARITY>
32 class short_vec;
33 
34 template<typename CARGO, int ARITY>
35 class sqrt_reference;
36 
37 #ifdef __ICC
38 // disabling this warning as implicit type conversion is exactly our goal here:
39 #pragma warning push
40 #pragma warning (disable: 2304)
41 #endif
42 
43 template<>
44 class short_vec<float, 16>
45 {
46 public:
47     static const int ARITY = 16;
48     typedef short_vec<float, 16> mask_type;
49     typedef short_vec_strategy::sse strategy;
50 
51     template<typename _CharT, typename _Traits>
52     friend std::basic_ostream<_CharT, _Traits>& operator<<(
53         std::basic_ostream<_CharT, _Traits>& __os,
54         const short_vec<float, 16>& vec);
55 
56     inline
short_vec(const float data=0)57     short_vec(const float data = 0) :
58         val1(_mm_set1_ps(data)),
59         val2(_mm_set1_ps(data)),
60         val3(_mm_set1_ps(data)),
61         val4(_mm_set1_ps(data))
62     {}
63 
64     inline
short_vec(const float * data)65     short_vec(const float *data)
66     {
67         load(data);
68     }
69 
70     inline
short_vec(const __m128 & val1,const __m128 & val2,const __m128 & val3,const __m128 & val4)71     short_vec(const __m128& val1, const __m128& val2, const __m128& val3, const __m128& val4) :
72         val1(val1),
73         val2(val2),
74         val3(val3),
75         val4(val4)
76     {}
77 
78 #ifdef LIBFLATARRAY_WITH_CPP14
79     inline
short_vec(const std::initializer_list<float> & il)80     short_vec(const std::initializer_list<float>& il)
81     {
82         const float *ptr = static_cast<const float *>(&(*il.begin()));
83         load(ptr);
84     }
85 #endif
86 
87     inline
88     short_vec(const sqrt_reference<float, 16>& other);
89 
90     inline
any() const91     bool any() const
92     {
93         __m128 buf1 = _mm_or_ps(
94             _mm_or_ps(val1, val2),
95             _mm_or_ps(val3, val4));
96         __m128 buf2 = _mm_shuffle_ps(buf1, buf1, (3 << 2) | (2 << 0));
97         buf1 = _mm_or_ps(buf1, buf2);
98         buf2 = _mm_shuffle_ps(buf1, buf1, (1 << 0));
99         return _mm_cvtss_f32(buf1) || _mm_cvtss_f32(buf2);
100     }
101 
102     inline
get(int i) const103     float get(int i) const
104     {
105         __m128 buf;
106         if (i < 8) {
107             if (i < 4) {
108                 buf = val1;
109             } else {
110                 buf = val2;
111             }
112         } else {
113             if (i < 12) {
114                 buf = val3;
115             } else {
116                 buf = val4;
117             }
118         }
119 
120         i &= 3;
121 
122         if (i == 3) {
123             return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 3));
124         }
125         if (i == 2) {
126             return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 2));
127         }
128         if (i == 1) {
129             return _mm_cvtss_f32(_mm_shuffle_ps(buf, buf, 1));
130         }
131 
132         return _mm_cvtss_f32(buf);
133     }
134 
135     inline
operator -=(const short_vec<float,16> & other)136     void operator-=(const short_vec<float, 16>& other)
137     {
138         val1 = _mm_sub_ps(val1, other.val1);
139         val2 = _mm_sub_ps(val2, other.val2);
140         val3 = _mm_sub_ps(val3, other.val3);
141         val4 = _mm_sub_ps(val4, other.val4);
142     }
143 
144     inline
operator -(const short_vec<float,16> & other) const145     short_vec<float, 16> operator-(const short_vec<float, 16>& other) const
146     {
147         return short_vec<float, 16>(
148             _mm_sub_ps(val1, other.val1),
149             _mm_sub_ps(val2, other.val2),
150             _mm_sub_ps(val3, other.val3),
151             _mm_sub_ps(val4, other.val4));
152     }
153 
154     inline
operator +=(const short_vec<float,16> & other)155     void operator+=(const short_vec<float, 16>& other)
156     {
157         val1 = _mm_add_ps(val1, other.val1);
158         val2 = _mm_add_ps(val2, other.val2);
159         val3 = _mm_add_ps(val3, other.val3);
160         val4 = _mm_add_ps(val4, other.val4);
161     }
162 
163     inline
operator +(const short_vec<float,16> & other) const164     short_vec<float, 16> operator+(const short_vec<float, 16>& other) const
165     {
166         return short_vec<float, 16>(
167             _mm_add_ps(val1, other.val1),
168             _mm_add_ps(val2, other.val2),
169             _mm_add_ps(val3, other.val3),
170             _mm_add_ps(val4, other.val4));
171     }
172 
173     inline
operator *=(const short_vec<float,16> & other)174     void operator*=(const short_vec<float, 16>& other)
175     {
176         val1 = _mm_mul_ps(val1, other.val1);
177         val2 = _mm_mul_ps(val2, other.val2);
178         val3 = _mm_mul_ps(val3, other.val3);
179         val4 = _mm_mul_ps(val4, other.val4);
180     }
181 
182     inline
operator *(const short_vec<float,16> & other) const183     short_vec<float, 16> operator*(const short_vec<float, 16>& other) const
184     {
185         return short_vec<float, 16>(
186             _mm_mul_ps(val1, other.val1),
187             _mm_mul_ps(val2, other.val2),
188             _mm_mul_ps(val3, other.val3),
189             _mm_mul_ps(val4, other.val4));
190     }
191 
192     inline
operator /=(const short_vec<float,16> & other)193     void operator/=(const short_vec<float, 16>& other)
194     {
195         val1 = _mm_div_ps(val1, other.val1);
196         val2 = _mm_div_ps(val2, other.val2);
197         val3 = _mm_div_ps(val3, other.val3);
198         val4 = _mm_div_ps(val4, other.val4);
199     }
200 
201     inline
202     void operator/=(const sqrt_reference<float, 16>& other);
203 
204     inline
operator /(const short_vec<float,16> & other) const205     short_vec<float, 16> operator/(const short_vec<float, 16>& other) const
206     {
207         return short_vec<float, 16>(
208             _mm_div_ps(val1, other.val1),
209             _mm_div_ps(val2, other.val2),
210             _mm_div_ps(val3, other.val3),
211             _mm_div_ps(val4, other.val4));
212     }
213 
214     inline
215     short_vec<float, 16> operator/(const sqrt_reference<float, 16>& other) const;
216 
217     inline
operator <(const short_vec<float,16> & other) const218     short_vec<float, 16> operator<(const short_vec<float, 16>& other) const
219     {
220         return short_vec<float, 16>(
221             _mm_cmplt_ps(val1, other.val1),
222             _mm_cmplt_ps(val2, other.val2),
223             _mm_cmplt_ps(val3, other.val3),
224             _mm_cmplt_ps(val4, other.val4));
225     }
226 
227     inline
operator <=(const short_vec<float,16> & other) const228     short_vec<float, 16> operator<=(const short_vec<float, 16>& other) const
229     {
230         return short_vec<float, 16>(
231             _mm_cmple_ps(val1, other.val1),
232             _mm_cmple_ps(val2, other.val2),
233             _mm_cmple_ps(val3, other.val3),
234             _mm_cmple_ps(val4, other.val4));
235     }
236 
237     inline
operator ==(const short_vec<float,16> & other) const238     short_vec<float, 16> operator==(const short_vec<float, 16>& other) const
239     {
240         return short_vec<float, 16>(
241             _mm_cmpeq_ps(val1, other.val1),
242             _mm_cmpeq_ps(val2, other.val2),
243             _mm_cmpeq_ps(val3, other.val3),
244             _mm_cmpeq_ps(val4, other.val4));
245     }
246 
247     inline
operator >(const short_vec<float,16> & other) const248     short_vec<float, 16> operator>(const short_vec<float, 16>& other) const
249     {
250         return short_vec<float, 16>(
251             _mm_cmpgt_ps(val1, other.val1),
252             _mm_cmpgt_ps(val2, other.val2),
253             _mm_cmpgt_ps(val3, other.val3),
254             _mm_cmpgt_ps(val4, other.val4));
255     }
256 
257     inline
operator >=(const short_vec<float,16> & other) const258     short_vec<float, 16> operator>=(const short_vec<float, 16>& other) const
259     {
260         return short_vec<float, 16>(
261             _mm_cmpge_ps(val1, other.val1),
262             _mm_cmpge_ps(val2, other.val2),
263             _mm_cmpge_ps(val3, other.val3),
264             _mm_cmpge_ps(val4, other.val4));
265     }
266 
267     inline
sqrt() const268     short_vec<float, 16> sqrt() const
269     {
270         return short_vec<float, 16>(
271             _mm_sqrt_ps(val1),
272             _mm_sqrt_ps(val2),
273             _mm_sqrt_ps(val3),
274             _mm_sqrt_ps(val4));
275     }
276 
277     inline
load(const float * data)278     void load(const float *data)
279     {
280         val1 = _mm_loadu_ps(data +  0);
281         val2 = _mm_loadu_ps(data +  4);
282         val3 = _mm_loadu_ps(data +  8);
283         val4 = _mm_loadu_ps(data + 12);
284     }
285 
286     inline
load_aligned(const float * data)287     void load_aligned(const float *data)
288     {
289         SHORTVEC_ASSERT_ALIGNED(data, 16);
290         val1 = _mm_load_ps(data +  0);
291         val2 = _mm_load_ps(data +  4);
292         val3 = _mm_load_ps(data +  8);
293         val4 = _mm_load_ps(data + 12);
294     }
295 
296     inline
store(float * data) const297     void store(float *data) const
298     {
299         _mm_storeu_ps(data +  0, val1);
300         _mm_storeu_ps(data +  4, val2);
301         _mm_storeu_ps(data +  8, val3);
302         _mm_storeu_ps(data + 12, val4);
303     }
304 
305     inline
store_aligned(float * data) const306     void store_aligned(float *data) const
307     {
308         SHORTVEC_ASSERT_ALIGNED(data, 16);
309         _mm_store_ps(data +  0, val1);
310         _mm_store_ps(data +  4, val2);
311         _mm_store_ps(data +  8, val3);
312         _mm_store_ps(data + 12, val4);
313     }
314 
315     inline
store_nt(float * data) const316     void store_nt(float *data) const
317     {
318         SHORTVEC_ASSERT_ALIGNED(data, 16);
319         _mm_stream_ps(data +  0, val1);
320         _mm_stream_ps(data +  4, val2);
321         _mm_stream_ps(data +  8, val3);
322         _mm_stream_ps(data + 12, val4);
323     }
324 
325 #ifdef __SSE4_1__
326     inline
gather(const float * ptr,const int * offsets)327     void gather(const float *ptr, const int *offsets)
328     {
329         val1 = _mm_load_ss(ptr + offsets[0]);
330         SHORTVEC_INSERT_PS(val1, ptr, offsets[ 1], _MM_MK_INSERTPS_NDX(0,1,0));
331         SHORTVEC_INSERT_PS(val1, ptr, offsets[ 2], _MM_MK_INSERTPS_NDX(0,2,0));
332         SHORTVEC_INSERT_PS(val1, ptr, offsets[ 3], _MM_MK_INSERTPS_NDX(0,3,0));
333         val2 = _mm_load_ss(ptr + offsets[4]);
334         SHORTVEC_INSERT_PS(val2, ptr, offsets[ 5], _MM_MK_INSERTPS_NDX(0,1,0));
335         SHORTVEC_INSERT_PS(val2, ptr, offsets[ 6], _MM_MK_INSERTPS_NDX(0,2,0));
336         SHORTVEC_INSERT_PS(val2, ptr, offsets[ 7], _MM_MK_INSERTPS_NDX(0,3,0));
337         val3 = _mm_load_ss(ptr + offsets[8]);
338         SHORTVEC_INSERT_PS(val3, ptr, offsets[ 9], _MM_MK_INSERTPS_NDX(0,1,0));
339         SHORTVEC_INSERT_PS(val3, ptr, offsets[10], _MM_MK_INSERTPS_NDX(0,2,0));
340         SHORTVEC_INSERT_PS(val3, ptr, offsets[11], _MM_MK_INSERTPS_NDX(0,3,0));
341         val4 = _mm_load_ss(ptr + offsets[12]);
342         SHORTVEC_INSERT_PS(val4, ptr, offsets[13], _MM_MK_INSERTPS_NDX(0,1,0));
343         SHORTVEC_INSERT_PS(val4, ptr, offsets[14], _MM_MK_INSERTPS_NDX(0,2,0));
344         SHORTVEC_INSERT_PS(val4, ptr, offsets[15], _MM_MK_INSERTPS_NDX(0,3,0));
345     }
346 
347     inline
scatter(float * ptr,const int * offsets) const348     void scatter(float *ptr, const int *offsets) const
349     {
350         ShortVecHelpers::ExtractResult r1, r2, r3, r4;
351         r1.i = _mm_extract_ps(val1, 0);
352         r2.i = _mm_extract_ps(val1, 1);
353         r3.i = _mm_extract_ps(val1, 2);
354         r4.i = _mm_extract_ps(val1, 3);
355         ptr[offsets[0]] = r1.f;
356         ptr[offsets[1]] = r2.f;
357         ptr[offsets[2]] = r3.f;
358         ptr[offsets[3]] = r4.f;
359         r1.i = _mm_extract_ps(val2, 0);
360         r2.i = _mm_extract_ps(val2, 1);
361         r3.i = _mm_extract_ps(val2, 2);
362         r4.i = _mm_extract_ps(val2, 3);
363         ptr[offsets[4]] = r1.f;
364         ptr[offsets[5]] = r2.f;
365         ptr[offsets[6]] = r3.f;
366         ptr[offsets[7]] = r4.f;
367         r1.i = _mm_extract_ps(val3, 0);
368         r2.i = _mm_extract_ps(val3, 1);
369         r3.i = _mm_extract_ps(val3, 2);
370         r4.i = _mm_extract_ps(val3, 3);
371         ptr[offsets[ 8]] = r1.f;
372         ptr[offsets[ 9]] = r2.f;
373         ptr[offsets[10]] = r3.f;
374         ptr[offsets[11]] = r4.f;
375         r1.i = _mm_extract_ps(val4, 0);
376         r2.i = _mm_extract_ps(val4, 1);
377         r3.i = _mm_extract_ps(val4, 2);
378         r4.i = _mm_extract_ps(val4, 3);
379         ptr[offsets[12]] = r1.f;
380         ptr[offsets[13]] = r2.f;
381         ptr[offsets[14]] = r3.f;
382         ptr[offsets[15]] = r4.f;
383     }
384 #else
385     inline
gather(const float * ptr,const int * offsets)386     void gather(const float *ptr, const int *offsets)
387     {
388         __m128 f1, f2, f3, f4;
389         f1   = _mm_load_ss(ptr + offsets[0]);
390         f2   = _mm_load_ss(ptr + offsets[2]);
391         f1   = _mm_unpacklo_ps(f1, f2);
392         f3   = _mm_load_ss(ptr + offsets[1]);
393         f4   = _mm_load_ss(ptr + offsets[3]);
394         f3   = _mm_unpacklo_ps(f3, f4);
395         val1 = _mm_unpacklo_ps(f1, f3);
396         f1   = _mm_load_ss(ptr + offsets[4]);
397         f2   = _mm_load_ss(ptr + offsets[6]);
398         f1   = _mm_unpacklo_ps(f1, f2);
399         f3   = _mm_load_ss(ptr + offsets[5]);
400         f4   = _mm_load_ss(ptr + offsets[7]);
401         f3   = _mm_unpacklo_ps(f3, f4);
402         val2 = _mm_unpacklo_ps(f1, f3);
403         f1   = _mm_load_ss(ptr + offsets[ 8]);
404         f2   = _mm_load_ss(ptr + offsets[10]);
405         f1   = _mm_unpacklo_ps(f1, f2);
406         f3   = _mm_load_ss(ptr + offsets[ 9]);
407         f4   = _mm_load_ss(ptr + offsets[11]);
408         f3   = _mm_unpacklo_ps(f3, f4);
409         val3 = _mm_unpacklo_ps(f1, f3);
410         f1   = _mm_load_ss(ptr + offsets[12]);
411         f2   = _mm_load_ss(ptr + offsets[14]);
412         f1   = _mm_unpacklo_ps(f1, f2);
413         f3   = _mm_load_ss(ptr + offsets[13]);
414         f4   = _mm_load_ss(ptr + offsets[15]);
415         f3   = _mm_unpacklo_ps(f3, f4);
416         val4 = _mm_unpacklo_ps(f1, f3);
417     }
418 
419     inline
scatter(float * ptr,const int * offsets) const420     void scatter(float *ptr, const int *offsets) const
421     {
422         __m128 tmp = val1;
423         _mm_store_ss(ptr + offsets[0], tmp);
424         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
425         _mm_store_ss(ptr + offsets[1], tmp);
426         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
427         _mm_store_ss(ptr + offsets[2], tmp);
428         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
429         _mm_store_ss(ptr + offsets[3], tmp);
430         tmp = val2;
431         _mm_store_ss(ptr + offsets[4], tmp);
432         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
433         _mm_store_ss(ptr + offsets[5], tmp);
434         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
435         _mm_store_ss(ptr + offsets[6], tmp);
436         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
437         _mm_store_ss(ptr + offsets[7], tmp);
438         tmp = val3;
439         _mm_store_ss(ptr + offsets[8], tmp);
440         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
441         _mm_store_ss(ptr + offsets[9], tmp);
442         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
443         _mm_store_ss(ptr + offsets[10], tmp);
444         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
445         _mm_store_ss(ptr + offsets[11], tmp);
446         tmp = val4;
447         _mm_store_ss(ptr + offsets[12], tmp);
448         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
449         _mm_store_ss(ptr + offsets[13], tmp);
450         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
451         _mm_store_ss(ptr + offsets[14], tmp);
452         tmp = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(0,3,2,1));
453         _mm_store_ss(ptr + offsets[15], tmp);
454    }
455 #endif
456 
457 private:
458     __m128 val1;
459     __m128 val2;
460     __m128 val3;
461     __m128 val4;
462 };
463 
464 inline
operator <<(float * data,const short_vec<float,16> & vec)465 void operator<<(float *data, const short_vec<float, 16>& vec)
466 {
467     vec.store(data);
468 }
469 
470 template<>
471 class sqrt_reference<float, 16>
472 {
473 public:
474     template<typename OTHER_CARGO, int OTHER_ARITY>
475     friend class short_vec;
476 
sqrt_reference(const short_vec<float,16> & vec)477     sqrt_reference(const short_vec<float, 16>& vec) :
478         vec(vec)
479     {}
480 
481 private:
482     short_vec<float, 16> vec;
483 };
484 
485 #ifdef __ICC
486 #pragma warning pop
487 #endif
488 
489 inline
short_vec(const sqrt_reference<float,16> & other)490 short_vec<float, 16>::short_vec(const sqrt_reference<float, 16>& other) :
491     val1(_mm_sqrt_ps(other.vec.val1)),
492     val2(_mm_sqrt_ps(other.vec.val2)),
493     val3(_mm_sqrt_ps(other.vec.val3)),
494     val4(_mm_sqrt_ps(other.vec.val4))
495 {}
496 
497 inline
operator /=(const sqrt_reference<float,16> & other)498 void short_vec<float, 16>::operator/=(const sqrt_reference<float, 16>& other)
499 {
500     val1 = _mm_mul_ps(val1, _mm_rsqrt_ps(other.vec.val1));
501     val2 = _mm_mul_ps(val2, _mm_rsqrt_ps(other.vec.val2));
502     val3 = _mm_mul_ps(val3, _mm_rsqrt_ps(other.vec.val3));
503     val4 = _mm_mul_ps(val4, _mm_rsqrt_ps(other.vec.val4));
504 }
505 
506 inline
operator /(const sqrt_reference<float,16> & other) const507 short_vec<float, 16> short_vec<float, 16>::operator/(const sqrt_reference<float, 16>& other) const
508 {
509     return short_vec<float, 16>(
510         _mm_mul_ps(val1, _mm_rsqrt_ps(other.vec.val1)),
511         _mm_mul_ps(val2, _mm_rsqrt_ps(other.vec.val2)),
512         _mm_mul_ps(val3, _mm_rsqrt_ps(other.vec.val3)),
513         _mm_mul_ps(val4, _mm_rsqrt_ps(other.vec.val4)));
514 }
515 
516 inline
sqrt(const short_vec<float,16> & vec)517 sqrt_reference<float, 16> sqrt(const short_vec<float, 16>& vec)
518 {
519     return sqrt_reference<float, 16>(vec);
520 }
521 
522 template<typename _CharT, typename _Traits>
523 std::basic_ostream<_CharT, _Traits>&
operator <<(std::basic_ostream<_CharT,_Traits> & __os,const short_vec<float,16> & vec)524 operator<<(std::basic_ostream<_CharT, _Traits>& __os,
525            const short_vec<float, 16>& vec)
526 {
527     const float *data1 = reinterpret_cast<const float *>(&vec.val1);
528     const float *data2 = reinterpret_cast<const float *>(&vec.val2);
529     const float *data3 = reinterpret_cast<const float *>(&vec.val3);
530     const float *data4 = reinterpret_cast<const float *>(&vec.val4);
531     __os << "["
532          << data1[0] << ", " << data1[1]  << ", " << data1[2] << ", " << data1[3] << ", "
533          << data2[0] << ", " << data2[1]  << ", " << data2[2] << ", " << data2[3] << ", "
534          << data3[0] << ", " << data3[1]  << ", " << data3[2] << ", " << data3[3] << ", "
535          << data4[0] << ", " << data4[1]  << ", " << data4[2] << ", " << data4[3] << "]";
536     return __os;
537 }
538 
539 }
540 
541 #endif
542 
543 #endif
544