1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "prelu_mips.h"
16 
17 #if __mips_msa
18 #include <msa.h>
19 #endif // __mips_msa
20 
21 #include "mips_usability.h"
22 
23 namespace ncnn {
24 
PReLU_mips()25 PReLU_mips::PReLU_mips()
26 {
27 #if __mips_msa
28     support_packing = true;
29 #endif // __mips_msa
30 }
31 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const32 int PReLU_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
33 {
34     int dims = bottom_top_blob.dims;
35     int elempack = bottom_top_blob.elempack;
36 
37 #if __mips_msa
38     if (elempack == 4)
39     {
40         v4f32 _zero = (v4f32)__msa_fill_w(0);
41 
42         if (dims == 1)
43         {
44             int w = bottom_top_blob.w;
45 
46             if (num_slope > 1)
47             {
48                 const float* slope = slope_data;
49 
50                 #pragma omp parallel for num_threads(opt.num_threads)
51                 for (int i = 0; i < w; i++)
52                 {
53                     float* ptr = (float*)bottom_top_blob + i * 4;
54 
55                     v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
56                     v4f32 _slope = (v4f32)__msa_ld_w(slope + i * 4, 0);
57                     v4i32_w _lemask = __msa_fcle_w(_p, _zero);
58                     v4f32 _ps = __msa_fmul_w(_p, _slope);
59                     _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
60                     __msa_st_w((v4i32)_p, ptr, 0);
61                 }
62             }
63             else
64             {
65                 v4f32 _slope = (v4f32)__msa_fill_w_f32(slope_data[0]);
66 
67                 #pragma omp parallel for num_threads(opt.num_threads)
68                 for (int i = 0; i < w; i++)
69                 {
70                     float* ptr = (float*)bottom_top_blob + i * 4;
71 
72                     v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
73                     v4i32_w _lemask = __msa_fcle_w(_p, _zero);
74                     v4f32 _ps = __msa_fmul_w(_p, _slope);
75                     _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
76                     __msa_st_w((v4i32)_p, ptr, 0);
77                 }
78             }
79         }
80 
81         if (dims == 2)
82         {
83             int w = bottom_top_blob.w;
84             int h = bottom_top_blob.h;
85 
86             #pragma omp parallel for num_threads(opt.num_threads)
87             for (int i = 0; i < h; i++)
88             {
89                 float* ptr = bottom_top_blob.row(i);
90                 v4f32 _slope = num_slope > 1 ? (v4f32)__msa_ld_w((const float*)slope_data + i * 4, 0) : (v4f32)__msa_fill_w_f32(slope_data[0]);
91 
92                 for (int j = 0; j < w; j++)
93                 {
94                     __builtin_prefetch(ptr + 32);
95                     v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
96                     v4i32_w _lemask = __msa_fcle_w(_p, _zero);
97                     v4f32 _ps = __msa_fmul_w(_p, _slope);
98                     _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
99                     __msa_st_w((v4i32)_p, ptr, 0);
100 
101                     ptr += 4;
102                 }
103             }
104         }
105 
106         if (dims == 3)
107         {
108             int w = bottom_top_blob.w;
109             int h = bottom_top_blob.h;
110             int channels = bottom_top_blob.c;
111             int size = w * h;
112 
113             #pragma omp parallel for num_threads(opt.num_threads)
114             for (int q = 0; q < channels; q++)
115             {
116                 float* ptr = bottom_top_blob.channel(q);
117                 v4f32 _slope = num_slope > 1 ? (v4f32)__msa_ld_w((const float*)slope_data + q * 4, 0) : (v4f32)__msa_fill_w_f32(slope_data[0]);
118 
119                 for (int i = 0; i < size; i++)
120                 {
121                     __builtin_prefetch(ptr + 32);
122                     v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
123                     v4i32_w _lemask = __msa_fcle_w(_p, _zero);
124                     v4f32 _ps = __msa_fmul_w(_p, _slope);
125                     _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
126                     __msa_st_w((v4i32)_p, ptr, 0);
127 
128                     ptr += 4;
129                 }
130             }
131         }
132 
133         return 0;
134     }
135 #endif // __mips_msa
136 
137     if (dims == 1)
138     {
139         int w = bottom_top_blob.w;
140 
141         float* ptr = bottom_top_blob;
142 
143         if (num_slope > 1)
144         {
145             const float* slope = slope_data;
146 
147             #pragma omp parallel for num_threads(opt.num_threads)
148             for (int i = 0; i < w; i++)
149             {
150                 float v = ptr[i];
151                 if (v < 0.f)
152                     ptr[i] = v * slope[i];
153             }
154         }
155         else
156         {
157             const float slope = slope_data[0];
158 
159             #pragma omp parallel for num_threads(opt.num_threads)
160             for (int i = 0; i < w; i++)
161             {
162                 float v = ptr[i];
163                 if (v < 0.f)
164                     ptr[i] = v * slope;
165             }
166         }
167     }
168 
169     if (dims == 2)
170     {
171         int w = bottom_top_blob.w;
172         int h = bottom_top_blob.h;
173 
174         #pragma omp parallel for num_threads(opt.num_threads)
175         for (int i = 0; i < h; i++)
176         {
177             float* ptr = bottom_top_blob.row(i);
178 
179             const float slope = num_slope > 1 ? slope_data[i] : slope_data[0];
180 
181             int j = 0;
182 #if __mips_msa
183             v4f32 _zero = (v4f32)__msa_fill_w(0);
184             v4f32 _slope = (v4f32)__msa_fill_w_f32(slope);
185 
186             for (; j + 3 < w; j += 4)
187             {
188                 __builtin_prefetch(ptr + 32);
189                 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
190                 v4i32_w _lemask = __msa_fcle_w(_p, _zero);
191                 v4f32 _ps = __msa_fmul_w(_p, _slope);
192                 _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
193                 __msa_st_w((v4i32)_p, ptr, 0);
194 
195                 ptr += 4;
196             }
197 #endif // __mips_msa
198             for (; j < w; j++)
199             {
200                 float v = *ptr;
201                 if (v < 0.f)
202                     *ptr = v * slope;
203 
204                 ptr++;
205             }
206         }
207     }
208 
209     if (dims == 3)
210     {
211         int w = bottom_top_blob.w;
212         int h = bottom_top_blob.h;
213         int channels = bottom_top_blob.c;
214         int size = w * h;
215 
216         const float* slope_data_ptr = slope_data;
217 
218         #pragma omp parallel for num_threads(opt.num_threads)
219         for (int q = 0; q < channels; q++)
220         {
221             float* ptr = bottom_top_blob.channel(q);
222             float slope = num_slope > 1 ? slope_data_ptr[q] : slope_data_ptr[0];
223 
224             int i = 0;
225 #if __mips_msa
226             v4f32 _zero = (v4f32)__msa_fill_w(0);
227             v4f32 _slope = (v4f32)__msa_fill_w_f32(slope);
228 
229             for (; i + 3 < size; i += 4)
230             {
231                 __builtin_prefetch(ptr + 32);
232                 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
233                 v4i32_w _lemask = __msa_fcle_w(_p, _zero);
234                 v4f32 _ps = __msa_fmul_w(_p, _slope);
235                 _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
236                 __msa_st_w((v4i32)_p, ptr, 0);
237 
238                 ptr += 4;
239             }
240 #endif // __mips_msa
241             for (; i < size; i++)
242             {
243                 if (*ptr < 0)
244                     *ptr *= slope;
245 
246                 ptr++;
247             }
248         }
249     }
250 
251     return 0;
252 }
253 
254 } // namespace ncnn
255