1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "prelu_mips.h"
16
17 #if __mips_msa
18 #include <msa.h>
19 #endif // __mips_msa
20
21 #include "mips_usability.h"
22
23 namespace ncnn {
24
PReLU_mips()25 PReLU_mips::PReLU_mips()
26 {
27 #if __mips_msa
28 support_packing = true;
29 #endif // __mips_msa
30 }
31
forward_inplace(Mat & bottom_top_blob,const Option & opt) const32 int PReLU_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
33 {
34 int dims = bottom_top_blob.dims;
35 int elempack = bottom_top_blob.elempack;
36
37 #if __mips_msa
38 if (elempack == 4)
39 {
40 v4f32 _zero = (v4f32)__msa_fill_w(0);
41
42 if (dims == 1)
43 {
44 int w = bottom_top_blob.w;
45
46 if (num_slope > 1)
47 {
48 const float* slope = slope_data;
49
50 #pragma omp parallel for num_threads(opt.num_threads)
51 for (int i = 0; i < w; i++)
52 {
53 float* ptr = (float*)bottom_top_blob + i * 4;
54
55 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
56 v4f32 _slope = (v4f32)__msa_ld_w(slope + i * 4, 0);
57 v4i32_w _lemask = __msa_fcle_w(_p, _zero);
58 v4f32 _ps = __msa_fmul_w(_p, _slope);
59 _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
60 __msa_st_w((v4i32)_p, ptr, 0);
61 }
62 }
63 else
64 {
65 v4f32 _slope = (v4f32)__msa_fill_w_f32(slope_data[0]);
66
67 #pragma omp parallel for num_threads(opt.num_threads)
68 for (int i = 0; i < w; i++)
69 {
70 float* ptr = (float*)bottom_top_blob + i * 4;
71
72 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
73 v4i32_w _lemask = __msa_fcle_w(_p, _zero);
74 v4f32 _ps = __msa_fmul_w(_p, _slope);
75 _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
76 __msa_st_w((v4i32)_p, ptr, 0);
77 }
78 }
79 }
80
81 if (dims == 2)
82 {
83 int w = bottom_top_blob.w;
84 int h = bottom_top_blob.h;
85
86 #pragma omp parallel for num_threads(opt.num_threads)
87 for (int i = 0; i < h; i++)
88 {
89 float* ptr = bottom_top_blob.row(i);
90 v4f32 _slope = num_slope > 1 ? (v4f32)__msa_ld_w((const float*)slope_data + i * 4, 0) : (v4f32)__msa_fill_w_f32(slope_data[0]);
91
92 for (int j = 0; j < w; j++)
93 {
94 __builtin_prefetch(ptr + 32);
95 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
96 v4i32_w _lemask = __msa_fcle_w(_p, _zero);
97 v4f32 _ps = __msa_fmul_w(_p, _slope);
98 _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
99 __msa_st_w((v4i32)_p, ptr, 0);
100
101 ptr += 4;
102 }
103 }
104 }
105
106 if (dims == 3)
107 {
108 int w = bottom_top_blob.w;
109 int h = bottom_top_blob.h;
110 int channels = bottom_top_blob.c;
111 int size = w * h;
112
113 #pragma omp parallel for num_threads(opt.num_threads)
114 for (int q = 0; q < channels; q++)
115 {
116 float* ptr = bottom_top_blob.channel(q);
117 v4f32 _slope = num_slope > 1 ? (v4f32)__msa_ld_w((const float*)slope_data + q * 4, 0) : (v4f32)__msa_fill_w_f32(slope_data[0]);
118
119 for (int i = 0; i < size; i++)
120 {
121 __builtin_prefetch(ptr + 32);
122 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
123 v4i32_w _lemask = __msa_fcle_w(_p, _zero);
124 v4f32 _ps = __msa_fmul_w(_p, _slope);
125 _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
126 __msa_st_w((v4i32)_p, ptr, 0);
127
128 ptr += 4;
129 }
130 }
131 }
132
133 return 0;
134 }
135 #endif // __mips_msa
136
137 if (dims == 1)
138 {
139 int w = bottom_top_blob.w;
140
141 float* ptr = bottom_top_blob;
142
143 if (num_slope > 1)
144 {
145 const float* slope = slope_data;
146
147 #pragma omp parallel for num_threads(opt.num_threads)
148 for (int i = 0; i < w; i++)
149 {
150 float v = ptr[i];
151 if (v < 0.f)
152 ptr[i] = v * slope[i];
153 }
154 }
155 else
156 {
157 const float slope = slope_data[0];
158
159 #pragma omp parallel for num_threads(opt.num_threads)
160 for (int i = 0; i < w; i++)
161 {
162 float v = ptr[i];
163 if (v < 0.f)
164 ptr[i] = v * slope;
165 }
166 }
167 }
168
169 if (dims == 2)
170 {
171 int w = bottom_top_blob.w;
172 int h = bottom_top_blob.h;
173
174 #pragma omp parallel for num_threads(opt.num_threads)
175 for (int i = 0; i < h; i++)
176 {
177 float* ptr = bottom_top_blob.row(i);
178
179 const float slope = num_slope > 1 ? slope_data[i] : slope_data[0];
180
181 int j = 0;
182 #if __mips_msa
183 v4f32 _zero = (v4f32)__msa_fill_w(0);
184 v4f32 _slope = (v4f32)__msa_fill_w_f32(slope);
185
186 for (; j + 3 < w; j += 4)
187 {
188 __builtin_prefetch(ptr + 32);
189 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
190 v4i32_w _lemask = __msa_fcle_w(_p, _zero);
191 v4f32 _ps = __msa_fmul_w(_p, _slope);
192 _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
193 __msa_st_w((v4i32)_p, ptr, 0);
194
195 ptr += 4;
196 }
197 #endif // __mips_msa
198 for (; j < w; j++)
199 {
200 float v = *ptr;
201 if (v < 0.f)
202 *ptr = v * slope;
203
204 ptr++;
205 }
206 }
207 }
208
209 if (dims == 3)
210 {
211 int w = bottom_top_blob.w;
212 int h = bottom_top_blob.h;
213 int channels = bottom_top_blob.c;
214 int size = w * h;
215
216 const float* slope_data_ptr = slope_data;
217
218 #pragma omp parallel for num_threads(opt.num_threads)
219 for (int q = 0; q < channels; q++)
220 {
221 float* ptr = bottom_top_blob.channel(q);
222 float slope = num_slope > 1 ? slope_data_ptr[q] : slope_data_ptr[0];
223
224 int i = 0;
225 #if __mips_msa
226 v4f32 _zero = (v4f32)__msa_fill_w(0);
227 v4f32 _slope = (v4f32)__msa_fill_w_f32(slope);
228
229 for (; i + 3 < size; i += 4)
230 {
231 __builtin_prefetch(ptr + 32);
232 v4f32 _p = (v4f32)__msa_ld_w(ptr, 0);
233 v4i32_w _lemask = __msa_fcle_w(_p, _zero);
234 v4f32 _ps = __msa_fmul_w(_p, _slope);
235 _p = (v4f32)__msa_bsel_v((v16u8)_lemask, (v16u8)_p, (v16u8)_ps);
236 __msa_st_w((v4i32)_p, ptr, 0);
237
238 ptr += 4;
239 }
240 #endif // __mips_msa
241 for (; i < size; i++)
242 {
243 if (*ptr < 0)
244 *ptr *= slope;
245
246 ptr++;
247 }
248 }
249 }
250
251 return 0;
252 }
253
254 } // namespace ncnn
255