1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "flatten_x86.h"
16 
17 #if __SSE2__
18 #include <emmintrin.h>
19 #endif // __SSE2__
20 
21 #include "x86_usability.h"
22 
23 namespace ncnn {
24 
Flatten_x86()25 Flatten_x86::Flatten_x86()
26 {
27 #if __SSE2__
28     support_packing = true;
29 #endif // __SSE2__
30 }
31 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const32 int Flatten_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
33 {
34     int elembits = bottom_blob.elembits();
35 
36     if (elembits == 8)
37         return forward_int8(bottom_blob, top_blob, opt);
38 
39     int dims = bottom_blob.dims;
40 
41     if (dims == 1)
42     {
43         top_blob = bottom_blob;
44         return 0;
45     }
46 
47     int w = bottom_blob.w;
48     int h = bottom_blob.h;
49     int channels = bottom_blob.c;
50     size_t elemsize = bottom_blob.elemsize;
51     int elempack = bottom_blob.elempack;
52     int size = w * h;
53 
54     int total = size * channels * elempack;
55 
56     int out_elempack = 1;
57 #if __SSE2__
58     if (opt.use_packing_layout)
59     {
60 #if __AVX__
61         out_elempack = total % 8 == 0 ? 8 : total % 4 == 0 ? 4 : 1;
62 #else
63         out_elempack = total % 4 == 0 ? 4 : 1;
64 #endif
65     }
66 #endif // __SSE2__
67     size_t out_elemsize = elemsize / elempack * out_elempack;
68 
69     if (out_elempack == 1)
70     {
71         return Flatten::forward(bottom_blob, top_blob, opt);
72     }
73 
74     if (dims == 2 && elempack == 1) // out_elempack == 4 || out_elempack == 8
75     {
76         top_blob = bottom_blob;
77         top_blob.dims = 1;
78         top_blob.w = total / out_elempack;
79         top_blob.h = 1;
80         top_blob.cstep = top_blob.w;
81         top_blob.elemsize = out_elemsize;
82         top_blob.elempack = out_elempack;
83         return 0;
84     }
85 
86     top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
87     if (top_blob.empty())
88         return -100;
89 
90     if (dims == 2)
91     {
92 #if __SSE2__
93 #if __AVX__
94         if (elempack == 8) // out_elempack == 8
95         {
96             #pragma omp parallel for num_threads(opt.num_threads)
97             for (int i = 0; i < h; i++)
98             {
99                 const float* ptr = bottom_blob.row(i);
100                 float* outptr0 = (float*)top_blob + w * i * 8;
101                 float* outptr1 = (float*)top_blob + w * (i * 8 + 1);
102                 float* outptr2 = (float*)top_blob + w * (i * 8 + 2);
103                 float* outptr3 = (float*)top_blob + w * (i * 8 + 3);
104                 float* outptr4 = (float*)top_blob + w * (i * 8 + 4);
105                 float* outptr5 = (float*)top_blob + w * (i * 8 + 5);
106                 float* outptr6 = (float*)top_blob + w * (i * 8 + 6);
107                 float* outptr7 = (float*)top_blob + w * (i * 8 + 7);
108 
109                 int j = 0;
110                 for (; j + 7 < w; j += 8)
111                 {
112                     __m256 _row0 = _mm256_loadu_ps(ptr);
113                     __m256 _row1 = _mm256_loadu_ps(ptr + 8);
114                     __m256 _row2 = _mm256_loadu_ps(ptr + 16);
115                     __m256 _row3 = _mm256_loadu_ps(ptr + 24);
116                     __m256 _row4 = _mm256_loadu_ps(ptr + 32);
117                     __m256 _row5 = _mm256_loadu_ps(ptr + 40);
118                     __m256 _row6 = _mm256_loadu_ps(ptr + 48);
119                     __m256 _row7 = _mm256_loadu_ps(ptr + 56);
120 
121                     transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7);
122 
123                     _mm256_storeu_ps(outptr0, _row0);
124                     _mm256_storeu_ps(outptr1, _row1);
125                     _mm256_storeu_ps(outptr2, _row2);
126                     _mm256_storeu_ps(outptr3, _row3);
127                     _mm256_storeu_ps(outptr4, _row4);
128                     _mm256_storeu_ps(outptr5, _row5);
129                     _mm256_storeu_ps(outptr6, _row6);
130                     _mm256_storeu_ps(outptr7, _row7);
131 
132                     outptr0 += 8;
133                     outptr1 += 8;
134                     outptr2 += 8;
135                     outptr3 += 8;
136                     outptr4 += 8;
137                     outptr5 += 8;
138                     outptr6 += 8;
139                     outptr7 += 8;
140                     ptr += 64;
141                 }
142                 for (; j < w; j++)
143                 {
144                     *outptr0++ = ptr[0];
145                     *outptr1++ = ptr[1];
146                     *outptr2++ = ptr[2];
147                     *outptr3++ = ptr[3];
148                     *outptr4++ = ptr[4];
149                     *outptr5++ = ptr[5];
150                     *outptr6++ = ptr[6];
151                     *outptr7++ = ptr[7];
152 
153                     ptr += 8;
154                 }
155             }
156         }
157 #endif // __AVX__
158 
159         if (elempack == 4) // out_elempack == 4 || out_elempack == 8
160         {
161             #pragma omp parallel for num_threads(opt.num_threads)
162             for (int i = 0; i < h; i++)
163             {
164                 const float* ptr = bottom_blob.row(i);
165                 float* outptr0 = (float*)top_blob + w * i * 4;
166                 float* outptr1 = (float*)top_blob + w * (i * 4 + 1);
167                 float* outptr2 = (float*)top_blob + w * (i * 4 + 2);
168                 float* outptr3 = (float*)top_blob + w * (i * 4 + 3);
169 
170                 int j = 0;
171                 for (; j + 3 < w; j += 4)
172                 {
173                     __m128 _row0 = _mm_loadu_ps(ptr);
174                     __m128 _row1 = _mm_loadu_ps(ptr + 4);
175                     __m128 _row2 = _mm_loadu_ps(ptr + 8);
176                     __m128 _row3 = _mm_loadu_ps(ptr + 12);
177 
178                     _MM_TRANSPOSE4_PS(_row0, _row1, _row2, _row3);
179 
180                     _mm_storeu_ps(outptr0, _row0);
181                     _mm_storeu_ps(outptr1, _row1);
182                     _mm_storeu_ps(outptr2, _row2);
183                     _mm_storeu_ps(outptr3, _row3);
184 
185                     ptr += 16;
186                     outptr0 += 4;
187                     outptr1 += 4;
188                     outptr2 += 4;
189                     outptr3 += 4;
190                 }
191                 for (; j < w; j++)
192                 {
193                     *outptr0++ = ptr[0];
194                     *outptr1++ = ptr[1];
195                     *outptr2++ = ptr[2];
196                     *outptr3++ = ptr[3];
197 
198                     ptr += 4;
199                 }
200             }
201         }
202 #endif // __SSE2__
203     }
204 
205     if (dims == 3)
206     {
207 #if __SSE2__
208 #if __AVX__
209         if (elempack == 8) // out_elempack == 8
210         {
211             #pragma omp parallel for num_threads(opt.num_threads)
212             for (int q = 0; q < channels; q++)
213             {
214                 const float* ptr = bottom_blob.channel(q);
215                 float* outptr0 = (float*)top_blob + size * q * 8;
216                 float* outptr1 = (float*)top_blob + size * (q * 8 + 1);
217                 float* outptr2 = (float*)top_blob + size * (q * 8 + 2);
218                 float* outptr3 = (float*)top_blob + size * (q * 8 + 3);
219                 float* outptr4 = (float*)top_blob + size * (q * 8 + 4);
220                 float* outptr5 = (float*)top_blob + size * (q * 8 + 5);
221                 float* outptr6 = (float*)top_blob + size * (q * 8 + 6);
222                 float* outptr7 = (float*)top_blob + size * (q * 8 + 7);
223 
224                 int i = 0;
225                 for (; i + 7 < size; i += 8)
226                 {
227                     __m256 _row0 = _mm256_loadu_ps(ptr);
228                     __m256 _row1 = _mm256_loadu_ps(ptr + 8);
229                     __m256 _row2 = _mm256_loadu_ps(ptr + 16);
230                     __m256 _row3 = _mm256_loadu_ps(ptr + 24);
231                     __m256 _row4 = _mm256_loadu_ps(ptr + 32);
232                     __m256 _row5 = _mm256_loadu_ps(ptr + 40);
233                     __m256 _row6 = _mm256_loadu_ps(ptr + 48);
234                     __m256 _row7 = _mm256_loadu_ps(ptr + 56);
235 
236                     transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7);
237 
238                     _mm256_storeu_ps(outptr0, _row0);
239                     _mm256_storeu_ps(outptr1, _row1);
240                     _mm256_storeu_ps(outptr2, _row2);
241                     _mm256_storeu_ps(outptr3, _row3);
242                     _mm256_storeu_ps(outptr4, _row4);
243                     _mm256_storeu_ps(outptr5, _row5);
244                     _mm256_storeu_ps(outptr6, _row6);
245                     _mm256_storeu_ps(outptr7, _row7);
246 
247                     outptr0 += 8;
248                     outptr1 += 8;
249                     outptr2 += 8;
250                     outptr3 += 8;
251                     outptr4 += 8;
252                     outptr5 += 8;
253                     outptr6 += 8;
254                     outptr7 += 8;
255                     ptr += 64;
256                 }
257                 for (; i < size; i++)
258                 {
259                     *outptr0++ = ptr[0];
260                     *outptr1++ = ptr[1];
261                     *outptr2++ = ptr[2];
262                     *outptr3++ = ptr[3];
263                     *outptr4++ = ptr[4];
264                     *outptr5++ = ptr[5];
265                     *outptr6++ = ptr[6];
266                     *outptr7++ = ptr[7];
267 
268                     ptr += 8;
269                 }
270             }
271         }
272 #endif // __AVX__
273 
274         if (elempack == 4) // out_elempack == 4 || out_elempack == 8
275         {
276             #pragma omp parallel for num_threads(opt.num_threads)
277             for (int q = 0; q < channels; q++)
278             {
279                 const float* ptr = bottom_blob.channel(q);
280                 float* outptr0 = (float*)top_blob + size * q * 4;
281                 float* outptr1 = (float*)top_blob + size * (q * 4 + 1);
282                 float* outptr2 = (float*)top_blob + size * (q * 4 + 2);
283                 float* outptr3 = (float*)top_blob + size * (q * 4 + 3);
284 
285                 int i = 0;
286                 for (; i + 3 < size; i += 4)
287                 {
288                     __m128 _row0 = _mm_loadu_ps(ptr);
289                     __m128 _row1 = _mm_loadu_ps(ptr + 4);
290                     __m128 _row2 = _mm_loadu_ps(ptr + 8);
291                     __m128 _row3 = _mm_loadu_ps(ptr + 12);
292 
293                     _MM_TRANSPOSE4_PS(_row0, _row1, _row2, _row3);
294 
295                     _mm_storeu_ps(outptr0, _row0);
296                     _mm_storeu_ps(outptr1, _row1);
297                     _mm_storeu_ps(outptr2, _row2);
298                     _mm_storeu_ps(outptr3, _row3);
299 
300                     ptr += 16;
301                     outptr0 += 4;
302                     outptr1 += 4;
303                     outptr2 += 4;
304                     outptr3 += 4;
305                 }
306                 for (; i < size; i++)
307                 {
308                     *outptr0++ = ptr[0];
309                     *outptr1++ = ptr[1];
310                     *outptr2++ = ptr[2];
311                     *outptr3++ = ptr[3];
312 
313                     ptr += 4;
314                 }
315             }
316         }
317 #endif // __SSE2__
318 
319         if (elempack == 1) // out_elempack == 4 || out_elempack == 8
320         {
321             #pragma omp parallel for num_threads(opt.num_threads)
322             for (int q = 0; q < channels; q++)
323             {
324                 const float* ptr = bottom_blob.channel(q);
325                 float* outptr = (float*)top_blob + size * q;
326 
327                 int i = 0;
328 #if __AVX__
329                 for (; i + 7 < size; i += 8)
330                 {
331                     __m256 _v = _mm256_loadu_ps(ptr);
332                     _mm256_storeu_ps(outptr, _v);
333                     ptr += 8;
334                     outptr += 8;
335                 }
336 #endif
337                 for (; i < size; i++)
338                 {
339                     *outptr++ = *ptr++;
340                 }
341             }
342         }
343     }
344 
345     return 0;
346 }
347 
forward_int8(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const348 int Flatten_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
349 {
350     int dims = bottom_blob.dims;
351 
352     if (dims == 1)
353     {
354         top_blob = bottom_blob;
355         return 0;
356     }
357 
358     int w = bottom_blob.w;
359     int h = bottom_blob.h;
360     int channels = bottom_blob.c;
361     size_t elemsize = bottom_blob.elemsize;
362     int elempack = bottom_blob.elempack;
363     int size = w * h;
364 
365     int total = size * channels * elempack;
366 
367     int out_elempack = 1;
368     if (opt.use_packing_layout)
369     {
370         out_elempack = total % 8 == 0 ? 8 : 1;
371     }
372     size_t out_elemsize = elemsize / elempack * out_elempack;
373 
374     if (out_elempack == 1)
375     {
376         return Flatten::forward(bottom_blob, top_blob, opt);
377     }
378 
379     if (dims == 2 && elempack == 1) // out_elempack == 8
380     {
381         top_blob = bottom_blob;
382         top_blob.dims = 1;
383         top_blob.w = total / out_elempack;
384         top_blob.h = 1;
385         top_blob.cstep = top_blob.w;
386         top_blob.elemsize = out_elemsize;
387         top_blob.elempack = out_elempack;
388         return 0;
389     }
390 
391     top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
392     if (top_blob.empty())
393         return -100;
394 
395     if (dims == 2)
396     {
397         if (elempack == 8) // out_elempack == 8
398         {
399             #pragma omp parallel for num_threads(opt.num_threads)
400             for (int i = 0; i < h; i++)
401             {
402                 const signed char* ptr = bottom_blob.row<const signed char>(i);
403                 signed char* outptr0 = (signed char*)top_blob + w * i * 8;
404                 signed char* outptr1 = (signed char*)top_blob + w * (i * 8 + 1);
405                 signed char* outptr2 = (signed char*)top_blob + w * (i * 8 + 2);
406                 signed char* outptr3 = (signed char*)top_blob + w * (i * 8 + 3);
407                 signed char* outptr4 = (signed char*)top_blob + w * (i * 8 + 4);
408                 signed char* outptr5 = (signed char*)top_blob + w * (i * 8 + 5);
409                 signed char* outptr6 = (signed char*)top_blob + w * (i * 8 + 6);
410                 signed char* outptr7 = (signed char*)top_blob + w * (i * 8 + 7);
411 
412                 int j = 0;
413                 for (; j < w; j++)
414                 {
415                     *outptr0++ = ptr[0];
416                     *outptr1++ = ptr[1];
417                     *outptr2++ = ptr[2];
418                     *outptr3++ = ptr[3];
419                     *outptr4++ = ptr[4];
420                     *outptr5++ = ptr[5];
421                     *outptr6++ = ptr[6];
422                     *outptr7++ = ptr[7];
423 
424                     ptr += 8;
425                 }
426             }
427         }
428     }
429 
430     if (dims == 3)
431     {
432         if (elempack == 8) // out_elempack == 8
433         {
434             #pragma omp parallel for num_threads(opt.num_threads)
435             for (int q = 0; q < channels; q++)
436             {
437                 const signed char* ptr = bottom_blob.channel(q);
438                 signed char* outptr0 = (signed char*)top_blob + size * q * 8;
439                 signed char* outptr1 = (signed char*)top_blob + size * (q * 8 + 1);
440                 signed char* outptr2 = (signed char*)top_blob + size * (q * 8 + 2);
441                 signed char* outptr3 = (signed char*)top_blob + size * (q * 8 + 3);
442                 signed char* outptr4 = (signed char*)top_blob + size * (q * 8 + 4);
443                 signed char* outptr5 = (signed char*)top_blob + size * (q * 8 + 5);
444                 signed char* outptr6 = (signed char*)top_blob + size * (q * 8 + 6);
445                 signed char* outptr7 = (signed char*)top_blob + size * (q * 8 + 7);
446 
447                 int i = 0;
448                 for (; i < size; i++)
449                 {
450                     *outptr0++ = ptr[0];
451                     *outptr1++ = ptr[1];
452                     *outptr2++ = ptr[2];
453                     *outptr3++ = ptr[3];
454                     *outptr4++ = ptr[4];
455                     *outptr5++ = ptr[5];
456                     *outptr6++ = ptr[6];
457                     *outptr7++ = ptr[7];
458 
459                     ptr += 8;
460                 }
461             }
462         }
463 
464         if (elempack == 1) // out_elempack == 8
465         {
466             #pragma omp parallel for num_threads(opt.num_threads)
467             for (int q = 0; q < channels; q++)
468             {
469                 const signed char* ptr = bottom_blob.channel(q);
470                 signed char* outptr = (signed char*)top_blob + size * q;
471 
472                 int i = 0;
473                 for (; i < size; i++)
474                 {
475                     *outptr++ = *ptr++;
476                 }
477             }
478         }
479     }
480 
481     return 0;
482 }
483 
484 } // namespace ncnn
485