1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "hardswish_arm.h"
16 
17 #if __ARM_NEON
18 #include <arm_neon.h>
19 #endif // __ARM_NEON
20 
21 namespace ncnn {
22 
HardSwish_arm()23 HardSwish_arm::HardSwish_arm()
24 {
25 #if __ARM_NEON
26     support_packing = true;
27 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
28     support_fp16_storage = true;
29 #endif
30 #endif // __ARM_NEON
31 
32     support_bf16_storage = true;
33 }
34 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const35 int HardSwish_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
36 {
37     int elembits = bottom_top_blob.elembits();
38 
39 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
40     if (opt.use_fp16_storage && elembits == 16)
41     {
42         if (opt.use_fp16_arithmetic)
43             return forward_inplace_fp16sa(bottom_top_blob, opt);
44         else
45             return forward_inplace_fp16s(bottom_top_blob, opt);
46     }
47 #endif
48 
49     if (opt.use_bf16_storage && elembits == 16)
50         return forward_inplace_bf16s(bottom_top_blob, opt);
51 
52     int w = bottom_top_blob.w;
53     int h = bottom_top_blob.h;
54     int channels = bottom_top_blob.c;
55     int size = w * h;
56     int elempack = bottom_top_blob.elempack;
57 
58 #if __ARM_NEON
59     if (elempack == 4)
60     {
61         #pragma omp parallel for num_threads(opt.num_threads)
62         for (int q = 0; q < channels; q++)
63         {
64             float* ptr = bottom_top_blob.channel(q);
65 
66             float32x4_t _zero = vdupq_n_f32(0.f);
67             float32x4_t _one = vdupq_n_f32(1.f);
68             for (int i = 0; i < size; i++)
69             {
70                 float32x4_t _p = vld1q_f32(ptr);
71                 float32x4_t _ans = vdupq_n_f32(beta);
72                 _ans = vmlaq_n_f32(_ans, _p, alpha);
73                 _ans = vmaxq_f32(_ans, _zero);
74                 _ans = vminq_f32(_ans, _one);
75                 _ans = vmulq_f32(_ans, _p);
76                 vst1q_f32(ptr, _ans);
77 
78                 ptr += 4;
79             }
80         }
81 
82         return 0;
83     }
84 #endif // __ARM_NEON
85 
86     #pragma omp parallel for num_threads(opt.num_threads)
87     for (int q = 0; q < channels; q++)
88     {
89         float* ptr = bottom_top_blob.channel(q);
90 
91 #if __ARM_NEON
92         int nn = size >> 2;
93         int remain = size - (nn << 2);
94 #else
95         int remain = size;
96 #endif // __ARM_NEON
97 
98 #if __ARM_NEON
99         float32x4_t _zero = vdupq_n_f32(0.f);
100         float32x4_t _one = vdupq_n_f32(1.f);
101         while (nn--)
102         {
103             float32x4_t _p = vld1q_f32(ptr);
104             float32x4_t _ans = vdupq_n_f32(beta);
105             _ans = vmlaq_n_f32(_ans, _p, alpha);
106             _ans = vmaxq_f32(_ans, _zero);
107             _ans = vminq_f32(_ans, _one);
108             _ans = vmulq_f32(_ans, _p);
109             vst1q_f32(ptr, _ans);
110 
111             ptr += 4;
112         }
113 #endif // __ARM_NEON
114         for (; remain > 0; remain--)
115         {
116             if (*ptr < lower)
117                 *ptr = 0.f;
118             else if (*ptr > upper)
119                 ;
120             else
121                 *ptr = *ptr * (*ptr * alpha + beta);
122             ++ptr;
123         }
124     }
125 
126     return 0;
127 }
128 
129 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
forward_inplace_fp16s(Mat & bottom_top_blob,const Option & opt) const130 int HardSwish_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
131 {
132     int w = bottom_top_blob.w;
133     int h = bottom_top_blob.h;
134     int channels = bottom_top_blob.c;
135     int size = w * h;
136     int elempack = bottom_top_blob.elempack;
137 
138     if (elempack == 4)
139     {
140         #pragma omp parallel for num_threads(opt.num_threads)
141         for (int q = 0; q < channels; q++)
142         {
143             __fp16* ptr = bottom_top_blob.channel(q);
144 
145             float32x4_t _zero = vdupq_n_f32(0.f);
146             float32x4_t _one = vdupq_n_f32(1.f);
147             for (int i = 0; i < size; i++)
148             {
149                 float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
150                 float32x4_t _ans = vdupq_n_f32(beta);
151                 _ans = vfmaq_n_f32(_ans, _p, alpha);
152                 _ans = vmaxq_f32(_ans, _zero);
153                 _ans = vminq_f32(_ans, _one);
154                 _ans = vmulq_f32(_ans, _p);
155                 vst1_f16(ptr, vcvt_f16_f32(_ans));
156 
157                 ptr += 4;
158             }
159         }
160 
161         return 0;
162     }
163 
164     #pragma omp parallel for num_threads(opt.num_threads)
165     for (int q = 0; q < channels; q++)
166     {
167         __fp16* ptr = bottom_top_blob.channel(q);
168 
169         float32x4_t _zero = vdupq_n_f32(0.f);
170         float32x4_t _one = vdupq_n_f32(1.f);
171 
172         int i = 0;
173         for (; i + 3 < size; i += 4)
174         {
175             float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
176             float32x4_t _ans = vdupq_n_f32(beta);
177             _ans = vfmaq_n_f32(_ans, _p, alpha);
178             _ans = vmaxq_f32(_ans, _zero);
179             _ans = vminq_f32(_ans, _one);
180             _ans = vmulq_f32(_ans, _p);
181             vst1_f16(ptr, vcvt_f16_f32(_ans));
182 
183             ptr += 4;
184         }
185         for (; i < size; i++)
186         {
187             float v = (float)*ptr;
188             if (v < lower)
189                 v = 0.f;
190             else if (v > upper)
191                 ;
192             else
193                 v = v * (v * alpha + beta);
194             *ptr = (__fp16)v;
195             ++ptr;
196         }
197     }
198 
199     return 0;
200 }
201 
forward_inplace_fp16sa(Mat & bottom_top_blob,const Option & opt) const202 int HardSwish_arm::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
203 {
204     int w = bottom_top_blob.w;
205     int h = bottom_top_blob.h;
206     int channels = bottom_top_blob.c;
207     int size = w * h;
208     int elempack = bottom_top_blob.elempack;
209 
210     if (elempack == 8)
211     {
212         #pragma omp parallel for num_threads(opt.num_threads)
213         for (int q = 0; q < channels; q++)
214         {
215             __fp16* ptr = bottom_top_blob.channel(q);
216 
217             __fp16 alpha_fp16 = (__fp16)alpha;
218             __fp16 beta_fp16 = (__fp16)beta;
219 
220             float16x8_t _zero = vdupq_n_f16((__fp16)0.f);
221             float16x8_t _one = vdupq_n_f16((__fp16)1.f);
222             for (int i = 0; i < size; i++)
223             {
224                 float16x8_t _p = vld1q_f16(ptr);
225                 float16x8_t _ans = vdupq_n_f16(beta_fp16);
226                 _ans = vfmaq_n_f16(_ans, _p, alpha_fp16);
227                 _ans = vmaxq_f16(_ans, _zero);
228                 _ans = vminq_f16(_ans, _one);
229                 _ans = vmulq_f16(_ans, _p);
230                 vst1q_f16(ptr, _ans);
231 
232                 ptr += 8;
233             }
234         }
235 
236         return 0;
237     }
238 
239     if (elempack == 4)
240     {
241         #pragma omp parallel for num_threads(opt.num_threads)
242         for (int q = 0; q < channels; q++)
243         {
244             __fp16* ptr = bottom_top_blob.channel(q);
245 
246             __fp16 alpha_fp16 = (__fp16)alpha;
247             __fp16 beta_fp16 = (__fp16)beta;
248 
249             float16x8_t _zero = vdupq_n_f16((__fp16)0.f);
250             float16x8_t _one = vdupq_n_f16((__fp16)1.f);
251 
252             int i = 0;
253             for (; i + 1 < size; i += 2)
254             {
255                 float16x8_t _p = vld1q_f16(ptr);
256                 float16x8_t _ans = vdupq_n_f16(beta_fp16);
257                 _ans = vfmaq_n_f16(_ans, _p, alpha_fp16);
258                 _ans = vmaxq_f16(_ans, _zero);
259                 _ans = vminq_f16(_ans, _one);
260                 _ans = vmulq_f16(_ans, _p);
261                 vst1q_f16(ptr, _ans);
262 
263                 ptr += 8;
264             }
265             for (; i < size; i++)
266             {
267                 float16x4_t _p = vld1_f16(ptr);
268                 float16x4_t _ans = vdup_n_f16(beta_fp16);
269                 _ans = vfma_n_f16(_ans, _p, alpha_fp16);
270                 _ans = vmax_f16(_ans, vget_low_f16(_zero));
271                 _ans = vmin_f16(_ans, vget_low_f16(_one));
272                 _ans = vmul_f16(_ans, _p);
273                 vst1_f16(ptr, _ans);
274 
275                 ptr += 4;
276             }
277         }
278 
279         return 0;
280     }
281 
282     #pragma omp parallel for num_threads(opt.num_threads)
283     for (int q = 0; q < channels; q++)
284     {
285         __fp16* ptr = bottom_top_blob.channel(q);
286 
287         __fp16 alpha_fp16 = (__fp16)alpha;
288         __fp16 beta_fp16 = (__fp16)beta;
289 
290         float16x8_t _zero = vdupq_n_f16((__fp16)0.f);
291         float16x8_t _one = vdupq_n_f16((__fp16)1.f);
292 
293         int i = 0;
294         for (; i + 7 < size; i += 8)
295         {
296             float16x8_t _p = vld1q_f16(ptr);
297             float16x8_t _ans = vdupq_n_f16(beta_fp16);
298             _ans = vfmaq_n_f16(_ans, _p, alpha_fp16);
299             _ans = vmaxq_f16(_ans, _zero);
300             _ans = vminq_f16(_ans, _one);
301             _ans = vmulq_f16(_ans, _p);
302             vst1q_f16(ptr, _ans);
303 
304             ptr += 8;
305         }
306         for (; i + 3 < size; i += 4)
307         {
308             float16x4_t _p = vld1_f16(ptr);
309             float16x4_t _ans = vdup_n_f16(beta_fp16);
310             _ans = vfma_n_f16(_ans, _p, alpha_fp16);
311             _ans = vmax_f16(_ans, vget_low_f16(_zero));
312             _ans = vmin_f16(_ans, vget_low_f16(_one));
313             _ans = vmul_f16(_ans, _p);
314             vst1_f16(ptr, _ans);
315 
316             ptr += 4;
317         }
318         for (; i < size; i++)
319         {
320             __fp16 v = *ptr;
321             if (v < (__fp16)lower)
322                 v = (__fp16)0.f;
323             else if (v > (__fp16)upper)
324                 ;
325             else
326                 v = v * (v * alpha_fp16 + beta_fp16);
327             *ptr = v;
328             ++ptr;
329         }
330     }
331 
332     return 0;
333 }
334 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
335 
forward_inplace_bf16s(Mat & bottom_top_blob,const Option & opt) const336 int HardSwish_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
337 {
338     int w = bottom_top_blob.w;
339     int h = bottom_top_blob.h;
340     int channels = bottom_top_blob.c;
341     int size = w * h;
342     int elempack = bottom_top_blob.elempack;
343 
344 #if __ARM_NEON
345     if (elempack == 4)
346     {
347         #pragma omp parallel for num_threads(opt.num_threads)
348         for (int q = 0; q < channels; q++)
349         {
350             unsigned short* ptr = bottom_top_blob.channel(q);
351 
352             float32x4_t _zero = vdupq_n_f32(0.f);
353             float32x4_t _one = vdupq_n_f32(1.f);
354             for (int i = 0; i < size; i++)
355             {
356                 float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr));
357                 float32x4_t _ans = vdupq_n_f32(beta);
358                 _ans = vmlaq_n_f32(_ans, _p, alpha);
359                 _ans = vmaxq_f32(_ans, _zero);
360                 _ans = vminq_f32(_ans, _one);
361                 _ans = vmulq_f32(_ans, _p);
362                 vst1_u16(ptr, vcvt_bf16_f32(_ans));
363 
364                 ptr += 4;
365             }
366         }
367 
368         return 0;
369     }
370 #endif // __ARM_NEON
371 
372     #pragma omp parallel for num_threads(opt.num_threads)
373     for (int q = 0; q < channels; q++)
374     {
375         unsigned short* ptr = bottom_top_blob.channel(q);
376 
377 #if __ARM_NEON
378         int nn = size >> 2;
379         int remain = size - (nn << 2);
380 #else
381         int remain = size;
382 #endif // __ARM_NEON
383 
384 #if __ARM_NEON
385         float32x4_t _zero = vdupq_n_f32(0.f);
386         float32x4_t _one = vdupq_n_f32(1.f);
387         while (nn--)
388         {
389             float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr));
390             float32x4_t _ans = vdupq_n_f32(beta);
391             _ans = vmlaq_n_f32(_ans, _p, alpha);
392             _ans = vmaxq_f32(_ans, _zero);
393             _ans = vminq_f32(_ans, _one);
394             _ans = vmulq_f32(_ans, _p);
395             vst1_u16(ptr, vcvt_bf16_f32(_ans));
396 
397             ptr += 4;
398         }
399 #endif // __ARM_NEON
400         for (; remain > 0; remain--)
401         {
402             float v = bfloat16_to_float32(*ptr);
403             if (v < lower)
404                 v = 0.f;
405             else if (v > upper)
406                 ;
407             else
408                 v = v * (v * alpha + beta);
409             *ptr = float32_to_bfloat16(v);
410             ++ptr;
411         }
412     }
413 
414     return 0;
415 }
416 
417 } // namespace ncnn
418