// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "hardswish_arm.h"

#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON

namespace ncnn {

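// Advertise which layouts this optimized layer accepts: packed SIMD storage,
// fp16 storage when half-precision vector arithmetic is compiled in, and bf16 storage.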
HardSwish_arm::HardSwish_arm()
{
#if __ARM_NEON
    support_packing = true;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

    support_bf16_storage = true;
}

int HardSwish_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
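    // 16-bit blobs are routed to dedicated kernels: fp16 with or without fp16
    // arithmetic, or bf16; everything else falls through to the fp32 path below.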
    int elembits = bottom_top_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
    {
        if (opt.use_fp16_arithmetic)
            return forward_inplace_fp16sa(bottom_top_blob, opt);
        else
            return forward_inplace_fp16s(bottom_top_blob, opt);
    }
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_inplace_bf16s(bottom_top_blob, opt);

    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

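    // Packed fp32 path (elempack == 4): hardswish y = x * clamp(alpha * x + beta, 0, 1)
    // is evaluated one float32x4_t per packed element; alpha and beta are the
    // parameters loaded by the base HardSwish layer.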
#if __ARM_NEON
    if (elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            float* ptr = bottom_top_blob.channel(q);

            float32x4_t _zero = vdupq_n_f32(0.f);
            float32x4_t _one = vdupq_n_f32(1.f);
            for (int i = 0; i < size; i++)
            {
                float32x4_t _p = vld1q_f32(ptr);
                float32x4_t _ans = vdupq_n_f32(beta);
                _ans = vmlaq_n_f32(_ans, _p, alpha);
                _ans = vmaxq_f32(_ans, _zero);
                _ans = vminq_f32(_ans, _one);
                _ans = vmulq_f32(_ans, _p);
                vst1q_f32(ptr, _ans);

                ptr += 4;
            }
        }

        return 0;
    }
#endif // __ARM_NEON

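    // Unpacked fp32 path (elempack == 1): NEON handles groups of 4 floats where
    // available, the scalar tail below handles the remaining 0-3 values per channel.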
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        float* ptr = bottom_top_blob.channel(q);

#if __ARM_NEON
        int nn = size >> 2;
        int remain = size - (nn << 2);
#else
        int remain = size;
#endif // __ARM_NEON

#if __ARM_NEON
        float32x4_t _zero = vdupq_n_f32(0.f);
        float32x4_t _one = vdupq_n_f32(1.f);
        while (nn--)
        {
            float32x4_t _p = vld1q_f32(ptr);
            float32x4_t _ans = vdupq_n_f32(beta);
            _ans = vmlaq_n_f32(_ans, _p, alpha);
            _ans = vmaxq_f32(_ans, _zero);
            _ans = vminq_f32(_ans, _one);
            _ans = vmulq_f32(_ans, _p);
            vst1q_f32(ptr, _ans);

            ptr += 4;
        }
#endif // __ARM_NEON
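        // Scalar tail: lower and upper are the thresholds precomputed by the base
        // HardSwish layer where alpha * x + beta crosses 0 and 1, so x below lower
        // becomes 0, x above upper is left unchanged (the deliberately empty branch),
        // and everything in between becomes x * (alpha * x + beta) -- the same
        // clamp the NEON loop applies.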
        for (; remain > 0; remain--)
        {
            if (*ptr < lower)
                *ptr = 0.f;
            else if (*ptr > upper)
                ;
            else
                *ptr = *ptr * (*ptr * alpha + beta);
            ++ptr;
        }
    }

    return 0;
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
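// fp16 storage with fp32 arithmetic: each group of 4 halves is widened to fp32,
// hardswish is applied exactly as in the fp32 kernel, and the result is narrowed
// back to fp16 on store.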
int HardSwish_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    if (elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            float32x4_t _zero = vdupq_n_f32(0.f);
            float32x4_t _one = vdupq_n_f32(1.f);
            for (int i = 0; i < size; i++)
            {
                float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
                float32x4_t _ans = vdupq_n_f32(beta);
                _ans = vfmaq_n_f32(_ans, _p, alpha);
                _ans = vmaxq_f32(_ans, _zero);
                _ans = vminq_f32(_ans, _one);
                _ans = vmulq_f32(_ans, _p);
                vst1_f16(ptr, vcvt_f16_f32(_ans));

                ptr += 4;
            }
        }

        return 0;
    }

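    // elempack == 1: process 4 halves at a time, then handle the last 0-3 values
    // as scalars, still doing the arithmetic in fp32.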
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

        float32x4_t _zero = vdupq_n_f32(0.f);
        float32x4_t _one = vdupq_n_f32(1.f);

        int i = 0;
        for (; i + 3 < size; i += 4)
        {
            float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
            float32x4_t _ans = vdupq_n_f32(beta);
            _ans = vfmaq_n_f32(_ans, _p, alpha);
            _ans = vmaxq_f32(_ans, _zero);
            _ans = vminq_f32(_ans, _one);
            _ans = vmulq_f32(_ans, _p);
            vst1_f16(ptr, vcvt_f16_f32(_ans));

            ptr += 4;
        }
        for (; i < size; i++)
        {
            float v = (float)*ptr;
            if (v < lower)
                v = 0.f;
            else if (v > upper)
                ;
            else
                v = v * (v * alpha + beta);
            *ptr = (__fp16)v;
            ++ptr;
        }
    }

    return 0;
}

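// fp16 storage and fp16 arithmetic: hardswish is evaluated entirely in half
// precision, 8 lanes per iteration, with alpha/beta cast to __fp16 once per channel.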
int HardSwish_arm::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

    if (elempack == 8)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            __fp16 alpha_fp16 = (__fp16)alpha;
            __fp16 beta_fp16 = (__fp16)beta;

            float16x8_t _zero = vdupq_n_f16((__fp16)0.f);
            float16x8_t _one = vdupq_n_f16((__fp16)1.f);
            for (int i = 0; i < size; i++)
            {
                float16x8_t _p = vld1q_f16(ptr);
                float16x8_t _ans = vdupq_n_f16(beta_fp16);
                _ans = vfmaq_n_f16(_ans, _p, alpha_fp16);
                _ans = vmaxq_f16(_ans, _zero);
                _ans = vminq_f16(_ans, _one);
                _ans = vmulq_f16(_ans, _p);
                vst1q_f16(ptr, _ans);

                ptr += 8;
            }
        }

        return 0;
    }

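    // elempack == 4: two packed elements fit one 8-lane register, so pairs are
    // processed together and a 4-lane tail covers an odd trailing element.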
    if (elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            __fp16 alpha_fp16 = (__fp16)alpha;
            __fp16 beta_fp16 = (__fp16)beta;

            float16x8_t _zero = vdupq_n_f16((__fp16)0.f);
            float16x8_t _one = vdupq_n_f16((__fp16)1.f);

            int i = 0;
            for (; i + 1 < size; i += 2)
            {
                float16x8_t _p = vld1q_f16(ptr);
                float16x8_t _ans = vdupq_n_f16(beta_fp16);
                _ans = vfmaq_n_f16(_ans, _p, alpha_fp16);
                _ans = vmaxq_f16(_ans, _zero);
                _ans = vminq_f16(_ans, _one);
                _ans = vmulq_f16(_ans, _p);
                vst1q_f16(ptr, _ans);

                ptr += 8;
            }
            for (; i < size; i++)
            {
                float16x4_t _p = vld1_f16(ptr);
                float16x4_t _ans = vdup_n_f16(beta_fp16);
                _ans = vfma_n_f16(_ans, _p, alpha_fp16);
                _ans = vmax_f16(_ans, vget_low_f16(_zero));
                _ans = vmin_f16(_ans, vget_low_f16(_one));
                _ans = vmul_f16(_ans, _p);
                vst1_f16(ptr, _ans);

                ptr += 4;
            }
        }

        return 0;
    }

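    // elempack == 1: 8-wide main loop, 4-wide secondary loop, then a scalar tail
    // that mirrors the fp32 branch logic in half precision.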
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

        __fp16 alpha_fp16 = (__fp16)alpha;
        __fp16 beta_fp16 = (__fp16)beta;

        float16x8_t _zero = vdupq_n_f16((__fp16)0.f);
        float16x8_t _one = vdupq_n_f16((__fp16)1.f);

        int i = 0;
        for (; i + 7 < size; i += 8)
        {
            float16x8_t _p = vld1q_f16(ptr);
            float16x8_t _ans = vdupq_n_f16(beta_fp16);
            _ans = vfmaq_n_f16(_ans, _p, alpha_fp16);
            _ans = vmaxq_f16(_ans, _zero);
            _ans = vminq_f16(_ans, _one);
            _ans = vmulq_f16(_ans, _p);
            vst1q_f16(ptr, _ans);

            ptr += 8;
        }
        for (; i + 3 < size; i += 4)
        {
            float16x4_t _p = vld1_f16(ptr);
            float16x4_t _ans = vdup_n_f16(beta_fp16);
            _ans = vfma_n_f16(_ans, _p, alpha_fp16);
            _ans = vmax_f16(_ans, vget_low_f16(_zero));
            _ans = vmin_f16(_ans, vget_low_f16(_one));
            _ans = vmul_f16(_ans, _p);
            vst1_f16(ptr, _ans);

            ptr += 4;
        }
        for (; i < size; i++)
        {
            __fp16 v = *ptr;
            if (v < (__fp16)lower)
                v = (__fp16)0.f;
            else if (v > (__fp16)upper)
                ;
            else
                v = v * (v * alpha_fp16 + beta_fp16);
            *ptr = v;
            ++ptr;
        }
    }

    return 0;
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

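// bf16 storage: values are kept as the upper 16 bits of an fp32, widened to fp32
// for the arithmetic and narrowed back on store via the vcvt_f32_bf16 /
// vcvt_bf16_f32 conversion helpers.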
int HardSwish_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

#if __ARM_NEON
    if (elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            unsigned short* ptr = bottom_top_blob.channel(q);

            float32x4_t _zero = vdupq_n_f32(0.f);
            float32x4_t _one = vdupq_n_f32(1.f);
            for (int i = 0; i < size; i++)
            {
                float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr));
                float32x4_t _ans = vdupq_n_f32(beta);
                _ans = vmlaq_n_f32(_ans, _p, alpha);
                _ans = vmaxq_f32(_ans, _zero);
                _ans = vminq_f32(_ans, _one);
                _ans = vmulq_f32(_ans, _p);
                vst1_u16(ptr, vcvt_bf16_f32(_ans));

                ptr += 4;
            }
        }

        return 0;
    }
#endif // __ARM_NEON

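    // Unpacked bf16 path: NEON over groups of 4 where available, scalar tail using
    // the bfloat16_to_float32 / float32_to_bfloat16 helpers.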
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        unsigned short* ptr = bottom_top_blob.channel(q);

#if __ARM_NEON
        int nn = size >> 2;
        int remain = size - (nn << 2);
#else
        int remain = size;
#endif // __ARM_NEON

#if __ARM_NEON
        float32x4_t _zero = vdupq_n_f32(0.f);
        float32x4_t _one = vdupq_n_f32(1.f);
        while (nn--)
        {
            float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr));
            float32x4_t _ans = vdupq_n_f32(beta);
            _ans = vmlaq_n_f32(_ans, _p, alpha);
            _ans = vmaxq_f32(_ans, _zero);
            _ans = vminq_f32(_ans, _one);
            _ans = vmulq_f32(_ans, _p);
            vst1_u16(ptr, vcvt_bf16_f32(_ans));

            ptr += 4;
        }
#endif // __ARM_NEON
        for (; remain > 0; remain--)
        {
            float v = bfloat16_to_float32(*ptr);
            if (v < lower)
                v = 0.f;
            else if (v > upper)
                ;
            else
                v = v * (v * alpha + beta);
            *ptr = float32_to_bfloat16(v);
            ++ptr;
        }
    }

    return 0;
}

} // namespace ncnn