// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15 #include "cast_arm.h"
16
17 #ifdef __ARM_NEON
18 #include <arm_neon.h>
19 #endif // __ARM_NEON
20
21 #include "cpu.h"
22
23 namespace ncnn {
24
Cast_arm()25 Cast_arm::Cast_arm()
26 {
27 #if __ARM_NEON
28 support_packing = true;
29 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
30 support_fp16_storage = true;
31 #endif
32 #endif // __ARM_NEON
33
34 support_bf16_storage = true;
35 }
36
// Convert bottom_blob from type_from to type_to into top_blob.
// Type codes (as used in the out_elemsize selection below):
//   1 = float32, 2 = float16, 3 = int8, 4 = bfloat16
// This ARM path only accelerates packed blobs (elempack % 4 == 0); all other
// cases are delegated to the generic Cast::forward implementation.
int Cast_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    // Same type on both sides: share the data, zero-copy.
    if (type_from == type_to)
    {
        top_blob = bottom_blob;
        return 0;
    }

    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int channels = bottom_blob.c;
    int dims = bottom_blob.dims;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

#if __ARM_NEON
    if (elempack % 4 == 0)
    {
#if (__ARM_FP & 2)
        // Compiled with fp16 conversion support, but still verify at runtime
        // that the CPU actually has vfpv4 before using the fp16 paths.
        if (!cpu_support_arm_vfpv4() && (type_from == 2 || type_to == 2))
#else
        // Built without fp16 conversion instructions: any fp16 cast must fall back.
        if (type_from == 2 || type_to == 2)
#endif // (__ARM_FP & 2)
        {
            // no fp16 conversion instruction, fallback
            return Cast::forward(bottom_blob, top_blob, opt);
        }

        // Per-element storage size of the destination type, scaled by elempack.
        size_t out_elemsize = elemsize;
        if (type_to == 1)
        {
            // float32
            out_elemsize = 4 * elempack;
        }
        else if (type_to == 2)
        {
            // float16
            out_elemsize = 2 * elempack;
        }
        else if (type_to == 3)
        {
            // int8
            out_elemsize = elempack;
        }
        else if (type_to == 4)
        {
            // bfloat16
            out_elemsize = 2 * elempack;
        }

        // Allocate the destination blob with the same geometry and packing.
        // For any dims outside 1..3, top_blob stays empty and we bail out below.
        if (dims == 1)
        {
            top_blob.create(w, out_elemsize, elempack, opt.blob_allocator);
        }
        else if (dims == 2)
        {
            top_blob.create(w, h, out_elemsize, elempack, opt.blob_allocator);
        }
        else if (dims == 3)
        {
            top_blob.create(w, h, channels, out_elemsize, elempack, opt.blob_allocator);
        }
        if (top_blob.empty())
            return -100;

        // Scalar elements per channel. Since elempack % 4 == 0, size is always
        // a whole number of 4-lane NEON vectors, so size / 4 has no remainder.
        int size = w * h * elempack;

#if (__ARM_FP & 2)
        // float32 -> float16
        if (type_from == 1 && type_to == 2)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = bottom_blob.channel(q);
                unsigned short* outptr = top_blob.channel(q);

                // Number of 4-lane vector iterations. The asm loops below are
                // do-while style (body runs before the subs/bne test), so nn
                // must be >= 1 — which holds for any non-empty blob here since
                // size >= elempack >= 4.
                int nn = size / 4;

#if __aarch64__
                // fcvtn narrows four fp32 lanes to four fp16 lanes.
                // The pointers appear as in/out operands ("0"/"1"/"2") because
                // the post-increment addressing modifies them.
                asm volatile(
                    "0:                               \n"
                    "prfm  pldl1keep, [%1, #128]      \n"
                    "ld1   {v0.4s}, [%1], #16         \n"
                    "fcvtn v1.4h, v0.4s               \n"
                    "subs  %w0, %w0, #1               \n"
                    "st1   {v1.4h}, [%2], #8          \n"
                    "bne   0b                         \n"
                    : "=r"(nn),    // %0
                    "=r"(ptr),     // %1
                    "=r"(outptr)   // %2
                    : "0"(nn),
                    "1"(ptr),
                    "2"(outptr)
                    : "cc", "memory", "v0", "v1");
#else
                // vcvt.f16.f32 narrows a q-register of fp32 into a d-register of fp16.
                asm volatile(
                    "0:                               \n"
                    "pld   [%1, #128]                 \n"
                    "vld1.f32 {d0-d1}, [%1 :128]!     \n"
                    "vcvt.f16.f32 d2, q0              \n"
                    "subs  %0, #1                     \n"
                    "vst1.f32 {d2}, [%2 :64]!         \n"
                    "bne   0b                         \n"
                    : "=r"(nn),    // %0
                    "=r"(ptr),     // %1
                    "=r"(outptr)   // %2
                    : "0"(nn),
                    "1"(ptr),
                    "2"(outptr)
                    : "cc", "memory", "q0", "q1");
#endif // __aarch64__
            }
        }

        // float16 -> float32
        if (type_from == 2 && type_to == 1)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const unsigned short* ptr = bottom_blob.channel(q);
                float* outptr = top_blob.channel(q);

                int nn = size / 4;

