1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "swish_arm.h"
16 
17 #if __ARM_NEON
18 #include <arm_neon.h>
19 #include "neon_mathfun.h"
20 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
21 #include "neon_mathfun_fp16s.h"
22 #endif
23 #endif // __ARM_NEON
24 
25 #include <math.h>
26 
27 namespace ncnn {
28 
Swish_arm::Swish_arm()
{
    // Declare which storage layouts this layer can consume in place,
    // so the graph runtime skips layout conversion before/after this layer.
#if __ARM_NEON
    // NEON builds handle elempack=4 packed data directly.
    support_packing = true;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // Native fp16 load/store (and optionally fp16 arithmetic) paths exist below.
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

#if NCNN_BF16
    // bf16 storage path is available on all builds compiled with NCNN_BF16.
    support_bf16_storage = true;
#endif
}
42 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const43 int Swish_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
44 {
45     int elembits = bottom_top_blob.elembits();
46 
47 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
48     if (opt.use_fp16_storage && elembits == 16)
49     {
50         if (opt.use_fp16_arithmetic)
51             return forward_inplace_fp16sa(bottom_top_blob, opt);
52         else
53             return forward_inplace_fp16s(bottom_top_blob, opt);
54     }
55 #endif
56 
57 #if NCNN_BF16
58     if (opt.use_bf16_storage && elembits == 16)
59         return forward_inplace_bf16s(bottom_top_blob, opt);
60 #endif
61 
62     int w = bottom_top_blob.w;
63     int h = bottom_top_blob.h;
64     int channels = bottom_top_blob.c;
65     int size = w * h;
66     int elempack = bottom_top_blob.elempack;
67 
68 #if __ARM_NEON
69     if (elempack == 4)
70     {
71         #pragma omp parallel for num_threads(opt.num_threads)
72         for (int q = 0; q < channels; q++)
73         {
74             float* ptr = bottom_top_blob.channel(q);
75 
76             float32x4_t _one = vdupq_n_f32(1.f);
77             for (int i = 0; i < size; i++)
78             {
79                 float32x4_t _p = vld1q_f32(ptr);
80                 _p = div_ps(_p, vaddq_f32(_one, exp_ps(vnegq_f32(_p))));
81                 vst1q_f32(ptr, _p);
82                 ptr += 4;
83             }
84         }
85 
86         return 0;
87     }
88 #endif // __ARM_NEON
89 
90     #pragma omp parallel for num_threads(opt.num_threads)
91     for (int q = 0; q < channels; q++)
92     {
93         float* ptr = bottom_top_blob.channel(q);
94 
95 #if __ARM_NEON
96         int nn = size >> 2;
97         int remain = size - (nn << 2);
98 #else
99         int remain = size;
100 #endif // __ARM_NEON
101 
102 #if __ARM_NEON
103         float32x4_t _one = vdupq_n_f32(1.f);
104         for (; nn > 0; nn--)
105         {
106             float32x4_t _p = vld1q_f32(ptr);
107             _p = div_ps(_p, vaddq_f32(_one, exp_ps(vnegq_f32(_p))));
108             vst1q_f32(ptr, _p);
109             ptr += 4;
110         }
111 #endif // __ARM_NEON
112         for (; remain > 0; remain--)
113         {
114             *ptr = *ptr / (1.f + exp(-*ptr));
115             ptr++;
116         }
117     }
118 
119     return 0;
120 }
121 
122 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
forward_inplace_fp16s(Mat & bottom_top_blob,const Option & opt) const123 int Swish_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
124 {
125     int w = bottom_top_blob.w;
126     int h = bottom_top_blob.h;
127     int channels = bottom_top_blob.c;
128     int size = w * h;
129     int elempack = bottom_top_blob.elempack;
130 
131     if (elempack == 4)
132     {
133         #pragma omp parallel for num_threads(opt.num_threads)
134         for (int q = 0; q < channels; q++)
135         {
136             __fp16* ptr = bottom_top_blob.channel(q);
137 
138             float32x4_t _one = vdupq_n_f32(1.f);
139             for (int i = 0; i < size; i++)
140             {
141                 float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
142                 _p = vdivq_f32(_p, vaddq_f32(_one, exp_ps(vnegq_f32(_p))));
143                 vst1_f16(ptr, vcvt_f16_f32(_p));
144 
145                 ptr += 4;
146             }
147         }
148 
149         return 0;
150     }
151 
152     #pragma omp parallel for num_threads(opt.num_threads)
153     for (int q = 0; q < channels; q++)
154     {
155         __fp16* ptr = bottom_top_blob.channel(q);
156 
157         float32x4_t _one = vdupq_n_f32(1.f);
158         int i = 0;
159         for (; i + 3 < size; i += 4)
160         {
161             float32x4_t _p = vcvt_f32_f16(vld1_f16(ptr));
162             _p = vdivq_f32(_p, vaddq_f32(_one, exp_ps(vnegq_f32(_p))));
163             vst1_f16(ptr, vcvt_f16_f32(_p));
164 
165             ptr += 4;
166         }
167         for (; i < size; i++)
168         {
169             float v = (float)*ptr;
170             v = v / (1.f + exp(-v));
171             *ptr = (__fp16)v;
172             ptr++;
173         }
174     }
175 
176     return 0;
177 }
178 
forward_inplace_fp16sa(Mat & bottom_top_blob,const Option & opt) const179 int Swish_arm::forward_inplace_fp16sa(Mat& bottom_top_blob, const Option& opt) const
180 {
181     int w = bottom_top_blob.w;
182     int h = bottom_top_blob.h;
183     int channels = bottom_top_blob.c;
184     int size = w * h;
185     int elempack = bottom_top_blob.elempack;
186 
187     if (elempack == 8)
188     {
189         #pragma omp parallel for num_threads(opt.num_threads)
190         for (int q = 0; q < channels; q++)
191         {
192             __fp16* ptr = bottom_top_blob.channel(q);
193 
194             float16x8_t _one = vdupq_n_f16(1.f);
195             for (int i = 0; i < size; i++)
196             {
197                 float16x8_t _p = vld1q_f16(ptr);
198                 _p = vdivq_f16(_p, vaddq_f16(_one, exp_ps(vnegq_f16(_p))));
199                 vst1q_f16(ptr, _p);
200 
201                 ptr += 8;
202             }
203         }
204 
205         return 0;
206     }
207 
208     if (elempack == 4)
209     {
210         #pragma omp parallel for num_threads(opt.num_threads)
211         for (int q = 0; q < channels; q++)
212         {
213             __fp16* ptr = bottom_top_blob.channel(q);
214 
215             float16x4_t _one = vdup_n_f16(1.f);
216             for (int i = 0; i < size; i++)
217             {
218                 float16x4_t _p = vld1_f16(ptr);
219                 _p = vdiv_f16(_p, vadd_f16(_one, exp_ps(vneg_f16(_p))));
220                 vst1_f16(ptr, _p);
221 
222                 ptr += 4;
223             }
224         }
225 
226         return 0;
227     }
228 
229     #pragma omp parallel for num_threads(opt.num_threads)
230     for (int q = 0; q < channels; q++)
231     {
232         __fp16* ptr = bottom_top_blob.channel(q);
233 
234         float16x4_t _one = vdup_n_f16(1.f);
235         int i = 0;
236         for (; i + 3 < size; i += 4)
237         {
238             float16x4_t _p = vld1_f16(ptr);
239             _p = vdiv_f16(_p, vadd_f16(_one, exp_ps(vneg_f16(_p))));
240             vst1_f16(ptr, _p);
241 
242             ptr += 4;
243         }
244         for (; i < size; i++)
245         {
246             __fp16 v = *ptr;
247             v = v / ((__fp16)1.f + exp(-v));
248             *ptr = v;
249             ptr++;
250         }
251     }
252 
253     return 0;
254 }
255 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
256 
257 #if NCNN_BF16
forward_inplace_bf16s(Mat & bottom_top_blob,const Option & opt) const258 int Swish_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
259 {
260     int w = bottom_top_blob.w;
261     int h = bottom_top_blob.h;
262     int channels = bottom_top_blob.c;
263     int size = w * h;
264     int elempack = bottom_top_blob.elempack;
265 
266 #if __ARM_NEON
267     if (elempack == 4)
268     {
269         #pragma omp parallel for num_threads(opt.num_threads)
270         for (int q = 0; q < channels; q++)
271         {
272             unsigned short* ptr = bottom_top_blob.channel(q);
273 
274             float32x4_t _one = vdupq_n_f32(1.f);
275             for (int i = 0; i < size; i++)
276             {
277                 float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr));
278                 _p = div_ps(_p, vaddq_f32(_one, exp_ps(vnegq_f32(_p))));
279                 vst1_u16(ptr, vcvt_bf16_f32(_p));
280                 ptr += 4;
281             }
282         }
283 
284         return 0;
285     }
286 #endif // __ARM_NEON
287 
288     #pragma omp parallel for num_threads(opt.num_threads)
289     for (int q = 0; q < channels; q++)
290     {
291         unsigned short* ptr = bottom_top_blob.channel(q);
292 
293 #if __ARM_NEON
294         int nn = size >> 2;
295         int remain = size - (nn << 2);
296 #else
297         int remain = size;
298 #endif // __ARM_NEON
299 
300 #if __ARM_NEON
301         float32x4_t _one = vdupq_n_f32(1.f);
302         for (; nn > 0; nn--)
303         {
304             float32x4_t _p = vcvt_f32_bf16(vld1_u16(ptr));
305             _p = div_ps(_p, vaddq_f32(_one, exp_ps(vnegq_f32(_p))));
306             vst1_u16(ptr, vcvt_bf16_f32(_p));
307             ptr += 4;
308         }
309 #endif // __ARM_NEON
310         for (; remain > 0; remain--)
311         {
312             float v = bfloat16_to_float32(*ptr);
313             v = v / (1.f + exp(-v));
314             *ptr = float32_to_bfloat16(v);
315             ptr++;
316         }
317     }
318 
319     return 0;
320 }
321 #endif // NCNN_BF16
322 
323 } // namespace ncnn
324