// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15 #include "cast_x86.h"
16
17 #if __SSE2__
18 #include <emmintrin.h>
19 #if __AVX__
20 #include <immintrin.h>
21 #endif // __AVX__
22 #endif // __SSE2__
23
24 #if __AVX__
25 #include <stdint.h>
26 typedef union m128i
27 {
28 __m128i vec;
29 uint16_t m128i_u16[8];
30 } m128;
31
32 typedef union m256i
33 {
34 __m256i vec;
35 uint32_t m256i_u32[8];
36 } m256;
bfloat2float_avx(__m128i v0)37 static inline __m256 bfloat2float_avx(__m128i v0)
38 {
39 __m128i zero = _mm_set1_epi32(0);
40 __m128i a = _mm_slli_epi32(_mm_unpacklo_epi16(v0, zero), 16);
41 __m128i b = _mm_slli_epi32(_mm_unpackhi_epi16(v0, zero), 16);
42 __m256i ab = _mm256_set1_epi32(0);
43 ab = _mm256_insertf128_si256(ab, a, 0); // insert in low 128-bit lane
44 ab = _mm256_insertf128_si256(ab, b, 1); // insert in high 128-bit lane
45 return _mm256_castsi256_ps(ab);
46 }
float2bfloat_avx(__m256 v0,__m256 v1)47 static inline __m256i float2bfloat_avx(__m256 v0, __m256 v1)
48 {
49 __m256i a = _mm256_castps_si256(v0);
50 a = _mm256_srli_epi32(a, 16);
51 __m256i b = _mm256_castps_si256(v1);
52 b = _mm256_srli_epi32(b, 16);
53 __m256i abab = _mm256_packus_epi32(a, b);
54 return _mm256_permutevar8x32_epi32(abab, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7));
55 }
float2bfloat_avx(__m256 v0)56 static inline __m128i float2bfloat_avx(__m256 v0)
57 {
58 __m256i a = _mm256_castps_si256(v0);
59 a = _mm256_srli_epi32(a, 16);
60 __m256i aaaa = _mm256_packus_epi32(a, a);
61 return _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(aaaa, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7)));
62 }
63 #endif // __AVX__
64
65 namespace ncnn {
66
Cast_x86()67 Cast_x86::Cast_x86()
68 {
69 support_packing = true;
70 }
71
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const72 int Cast_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
73 {
74 #if __AVX__
75 if (type_from == type_to)
76 {
77 top_blob = bottom_blob;
78 return 0;
79 }
80
81 int w = bottom_blob.w;
82 int h = bottom_blob.h;
83 int channels = bottom_blob.c;
84 int dims = bottom_blob.dims;
85 size_t elemsize = bottom_blob.elemsize;
86 int elempack = bottom_blob.elempack;
87
88 size_t out_elemsize = elemsize;
89 if (type_to == 1)
90 {
91 if (type_from == 3)
92 {
93 Cast::forward(bottom_blob, top_blob, opt);
94 }
95
96 // float32
97 out_elemsize = 4 * elempack;
98 }
99 else if (type_to == 2)
100 {
101 // float16
102 out_elemsize = 2 * elempack;
103 }
104 else if (type_to == 3)
105 {
106 // int8
107 out_elemsize = elempack;
108 }
109 else if (type_to == 4)
110 {
111 // bfloat16
112 out_elemsize = 2 * elempack;
113 }
114
115 if (dims == 1)
116 {
117 top_blob.create(w, out_elemsize, elempack, opt.blob_allocator);
118 }
119 else if (dims == 2)
120 {
121 top_blob.create(w, h, out_elemsize, elempack, opt.blob_allocator);
122 }
123 else if (dims == 3)
124 {
125 top_blob.create(w, h, channels, out_elemsize, elempack, opt.blob_allocator);
126 }
127 if (top_blob.empty())
128 return -100;
129
130 int size = w * h * elempack;
131
132 if (type_from == 1 && type_to == 2)
133 {
134 int nn = size >> 3;
135 int remain = size - (nn << 3);
136 m256i mask = {_mm256_setzero_si256()};
137 for (int i = 0; i < remain; i++)
138 mask.m256i_u32[i] = 0x80000000;
139
140 #pragma omp parallel for num_threads(opt.num_threads)
141 for (int q = 0; q < channels; q++)
142 {
143 const float* ptr = bottom_blob.channel(q);
144 unsigned short* outptr = top_blob.channel(q);
145
146 for (int i = 0; i < nn; i++)
147 {
148 __m256 fp32 = _mm256_loadu_ps(ptr);
149 __m128i fp16 = _mm256_cvtps_ph(fp32, _MM_FROUND_TRUNC);
150 _mm_store_si128((__m128i*)outptr, fp16);
151 ptr += 8;
152 outptr += 8;
153 }
154
155 if (remain > 0)
156 {
157 __m256 fp32 = _mm256_maskload_ps(ptr, mask.vec);
158 m128i fp16 = {_mm256_cvtps_ph(fp32, _MM_FROUND_TRUNC)};
159 memcpy(outptr, fp16.m128i_u16, remain * sizeof(unsigned short));
160 }
161 }
162 }
163
164 if (type_from == 2 && type_to == 1)
165 {
166 int nn = size >> 3;
167 int remain = size - (nn << 3);
168 m256i mask = {_mm256_setzero_si256()};
169 for (int i = 0; i < remain; i++)
170 mask.m256i_u32[i] = 0x80000000;
171
172 #pragma omp parallel for num_threads(opt.num_threads)
173 for (int q = 0; q < channels; q++)
174 {
175 const unsigned short* ptr = bottom_blob.channel(q);
176 float* outptr = top_blob.channel(q);
177
178 for (int i = 0; i < nn; i++)
179 {
180 __m128i fp16 = _mm_lddqu_si128((__m128i const*)ptr);
181 __m256 fp32 = _mm256_cvtph_ps(fp16);
182 _mm256_storeu_ps(outptr, fp32);
183 ptr += 8;
184 outptr += 8;
185 }
186
187 if (remain > 0)
188 {
189 m128i fp16 = {_mm_setzero_si128()};
190 memcpy(fp16.m128i_u16, ptr, remain * sizeof(unsigned short));
191 __m256 fp32 = _mm256_cvtph_ps(fp16.vec);
192 _mm256_maskstore_ps(outptr, mask.vec, fp32);
193 }
194 }
195 }
196 if (type_from == 4 && type_to == 1)
197 {
198 #pragma omp parallel for num_threads(opt.num_threads)
199 for (int q = 0; q < channels; q++)
200 {
201 const unsigned short* ptr = bottom_blob.channel(q);
202 float* outptr = top_blob.channel(q);
203
204 int nn = size >> 3;
205 int remain = size & 7;
206 for (; nn > 0; nn--)
207 {
208 _mm256_storeu_ps(outptr, bfloat2float_avx(_mm_lddqu_si128((__m128i const*)ptr)));
209 ptr += 8;
210 outptr += 8;
211 }
212
213 for (; remain > 0; remain--)
214 {
215 *outptr = bfloat16_to_float32(*ptr);
216 outptr++;
217 ptr++;
218 }
219 }
220 }
221 if (type_from == 1 && type_to == 4)
222 {
223 #pragma omp parallel for num_threads(opt.num_threads)
224 for (int q = 0; q < channels; q++)
225 {
226 const float* ptr = bottom_blob.channel(q);
227 unsigned short* outptr = top_blob.channel(q);
228 int nn = size >> 4;
229 int remain = size & 15;
230 for (; nn > 0; nn--)
231 {
232 _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr), _mm256_loadu_ps(ptr + 8)));
233 ptr += 16;
234 outptr += 16;
235 }
236 if (remain >= 8)
237 {
238 remain -= 8;
239 _mm_store_si128((__m128i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr)));
240 ptr += 8;
241 outptr += 8;
242 }
243 for (; remain > 0; remain--)
244 {
245 *outptr = float32_to_bfloat16(*ptr);
246 outptr++;
247 ptr++;
248 }
249 }
250 }
251
252 return 0;
253 #else // __AVX__
254
255 return Cast::forward(bottom_blob, top_blob, opt);
256
257 #endif // __AVX__
258 }
259
260 } // namespace ncnn
261