// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "scale_x86.h"

#if __SSE2__
#include <emmintrin.h>
#if __AVX__
#include <immintrin.h>
#endif // __AVX__
#endif // __SSE2__

namespace ncnn {

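// opt in to packed blob layouts (elempack 4 with SSE2, 8 with AVX);
// the packed code paths in forward_inplace rely on this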
Scale_x86::Scale_x86()
{
#if __SSE2__
    support_packing = true;
#endif // __SSE2__
}

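// in-place scale: blob 0 is the data, blob 1 the scale factors, indexed
// per element (dims 1), per row (dims 2) or per channel (dims 3);
// the optional bias comes from the layer's own bias_data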
int Scale_x86::forward_inplace(std::vector<Mat>& bottom_top_blobs, const Option& opt) const
{
    Mat& bottom_top_blob = bottom_top_blobs[0];
    const Mat& scale_blob = bottom_top_blobs[1];

    int dims = bottom_top_blob.dims;
#if __SSE2__
    int elempack = bottom_top_blob.elempack;

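    // elempack 8: each element is 8 interleaved floats, so every scale/bias
    // entry is loaded as a full 8-wide vector
    // note: _mm256_fmadd_ps is an FMA instruction, so this path assumes the
    // build also enables FMA alongside AVX for this translation unit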
#if __AVX__
    if (elempack == 8)
    {
        if (dims == 1)
        {
            int w = bottom_top_blob.w;

            const float* scale = scale_blob;
            if (bias_term)
            {
                const float* bias = bias_data;
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < w; i++)
                {
                    float* ptr = (float*)bottom_top_blob + i * 8;

                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _s = _mm256_loadu_ps(scale + i * 8);
                    __m256 _bias = _mm256_loadu_ps(bias + i * 8);
                    _p = _mm256_fmadd_ps(_p, _s, _bias);
                    _mm256_storeu_ps(ptr, _p);
                }
            }
            else
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < w; i++)
                {
                    float* ptr = (float*)bottom_top_blob + i * 8;

                    __m256 _p = _mm256_loadu_ps(ptr);
                    __m256 _s = _mm256_loadu_ps(scale + i * 8);
                    _p = _mm256_mul_ps(_p, _s);
                    _mm256_storeu_ps(ptr, _p);
                }
            }
        }

        if (dims == 2)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;

            if (bias_term)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < h; i++)
                {
                    float* ptr = bottom_top_blob.row(i);
                    __m256 _s = _mm256_loadu_ps((const float*)scale_blob + i * 8);
                    __m256 _bias = _mm256_loadu_ps((const float*)bias_data + i * 8);

                    for (int j = 0; j < w; j++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        _p = _mm256_fmadd_ps(_p, _s, _bias);
                        _mm256_storeu_ps(ptr, _p);

                        ptr += 8;
                    }
                }
            }
            else
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < h; i++)
                {
                    float* ptr = bottom_top_blob.row(i);
                    __m256 _s = _mm256_loadu_ps((const float*)scale_blob + i * 8);

                    for (int j = 0; j < w; j++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        _p = _mm256_mul_ps(_p, _s);
                        _mm256_storeu_ps(ptr, _p);

                        ptr += 8;
                    }
                }
            }
        }

        if (dims == 3)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            int channels = bottom_top_blob.c;
            int size = w * h;

            if (bias_term)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    float* ptr = bottom_top_blob.channel(q);
                    __m256 _s = _mm256_loadu_ps((const float*)scale_blob + q * 8);
                    __m256 _bias = _mm256_loadu_ps((const float*)bias_data + q * 8);

                    for (int i = 0; i < size; i++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        _p = _mm256_fmadd_ps(_p, _s, _bias);
                        _mm256_storeu_ps(ptr, _p);

                        ptr += 8;
                    }
                }
            }
            else
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    float* ptr = bottom_top_blob.channel(q);
                    __m256 _s = _mm256_loadu_ps((const float*)scale_blob + q * 8);

                    for (int i = 0; i < size; i++)
                    {
                        __m256 _p = _mm256_loadu_ps(ptr);
                        _p = _mm256_mul_ps(_p, _s);
                        _mm256_storeu_ps(ptr, _p);

                        ptr += 8;
                    }
                }
            }
        }

        return 0;
    }
#endif // __AVX__

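    // elempack 4: same scheme with 4 interleaved floats per element;
    // SSE2 has no fused multiply-add, hence the separate mul and add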
    if (elempack == 4)
    {
        if (dims == 1)
        {
            int w = bottom_top_blob.w;

            const float* scale = scale_blob;
            if (bias_term)
            {
                const float* bias = bias_data;
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < w; i++)
                {
                    float* ptr = (float*)bottom_top_blob + i * 4;

                    __m128 _p = _mm_loadu_ps(ptr);
                    __m128 _s = _mm_loadu_ps(scale + i * 4);
                    __m128 _bias = _mm_loadu_ps(bias + i * 4);
                    _p = _mm_add_ps(_mm_mul_ps(_p, _s), _bias);
                    _mm_storeu_ps(ptr, _p);
                }
            }
            else
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < w; i++)
                {
                    float* ptr = (float*)bottom_top_blob + i * 4;

                    __m128 _p = _mm_loadu_ps(ptr);
                    __m128 _s = _mm_loadu_ps(scale + i * 4);
                    _p = _mm_mul_ps(_p, _s);
                    _mm_storeu_ps(ptr, _p);
                }
            }
        }

        if (dims == 2)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;

            if (bias_term)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < h; i++)
                {
                    float* ptr = bottom_top_blob.row(i);
                    __m128 _s = _mm_loadu_ps((const float*)scale_blob + i * 4);
                    __m128 _bias = _mm_loadu_ps((const float*)bias_data + i * 4);

                    for (int j = 0; j < w; j++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr);
                        _p = _mm_add_ps(_mm_mul_ps(_p, _s), _bias);
                        _mm_storeu_ps(ptr, _p);

                        ptr += 4;
                    }
                }
            }
            else
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int i = 0; i < h; i++)
                {
                    float* ptr = bottom_top_blob.row(i);
                    __m128 _s = _mm_loadu_ps((const float*)scale_blob + i * 4);

                    for (int j = 0; j < w; j++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr);
                        _p = _mm_mul_ps(_p, _s);
                        _mm_storeu_ps(ptr, _p);

                        ptr += 4;
                    }
                }
            }
        }

        if (dims == 3)
        {
            int w = bottom_top_blob.w;
            int h = bottom_top_blob.h;
            int channels = bottom_top_blob.c;
            int size = w * h;

            if (bias_term)
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    float* ptr = bottom_top_blob.channel(q);
                    __m128 _s = _mm_loadu_ps((const float*)scale_blob + q * 4);
                    __m128 _bias = _mm_loadu_ps((const float*)bias_data + q * 4);

                    for (int i = 0; i < size; i++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr);
                        _p = _mm_add_ps(_mm_mul_ps(_p, _s), _bias);
                        _mm_storeu_ps(ptr, _p);

                        ptr += 4;
                    }
                }
            }
            else
            {
                #pragma omp parallel for num_threads(opt.num_threads)
                for (int q = 0; q < channels; q++)
                {
                    float* ptr = bottom_top_blob.channel(q);
                    __m128 _s = _mm_loadu_ps((const float*)scale_blob + q * 4);

                    for (int i = 0; i < size; i++)
                    {
                        __m128 _p = _mm_loadu_ps(ptr);
                        _p = _mm_mul_ps(_p, _s);
                        _mm_storeu_ps(ptr, _p);

                        ptr += 4;
                    }
                }
            }
        }

        return 0;
    }
#endif // __SSE2__

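    // elempack 1: only dims == 3 is vectorized here, everything else
    // falls back to the generic implementation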
    if (dims != 3)
        return Scale::forward_inplace(bottom_top_blobs, opt);

    int w = bottom_top_blob.w;
    int h = bottom_top_blob.h;
    int channels = bottom_top_blob.c;
    int size = w * h;

    if (bias_term)
    {
        const float* scale_ptr = scale_blob;
        const float* bias_ptr = bias_data;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            float* ptr = bottom_top_blob.channel(q);

            float s = scale_ptr[q];
            float bias = bias_ptr[q];

#if __AVX__
            int nn = size >> 3;
            int remain = size & 7;
#else
            int remain = size;
#endif // __AVX__

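            // nn counts the 8-float AVX iterations, remain the scalar tail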
#if __AVX__
            __m256 _s = _mm256_set1_ps(s);
            __m256 _bias = _mm256_set1_ps(bias);
            for (; nn > 0; nn--)
            {
                __m256 _p = _mm256_loadu_ps(ptr);
                _p = _mm256_fmadd_ps(_p, _s, _bias);
                _mm256_storeu_ps(ptr, _p);

                ptr += 8;
            }
#endif // __AVX__

            for (; remain > 0; remain--)
            {
                *ptr = *ptr * s + bias;

                ptr++;
            }
        }
    }
    else
    {
        const float* scale_ptr = scale_blob;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            float* ptr = bottom_top_blob.channel(q);

            float s = scale_ptr[q];

#if __AVX__
            int nn = size >> 3;
            int remain = size & 7;
#else
            int remain = size;
#endif // __AVX__

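            // same vector/scalar split as above, without the bias term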
#if __AVX__
            __m256 _s = _mm256_set1_ps(s);
            for (; nn > 0; nn--)
            {
                __m256 _p = _mm256_loadu_ps(ptr);
                _p = _mm256_mul_ps(_p, _s);
                _mm256_storeu_ps(ptr, _p);

                ptr += 8;
            }
#endif // __AVX__

            for (; remain > 0; remain--)
            {
                *ptr *= s;

                ptr++;
            }
        }
    }

    return 0;
}

} // namespace ncnn