1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #ifndef X86_USABILITY_H
16 #define X86_USABILITY_H
17
// Quantize one float to int8: round half away from zero (as round() does) and
// saturate to the symmetric range [-127, 127]; -128 is never produced.
// The range check is done on the float itself, before any integer conversion:
// converting an out-of-range value (e.g. 3e9f or +inf) or NaN to int is
// undefined behavior, and on x86 it yields INT_MIN, which would have turned
// large positive inputs into -127. NaN maps to -127.
static inline signed char float2int8(float v)
{
    if (v > 127.f) return 127;
    // also true for NaN, since every comparison with NaN is false
    if (!(v >= -127.f)) return -127;
    // v is within [-127, 127] here, so the rounded value always fits
    return (signed char)(int)round(v);
}
25
26 #if __SSE2__
27 #include <emmintrin.h>
28
// Horizontal sum of the four lanes of x128, evaluated as
// (x0 + x2) + (x1 + x3).
static inline float _mm_reduce_add_ps(__m128 x128)
{
    // fold lanes 2,3 onto lanes 0,1: [x0+x2, x1+x3, ...]
    __m128 acc = _mm_add_ps(x128, _mm_shuffle_ps(x128, x128, _MM_SHUFFLE(3, 2, 3, 2)));
    // add lane 1 into lane 0 (only lane 0 is read afterwards)
    acc = _mm_add_ss(acc, _mm_shuffle_ps(acc, acc, _MM_SHUFFLE(1, 1, 1, 1)));
    return _mm_cvtss_f32(acc);
}
35
// Quantize 8 floats (the lanes of _v0, then those of _v1) to int8 using round
// half away from zero and saturation to [-127, 127]. The 8 result bytes are
// returned packed in an int64_t, byte i holding element i.
// Inputs are clamped to [-127, 127] in the float domain first, so the
// float -> int conversion below is always in range; previously huge values,
// +/-inf and NaN hit undefined behavior. NaN becomes -127 because
// _mm_max_ps returns its second operand when either operand is NaN.
static inline int64_t float2int8_sse(const __m128& _v0, const __m128& _v1)
{
    const __m128 _lo = _mm_set1_ps(-127.f);
    const __m128 _hi = _mm_set1_ps(127.f);

    float v0[4];
    float v1[4];
    _mm_storeu_ps(v0, _mm_min_ps(_mm_max_ps(_v0, _lo), _hi));
    _mm_storeu_ps(v1, _mm_min_ps(_mm_max_ps(_v1, _lo), _hi));

    // _mm_cvtps_epi32 follows MXCSR (half-to-even by default), so round()
    // is applied per lane to get half-away-from-zero like float2int8()
    int v0_i[4];
    int v1_i[4];
    for (int i = 0; i < 4; i++)
    {
        v0_i[i] = (int)round(v0[i]);
        v1_i[i] = (int)round(v1[i]);
    }

    __m128i _v0_i = _mm_loadu_si128((const __m128i*)v0_i);
    __m128i _v1_i = _mm_loadu_si128((const __m128i*)v1_i);

    // every value is already within [-127, 127], so the saturating packs are exact
    __m128i _v01_s16 = _mm_packs_epi32(_v0_i, _v1_i);
    __m128i _v8 = _mm_packs_epi16(_v01_s16, _v01_s16);

    // the result lives in the low 8 bytes; storel works on 32-bit targets too
    int64_t v8;
    _mm_storel_epi64((__m128i*)&v8, _v8);
    return v8;
}
69
// Quantize 16 floats (lanes of _v0, _v1, _v2, _v3 in that order) to int8 using
// round half away from zero and saturation to [-127, 127]; byte i of the
// returned vector holds element i.
// Inputs are clamped to [-127, 127] in the float domain first, so the
// float -> int conversion below is always in range; previously huge values,
// +/-inf and NaN hit undefined behavior. NaN becomes -127 because
// _mm_max_ps returns its second operand when either operand is NaN.
static inline __m128i float2int8_sse(const __m128& _v0, const __m128& _v1, const __m128& _v2, const __m128& _v3)
{
    const __m128 _lo = _mm_set1_ps(-127.f);
    const __m128 _hi = _mm_set1_ps(127.f);

    float v[4][4];
    _mm_storeu_ps(v[0], _mm_min_ps(_mm_max_ps(_v0, _lo), _hi));
    _mm_storeu_ps(v[1], _mm_min_ps(_mm_max_ps(_v1, _lo), _hi));
    _mm_storeu_ps(v[2], _mm_min_ps(_mm_max_ps(_v2, _lo), _hi));
    _mm_storeu_ps(v[3], _mm_min_ps(_mm_max_ps(_v3, _lo), _hi));

    // _mm_cvtps_epi32 follows MXCSR (half-to-even by default), so round()
    // is applied per lane to get half-away-from-zero like float2int8()
    int v_i[4][4];
    for (int i = 0; i < 4; i++)
    {
        for (int j = 0; j < 4; j++)
        {
            v_i[i][j] = (int)round(v[i][j]);
        }
    }

    __m128i _v0_i = _mm_loadu_si128((const __m128i*)v_i[0]);
    __m128i _v1_i = _mm_loadu_si128((const __m128i*)v_i[1]);
    __m128i _v2_i = _mm_loadu_si128((const __m128i*)v_i[2]);
    __m128i _v3_i = _mm_loadu_si128((const __m128i*)v_i[3]);

    // every value is already within [-127, 127], so the saturating packs are exact
    __m128i _v01_s16 = _mm_packs_epi32(_v0_i, _v1_i);
    __m128i _v23_s16 = _mm_packs_epi32(_v2_i, _v3_i);

    return _mm_packs_epi16(_v01_s16, _v23_s16);
}
119
120 #if __AVX__
121 #include <immintrin.h>
122
loadfp16(const unsigned short * ptr)123 static inline __m256 loadfp16(const unsigned short* ptr)
124 {
125 return _mm256_cvtph_ps(_mm_lddqu_si128((__m128i*)(ptr)));
126 }
_mm256_fmadd_1_ps(__m256 a,__m256 b,float c)127 static inline __m256 _mm256_fmadd_1_ps(__m256 a, __m256 b, float c)
128 {
129 return _mm256_fmadd_ps(b, _mm256_set1_ps(c), a);
130 }
131
_mm256_fmrsub_1_ps(__m256 a,__m256 b,float c)132 static inline __m256 _mm256_fmrsub_1_ps(__m256 a, __m256 b, float c)
133 {
134 return _mm256_sub_ps(a, _mm256_mul_ps(b, _mm256_set1_ps(c)));
135 }
136 // From: https://stackoverflow.com/a/25627536
transpose8_ps(__m256 & row0,__m256 & row1,__m256 & row2,__m256 & row3,__m256 & row4,__m256 & row5,__m256 & row6,__m256 & row7)137 static inline void transpose8_ps(__m256& row0, __m256& row1, __m256& row2, __m256& row3, __m256& row4, __m256& row5, __m256& row6, __m256& row7)
138 {
139 __m256 __t0, __t1, __t2, __t3, __t4, __t5, __t6, __t7;
140 __m256 __tt0, __tt1, __tt2, __tt3, __tt4, __tt5, __tt6, __tt7;
141 __t0 = _mm256_unpacklo_ps(row0, row1);
142 __t1 = _mm256_unpackhi_ps(row0, row1);
143 __t2 = _mm256_unpacklo_ps(row2, row3);
144 __t3 = _mm256_unpackhi_ps(row2, row3);
145 __t4 = _mm256_unpacklo_ps(row4, row5);
146 __t5 = _mm256_unpackhi_ps(row4, row5);
147 __t6 = _mm256_unpacklo_ps(row6, row7);
148 __t7 = _mm256_unpackhi_ps(row6, row7);
149 __tt0 = _mm256_shuffle_ps(__t0, __t2, _MM_SHUFFLE(1, 0, 1, 0));
150 __tt1 = _mm256_shuffle_ps(__t0, __t2, _MM_SHUFFLE(3, 2, 3, 2));
151 __tt2 = _mm256_shuffle_ps(__t1, __t3, _MM_SHUFFLE(1, 0, 1, 0));
152 __tt3 = _mm256_shuffle_ps(__t1, __t3, _MM_SHUFFLE(3, 2, 3, 2));
153 __tt4 = _mm256_shuffle_ps(__t4, __t6, _MM_SHUFFLE(1, 0, 1, 0));
154 __tt5 = _mm256_shuffle_ps(__t4, __t6, _MM_SHUFFLE(3, 2, 3, 2));
155 __tt6 = _mm256_shuffle_ps(__t5, __t7, _MM_SHUFFLE(1, 0, 1, 0));
156 __tt7 = _mm256_shuffle_ps(__t5, __t7, _MM_SHUFFLE(3, 2, 3, 2));
157 row0 = _mm256_permute2f128_ps(__tt0, __tt4, 0x20);
158 row1 = _mm256_permute2f128_ps(__tt1, __tt5, 0x20);
159 row2 = _mm256_permute2f128_ps(__tt2, __tt6, 0x20);
160 row3 = _mm256_permute2f128_ps(__tt3, __tt7, 0x20);
161 row4 = _mm256_permute2f128_ps(__tt0, __tt4, 0x31);
162 row5 = _mm256_permute2f128_ps(__tt1, __tt5, 0x31);
163 row6 = _mm256_permute2f128_ps(__tt2, __tt6, 0x31);
164 row7 = _mm256_permute2f128_ps(__tt3, __tt7, 0x31);
165 }
166
HorizontalSums(__m256 & v0,__m256 & v1,__m256 & v2,__m256 & v3,__m256 & v4,__m256 & v5,__m256 & v6,__m256 & v7)167 static inline __m256 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3, __m256& v4, __m256& v5, __m256& v6, __m256& v7)
168 {
169 const __m256 s01 = _mm256_hadd_ps(v0, v1);
170 const __m256 s23 = _mm256_hadd_ps(v2, v3);
171 const __m256 s45 = _mm256_hadd_ps(v4, v5);
172 const __m256 s67 = _mm256_hadd_ps(v6, v7);
173 const __m256 s0123 = _mm256_hadd_ps(s01, s23);
174 const __m256 s4556 = _mm256_hadd_ps(s45, s67);
175
176 // inter-lane shuffle
177 const __m256 vb0 = _mm256_blend_ps(s0123, s4556, 0xF0);
178 const __m256 vb1 = _mm256_permute2f128_ps(s0123, s4556, 0x21);
179
180 return _mm256_add_ps(vb0, vb1);
181 }
182
HorizontalSums(__m256 & v0,__m256 & v1,__m256 & v2,__m256 & v3)183 static inline __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2, __m256& v3)
184 {
185 const __m256 s01 = _mm256_hadd_ps(v0, v1);
186 const __m256 s23 = _mm256_hadd_ps(v2, v3);
187 const __m256 s0123 = _mm256_hadd_ps(s01, s23);
188
189 return _mm_add_ps(_mm256_extractf128_ps(s0123, 1),
190 _mm256_castps256_ps128(s0123));
191 }
192
HorizontalSums(__m256 & v0,__m256 & v1,__m256 & v2)193 static inline __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2)
194 {
195 const __m256 v3 = _mm256_set1_ps(0.0f);
196 const __m256 s01 = _mm256_hadd_ps(v0, v1);
197 const __m256 s23 = _mm256_hadd_ps(v2, v3);
198 const __m256 s0123 = _mm256_hadd_ps(s01, s23);
199
200 return _mm_add_ps(_mm256_extractf128_ps(s0123, 1),
201 _mm256_castps256_ps128(s0123));
202 }
203
_mm256_reduce_add_ps(__m256 x)204 static inline float _mm256_reduce_add_ps(__m256 x)
205 {
206 /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
207 const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
208 /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
209 const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
210 /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
211 const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
212 /* Conversion to float is a no-op on x86-64 */
213 return _mm_cvtss_f32(x32);
214 }
215
float2int8_avx(const __m256 & _v0)216 static inline int64_t float2int8_avx(const __m256& _v0)
217 {
218 __m256i _v0_i = _mm256_cvtps_epi32(_mm256_round_ps(_v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
219
220 __m256i _v01_s16 = _mm256_packs_epi32(_v0_i, _v0_i);
221 _v01_s16 = _mm256_permute4x64_epi64(_v01_s16, 0xd8);
222
223 __m128i _v01_s16low = _mm256_extractf128_si256(_v01_s16, 0);
224
225 _v01_s16low = _mm_min_epi16(_v01_s16low, _mm_set1_epi16(127));
226 _v01_s16low = _mm_max_epi16(_v01_s16low, _mm_set1_epi16(-127));
227
228 __m128i _v8 = _mm_packs_epi16(_v01_s16low, _v01_s16low);
229
230 // TODO use _mm_cvtsi128_si64 on 64bit target
231 int64_t v8[2];
232 _mm_storeu_si128((__m128i*)v8, _v8);
233 return v8[0];
234 }
235
float2int8_avx(const __m256 & _v0,const __m256 & _v1)236 static inline __m128i float2int8_avx(const __m256& _v0, const __m256& _v1)
237 {
238 __m256i _v0_i = _mm256_cvtps_epi32(_mm256_round_ps(_v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
239 __m256i _v1_i = _mm256_cvtps_epi32(_mm256_round_ps(_v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
240
241 __m256i _v01_s16 = _mm256_packs_epi32(_v0_i, _v1_i);
242 _v01_s16 = _mm256_permute4x64_epi64(_v01_s16, 0xd8);
243
244 _v01_s16 = _mm256_min_epi16(_v01_s16, _mm256_set1_epi16(127));
245 _v01_s16 = _mm256_max_epi16(_v01_s16, _mm256_set1_epi16(-127));
246
247 __m256i _v8 = _mm256_packs_epi16(_v01_s16, _v01_s16);
248 _v8 = _mm256_permute4x64_epi64(_v8, 0xd8);
249
250 return _mm256_extractf128_si256(_v8, 0);
251 }
252
_mm256_fmadd_ps4(__m256 & _sum,const __m256 & _w0,const __m256 & _w1,const __m256 & _w2,const __m256 & _w3,const __m256 & _v0,const __m256 & _v1,const __m256 & _v2,const __m256 & _v3)253 static inline void _mm256_fmadd_ps4(__m256& _sum,
254 const __m256& _w0, const __m256& _w1, const __m256& _w2, const __m256& _w3,
255 const __m256& _v0, const __m256& _v1, const __m256& _v2, const __m256& _v3)
256 {
257 __m256 _mul0 = _mm256_mul_ps(_w0, _v0);
258 __m256 _mul1 = _mm256_mul_ps(_w1, _v1);
259 __m256 _sum01 = _mm256_add_ps(_mul0, _mul1);
260 __m256 _mul2 = _mm256_mul_ps(_w2, _v2);
261 __m256 _mul3 = _mm256_mul_ps(_w3, _v3);
262 __m256 _sum23 = _mm256_add_ps(_mul2, _mul3);
263 __m256 _sum0123 = _mm256_add_ps(_sum01, _sum23);
264 _sum = _mm256_add_ps(_sum, _sum0123);
265 }
266
_mm256_fmadd_ps8(__m256 & _sum,const __m256 & _w0,const __m256 & _w1,const __m256 & _w2,const __m256 & _w3,const __m256 & _w4,const __m256 & _w5,const __m256 & _w6,const __m256 & _w7,const __m256 & _v0,const __m256 & _v1,const __m256 & _v2,const __m256 & _v3,const __m256 & _v4,const __m256 & _v5,const __m256 & _v6,const __m256 & _v7)267 static inline void _mm256_fmadd_ps8(__m256& _sum,
268 const __m256& _w0, const __m256& _w1, const __m256& _w2, const __m256& _w3, const __m256& _w4, const __m256& _w5, const __m256& _w6, const __m256& _w7,
269 const __m256& _v0, const __m256& _v1, const __m256& _v2, const __m256& _v3, const __m256& _v4, const __m256& _v5, const __m256& _v6, const __m256& _v7)
270 {
271 _mm256_fmadd_ps4(_sum, _w0, _w1, _w2, _w3, _v0, _v1, _v2, _v3);
272
273 _mm256_fmadd_ps4(_sum, _w4, _w5, _w6, _w7, _v4, _v5, _v6, _v7);
274 }
275 #endif // __AVX__
276 #endif // __SSE2__
277
278 #endif // X86_USABILITY_H
279