// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "flatten_x86.h"

#if __SSE2__
#include <emmintrin.h>
#endif // __SSE2__

#include "x86_usability.h"

namespace ncnn {
24
Flatten_x86()25 Flatten_x86::Flatten_x86()
26 {
27 #if __SSE2__
28 support_packing = true;
29 #endif // __SSE2__
30 }
31
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const32 int Flatten_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
33 {
34 int elembits = bottom_blob.elembits();
35
36 if (elembits == 8)
37 return forward_int8(bottom_blob, top_blob, opt);
38
39 int dims = bottom_blob.dims;
40
41 if (dims == 1)
42 {
43 top_blob = bottom_blob;
44 return 0;
45 }
46
47 int w = bottom_blob.w;
48 int h = bottom_blob.h;
49 int channels = bottom_blob.c;
50 size_t elemsize = bottom_blob.elemsize;
51 int elempack = bottom_blob.elempack;
52 int size = w * h;
53
54 int total = size * channels * elempack;
55
56 int out_elempack = 1;
57 #if __SSE2__
58 if (opt.use_packing_layout)
59 {
60 #if __AVX__
61 out_elempack = total % 8 == 0 ? 8 : total % 4 == 0 ? 4 : 1;
62 #else
63 out_elempack = total % 4 == 0 ? 4 : 1;
64 #endif
65 }
66 #endif // __SSE2__
67 size_t out_elemsize = elemsize / elempack * out_elempack;
68
69 if (out_elempack == 1)
70 {
71 return Flatten::forward(bottom_blob, top_blob, opt);
72 }
73
74 if (dims == 2 && elempack == 1) // out_elempack == 4 || out_elempack == 8
75 {
76 top_blob = bottom_blob;
77 top_blob.dims = 1;
78 top_blob.w = total / out_elempack;
79 top_blob.h = 1;
80 top_blob.cstep = top_blob.w;
81 top_blob.elemsize = out_elemsize;
82 top_blob.elempack = out_elempack;
83 return 0;
84 }
85
86 top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
87 if (top_blob.empty())
88 return -100;
89
90 if (dims == 2)
91 {
92 #if __SSE2__
93 #if __AVX__
94 if (elempack == 8) // out_elempack == 8
95 {
96 #pragma omp parallel for num_threads(opt.num_threads)
97 for (int i = 0; i < h; i++)
98 {
99 const float* ptr = bottom_blob.row(i);
100 float* outptr0 = (float*)top_blob + w * i * 8;
101 float* outptr1 = (float*)top_blob + w * (i * 8 + 1);
102 float* outptr2 = (float*)top_blob + w * (i * 8 + 2);
103 float* outptr3 = (float*)top_blob + w * (i * 8 + 3);
104 float* outptr4 = (float*)top_blob + w * (i * 8 + 4);
105 float* outptr5 = (float*)top_blob + w * (i * 8 + 5);
106 float* outptr6 = (float*)top_blob + w * (i * 8 + 6);
107 float* outptr7 = (float*)top_blob + w * (i * 8 + 7);
108
109 int j = 0;
110 for (; j + 7 < w; j += 8)
111 {
112 __m256 _row0 = _mm256_loadu_ps(ptr);
113 __m256 _row1 = _mm256_loadu_ps(ptr + 8);
114 __m256 _row2 = _mm256_loadu_ps(ptr + 16);
115 __m256 _row3 = _mm256_loadu_ps(ptr + 24);
116 __m256 _row4 = _mm256_loadu_ps(ptr + 32);
117 __m256 _row5 = _mm256_loadu_ps(ptr + 40);
118 __m256 _row6 = _mm256_loadu_ps(ptr + 48);
119 __m256 _row7 = _mm256_loadu_ps(ptr + 56);
120
121 transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7);
122
123 _mm256_storeu_ps(outptr0, _row0);
124 _mm256_storeu_ps(outptr1, _row1);
125 _mm256_storeu_ps(outptr2, _row2);
126 _mm256_storeu_ps(outptr3, _row3);
127 _mm256_storeu_ps(outptr4, _row4);
128 _mm256_storeu_ps(outptr5, _row5);
129 _mm256_storeu_ps(outptr6, _row6);
130 _mm256_storeu_ps(outptr7, _row7);
131
132 outptr0 += 8;
133 outptr1 += 8;
134 outptr2 += 8;
135 outptr3 += 8;
136 outptr4 += 8;
137 outptr5 += 8;
138 outptr6 += 8;
139 outptr7 += 8;
140 ptr += 64;
141 }
142 for (; j < w; j++)
143 {
144 *outptr0++ = ptr[0];
145 *outptr1++ = ptr[1];
146 *outptr2++ = ptr[2];
147 *outptr3++ = ptr[3];
148 *outptr4++ = ptr[4];
149 *outptr5++ = ptr[5];
150 *outptr6++ = ptr[6];
151 *outptr7++ = ptr[7];
152
153 ptr += 8;
154 }
155 }
156 }
157 #endif // __AVX__
158
159 if (elempack == 4) // out_elempack == 4 || out_elempack == 8
160 {
161 #pragma omp parallel for num_threads(opt.num_threads)
162 for (int i = 0; i < h; i++)
163 {
164 const float* ptr = bottom_blob.row(i);
165 float* outptr0 = (float*)top_blob + w * i * 4;
166 float* outptr1 = (float*)top_blob + w * (i * 4 + 1);
167 float* outptr2 = (float*)top_blob + w * (i * 4 + 2);
168 float* outptr3 = (float*)top_blob + w * (i * 4 + 3);
169
170 int j = 0;
171 for (; j + 3 < w; j += 4)
172 {
173 __m128 _row0 = _mm_loadu_ps(ptr);
174 __m128 _row1 = _mm_loadu_ps(ptr + 4);
175 __m128 _row2 = _mm_loadu_ps(ptr + 8);
176 __m128 _row3 = _mm_loadu_ps(ptr + 12);
177
178 _MM_TRANSPOSE4_PS(_row0, _row1, _row2, _row3);
179
180 _mm_storeu_ps(outptr0, _row0);
181 _mm_storeu_ps(outptr1, _row1);
182 _mm_storeu_ps(outptr2, _row2);
183 _mm_storeu_ps(outptr3, _row3);
184
185 ptr += 16;
186 outptr0 += 4;
187 outptr1 += 4;
188 outptr2 += 4;
189 outptr3 += 4;
190 }
191 for (; j < w; j++)
192 {
193 *outptr0++ = ptr[0];
194 *outptr1++ = ptr[1];
195 *outptr2++ = ptr[2];
196 *outptr3++ = ptr[3];
197
198 ptr += 4;
199 }
200 }
201 }
202 #endif // __SSE2__
203 }
204
205 if (dims == 3)
206 {
207 #if __SSE2__
208 #if __AVX__
209 if (elempack == 8) // out_elempack == 8
210 {
211 #pragma omp parallel for num_threads(opt.num_threads)
212 for (int q = 0; q < channels; q++)
213 {
214 const float* ptr = bottom_blob.channel(q);
215 float* outptr0 = (float*)top_blob + size * q * 8;
216 float* outptr1 = (float*)top_blob + size * (q * 8 + 1);
217 float* outptr2 = (float*)top_blob + size * (q * 8 + 2);
218 float* outptr3 = (float*)top_blob + size * (q * 8 + 3);
219 float* outptr4 = (float*)top_blob + size * (q * 8 + 4);
220 float* outptr5 = (float*)top_blob + size * (q * 8 + 5);
221 float* outptr6 = (float*)top_blob + size * (q * 8 + 6);
222 float* outptr7 = (float*)top_blob + size * (q * 8 + 7);
223
224 int i = 0;
225 for (; i + 7 < size; i += 8)
226 {
227 __m256 _row0 = _mm256_loadu_ps(ptr);
228 __m256 _row1 = _mm256_loadu_ps(ptr + 8);
229 __m256 _row2 = _mm256_loadu_ps(ptr + 16);
230 __m256 _row3 = _mm256_loadu_ps(ptr + 24);
231 __m256 _row4 = _mm256_loadu_ps(ptr + 32);
232 __m256 _row5 = _mm256_loadu_ps(ptr + 40);
233 __m256 _row6 = _mm256_loadu_ps(ptr + 48);
234 __m256 _row7 = _mm256_loadu_ps(ptr + 56);
235
236 transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7);
237
238 _mm256_storeu_ps(outptr0, _row0);
239 _mm256_storeu_ps(outptr1, _row1);
240 _mm256_storeu_ps(outptr2, _row2);
241 _mm256_storeu_ps(outptr3, _row3);
242 _mm256_storeu_ps(outptr4, _row4);
243 _mm256_storeu_ps(outptr5, _row5);
244 _mm256_storeu_ps(outptr6, _row6);
245 _mm256_storeu_ps(outptr7, _row7);
246
247 outptr0 += 8;
248 outptr1 += 8;
249 outptr2 += 8;
250 outptr3 += 8;
251 outptr4 += 8;
252 outptr5 += 8;
253 outptr6 += 8;
254 outptr7 += 8;
255 ptr += 64;
256 }
257 for (; i < size; i++)
258 {
259 *outptr0++ = ptr[0];
260 *outptr1++ = ptr[1];
261 *outptr2++ = ptr[2];
262 *outptr3++ = ptr[3];
263 *outptr4++ = ptr[4];
264 *outptr5++ = ptr[5];
265 *outptr6++ = ptr[6];
266 *outptr7++ = ptr[7];
267
268 ptr += 8;
269 }
270 }
271 }
272 #endif // __AVX__
273
274 if (elempack == 4) // out_elempack == 4 || out_elempack == 8
275 {
276 #pragma omp parallel for num_threads(opt.num_threads)
277 for (int q = 0; q < channels; q++)
278 {
279 const float* ptr = bottom_blob.channel(q);
280 float* outptr0 = (float*)top_blob + size * q * 4;
281 float* outptr1 = (float*)top_blob + size * (q * 4 + 1);
282 float* outptr2 = (float*)top_blob + size * (q * 4 + 2);
283 float* outptr3 = (float*)top_blob + size * (q * 4 + 3);
284
285 int i = 0;
286 for (; i + 3 < size; i += 4)
287 {
288 __m128 _row0 = _mm_loadu_ps(ptr);
289 __m128 _row1 = _mm_loadu_ps(ptr + 4);
290 __m128 _row2 = _mm_loadu_ps(ptr + 8);
291 __m128 _row3 = _mm_loadu_ps(ptr + 12);
292
293 _MM_TRANSPOSE4_PS(_row0, _row1, _row2, _row3);
294
295 _mm_storeu_ps(outptr0, _row0);
296 _mm_storeu_ps(outptr1, _row1);
297 _mm_storeu_ps(outptr2, _row2);
298 _mm_storeu_ps(outptr3, _row3);
299
300 ptr += 16;
301 outptr0 += 4;
302 outptr1 += 4;
303 outptr2 += 4;
304 outptr3 += 4;
305 }
306 for (; i < size; i++)
307 {
308 *outptr0++ = ptr[0];
309 *outptr1++ = ptr[1];
310 *outptr2++ = ptr[2];
311 *outptr3++ = ptr[3];
312
313 ptr += 4;
314 }
315 }
316 }
317 #endif // __SSE2__
318
319 if (elempack == 1) // out_elempack == 4 || out_elempack == 8
320 {
321 #pragma omp parallel for num_threads(opt.num_threads)
322 for (int q = 0; q < channels; q++)
323 {
324 const float* ptr = bottom_blob.channel(q);
325 float* outptr = (float*)top_blob + size * q;
326
327 int i = 0;
328 #if __AVX__
329 for (; i + 7 < size; i += 8)
330 {
331 __m256 _v = _mm256_loadu_ps(ptr);
332 _mm256_storeu_ps(outptr, _v);
333 ptr += 8;
334 outptr += 8;
335 }
336 #endif
337 for (; i < size; i++)
338 {
339 *outptr++ = *ptr++;
340 }
341 }
342 }
343 }
344
345 return 0;
346 }
347
forward_int8(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const348 int Flatten_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
349 {
350 int dims = bottom_blob.dims;
351
352 if (dims == 1)
353 {
354 top_blob = bottom_blob;
355 return 0;
356 }
357
358 int w = bottom_blob.w;
359 int h = bottom_blob.h;
360 int channels = bottom_blob.c;
361 size_t elemsize = bottom_blob.elemsize;
362 int elempack = bottom_blob.elempack;
363 int size = w * h;
364
365 int total = size * channels * elempack;
366
367 int out_elempack = 1;
368 if (opt.use_packing_layout)
369 {
370 out_elempack = total % 8 == 0 ? 8 : 1;
371 }
372 size_t out_elemsize = elemsize / elempack * out_elempack;
373
374 if (out_elempack == 1)
375 {
376 return Flatten::forward(bottom_blob, top_blob, opt);
377 }
378
379 if (dims == 2 && elempack == 1) // out_elempack == 8
380 {
381 top_blob = bottom_blob;
382 top_blob.dims = 1;
383 top_blob.w = total / out_elempack;
384 top_blob.h = 1;
385 top_blob.cstep = top_blob.w;
386 top_blob.elemsize = out_elemsize;
387 top_blob.elempack = out_elempack;
388 return 0;
389 }
390
391 top_blob.create(total / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
392 if (top_blob.empty())
393 return -100;
394
395 if (dims == 2)
396 {
397 if (elempack == 8) // out_elempack == 8
398 {
399 #pragma omp parallel for num_threads(opt.num_threads)
400 for (int i = 0; i < h; i++)
401 {
402 const signed char* ptr = bottom_blob.row<const signed char>(i);
403 signed char* outptr0 = (signed char*)top_blob + w * i * 8;
404 signed char* outptr1 = (signed char*)top_blob + w * (i * 8 + 1);
405 signed char* outptr2 = (signed char*)top_blob + w * (i * 8 + 2);
406 signed char* outptr3 = (signed char*)top_blob + w * (i * 8 + 3);
407 signed char* outptr4 = (signed char*)top_blob + w * (i * 8 + 4);
408 signed char* outptr5 = (signed char*)top_blob + w * (i * 8 + 5);
409 signed char* outptr6 = (signed char*)top_blob + w * (i * 8 + 6);
410 signed char* outptr7 = (signed char*)top_blob + w * (i * 8 + 7);
411
412 int j = 0;
413 for (; j < w; j++)
414 {
415 *outptr0++ = ptr[0];
416 *outptr1++ = ptr[1];
417 *outptr2++ = ptr[2];
418 *outptr3++ = ptr[3];
419 *outptr4++ = ptr[4];
420 *outptr5++ = ptr[5];
421 *outptr6++ = ptr[6];
422 *outptr7++ = ptr[7];
423
424 ptr += 8;
425 }
426 }
427 }
428 }
429
430 if (dims == 3)
431 {
432 if (elempack == 8) // out_elempack == 8
433 {
434 #pragma omp parallel for num_threads(opt.num_threads)
435 for (int q = 0; q < channels; q++)
436 {
437 const signed char* ptr = bottom_blob.channel(q);
438 signed char* outptr0 = (signed char*)top_blob + size * q * 8;
439 signed char* outptr1 = (signed char*)top_blob + size * (q * 8 + 1);
440 signed char* outptr2 = (signed char*)top_blob + size * (q * 8 + 2);
441 signed char* outptr3 = (signed char*)top_blob + size * (q * 8 + 3);
442 signed char* outptr4 = (signed char*)top_blob + size * (q * 8 + 4);
443 signed char* outptr5 = (signed char*)top_blob + size * (q * 8 + 5);
444 signed char* outptr6 = (signed char*)top_blob + size * (q * 8 + 6);
445 signed char* outptr7 = (signed char*)top_blob + size * (q * 8 + 7);
446
447 int i = 0;
448 for (; i < size; i++)
449 {
450 *outptr0++ = ptr[0];
451 *outptr1++ = ptr[1];
452 *outptr2++ = ptr[2];
453 *outptr3++ = ptr[3];
454 *outptr4++ = ptr[4];
455 *outptr5++ = ptr[5];
456 *outptr6++ = ptr[6];
457 *outptr7++ = ptr[7];
458
459 ptr += 8;
460 }
461 }
462 }
463
464 if (elempack == 1) // out_elempack == 8
465 {
466 #pragma omp parallel for num_threads(opt.num_threads)
467 for (int q = 0; q < channels; q++)
468 {
469 const signed char* ptr = bottom_blob.channel(q);
470 signed char* outptr = (signed char*)top_blob + size * q;
471
472 int i = 0;
473 for (; i < size; i++)
474 {
475 *outptr++ = *ptr++;
476 }
477 }
478 }
479 }
480
481 return 0;
482 }
483
} // namespace ncnn