// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "clip_arm.h"

#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON

namespace ncnn {

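// Clip_arm: ARM-optimized Clip layer; clamps every element of the blob to [min, max]
// (min/max are the parameters inherited from the base Clip layer).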
Clip_arm::Clip_arm()
{
#if __ARM_NEON
    support_packing = true;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

    support_bf16_storage = true;
}

int Clip_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    int elembits = bottom_top_blob.elembits();

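    // dispatch to the fp16 / bf16 storage kernels when the blob holds 16-bit elements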
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
        return forward_inplace_fp16s(bottom_top_blob, opt);
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_inplace_bf16s(bottom_top_blob, opt);

    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

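    // packed fp32 path: each elempack=4 element fills one float32x4_t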
#if __ARM_NEON
    if (elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            float* ptr = bottom_top_blob.channel(q);

            float32x4_t _max = vdupq_n_f32(max);
            float32x4_t _min = vdupq_n_f32(min);

            for (int i = 0; i < size; i++)
            {
                float32x4_t _ptr = vld1q_f32(ptr);
                _ptr = vmaxq_f32(_ptr, _min);
                _ptr = vminq_f32(_ptr, _max);
                vst1q_f32(ptr, _ptr);

                ptr += 4;
            }
        }

        return 0;
    }
#endif // __ARM_NEON

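    // unpacked fp32 path: NEON main loop over groups of 4, scalar tail for the remainder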
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        float* ptr = bottom_top_blob.channel(q);

#if __ARM_NEON
        int nn = size >> 2;
        int remain = size & 3;
#else
        int remain = size;
#endif

#if __ARM_NEON
        float32x4_t _max = vdupq_n_f32(max);
        float32x4_t _min = vdupq_n_f32(min);
#if __aarch64__
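        // aarch64: clamp 4 floats per iteration with intrinsics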
        for (; nn > 0; nn--)
        {
            float32x4_t _ptr = vld1q_f32(ptr);
            _ptr = vmaxq_f32(_ptr, _min);
            _ptr = vminq_f32(_ptr, _max);
            vst1q_f32(ptr, _ptr);
            ptr += 4;
        }
#else
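        // armv7: hand-written NEON assembly, clamps 4 floats per iteration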
        if (nn > 0)
        {
            asm volatile(
                "0:                             \n"
                "pld        [%1, #128]          \n"
                "vld1.f32   {d0-d1}, [%1 :128]  \n"

                "vmax.f32   q0, q0, %q4         \n"
                "vmin.f32   q0, q0, %q5         \n"

                "subs       %0, #1              \n"
                "vst1.f32   {d0-d1}, [%1 :128]! \n"

                "bne        0b                  \n"

                : "=r"(nn), // %0
                "=r"(ptr) // %1
                : "0"(nn),
                "1"(ptr),
                "w"(_min), // %q4
                "w"(_max)  // %q5
                : "cc", "memory", "q0");
        }
#endif // __aarch64__
#endif // __ARM_NEON

        for (; remain > 0; remain--)
        {
            if (*ptr < min)
                *ptr = min;

            if (*ptr > max)
                *ptr = max;

            ptr++;
        }
    }

    return 0;
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
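// fp16 storage kernel: clamps __fp16 data in place using half-precision NEON arithmetic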
int Clip_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

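    // packed fp16 path: elempack=8 fills one float16x8_t per element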
    if (elempack == 8)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            float16x8_t _max = vdupq_n_f16(max);
            float16x8_t _min = vdupq_n_f16(min);

            for (int i = 0; i < size; i++)
            {
                float16x8_t _ptr = vld1q_f16(ptr);
                _ptr = vmaxq_f16(_ptr, _min);
                _ptr = vminq_f16(_ptr, _max);
                vst1q_f16(ptr, _ptr);

                ptr += 8;
            }
        }

        return 0;
    }

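    // elempack=4: process two packs (8 halves) per iteration, then one pack at a time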
    if (elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            __fp16* ptr = bottom_top_blob.channel(q);

            float16x8_t _max = vdupq_n_f16(max);
            float16x8_t _min = vdupq_n_f16(min);

            int i = 0;
            for (; i + 1 < size; i += 2)
            {
                float16x8_t _ptr = vld1q_f16(ptr);
                _ptr = vmaxq_f16(_ptr, _min);
                _ptr = vminq_f16(_ptr, _max);
                vst1q_f16(ptr, _ptr);

                ptr += 8;
            }
            for (; i < size; i++)
            {
                float16x4_t _ptr = vld1_f16(ptr);
                _ptr = vmax_f16(_ptr, vget_low_f16(_min));
                _ptr = vmin_f16(_ptr, vget_low_f16(_max));
                vst1_f16(ptr, _ptr);

                ptr += 4;
            }
        }

        return 0;
    }

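    // elempack=1: vectorize by 8, then 4, then finish with a scalar __fp16 tail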
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        __fp16* ptr = bottom_top_blob.channel(q);

        int i = 0;

        float16x8_t _max = vdupq_n_f16(max);
        float16x8_t _min = vdupq_n_f16(min);

        for (; i + 7 < size; i += 8)
        {
            float16x8_t _ptr = vld1q_f16(ptr);
            _ptr = vmaxq_f16(_ptr, _min);
            _ptr = vminq_f16(_ptr, _max);
            vst1q_f16(ptr, _ptr);

            ptr += 8;
        }
        for (; i + 3 < size; i += 4)
        {
            float16x4_t _ptr = vld1_f16(ptr);
            _ptr = vmax_f16(_ptr, vget_low_f16(_min));
            _ptr = vmin_f16(_ptr, vget_low_f16(_max));
            vst1_f16(ptr, _ptr);

            ptr += 4;
        }

        __fp16 min_fp16 = min;
        __fp16 max_fp16 = max;

        for (; i < size; i++)
        {
            __fp16 v = *ptr;
            if (v < min_fp16)
                v = min_fp16;

            if (v > max_fp16)
                v = max_fp16;

            *ptr = v;
            ptr++;
        }
    }

    return 0;
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

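// bf16 storage kernel: widen bf16 to fp32, clamp, then narrow back to bf16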
int Clip_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) const
{
    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;
    int elempack = bottom_top_blob.elempack;

#if __ARM_NEON
    if (elempack == 4)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            unsigned short* ptr = bottom_top_blob.channel(q);

            float32x4_t _max = vdupq_n_f32(max);
            float32x4_t _min = vdupq_n_f32(min);

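            // vcvt_f32_bf16 / vcvt_bf16_f32 convert between bf16 and fp32 so the clamp runs in fp32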
            for (int i = 0; i < size; i++)
            {
                float32x4_t _ptr = vcvt_f32_bf16(vld1_u16(ptr));
                _ptr = vmaxq_f32(_ptr, _min);
                _ptr = vminq_f32(_ptr, _max);
                vst1_u16(ptr, vcvt_bf16_f32(_ptr));

                ptr += 4;
            }
        }

        return 0;
    }
#endif // __ARM_NEON

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        unsigned short* ptr = bottom_top_blob.channel(q);

#if __ARM_NEON
        int nn = size >> 2;
        int remain = size & 3;
#else
        int remain = size;
#endif

#if __ARM_NEON
        float32x4_t _max = vdupq_n_f32(max);
        float32x4_t _min = vdupq_n_f32(min);
        for (; nn > 0; nn--)
        {
            float32x4_t _ptr = vcvt_f32_bf16(vld1_u16(ptr));
            _ptr = vmaxq_f32(_ptr, _min);
            _ptr = vminq_f32(_ptr, _max);
            vst1_u16(ptr, vcvt_bf16_f32(_ptr));
            ptr += 4;
        }
#endif // __ARM_NEON

        for (; remain > 0; remain--)
        {
            float v = bfloat16_to_float32(*ptr);
            if (v < min)
                v = min;

            if (v > max)
                v = max;

            *ptr = float32_to_bfloat16(v);
            ptr++;
        }
    }

    return 0;
}

} // namespace ncnn