#if __aarch64__
                // fcvtl widens four fp16 lanes to four fp32 lanes (lossless).
                asm volatile(
                    "0:                               \n"
                    "prfm  pldl1keep, [%1, #64]       \n"
                    "ld1   {v0.4h}, [%1], #8          \n"
                    "fcvtl v1.4s, v0.4h               \n"
                    "subs  %w0, %w0, #1               \n"
                    "st1   {v1.4s}, [%2], #16         \n"
                    "bne   0b                         \n"
                    : "=r"(nn),    // %0
                    "=r"(ptr),     // %1
                    "=r"(outptr)   // %2
                    : "0"(nn),
                    "1"(ptr),
                    "2"(outptr)
                    : "cc", "memory", "v0", "v1");
#else
                asm volatile(
                    "0:                               \n"
                    "pld   [%1, #64]                  \n"
                    "vld1.s16 {d0}, [%1 :64]!         \n"
                    "vcvt.f32.f16 q1, d0              \n"
                    "subs  %0, #1                     \n"
                    "vst1.f32 {d2-d3}, [%2 :128]!     \n"
                    "bne   0b                         \n"
                    : "=r"(nn),    // %0
                    "=r"(ptr),     // %1
                    "=r"(outptr)   // %2
                    : "0"(nn),
                    "1"(ptr),
                    "2"(outptr)
                    : "cc", "memory", "q0", "q1");
#endif // __aarch64__
            }
        }
#endif // (__ARM_FP & 2)

        // int8 -> float32, simple scalar widening loop.
        if (type_from == 3 && type_to == 1)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const signed char* ptr = bottom_blob.channel(q);
                float* outptr = top_blob.channel(q);

                for (int i = 0; i < size; i++)
                {
                    outptr[i] = (float)ptr[i];
                }
            }
        }

        // float32 -> bfloat16: bf16 is the top 16 bits of an fp32, so the
        // conversion is a plain right shift by 16 — this truncates the
        // mantissa (round-toward-zero), it does not round-to-nearest.
        if (type_from == 1 && type_to == 4)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* ptr = bottom_blob.channel(q);
                unsigned short* outptr = top_blob.channel(q);

                int nn = size / 4;

#if __aarch64__
                // shrn keeps the high 16 bits of each 32-bit lane.
                asm volatile(
                    "0:                               \n"
                    "prfm  pldl1keep, [%1, #128]      \n"
                    "ld1   {v0.4s}, [%1], #16         \n"
                    "shrn  v1.4h, v0.4s, #16          \n"
                    "subs  %w0, %w0, #1               \n"
                    "st1   {v1.4h}, [%2], #8          \n"
                    "bne   0b                         \n"
                    : "=r"(nn),    // %0
                    "=r"(ptr),     // %1
                    "=r"(outptr)   // %2
                    : "0"(nn),
                    "1"(ptr),
                    "2"(outptr)
                    : "cc", "memory", "v0", "v1");
#else
                asm volatile(
                    "0:                               \n"
                    "pld   [%1, #128]                 \n"
                    "vld1.f32 {d0-d1}, [%1 :128]!     \n"
                    "vshrn.u32 d2, q0, #16            \n"
                    "subs  %0, #1                     \n"
                    "vst1.u16 {d2}, [%2 :64]!         \n"
                    "bne   0b                         \n"
                    : "=r"(nn),    // %0
                    "=r"(ptr),     // %1
                    "=r"(outptr)   // %2
                    : "0"(nn),
                    "1"(ptr),
                    "2"(outptr)
                    : "cc", "memory", "q0", "q1");
#endif // __aarch64__
            }
        }

        // bfloat16 -> float32: widen by shifting left 16 bits (the low
        // mantissa bits become zero) — exact for every bf16 value.
        if (type_from == 4 && type_to == 1)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const unsigned short* ptr = bottom_blob.channel(q);
                float* outptr = top_blob.channel(q);

                int nn = size / 4;

#if __aarch64__
                // shll widens each 16-bit lane into the high half of a 32-bit lane.
                asm volatile(
                    "0:                               \n"
                    "prfm  pldl1keep, [%1, #64]       \n"
                    "ld1   {v0.4h}, [%1], #8          \n"
                    "shll  v1.4s, v0.4h, #16          \n"
                    "subs  %w0, %w0, #1               \n"
                    "st1   {v1.4s}, [%2], #16         \n"
                    "bne   0b                         \n"
                    : "=r"(nn),    // %0
                    "=r"(ptr),     // %1
                    "=r"(outptr)   // %2
                    : "0"(nn),
                    "1"(ptr),
                    "2"(outptr)
                    : "cc", "memory", "v0", "v1");
#else
                asm volatile(
                    "0:                               \n"
                    "pld   [%1, #64]                  \n"
                    "vld1.u16 {d0}, [%1 :64]!         \n"
                    "vshll.u16 q1, d0, #16            \n"
                    "subs  %0, #1                     \n"
                    "vst1.f32 {d2-d3}, [%2 :128]!     \n"
                    "bne   0b                         \n"
                    : "=r"(nn),    // %0
                    "=r"(ptr),     // %1
                    "=r"(outptr)   // %2
                    : "0"(nn),
                    "1"(ptr),
                    "2"(outptr)
                    : "cc", "memory", "q0", "q1");
#endif // __aarch64__
            }
        }

        // TODO more cast type
        // NOTE(review): type pairs not covered above (e.g. float32 -> int8 in
        // the packed path) return 0 with top_blob allocated but never written —
        // confirm callers never request those combinations with elempack % 4 == 0.

        return 0;
    }
#endif // __ARM_NEON

    // Unpacked (or non-NEON) data: use the generic reference implementation.
    return Cast::forward(bottom_blob, top_blob, opt);
}
313
314 } // namespace ncnn
315