// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef X86_USABILITY_H
#define X86_USABILITY_H

#include <math.h> // round() used by float2int8 below

static inline signed char float2int8(float v)
{
    int int32 = round(v);
    if (int32 > 127) return 127;
    if (int32 < -127) return -127;
    return (signed char)int32;
}

#if __SSE2__
#include <emmintrin.h>

static inline float _mm_reduce_add_ps(__m128 x128)
{
    const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
    const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
    return _mm_cvtss_f32(x32);
}
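
// Usage sketch (illustrative, not part of the header's API): reduce a dot
// product accumulated four lanes at a time into a single float. The names
// a, b, n and sum are hypothetical caller-side variables.
//
//   __m128 _acc = _mm_setzero_ps();
//   for (int i = 0; i + 3 < n; i += 4)
//   {
//       _acc = _mm_add_ps(_acc, _mm_mul_ps(_mm_loadu_ps(a + i), _mm_loadu_ps(b + i)));
//   }
//   float sum = _mm_reduce_add_ps(_acc);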

static inline int64_t float2int8_sse(const __m128& _v0, const __m128& _v1)
{
    float v0[4];
    float v1[4];
    _mm_storeu_ps(v0, _v0);
    _mm_storeu_ps(v1, _v1);

    int v0_i[4];
    int v1_i[4];
    v0_i[0] = round(v0[0]);
    v0_i[1] = round(v0[1]);
    v0_i[2] = round(v0[2]);
    v0_i[3] = round(v0[3]);
    v1_i[0] = round(v1[0]);
    v1_i[1] = round(v1[1]);
    v1_i[2] = round(v1[2]);
    v1_i[3] = round(v1[3]);

    __m128i _v0_i = _mm_loadu_si128((const __m128i*)v0_i);
    __m128i _v1_i = _mm_loadu_si128((const __m128i*)v1_i);

    __m128i _v01_s16 = _mm_packs_epi32(_v0_i, _v1_i);

    _v01_s16 = _mm_min_epi16(_v01_s16, _mm_set1_epi16(127));
    _v01_s16 = _mm_max_epi16(_v01_s16, _mm_set1_epi16(-127));

    __m128i _v8 = _mm_packs_epi16(_v01_s16, _v01_s16);

    // TODO use _mm_cvtsi128_si64 on 64bit target
    int64_t v8[2];
    _mm_storeu_si128((__m128i*)v8, _v8);
    return v8[0];
}
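
// Usage sketch (illustrative): quantize 8 floats to 8 saturated int8 values
// and copy them out as one 64-bit chunk. p and outptr are hypothetical
// caller-side pointers (const float* and signed char*); memcpy needs <string.h>.
//
//   __m128 _v0 = _mm_loadu_ps(p);
//   __m128 _v1 = _mm_loadu_ps(p + 4);
//   int64_t v8 = float2int8_sse(_v0, _v1);
//   memcpy(outptr, &v8, 8);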

static inline __m128i float2int8_sse(const __m128& _v0, const __m128& _v1, const __m128& _v2, const __m128& _v3)
{
    float v0[4];
    float v1[4];
    float v2[4];
    float v3[4];
    _mm_storeu_ps(v0, _v0);
    _mm_storeu_ps(v1, _v1);
    _mm_storeu_ps(v2, _v2);
    _mm_storeu_ps(v3, _v3);

    int v0_i[4];
    int v1_i[4];
    int v2_i[4];
    int v3_i[4];
    v0_i[0] = round(v0[0]);
    v0_i[1] = round(v0[1]);
    v0_i[2] = round(v0[2]);
    v0_i[3] = round(v0[3]);
    v1_i[0] = round(v1[0]);
    v1_i[1] = round(v1[1]);
    v1_i[2] = round(v1[2]);
    v1_i[3] = round(v1[3]);
    v2_i[0] = round(v2[0]);
    v2_i[1] = round(v2[1]);
    v2_i[2] = round(v2[2]);
    v2_i[3] = round(v2[3]);
    v3_i[0] = round(v3[0]);
    v3_i[1] = round(v3[1]);
    v3_i[2] = round(v3[2]);
    v3_i[3] = round(v3[3]);

    __m128i _v0_i = _mm_loadu_si128((const __m128i*)v0_i);
    __m128i _v1_i = _mm_loadu_si128((const __m128i*)v1_i);
    __m128i _v2_i = _mm_loadu_si128((const __m128i*)v2_i);
    __m128i _v3_i = _mm_loadu_si128((const __m128i*)v3_i);

    __m128i _v01_s16 = _mm_packs_epi32(_v0_i, _v1_i);
    __m128i _v23_s16 = _mm_packs_epi32(_v2_i, _v3_i);

    _v01_s16 = _mm_min_epi16(_v01_s16, _mm_set1_epi16(127));
    _v23_s16 = _mm_min_epi16(_v23_s16, _mm_set1_epi16(127));
    _v01_s16 = _mm_max_epi16(_v01_s16, _mm_set1_epi16(-127));
    _v23_s16 = _mm_max_epi16(_v23_s16, _mm_set1_epi16(-127));

    __m128i _v8 = _mm_packs_epi16(_v01_s16, _v23_s16);

    return _v8;
}
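
// Usage sketch (illustrative): the four-vector overload packs 16 floats into
// 16 saturated int8 values that can be stored in one instruction. outptr is a
// hypothetical signed char[16] destination.
//
//   __m128i _out = float2int8_sse(_v0, _v1, _v2, _v3);
//   _mm_storeu_si128((__m128i*)outptr, _out);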

#if __AVX__
#include <immintrin.h>

static inline __m256 loadfp16(const unsigned short* ptr)
{
    return _mm256_cvtph_ps(_mm_lddqu_si128((__m128i*)(ptr)));
}

static inline __m256 _mm256_fmadd_1_ps(__m256 a, __m256 b, float c)
{
    return _mm256_fmadd_ps(b, _mm256_set1_ps(c), a);
}

static inline __m256 _mm256_fmrsub_1_ps(__m256 a, __m256 b, float c)
{
    return _mm256_sub_ps(a, _mm256_mul_ps(b, _mm256_set1_ps(c)));
}
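
// Usage sketch (illustrative): accumulate or subtract a vector scaled by a
// scalar, i.e. acc += x * scale and acc -= x * scale. acc, x and scale are
// hypothetical caller-side variables. Note that _mm256_fmadd_1_ps relies on
// the FMA intrinsic _mm256_fmadd_ps.
//
//   __m256 _acc = _mm256_loadu_ps(acc);
//   __m256 _x = _mm256_loadu_ps(x);
//   _acc = _mm256_fmadd_1_ps(_acc, _x, scale);    // _acc += _x * scale
//   _acc = _mm256_fmrsub_1_ps(_acc, _x, scale);   // _acc -= _x * scale
//   _mm256_storeu_ps(acc, _acc);
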
// From: https://stackoverflow.com/a/25627536
static inline void transpose8_ps(__m256& row0, __m256& row1, __m256& row2, __m256& row3, __m256& row4, __m256& row5, __m256& row6, __m256& row7)
{
    __m256 __t0, __t1, __t2, __t3, __t4, __t5, __t6, __t7;
    __m256 __tt0, __tt1, __tt2, __tt3, __tt4, __tt5, __tt6, __tt7;
    __t0 = _mm256_unpacklo_ps(row0, row1);
    __t1 = _mm256_unpackhi_ps(row0, row1);
    __t2 = _mm256_unpacklo_ps(row2, row3);
    __t3 = _mm256_unpackhi_ps(row2, row3);
    __t4 = _mm256_unpacklo_ps(row4, row5);
    __t5 = _mm256_unpackhi_ps(row4, row5);
    __t6 = _mm256_unpacklo_ps(row6, row7);
    __t7 = _mm256_unpackhi_ps(row6, row7);
    __tt0 = _mm256_shuffle_ps(__t0, __t2, _MM_SHUFFLE(1, 0, 1, 0));
    __tt1 = _mm256_shuffle_ps(__t0, __t2, _MM_SHUFFLE(3, 2, 3, 2));
    __tt2 = _mm256_shuffle_ps(__t1, __t3, _MM_SHUFFLE(1, 0, 1, 0));
    __tt3 = _mm256_shuffle_ps(__t1, __t3, _MM_SHUFFLE(3, 2, 3, 2));
    __tt4 = _mm256_shuffle_ps(__t4, __t6, _MM_SHUFFLE(1, 0, 1, 0));
    __tt5 = _mm256_shuffle_ps(__t4, __t6, _MM_SHUFFLE(3, 2, 3, 2));
    __tt6 = _mm256_shuffle_ps(__t5, __t7, _MM_SHUFFLE(1, 0, 1, 0));
    __tt7 = _mm256_shuffle_ps(__t5, __t7, _MM_SHUFFLE(3, 2, 3, 2));
    row0 = _mm256_permute2f128_ps(__tt0, __tt4, 0x20);
    row1 = _mm256_permute2f128_ps(__tt1, __tt5, 0x20);
    row2 = _mm256_permute2f128_ps(__tt2, __tt6, 0x20);
    row3 = _mm256_permute2f128_ps(__tt3, __tt7, 0x20);
    row4 = _mm256_permute2f128_ps(__tt0, __tt4, 0x31);
    row5 = _mm256_permute2f128_ps(__tt1, __tt5, 0x31);
    row6 = _mm256_permute2f128_ps(__tt2, __tt6, 0x31);
    row7 = _mm256_permute2f128_ps(__tt3, __tt7, 0x31);
}
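
// Usage sketch (illustrative): transpose an 8x8 float block held in eight
// registers, e.g. loaded row-major from a hypothetical pointer m with a row
// stride of 8 floats.
//
//   __m256 _r0 = _mm256_loadu_ps(m + 0 * 8);
//   __m256 _r1 = _mm256_loadu_ps(m + 1 * 8);
//   __m256 _r2 = _mm256_loadu_ps(m + 2 * 8);
//   __m256 _r3 = _mm256_loadu_ps(m + 3 * 8);
//   __m256 _r4 = _mm256_loadu_ps(m + 4 * 8);
//   __m256 _r5 = _mm256_loadu_ps(m + 5 * 8);
//   __m256 _r6 = _mm256_loadu_ps(m + 6 * 8);
//   __m256 _r7 = _mm256_loadu_ps(m + 7 * 8);
//   transpose8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7);
//   // _r0 now holds column 0 of the original block, _r1 column 1, and so on.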

static inline __m256 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3, __m256& v4, __m256& v5, __m256& v6, __m256& v7)
{
    const __m256 s01 = _mm256_hadd_ps(v0, v1);
    const __m256 s23 = _mm256_hadd_ps(v2, v3);
    const __m256 s45 = _mm256_hadd_ps(v4, v5);
    const __m256 s67 = _mm256_hadd_ps(v6, v7);
    const __m256 s0123 = _mm256_hadd_ps(s01, s23);
    const __m256 s4567 = _mm256_hadd_ps(s45, s67);

    // inter-lane shuffle
    const __m256 vb0 = _mm256_blend_ps(s0123, s4567, 0xF0);
    const __m256 vb1 = _mm256_permute2f128_ps(s0123, s4567, 0x21);

    return _mm256_add_ps(vb0, vb1);
}
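
// Usage sketch (illustrative): collapse eight accumulators, one per output
// element, so that lane i of the result equals the horizontal sum of
// accumulator i. _acc0.._acc7 and out are hypothetical caller-side names.
//
//   __m256 _sums = HorizontalSums(_acc0, _acc1, _acc2, _acc3, _acc4, _acc5, _acc6, _acc7);
//   _mm256_storeu_ps(out, _sums);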

static inline __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3)
{
    const __m256 s01 = _mm256_hadd_ps(v0, v1);
    const __m256 s23 = _mm256_hadd_ps(v2, v3);
    const __m256 s0123 = _mm256_hadd_ps(s01, s23);

    return _mm_add_ps(_mm256_extractf128_ps(s0123, 1),
                      _mm256_castps256_ps128(s0123));
}

static inline __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2)
{
    const __m256 v3 = _mm256_set1_ps(0.0f);
    const __m256 s01 = _mm256_hadd_ps(v0, v1);
    const __m256 s23 = _mm256_hadd_ps(v2, v3);
    const __m256 s0123 = _mm256_hadd_ps(s01, s23);

    return _mm_add_ps(_mm256_extractf128_ps(s0123, 1),
                      _mm256_castps256_ps128(s0123));
}

static inline float _mm256_reduce_add_ps(__m256 x)
{
    /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
    const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
    /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
    const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
    /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
    const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
    /* Conversion to float is a no-op on x86-64 */
    return _mm_cvtss_f32(x32);
}

static inline int64_t float2int8_avx(const __m256& _v0)
{
    __m256i _v0_i = _mm256_cvtps_epi32(_mm256_round_ps(_v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));

    __m256i _v01_s16 = _mm256_packs_epi32(_v0_i, _v0_i);
    _v01_s16 = _mm256_permute4x64_epi64(_v01_s16, 0xd8);

    __m128i _v01_s16low = _mm256_extractf128_si256(_v01_s16, 0);

    _v01_s16low = _mm_min_epi16(_v01_s16low, _mm_set1_epi16(127));
    _v01_s16low = _mm_max_epi16(_v01_s16low, _mm_set1_epi16(-127));

    __m128i _v8 = _mm_packs_epi16(_v01_s16low, _v01_s16low);

    // TODO use _mm_cvtsi128_si64 on 64bit target
    int64_t v8[2];
    _mm_storeu_si128((__m128i*)v8, _v8);
    return v8[0];
}

static inline __m128i float2int8_avx(const __m256& _v0, const __m256& _v1)
{
    __m256i _v0_i = _mm256_cvtps_epi32(_mm256_round_ps(_v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    __m256i _v1_i = _mm256_cvtps_epi32(_mm256_round_ps(_v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));

    __m256i _v01_s16 = _mm256_packs_epi32(_v0_i, _v1_i);
    _v01_s16 = _mm256_permute4x64_epi64(_v01_s16, 0xd8);

    _v01_s16 = _mm256_min_epi16(_v01_s16, _mm256_set1_epi16(127));
    _v01_s16 = _mm256_max_epi16(_v01_s16, _mm256_set1_epi16(-127));

    __m256i _v8 = _mm256_packs_epi16(_v01_s16, _v01_s16);
    _v8 = _mm256_permute4x64_epi64(_v8, 0xd8);

    return _mm256_extractf128_si256(_v8, 0);
}

static inline void _mm256_fmadd_ps4(__m256& _sum,
                                    const __m256& _w0, const __m256& _w1, const __m256& _w2, const __m256& _w3,
                                    const __m256& _v0, const __m256& _v1, const __m256& _v2, const __m256& _v3)
{
    __m256 _mul0 = _mm256_mul_ps(_w0, _v0);
    __m256 _mul1 = _mm256_mul_ps(_w1, _v1);
    __m256 _sum01 = _mm256_add_ps(_mul0, _mul1);
    __m256 _mul2 = _mm256_mul_ps(_w2, _v2);
    __m256 _mul3 = _mm256_mul_ps(_w3, _v3);
    __m256 _sum23 = _mm256_add_ps(_mul2, _mul3);
    __m256 _sum0123 = _mm256_add_ps(_sum01, _sum23);
    _sum = _mm256_add_ps(_sum, _sum0123);
}

static inline void _mm256_fmadd_ps8(__m256& _sum,
                                    const __m256& _w0, const __m256& _w1, const __m256& _w2, const __m256& _w3, const __m256& _w4, const __m256& _w5, const __m256& _w6, const __m256& _w7,
                                    const __m256& _v0, const __m256& _v1, const __m256& _v2, const __m256& _v3, const __m256& _v4, const __m256& _v5, const __m256& _v6, const __m256& _v7)
{
    _mm256_fmadd_ps4(_sum, _w0, _w1, _w2, _w3, _v0, _v1, _v2, _v3);

    _mm256_fmadd_ps4(_sum, _w4, _w5, _w6, _w7, _v4, _v5, _v6, _v7);
}
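
// Usage sketch (illustrative): accumulate eight products into one running sum
// register, e.g. an 8-tap inner loop. _sum ends up holding, element-wise,
// _sum + _w0*_v0 + _w1*_v1 + ... + _w7*_v7. The _w* and _v* registers are
// hypothetical caller-side values.
//
//   __m256 _sum = _mm256_setzero_ps();
//   _mm256_fmadd_ps8(_sum,
//                    _w0, _w1, _w2, _w3, _w4, _w5, _w6, _w7,
//                    _v0, _v1, _v2, _v3, _v4, _v5, _v6, _v7);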
#endif // __AVX__
#endif // __SSE2__

#endif // X86_USABILITY_H