// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "prelu_x86.h"

#include "x86_activation.h"

namespace ncnn {

PReLU_x86()21 PReLU_x86::PReLU_x86()
22 {
23 #if __SSE2__
24 support_packing = true;
25 #endif // __SSE2__
26 }
forward_inplace(Mat & bottom_top_blob,const Option & opt) const28 int PReLU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
29 {
30 int dims = bottom_top_blob.dims;
31 #if __SSE2__
32 int elempack = bottom_top_blob.elempack;
33
34 #if __AVX__
35 if (elempack == 8)
36 {
37 if (dims == 1)
38 {
39 int w = bottom_top_blob.w;
40
41 if (num_slope > 1)
42 {
43 const float* slope = slope_data;
44
45 #pragma omp parallel for num_threads(opt.num_threads)
46 for (int i = 0; i < w; i++)
47 {
48 float* ptr = (float*)bottom_top_blob + i * 8;
49 __m256 _p = _mm256_loadu_ps(ptr);
50 __m256 _slope = _mm256_loadu_ps(slope + i * 8);
51 _mm256_storeu_ps(ptr, prelu_avx(_p, _slope));
52 }
53 }
54 else
55 {
56 __m256 _slope = _mm256_set1_ps(slope_data[0]);
57
58 #pragma omp parallel for num_threads(opt.num_threads)
59 for (int i = 0; i < w; i++)
60 {
61 float* ptr = (float*)bottom_top_blob + i * 8;
62 __m256 _p = _mm256_loadu_ps(ptr);
63 _mm256_storeu_ps(ptr, prelu_avx(_p, _slope));
64 }
65 }
66 }
67
68 if (dims == 2)
69 {
70 int w = bottom_top_blob.w;
71 int h = bottom_top_blob.h;
72
73 #pragma omp parallel for num_threads(opt.num_threads)
74 for (int i = 0; i < h; i++)
75 {
76 float* ptr = bottom_top_blob.row(i);
77 __m256 _slope = num_slope > 1 ? _mm256_loadu_ps((const float*)slope_data + i * 8) : _mm256_set1_ps(slope_data[0]);
78
79 for (int j = 0; j < w; j++)
80 {
81 __m256 _p = _mm256_loadu_ps(ptr);
82 _mm256_storeu_ps(ptr, prelu_avx(_p, _slope));
83 ptr += 8;
84 }
85 }
86 }
87
88 if (dims == 3)
89 {
90 int w = bottom_top_blob.w;
91 int h = bottom_top_blob.h;
92 int channels = bottom_top_blob.c;
93 int size = w * h;
94
95 #pragma omp parallel for num_threads(opt.num_threads)
96 for (int q = 0; q < channels; q++)
97 {
98 float* ptr = bottom_top_blob.channel(q);
99 __m256 _slope = num_slope > 1 ? _mm256_loadu_ps((const float*)slope_data + q * 8) : _mm256_set1_ps(slope_data[0]);
100
101 for (int i = 0; i < size; i++)
102 {
103 __m256 _p = _mm256_loadu_ps(ptr);
104 _mm256_storeu_ps(ptr, prelu_avx(_p, _slope));
105 ptr += 8;
106 }
107 }
108 }
109
110 return 0;
111 }
112 #endif // __AVX__
113
114 if (elempack == 4)
115 {
116 if (dims == 1)
117 {
118 int w = bottom_top_blob.w;
119
120 if (num_slope > 1)
121 {
122 const float* slope = slope_data;
123
124 #pragma omp parallel for num_threads(opt.num_threads)
125 for (int i = 0; i < w; i++)
126 {
127 float* ptr = (float*)bottom_top_blob + i * 4;
128 __m128 _p = _mm_loadu_ps(ptr);
129 __m128 _slope = _mm_loadu_ps(slope + i * 4);
130 _mm_storeu_ps(ptr, prelu_sse(_p, _slope));
131 }
132 }
133 else
134 {
135 __m128 _slope = _mm_set1_ps(slope_data[0]);
136
137 #pragma omp parallel for num_threads(opt.num_threads)
138 for (int i = 0; i < w; i++)
139 {
140 float* ptr = (float*)bottom_top_blob + i * 4;
141 __m128 _p = _mm_loadu_ps(ptr);
142 _mm_storeu_ps(ptr, prelu_sse(_p, _slope));
143 }
144 }
145 }
146
147 if (dims == 2)
148 {
149 int w = bottom_top_blob.w;
150 int h = bottom_top_blob.h;
151
152 #pragma omp parallel for num_threads(opt.num_threads)
153 for (int i = 0; i < h; i++)
154 {
155 float* ptr = bottom_top_blob.row(i);
156 __m128 _slope = num_slope > 1 ? _mm_loadu_ps((const float*)slope_data + i * 4) : _mm_set1_ps(slope_data[0]);
157
158 for (int j = 0; j < w; j++)
159 {
160 __m128 _p = _mm_loadu_ps(ptr);
161 _mm_storeu_ps(ptr, prelu_sse(_p, _slope));
162 ptr += 4;
163 }
164 }
165 }
166
167 if (dims == 3)
168 {
169 int w = bottom_top_blob.w;
170 int h = bottom_top_blob.h;
171 int channels = bottom_top_blob.c;
172 int size = w * h;
173
174 #pragma omp parallel for num_threads(opt.num_threads)
175 for (int q = 0; q < channels; q++)
176 {
177 float* ptr = bottom_top_blob.channel(q);
178 __m128 _slope = num_slope > 1 ? _mm_loadu_ps((const float*)slope_data + q * 4) : _mm_set1_ps(slope_data[0]);
179
180 for (int i = 0; i < size; i++)
181 {
182 __m128 _p = _mm_loadu_ps(ptr);
183 _mm_storeu_ps(ptr, prelu_sse(_p, _slope));
184 ptr += 4;
185 }
186 }
187 }
188
189 return 0;
190 }
191 #endif // __SSE2__
192
193 if (dims != 3)
194 return PReLU::forward_inplace(bottom_top_blob, opt);
195
196 int w = bottom_top_blob.w;
197 int h = bottom_top_blob.h;
198 int channels = bottom_top_blob.c;
199 int size = w * h;
200
201 const float* slope_data_ptr = slope_data;
202
203 #pragma omp parallel for num_threads(opt.num_threads)
204 for (int q = 0; q < channels; q++)
205 {
206 float* ptr = bottom_top_blob.channel(q);
207 float slope = num_slope > 1 ? slope_data_ptr[q] : slope_data_ptr[0];
208
209 #if __AVX__
210 int nn = size >> 3;
211 int remain = size - (nn << 3);
212 #else
213 int remain = size;
214 #endif // __AVX__
215
216 #if __AVX__
217 for (; nn > 0; nn--)
218 {
219 __m256 _p = _mm256_loadu_ps(ptr);
220 _mm256_storeu_ps(ptr, prelu_avx(_p, _mm256_set1_ps(slope)));
221 ptr += 8;
222 }
223 #endif // __AVX__
224 for (; remain > 0; remain--)
225 {
226 if (*ptr < 0)
227 *ptr *= slope;
228
229 ptr++;
230 }
231 }
232
233 return 0;
234 }

} // namespace ncnn