// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef X86_USABILITY_H
#define X86_USABILITY_H

#include <math.h> // round()

// Convert a float to a signed 8-bit value, rounding to nearest and saturating to [-127, 127].
static NCNN_FORCEINLINE signed char float2int8(float v)
{
    int int32 = (int)round(v);
    if (int32 > 127) return 127;
    if (int32 < -127) return -127;
    return (signed char)int32;
}
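
// Illustrative values (round() rounds halves away from zero):
//   float2int8(3.4f)    == 3
//   float2int8(-3.5f)   == -4
//   float2int8(300.0f)  == 127  (saturated)
//   float2int8(-130.0f) == -127 (saturated)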

#if __SSE2__
#include <emmintrin.h>

static NCNN_FORCEINLINE float _mm_reduce_add_ps(__m128 x128)
{
    const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
    const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
    return _mm_cvtss_f32(x32);
}
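
// The two-step reduction above computes:
//   x64 = ( -, -, x1+x3, x0+x2 )   after folding the high pair onto the low pair
//   x32 = ( -, -, -, x0+x1+x2+x3 ) after adding lane 1 of x64 into lane 0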

static NCNN_FORCEINLINE int64_t float2int8_sse(const __m128& _v0, const __m128& _v1)
{
    // _MM_ROUND_NEAREST rounds halves to even;
    // emulate round-half-away-from-zero by adding +/-0.5 and truncating toward zero
    __m128 _p5 = _mm_set1_ps(0.5f);
    __m128 _signmask = _mm_castsi128_ps(_mm_set1_epi32(1 << 31));
    __m128 _sign0 = _mm_and_ps(_v0, _signmask);
    __m128 _sign1 = _mm_and_ps(_v1, _signmask);
    __m128 _v0_p5 = _mm_or_ps(_p5, _sign0);
    __m128 _v1_p5 = _mm_or_ps(_p5, _sign1);
    __m128 _v0_adj = _mm_add_ps(_v0, _v0_p5);
    __m128 _v1_adj = _mm_add_ps(_v1, _v1_p5);
    __m128i _v0_i = _mm_cvttps_epi32(_v0_adj);
    __m128i _v1_i = _mm_cvttps_epi32(_v1_adj);

    __m128i _v01_s16 = _mm_packs_epi32(_v0_i, _v1_i);

    _v01_s16 = _mm_min_epi16(_v01_s16, _mm_set1_epi16(127));
    _v01_s16 = _mm_max_epi16(_v01_s16, _mm_set1_epi16(-127));

    __m128i _v8 = _mm_packs_epi16(_v01_s16, _v01_s16);

    // TODO use _mm_cvtsi128_si64 on 64bit target
    int64_t v8[2];
    _mm_storeu_si128((__m128i*)v8, _v8);
    return v8[0];
}
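
// The result packs the 8 converted values as signed bytes into one int64_t:
// bytes 0..3 come from _v0 lanes 0..3, bytes 4..7 from _v1 lanes 0..3.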

static NCNN_FORCEINLINE __m128i float2int8_sse(const __m128& _v0, const __m128& _v1, const __m128& _v2, const __m128& _v3)
{
    // _MM_ROUND_NEAREST rounds halves to even;
    // emulate round-half-away-from-zero by adding +/-0.5 and truncating toward zero
    __m128 _p5 = _mm_set1_ps(0.5f);
    __m128 _signmask = _mm_castsi128_ps(_mm_set1_epi32(1 << 31));
    __m128 _sign0 = _mm_and_ps(_v0, _signmask);
    __m128 _sign1 = _mm_and_ps(_v1, _signmask);
    __m128 _sign2 = _mm_and_ps(_v2, _signmask);
    __m128 _sign3 = _mm_and_ps(_v3, _signmask);
    __m128 _v0_p5 = _mm_or_ps(_p5, _sign0);
    __m128 _v1_p5 = _mm_or_ps(_p5, _sign1);
    __m128 _v2_p5 = _mm_or_ps(_p5, _sign2);
    __m128 _v3_p5 = _mm_or_ps(_p5, _sign3);
    __m128 _v0_adj = _mm_add_ps(_v0, _v0_p5);
    __m128 _v1_adj = _mm_add_ps(_v1, _v1_p5);
    __m128 _v2_adj = _mm_add_ps(_v2, _v2_p5);
    __m128 _v3_adj = _mm_add_ps(_v3, _v3_p5);
    __m128i _v0_i = _mm_cvttps_epi32(_v0_adj);
    __m128i _v1_i = _mm_cvttps_epi32(_v1_adj);
    __m128i _v2_i = _mm_cvttps_epi32(_v2_adj);
    __m128i _v3_i = _mm_cvttps_epi32(_v3_adj);

    __m128i _v01_s16 = _mm_packs_epi32(_v0_i, _v1_i);
    __m128i _v23_s16 = _mm_packs_epi32(_v2_i, _v3_i);

    _v01_s16 = _mm_min_epi16(_v01_s16, _mm_set1_epi16(127));
    _v23_s16 = _mm_min_epi16(_v23_s16, _mm_set1_epi16(127));
    _v01_s16 = _mm_max_epi16(_v01_s16, _mm_set1_epi16(-127));
    _v23_s16 = _mm_max_epi16(_v23_s16, _mm_set1_epi16(-127));

    __m128i _v8 = _mm_packs_epi16(_v01_s16, _v23_s16);

    return _v8;
}
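
// The returned __m128i holds 16 saturated int8 values in source order:
// bytes 0..3 from _v0, 4..7 from _v1, 8..11 from _v2, 12..15 from _v3.
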
#if __SSE2__
#ifndef __AVX2__

static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(__m128 _a, const __m128 _b, const __m128 _c)
{
    return _mm_add_ps(_mm_mul_ps(_a, _b), _c);
}
#endif // __AVX2__
#endif // __SSE2__
#if __AVX__
#include <immintrin.h>
#ifndef __AVX2__
static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(__m256 _a, const __m256 _b, const __m256 _c)
{
    return _mm256_add_ps(_mm256_mul_ps(_a, _b), _c);
}
#ifndef __SSE2__
static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(__m128 _a, const __m128 _b, const __m128 _c)
{
    return _mm_add_ps(_mm_mul_ps(_a, _b), _c);
}
#endif // __SSE2__
#else
static NCNN_FORCEINLINE __m128 _mm_comp_fmadd_ps(__m128 _a, const __m128 _b, const __m128 _c)
{
    return _mm_fmadd_ps(_a, _b, _c);
}
static NCNN_FORCEINLINE __m256 _mm256_comp_fmadd_ps(__m256 _a, const __m256 _b, const __m256 _c)
{
    return _mm256_fmadd_ps(_a, _b, _c);
}
#endif // __AVX2__
#if __AVX2__

static NCNN_FORCEINLINE __m256 loadfp16(const unsigned short* ptr)
{
    return _mm256_cvtph_ps(_mm_lddqu_si128((__m128i*)(ptr)));
}
#endif // __AVX2__
static NCNN_FORCEINLINE __m256 _mm256_fmadd_1_ps(__m256 a, __m256 b, float c)
{
    return _mm256_comp_fmadd_ps(b, _mm256_set1_ps(c), a);
}

static NCNN_FORCEINLINE __m256 _mm256_fmrsub_1_ps(__m256 a, __m256 b, float c)
{
    return _mm256_sub_ps(a, _mm256_mul_ps(b, _mm256_set1_ps(c)));
}
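
// For reference:
//   _mm256_fmadd_1_ps(a, b, c)  computes a + b * c
//   _mm256_fmrsub_1_ps(a, b, c) computes a - b * c
// where the scalar c is broadcast across all 8 lanes.
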
// From: https://stackoverflow.com/a/25627536
static NCNN_FORCEINLINE void transpose8_ps(__m256& row0, __m256& row1, __m256& row2, __m256& row3, __m256& row4, __m256& row5, __m256& row6, __m256& row7)
{
    __m256 __t0, __t1, __t2, __t3, __t4, __t5, __t6, __t7;
    __m256 __tt0, __tt1, __tt2, __tt3, __tt4, __tt5, __tt6, __tt7;
    __t0 = _mm256_unpacklo_ps(row0, row1);
    __t1 = _mm256_unpackhi_ps(row0, row1);
    __t2 = _mm256_unpacklo_ps(row2, row3);
    __t3 = _mm256_unpackhi_ps(row2, row3);
    __t4 = _mm256_unpacklo_ps(row4, row5);
    __t5 = _mm256_unpackhi_ps(row4, row5);
    __t6 = _mm256_unpacklo_ps(row6, row7);
    __t7 = _mm256_unpackhi_ps(row6, row7);
    __tt0 = _mm256_shuffle_ps(__t0, __t2, _MM_SHUFFLE(1, 0, 1, 0));
    __tt1 = _mm256_shuffle_ps(__t0, __t2, _MM_SHUFFLE(3, 2, 3, 2));
    __tt2 = _mm256_shuffle_ps(__t1, __t3, _MM_SHUFFLE(1, 0, 1, 0));
    __tt3 = _mm256_shuffle_ps(__t1, __t3, _MM_SHUFFLE(3, 2, 3, 2));
    __tt4 = _mm256_shuffle_ps(__t4, __t6, _MM_SHUFFLE(1, 0, 1, 0));
    __tt5 = _mm256_shuffle_ps(__t4, __t6, _MM_SHUFFLE(3, 2, 3, 2));
    __tt6 = _mm256_shuffle_ps(__t5, __t7, _MM_SHUFFLE(1, 0, 1, 0));
    __tt7 = _mm256_shuffle_ps(__t5, __t7, _MM_SHUFFLE(3, 2, 3, 2));
    row0 = _mm256_permute2f128_ps(__tt0, __tt4, 0x20);
    row1 = _mm256_permute2f128_ps(__tt1, __tt5, 0x20);
    row2 = _mm256_permute2f128_ps(__tt2, __tt6, 0x20);
    row3 = _mm256_permute2f128_ps(__tt3, __tt7, 0x20);
    row4 = _mm256_permute2f128_ps(__tt0, __tt4, 0x31);
    row5 = _mm256_permute2f128_ps(__tt1, __tt5, 0x31);
    row6 = _mm256_permute2f128_ps(__tt2, __tt6, 0x31);
    row7 = _mm256_permute2f128_ps(__tt3, __tt7, 0x31);
}
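
// transpose8_ps treats row0..row7 as an 8x8 float matrix and transposes it in place:
// unpacklo/unpackhi interleave pairs of rows, shuffle gathers 4-element column pieces
// within each 128-bit lane, and permute2f128 finally exchanges the 128-bit halves.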

static NCNN_FORCEINLINE __m256 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3, __m256& v4, __m256& v5, __m256& v6, __m256& v7)
{
    const __m256 s01 = _mm256_hadd_ps(v0, v1);
    const __m256 s23 = _mm256_hadd_ps(v2, v3);
    const __m256 s45 = _mm256_hadd_ps(v4, v5);
    const __m256 s67 = _mm256_hadd_ps(v6, v7);
    const __m256 s0123 = _mm256_hadd_ps(s01, s23);
    const __m256 s4567 = _mm256_hadd_ps(s45, s67);

    // inter-lane shuffle
    const __m256 vb0 = _mm256_blend_ps(s0123, s4567, 0xF0);
    const __m256 vb1 = _mm256_permute2f128_ps(s0123, s4567, 0x21);

    return _mm256_add_ps(vb0, vb1);
}
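
// Lane i of the returned vector holds the full horizontal sum of v_i (i = 0..7).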

static NCNN_FORCEINLINE __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3)
{
    const __m256 s01 = _mm256_hadd_ps(v0, v1);
    const __m256 s23 = _mm256_hadd_ps(v2, v3);
    const __m256 s0123 = _mm256_hadd_ps(s01, s23);

    return _mm_add_ps(_mm256_extractf128_ps(s0123, 1),
                      _mm256_castps256_ps128(s0123));
}

static NCNN_FORCEINLINE __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2)
{
    const __m256 v3 = _mm256_set1_ps(0.0f);
    const __m256 s01 = _mm256_hadd_ps(v0, v1);
    const __m256 s23 = _mm256_hadd_ps(v2, v3);
    const __m256 s0123 = _mm256_hadd_ps(s01, s23);

    return _mm_add_ps(_mm256_extractf128_ps(s0123, 1),
                      _mm256_castps256_ps128(s0123));
}

static NCNN_FORCEINLINE float _mm256_reduce_add_ps(__m256 x)
{
    /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
    const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
    /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
    const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
    /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
    const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
    /* Conversion to float is a no-op on x86-64 */
    return _mm_cvtss_f32(x32);
}

static NCNN_FORCEINLINE int64_t float2int8_avx(const __m256& _v0)
{
    // _MM_FROUND_TO_NEAREST_INT rounds halves to even;
    // emulate round-half-away-from-zero by adding +/-0.5 and truncating toward zero
    __m256 _p5 = _mm256_set1_ps(0.5f);
    __m256 _signmask = _mm256_castsi256_ps(_mm256_set1_epi32(1 << 31));
    __m256 _sign = _mm256_and_ps(_v0, _signmask);
    __m256 _v0_p5 = _mm256_or_ps(_p5, _sign);
    __m256 _v0_adj = _mm256_add_ps(_v0, _v0_p5);
    __m256i _v0_i = _mm256_cvttps_epi32(_v0_adj);

#if __AVX2__
    __m256i _v01_s16 = _mm256_packs_epi32(_v0_i, _v0_i);
    _v01_s16 = _mm256_permute4x64_epi64(_v01_s16, 0xd8);

    __m128i _v01_s16low = _mm256_extractf128_si256(_v01_s16, 0);
#else  // __AVX2__
    __m128i _v0_i_low = _mm256_extractf128_si256(_v0_i, 0);
    __m128i _v0_i_high = _mm256_extractf128_si256(_v0_i, 1);

    __m128i _v01_s16low = _mm_packs_epi32(_v0_i_low, _v0_i_high);
#endif // __AVX2__

    _v01_s16low = _mm_min_epi16(_v01_s16low, _mm_set1_epi16(127));
    _v01_s16low = _mm_max_epi16(_v01_s16low, _mm_set1_epi16(-127));

    __m128i _v8 = _mm_packs_epi16(_v01_s16low, _v01_s16low);

    // TODO use _mm_cvtsi128_si64 on 64bit target
    int64_t v8[2];
    _mm_storeu_si128((__m128i*)v8, _v8);
    return v8[0];
}
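
// As in float2int8_sse, the 8 saturated int8 results are returned packed into an
// int64_t, in the same lane order as the input vector.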

static NCNN_FORCEINLINE __m128i float2int8_avx(const __m256& _v0, const __m256& _v1)
{
    // _MM_FROUND_TO_NEAREST_INT rounds halves to even;
    // emulate round-half-away-from-zero by adding +/-0.5 and truncating toward zero
    __m256 _p5 = _mm256_set1_ps(0.5f);
    __m256 _signmask = _mm256_castsi256_ps(_mm256_set1_epi32(1 << 31));
    __m256 _sign0 = _mm256_and_ps(_v0, _signmask);
    __m256 _sign1 = _mm256_and_ps(_v1, _signmask);
    __m256 _v0_p5 = _mm256_or_ps(_p5, _sign0);
    __m256 _v1_p5 = _mm256_or_ps(_p5, _sign1);
    __m256 _v0_adj = _mm256_add_ps(_v0, _v0_p5);
    __m256 _v1_adj = _mm256_add_ps(_v1, _v1_p5);
    __m256i _v0_i = _mm256_cvttps_epi32(_v0_adj);
    __m256i _v1_i = _mm256_cvttps_epi32(_v1_adj);

#if __AVX2__
    __m256i _v01_s16 = _mm256_packs_epi32(_v0_i, _v1_i);
    _v01_s16 = _mm256_permute4x64_epi64(_v01_s16, 0xd8);

    _v01_s16 = _mm256_min_epi16(_v01_s16, _mm256_set1_epi16(127));
    _v01_s16 = _mm256_max_epi16(_v01_s16, _mm256_set1_epi16(-127));

    __m256i _v8 = _mm256_packs_epi16(_v01_s16, _v01_s16);
    _v8 = _mm256_permute4x64_epi64(_v8, 0xd8);

    return _mm256_extractf128_si256(_v8, 0);
#else  // __AVX2__
    __m128i _v0_i_low = _mm256_extractf128_si256(_v0_i, 0);
    __m128i _v0_i_high = _mm256_extractf128_si256(_v0_i, 1);
    __m128i _v1_i_low = _mm256_extractf128_si256(_v1_i, 0);
    __m128i _v1_i_high = _mm256_extractf128_si256(_v1_i, 1);

    __m128i _v01_s16low = _mm_packs_epi32(_v0_i_low, _v0_i_high);
    __m128i _v01_s16high = _mm_packs_epi32(_v1_i_low, _v1_i_high);

    _v01_s16low = _mm_min_epi16(_v01_s16low, _mm_set1_epi16(127));
    _v01_s16high = _mm_min_epi16(_v01_s16high, _mm_set1_epi16(127));
    _v01_s16low = _mm_max_epi16(_v01_s16low, _mm_set1_epi16(-127));
    _v01_s16high = _mm_max_epi16(_v01_s16high, _mm_set1_epi16(-127));

    __m128i _v8 = _mm_packs_epi16(_v01_s16low, _v01_s16high);
    return _v8;
#endif // __AVX2__
}
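
// Both branches return 16 saturated int8 values in a __m128i:
// bytes 0..7 come from _v0, bytes 8..15 from _v1, preserving lane order.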

static NCNN_FORCEINLINE void _mm256_comp_fmadd_ps4(__m256& _sum,
        const __m256& _w0, const __m256& _w1, const __m256& _w2, const __m256& _w3,
        const __m256& _v0, const __m256& _v1, const __m256& _v2, const __m256& _v3)
{
    __m256 _mul0 = _mm256_mul_ps(_w0, _v0);
    __m256 _mul1 = _mm256_mul_ps(_w1, _v1);
    __m256 _sum01 = _mm256_add_ps(_mul0, _mul1);
    __m256 _mul2 = _mm256_mul_ps(_w2, _v2);
    __m256 _mul3 = _mm256_mul_ps(_w3, _v3);
    __m256 _sum23 = _mm256_add_ps(_mul2, _mul3);
    __m256 _sum0123 = _mm256_add_ps(_sum01, _sum23);
    _sum = _mm256_add_ps(_sum, _sum0123);
}

static NCNN_FORCEINLINE void _mm256_comp_fmadd_ps8(__m256& _sum,
        const __m256& _w0, const __m256& _w1, const __m256& _w2, const __m256& _w3, const __m256& _w4, const __m256& _w5, const __m256& _w6, const __m256& _w7,
        const __m256& _v0, const __m256& _v1, const __m256& _v2, const __m256& _v3, const __m256& _v4, const __m256& _v5, const __m256& _v6, const __m256& _v7)
{
    _mm256_comp_fmadd_ps4(_sum, _w0, _w1, _w2, _w3, _v0, _v1, _v2, _v3);

    _mm256_comp_fmadd_ps4(_sum, _w4, _w5, _w6, _w7, _v4, _v5, _v6, _v7);
}
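
// _mm256_comp_fmadd_ps4/_ps8 accumulate _sum += _w0*_v0 + _w1*_v1 + ... using explicit
// multiplies and a pairwise add tree, so the independent products can be computed in
// parallel instead of serializing every step on _sum.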

#endif // __AVX__
#endif // __SSE2__

#endif // X86_USABILITY_H