// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "prelu_x86.h"

#if __SSE2__
#include <immintrin.h> // _mm256_*/_mm_* intrinsics used directly below
#endif // __SSE2__

#include "x86_activation.h"
18 
19 namespace ncnn {
20 
// Constructor: when SSE2 is available, advertise packed-layout support so the
// framework may hand this layer blobs with elempack 4 (SSE) or 8 (AVX).
PReLU_x86::PReLU_x86()
{
#if __SSE2__
    support_packing = true;
#endif // __SSE2__
}
27 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const28 int PReLU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
29 {
30     int dims = bottom_top_blob.dims;
31 #if __SSE2__
32     int elempack = bottom_top_blob.elempack;
33 
34 #if __AVX__
35     if (elempack == 8)
36     {
37         if (dims == 1)
38         {
39             int w = bottom_top_blob.w;
40 
41             if (num_slope > 1)
42             {
43                 const float* slope = slope_data;
44 
45                 #pragma omp parallel for num_threads(opt.num_threads)
46                 for (int i = 0; i < w; i++)
47                 {
48                     float* ptr = (float*)bottom_top_blob + i * 8;
49                     __m256 _p = _mm256_loadu_ps(ptr);
50                     __m256 _slope = _mm256_loadu_ps(slope + i * 8);
51                     _mm256_storeu_ps(ptr, prelu_avx(_p, _slope));
52                 }
53             }
54             else
55             {
56                 __m256 _slope = _mm256_set1_ps(slope_data[0]);
57 
58                 #pragma omp parallel for num_threads(opt.num_threads)
59                 for (int i = 0; i < w; i++)
60                 {
61                     float* ptr = (float*)bottom_top_blob + i * 8;
62                     __m256 _p = _mm256_loadu_ps(ptr);
63                     _mm256_storeu_ps(ptr, prelu_avx(_p, _slope));
64                 }
65             }
66         }
67 
68         if (dims == 2)
69         {
70             int w = bottom_top_blob.w;
71             int h = bottom_top_blob.h;
72 
73             #pragma omp parallel for num_threads(opt.num_threads)
74             for (int i = 0; i < h; i++)
75             {
76                 float* ptr = bottom_top_blob.row(i);
77                 __m256 _slope = num_slope > 1 ? _mm256_loadu_ps((const float*)slope_data + i * 8) : _mm256_set1_ps(slope_data[0]);
78 
79                 for (int j = 0; j < w; j++)
80                 {
81                     __m256 _p = _mm256_loadu_ps(ptr);
82                     _mm256_storeu_ps(ptr, prelu_avx(_p, _slope));
83                     ptr += 8;
84                 }
85             }
86         }
87 
88         if (dims == 3)
89         {
90             int w = bottom_top_blob.w;
91             int h = bottom_top_blob.h;
92             int channels = bottom_top_blob.c;
93             int size = w * h;
94 
95             #pragma omp parallel for num_threads(opt.num_threads)
96             for (int q = 0; q < channels; q++)
97             {
98                 float* ptr = bottom_top_blob.channel(q);
99                 __m256 _slope = num_slope > 1 ? _mm256_loadu_ps((const float*)slope_data + q * 8) : _mm256_set1_ps(slope_data[0]);
100 
101                 for (int i = 0; i < size; i++)
102                 {
103                     __m256 _p = _mm256_loadu_ps(ptr);
104                     _mm256_storeu_ps(ptr, prelu_avx(_p, _slope));
105                     ptr += 8;
106                 }
107             }
108         }
109 
110         return 0;
111     }
112 #endif // __AVX__
113 
114     if (elempack == 4)
115     {
116         if (dims == 1)
117         {
118             int w = bottom_top_blob.w;
119 
120             if (num_slope > 1)
121             {
122                 const float* slope = slope_data;
123 
124                 #pragma omp parallel for num_threads(opt.num_threads)
125                 for (int i = 0; i < w; i++)
126                 {
127                     float* ptr = (float*)bottom_top_blob + i * 4;
128                     __m128 _p = _mm_loadu_ps(ptr);
129                     __m128 _slope = _mm_loadu_ps(slope + i * 4);
130                     _mm_storeu_ps(ptr, prelu_sse(_p, _slope));
131                 }
132             }
133             else
134             {
135                 __m128 _slope = _mm_set1_ps(slope_data[0]);
136 
137                 #pragma omp parallel for num_threads(opt.num_threads)
138                 for (int i = 0; i < w; i++)
139                 {
140                     float* ptr = (float*)bottom_top_blob + i * 4;
141                     __m128 _p = _mm_loadu_ps(ptr);
142                     _mm_storeu_ps(ptr, prelu_sse(_p, _slope));
143                 }
144             }
145         }
146 
147         if (dims == 2)
148         {
149             int w = bottom_top_blob.w;
150             int h = bottom_top_blob.h;
151 
152             #pragma omp parallel for num_threads(opt.num_threads)
153             for (int i = 0; i < h; i++)
154             {
155                 float* ptr = bottom_top_blob.row(i);
156                 __m128 _slope = num_slope > 1 ? _mm_loadu_ps((const float*)slope_data + i * 4) : _mm_set1_ps(slope_data[0]);
157 
158                 for (int j = 0; j < w; j++)
159                 {
160                     __m128 _p = _mm_loadu_ps(ptr);
161                     _mm_storeu_ps(ptr, prelu_sse(_p, _slope));
162                     ptr += 4;
163                 }
164             }
165         }
166 
167         if (dims == 3)
168         {
169             int w = bottom_top_blob.w;
170             int h = bottom_top_blob.h;
171             int channels = bottom_top_blob.c;
172             int size = w * h;
173 
174             #pragma omp parallel for num_threads(opt.num_threads)
175             for (int q = 0; q < channels; q++)
176             {
177                 float* ptr = bottom_top_blob.channel(q);
178                 __m128 _slope = num_slope > 1 ? _mm_loadu_ps((const float*)slope_data + q * 4) : _mm_set1_ps(slope_data[0]);
179 
180                 for (int i = 0; i < size; i++)
181                 {
182                     __m128 _p = _mm_loadu_ps(ptr);
183                     _mm_storeu_ps(ptr, prelu_sse(_p, _slope));
184                     ptr += 4;
185                 }
186             }
187         }
188 
189         return 0;
190     }
191 #endif // __SSE2__
192 
193     if (dims != 3)
194         return PReLU::forward_inplace(bottom_top_blob, opt);
195 
196     int w = bottom_top_blob.w;
197     int h = bottom_top_blob.h;
198     int channels = bottom_top_blob.c;
199     int size = w * h;
200 
201     const float* slope_data_ptr = slope_data;
202 
203     #pragma omp parallel for num_threads(opt.num_threads)
204     for (int q = 0; q < channels; q++)
205     {
206         float* ptr = bottom_top_blob.channel(q);
207         float slope = num_slope > 1 ? slope_data_ptr[q] : slope_data_ptr[0];
208 
209 #if __AVX__
210         int nn = size >> 3;
211         int remain = size - (nn << 3);
212 #else
213         int remain = size;
214 #endif // __AVX__
215 
216 #if __AVX__
217         for (; nn > 0; nn--)
218         {
219             __m256 _p = _mm256_loadu_ps(ptr);
220             _mm256_storeu_ps(ptr, prelu_avx(_p, _mm256_set1_ps(slope)));
221             ptr += 8;
222         }
223 #endif // __AVX__
224         for (; remain > 0; remain--)
225         {
226             if (*ptr < 0)
227                 *ptr *= slope;
228 
229             ptr++;
230         }
231     }
232 
233     return 0;
234 }
235 
236 } // namespace ncnn
237