1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
#include "cast_x86.h"

#include <string.h>

#if __SSE2__
#include <emmintrin.h>
#if __AVX__
#include <immintrin.h>
#endif // __AVX__
#endif // __SSE2__
23 
24 #if __AVX__
25 #include <stdint.h>
// Union view of a 128-bit SSE register, so a partial tail of 16-bit results
// (fp16 / bf16) can be memcpy'd out lane by lane without storing past the
// end of the output buffer.
typedef union m128i
{
    __m128i vec;
    uint16_t m128i_u16[8]; // the same 128 bits viewed as 8 x u16
} m128;
31 
// Union view of a 256-bit AVX register; used below to build per-lane masks
// for _mm256_maskload_ps / _mm256_maskstore_ps (sign bit of each u32 lane
// selects the lane).
typedef union m256i
{
    __m256i vec;
    uint32_t m256i_u32[8]; // the same 256 bits viewed as 8 x u32
} m256;
// Widen 8 bfloat16 values (packed in one 128-bit register) to 8 float32.
// A bfloat16 is exactly the upper 16 bits of an IEEE-754 float32, so
// interleaving a zero word BELOW each 16-bit value reconstructs the float
// bit pattern directly (no shift needed).
static inline __m256 bfloat2float_avx(__m128i v0)
{
    const __m128i zeros = _mm_setzero_si128();
    // unpack(zeros, v0): each 32-bit lane becomes (bf16 << 16) | 0
    __m128i lo = _mm_unpacklo_epi16(zeros, v0);
    __m128i hi = _mm_unpackhi_epi16(zeros, v0);
    // glue the two halves into one 256-bit register: lo -> lane 0, hi -> lane 1
    __m256i combined = _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
    return _mm256_castsi256_ps(combined);
}
// Narrow 16 float32 (two 256-bit registers) to 16 bfloat16 in one 256-bit
// register by truncation: keep only the upper 16 bits of each float.
// NOTE(review): _mm256_packus_epi32 and _mm256_permutevar8x32_epi32 are AVX2
// instructions but this block is guarded only by __AVX__ — confirm the build
// always enables AVX2 when __AVX__ is set.
static inline __m256i float2bfloat_avx(__m256 v0, __m256 v1)
{
    __m256i a = _mm256_castps_si256(v0);
    a = _mm256_srli_epi32(a, 16); // high 16 bits of each float = its bf16 value
    __m256i b = _mm256_castps_si256(v1);
    b = _mm256_srli_epi32(b, 16);
    // packus operates per 128-bit lane: abab = a0..a3 b0..b3 | a4..a7 b4..b7.
    // The permute (on 32-bit chunks) restores natural order: a0..a7 b0..b7.
    __m256i abab = _mm256_packus_epi32(a, b);
    return _mm256_permutevar8x32_epi32(abab, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7));
}
// Narrow 8 float32 (one 256-bit register) to 8 bfloat16 in a 128-bit
// register, by truncation. Same lane-reordering trick as the two-register
// overload above; packing a with itself then permuting leaves a0..a7 in the
// low 128 bits, which is all that is returned.
// NOTE(review): uses AVX2 intrinsics under an __AVX__-only guard — confirm
// the build always enables AVX2 here.
static inline __m128i float2bfloat_avx(__m256 v0)
{
    __m256i a = _mm256_castps_si256(v0);
    a = _mm256_srli_epi32(a, 16); // high 16 bits of each float = its bf16 value
    __m256i aaaa = _mm256_packus_epi32(a, a);
    return _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(aaaa, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7)));
}
63 #endif // __AVX__
64 
65 namespace ncnn {
66 
// The SIMD paths in forward() work on flat per-channel element counts
// (w * h * elempack), so packed layouts need no special handling and can be
// declared supported here.
Cast_x86::Cast_x86()
{
    support_packing = true;
}
71 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const72 int Cast_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
73 {
74 #if __AVX__
75     if (type_from == type_to)
76     {
77         top_blob = bottom_blob;
78         return 0;
79     }
80 
81     int w = bottom_blob.w;
82     int h = bottom_blob.h;
83     int channels = bottom_blob.c;
84     int dims = bottom_blob.dims;
85     size_t elemsize = bottom_blob.elemsize;
86     int elempack = bottom_blob.elempack;
87 
88     size_t out_elemsize = elemsize;
89     if (type_to == 1)
90     {
91         if (type_from == 3)
92         {
93             Cast::forward(bottom_blob, top_blob, opt);
94         }
95 
96         // float32
97         out_elemsize = 4 * elempack;
98     }
99     else if (type_to == 2)
100     {
101         // float16
102         out_elemsize = 2 * elempack;
103     }
104     else if (type_to == 3)
105     {
106         // int8
107         out_elemsize = elempack;
108     }
109     else if (type_to == 4)
110     {
111         // bfloat16
112         out_elemsize = 2 * elempack;
113     }
114 
115     if (dims == 1)
116     {
117         top_blob.create(w, out_elemsize, elempack, opt.blob_allocator);
118     }
119     else if (dims == 2)
120     {
121         top_blob.create(w, h, out_elemsize, elempack, opt.blob_allocator);
122     }
123     else if (dims == 3)
124     {
125         top_blob.create(w, h, channels, out_elemsize, elempack, opt.blob_allocator);
126     }
127     if (top_blob.empty())
128         return -100;
129 
130     int size = w * h * elempack;
131 
132     if (type_from == 1 && type_to == 2)
133     {
134         int nn = size >> 3;
135         int remain = size - (nn << 3);
136         m256i mask = {_mm256_setzero_si256()};
137         for (int i = 0; i < remain; i++)
138             mask.m256i_u32[i] = 0x80000000;
139 
140         #pragma omp parallel for num_threads(opt.num_threads)
141         for (int q = 0; q < channels; q++)
142         {
143             const float* ptr = bottom_blob.channel(q);
144             unsigned short* outptr = top_blob.channel(q);
145 
146             for (int i = 0; i < nn; i++)
147             {
148                 __m256 fp32 = _mm256_loadu_ps(ptr);
149                 __m128i fp16 = _mm256_cvtps_ph(fp32, _MM_FROUND_TRUNC);
150                 _mm_store_si128((__m128i*)outptr, fp16);
151                 ptr += 8;
152                 outptr += 8;
153             }
154 
155             if (remain > 0)
156             {
157                 __m256 fp32 = _mm256_maskload_ps(ptr, mask.vec);
158                 m128i fp16 = {_mm256_cvtps_ph(fp32, _MM_FROUND_TRUNC)};
159                 memcpy(outptr, fp16.m128i_u16, remain * sizeof(unsigned short));
160             }
161         }
162     }
163 
164     if (type_from == 2 && type_to == 1)
165     {
166         int nn = size >> 3;
167         int remain = size - (nn << 3);
168         m256i mask = {_mm256_setzero_si256()};
169         for (int i = 0; i < remain; i++)
170             mask.m256i_u32[i] = 0x80000000;
171 
172         #pragma omp parallel for num_threads(opt.num_threads)
173         for (int q = 0; q < channels; q++)
174         {
175             const unsigned short* ptr = bottom_blob.channel(q);
176             float* outptr = top_blob.channel(q);
177 
178             for (int i = 0; i < nn; i++)
179             {
180                 __m128i fp16 = _mm_lddqu_si128((__m128i const*)ptr);
181                 __m256 fp32 = _mm256_cvtph_ps(fp16);
182                 _mm256_storeu_ps(outptr, fp32);
183                 ptr += 8;
184                 outptr += 8;
185             }
186 
187             if (remain > 0)
188             {
189                 m128i fp16 = {_mm_setzero_si128()};
190                 memcpy(fp16.m128i_u16, ptr, remain * sizeof(unsigned short));
191                 __m256 fp32 = _mm256_cvtph_ps(fp16.vec);
192                 _mm256_maskstore_ps(outptr, mask.vec, fp32);
193             }
194         }
195     }
196     if (type_from == 4 && type_to == 1)
197     {
198         #pragma omp parallel for num_threads(opt.num_threads)
199         for (int q = 0; q < channels; q++)
200         {
201             const unsigned short* ptr = bottom_blob.channel(q);
202             float* outptr = top_blob.channel(q);
203 
204             int nn = size >> 3;
205             int remain = size & 7;
206             for (; nn > 0; nn--)
207             {
208                 _mm256_storeu_ps(outptr, bfloat2float_avx(_mm_lddqu_si128((__m128i const*)ptr)));
209                 ptr += 8;
210                 outptr += 8;
211             }
212 
213             for (; remain > 0; remain--)
214             {
215                 *outptr = bfloat16_to_float32(*ptr);
216                 outptr++;
217                 ptr++;
218             }
219         }
220     }
221     if (type_from == 1 && type_to == 4)
222     {
223         #pragma omp parallel for num_threads(opt.num_threads)
224         for (int q = 0; q < channels; q++)
225         {
226             const float* ptr = bottom_blob.channel(q);
227             unsigned short* outptr = top_blob.channel(q);
228             int nn = size >> 4;
229             int remain = size & 15;
230             for (; nn > 0; nn--)
231             {
232                 _mm256_storeu_si256((__m256i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr), _mm256_loadu_ps(ptr + 8)));
233                 ptr += 16;
234                 outptr += 16;
235             }
236             if (remain >= 8)
237             {
238                 remain -= 8;
239                 _mm_store_si128((__m128i*)outptr, float2bfloat_avx(_mm256_loadu_ps(ptr)));
240                 ptr += 8;
241                 outptr += 8;
242             }
243             for (; remain > 0; remain--)
244             {
245                 *outptr = float32_to_bfloat16(*ptr);
246                 outptr++;
247                 ptr++;
248             }
249         }
250     }
251 
252     return 0;
253 #else // __AVX__
254 
255     return Cast::forward(bottom_blob, top_blob, opt);
256 
257 #endif // __AVX__
258 }
259 
260 } // namespace ncnn
261