1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "cast_arm.h"
16 
17 #ifdef __ARM_NEON
18 #include <arm_neon.h>
19 #endif // __ARM_NEON
20 
21 #include "cpu.h"
22 
23 namespace ncnn {
24 
Cast_arm()25 Cast_arm::Cast_arm()
26 {
27 #if __ARM_NEON
28     support_packing = true;
29 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
30     support_fp16_storage = true;
31 #endif
32 #endif // __ARM_NEON
33 
34     support_bf16_storage = true;
35 }
36 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const37 int Cast_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
38 {
39     if (type_from == type_to)
40     {
41         top_blob = bottom_blob;
42         return 0;
43     }
44 
45     int w = bottom_blob.w;
46     int h = bottom_blob.h;
47     int channels = bottom_blob.c;
48     int dims = bottom_blob.dims;
49     size_t elemsize = bottom_blob.elemsize;
50     int elempack = bottom_blob.elempack;
51 
52 #if __ARM_NEON
53     if (elempack % 4 == 0)
54     {
55 #if (__ARM_FP & 2)
56         if (!cpu_support_arm_vfpv4() && (type_from == 2 || type_to == 2))
57 #else
58         if (type_from == 2 || type_to == 2)
59 #endif // (__ARM_FP & 2)
60         {
61             // no fp16 conversion instruction, fallback
62             return Cast::forward(bottom_blob, top_blob, opt);
63         }
64 
65         size_t out_elemsize = elemsize;
66         if (type_to == 1)
67         {
68             // float32
69             out_elemsize = 4 * elempack;
70         }
71         else if (type_to == 2)
72         {
73             // float16
74             out_elemsize = 2 * elempack;
75         }
76         else if (type_to == 3)
77         {
78             // int8
79             out_elemsize = elempack;
80         }
81         else if (type_to == 4)
82         {
83             // bfloat16
84             out_elemsize = 2 * elempack;
85         }
86 
87         if (dims == 1)
88         {
89             top_blob.create(w, out_elemsize, elempack, opt.blob_allocator);
90         }
91         else if (dims == 2)
92         {
93             top_blob.create(w, h, out_elemsize, elempack, opt.blob_allocator);
94         }
95         else if (dims == 3)
96         {
97             top_blob.create(w, h, channels, out_elemsize, elempack, opt.blob_allocator);
98         }
99         if (top_blob.empty())
100             return -100;
101 
102         int size = w * h * elempack;
103 
104 #if (__ARM_FP & 2)
105         if (type_from == 1 && type_to == 2)
106         {
107             #pragma omp parallel for num_threads(opt.num_threads)
108             for (int q = 0; q < channels; q++)
109             {
110                 const float* ptr = bottom_blob.channel(q);
111                 unsigned short* outptr = top_blob.channel(q);
112 
113                 int nn = size / 4;
114 
115 #if __aarch64__
116                 asm volatile(
117                     "0:                             \n"
118                     "prfm   pldl1keep, [%1, #128]   \n"
119                     "ld1    {v0.4s}, [%1], #16      \n"
120                     "fcvtn  v1.4h, v0.4s            \n"
121                     "subs   %w0, %w0, #1            \n"
122                     "st1    {v1.4h}, [%2], #8       \n"
123                     "bne    0b                      \n"
124                     : "=r"(nn),    // %0
125                     "=r"(ptr),   // %1
126                     "=r"(outptr) // %2
127                     : "0"(nn),
128                     "1"(ptr),
129                     "2"(outptr)
130                     : "cc", "memory", "v0", "v1");
131 #else
132                 asm volatile(
133                     "0:                             \n"
134                     "pld        [%1, #128]          \n"
135                     "vld1.f32   {d0-d1}, [%1 :128]! \n"
136                     "vcvt.f16.f32 d2, q0            \n"
137                     "subs       %0, #1              \n"
138                     "vst1.f32   {d2}, [%2 :64]!     \n"
139                     "bne        0b                  \n"
140                     : "=r"(nn),    // %0
141                     "=r"(ptr),   // %1
142                     "=r"(outptr) // %2
143                     : "0"(nn),
144                     "1"(ptr),
145                     "2"(outptr)
146                     : "cc", "memory", "q0", "q1");
147 #endif // __aarch64__
148             }
149         }
150 
151         if (type_from == 2 && type_to == 1)
152         {
153             #pragma omp parallel for num_threads(opt.num_threads)
154             for (int q = 0; q < channels; q++)
155             {
156                 const unsigned short* ptr = bottom_blob.channel(q);
157                 float* outptr = top_blob.channel(q);
158 
159                 int nn = size / 4;
160 
161 #if __aarch64__
162                 asm volatile(
163                     "0:                             \n"
164                     "prfm   pldl1keep, [%1, #64]    \n"
165                     "ld1    {v0.4h}, [%1], #8       \n"
166                     "fcvtl  v1.4s, v0.4h            \n"
167                     "subs   %w0, %w0, #1            \n"
168                     "st1    {v1.4s}, [%2], #16      \n"
169                     "bne    0b                      \n"
170                     : "=r"(nn),    // %0
171                     "=r"(ptr),   // %1
172                     "=r"(outptr) // %2
173                     : "0"(nn),
174                     "1"(ptr),
175                     "2"(outptr)
176                     : "cc", "memory", "v0", "v1");
177 #else
178                 asm volatile(
179                     "0:                             \n"
180                     "pld        [%1, #64]           \n"
181                     "vld1.s16   {d0}, [%1 :64]!     \n"
182                     "vcvt.f32.f16 q1, d0            \n"
183                     "subs       %0, #1              \n"
184                     "vst1.f32   {d2-d3}, [%2 :128]! \n"
185                     "bne        0b                  \n"
186                     : "=r"(nn),    // %0
187                     "=r"(ptr),   // %1
188                     "=r"(outptr) // %2
189                     : "0"(nn),
190                     "1"(ptr),
191                     "2"(outptr)
192                     : "cc", "memory", "q0", "q1");
193 #endif // __aarch64__
194             }
195         }
196 #endif // (__ARM_FP & 2)
197 
198         if (type_from == 3 && type_to == 1)
199         {
200             #pragma omp parallel for num_threads(opt.num_threads)
201             for (int q = 0; q < channels; q++)
202             {
203                 const signed char* ptr = bottom_blob.channel(q);
204                 float* outptr = top_blob.channel(q);
205 
206                 for (int i = 0; i < size; i++)
207                 {
208                     outptr[i] = (float)ptr[i];
209                 }
210             }
211         }
212 
213         if (type_from == 1 && type_to == 4)
214         {
215             #pragma omp parallel for num_threads(opt.num_threads)
216             for (int q = 0; q < channels; q++)
217             {
218                 const float* ptr = bottom_blob.channel(q);
219                 unsigned short* outptr = top_blob.channel(q);
220 
221                 int nn = size / 4;
222 
223 #if __aarch64__
224                 asm volatile(
225                     "0:                             \n"
226                     "prfm   pldl1keep, [%1, #128]   \n"
227                     "ld1    {v0.4s}, [%1], #16      \n"
228                     "shrn   v1.4h, v0.4s, #16       \n"
229                     "subs   %w0, %w0, #1            \n"
230                     "st1    {v1.4h}, [%2], #8       \n"
231                     "bne    0b                      \n"
232                     : "=r"(nn),    // %0
233                     "=r"(ptr),   // %1
234                     "=r"(outptr) // %2
235                     : "0"(nn),
236                     "1"(ptr),
237                     "2"(outptr)
238                     : "cc", "memory", "v0", "v1");
239 #else
240                 asm volatile(
241                     "0:                             \n"
242                     "pld        [%1, #128]          \n"
243                     "vld1.f32   {d0-d1}, [%1 :128]! \n"
244                     "vshrn.u32  d2, q0, #16         \n"
245                     "subs       %0, #1              \n"
246                     "vst1.u16   {d2}, [%2 :64]!     \n"
247                     "bne        0b                  \n"
248                     : "=r"(nn),    // %0
249                     "=r"(ptr),   // %1
250                     "=r"(outptr) // %2
251                     : "0"(nn),
252                     "1"(ptr),
253                     "2"(outptr)
254                     : "cc", "memory", "q0", "q1");
255 #endif // __aarch64__
256             }
257         }
258 
259         if (type_from == 4 && type_to == 1)
260         {
261             #pragma omp parallel for num_threads(opt.num_threads)
262             for (int q = 0; q < channels; q++)
263             {
264                 const unsigned short* ptr = bottom_blob.channel(q);
265                 float* outptr = top_blob.channel(q);
266 
267                 int nn = size / 4;
268 
269 #if __aarch64__
270                 asm volatile(
271                     "0:                             \n"
272                     "prfm   pldl1keep, [%1, #64]    \n"
273                     "ld1    {v0.4h}, [%1], #8       \n"
274                     "shll   v1.4s, v0.4h, #16       \n"
275                     "subs   %w0, %w0, #1            \n"
276                     "st1    {v1.4s}, [%2], #16      \n"
277                     "bne    0b                      \n"
278                     : "=r"(nn),    // %0
279                     "=r"(ptr),   // %1
280                     "=r"(outptr) // %2
281                     : "0"(nn),
282                     "1"(ptr),
283                     "2"(outptr)
284                     : "cc", "memory", "v0", "v1");
285 #else
286                 asm volatile(
287                     "0:                             \n"
288                     "pld        [%1, #64]           \n"
289                     "vld1.u16   {d0}, [%1 :64]!     \n"
290                     "vshll.u16  q1, d0, #16         \n"
291                     "subs       %0, #1              \n"
292                     "vst1.f32   {d2-d3}, [%2 :128]! \n"
293                     "bne        0b                  \n"
294                     : "=r"(nn),    // %0
295                     "=r"(ptr),   // %1
296                     "=r"(outptr) // %2
297                     : "0"(nn),
298                     "1"(ptr),
299                     "2"(outptr)
300                     : "cc", "memory", "q0", "q1");
301 #endif // __aarch64__
302             }
303         }
304 
305         // TODO more cast type
306 
307         return 0;
308     }
309 #endif // __ARM_NEON
310 
311     return Cast::forward(bottom_blob, top_blob, opt);
312 }
313 
314 } // namespace ncnn
315