// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "reshape_x86.h"

#if __SSE2__
#include <emmintrin.h>
#endif // __SSE2__

#include "x86_usability.h"

23 namespace ncnn {
24 
Reshape_x86()25 Reshape_x86::Reshape_x86()
26 {
27 #if __SSE2__
28     support_packing = true;
29 #endif // __SSE2__
30 }
31 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const32 int Reshape_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
33 {
34     int elempack = bottom_blob.elempack;
35 
36     if (permute == 1)
37     {
38         // TODO implement permute on-the-fly
39         Option opt_pack = opt;
40         opt_pack.blob_allocator = opt.workspace_allocator;
41 
42         Mat bottom_blob_unpacked;
43         convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);
44 
45         Mat top_blob_unpacked;
46         int ret = Reshape::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
47         if (ret != 0)
48             return ret;
49 
50         int out_elempack = 1;
51 #if __SSE2__
52         if (opt.use_packing_layout)
53         {
54             // resolve dst_elempack
55             int dims = top_blob_unpacked.dims;
56 #if __AVX__
57             if (dims == 1) out_elempack = top_blob_unpacked.w % 8 == 0 ? 8 : top_blob_unpacked.w % 4 == 0 ? 4 : 1;
58             if (dims == 2) out_elempack = top_blob_unpacked.h % 8 == 0 ? 8 : top_blob_unpacked.h % 4 == 0 ? 4 : 1;
59             if (dims == 3) out_elempack = top_blob_unpacked.c % 8 == 0 ? 8 : top_blob_unpacked.c % 4 == 0 ? 4 : 1;
60 #else
61             if (dims == 1) out_elempack = top_blob_unpacked.w % 4 == 0 ? 4 : 1;
62             if (dims == 2) out_elempack = top_blob_unpacked.h % 4 == 0 ? 4 : 1;
63             if (dims == 3) out_elempack = top_blob_unpacked.c % 4 == 0 ? 4 : 1;
64 #endif
65         }
66 #endif // __SSE2__
67         convert_packing(top_blob_unpacked, top_blob, out_elempack, opt);
68 
69         return 0;
70     }
71 
72     if (ndim == 1)
73     {
74         // flatten
75         flatten(bottom_blob, top_blob, opt);
76         if (top_blob.empty())
77             return -100;
78 
79         return 0;
80     }
81 
82     int dims = bottom_blob.dims;
83     size_t elemsize = bottom_blob.elemsize;
84 
85     int total = bottom_blob.w * bottom_blob.h * bottom_blob.c * elempack;
86 
87     if (ndim == 2)
88     {
89         int _w = w;
90         int _h = h;
91 
92         if (_w == 0)
93             _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
94         if (_h == 0)
95             _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
96 
97         if (_w == -1)
98             _w = total / _h;
99         if (_h == -1)
100             _h = total / _w;
101 
102         int out_elempack = 1;
103 #if __SSE2__
104         if (opt.use_packing_layout)
105         {
106 #if __AVX__
107             out_elempack = _h % 8 == 0 ? 8 : _h % 4 == 0 ? 4 : 1;
108 #else
109             out_elempack = _h % 4 == 0 ? 4 : 1;
110 #endif
111         }
112 #endif // __SSE2__
113         size_t out_elemsize = elemsize / elempack * out_elempack;
114 
115         if (dims == 2 && bottom_blob.h == _h && elempack == out_elempack)
116         {
117             top_blob = bottom_blob;
118             return 0;
119         }
120 
121         if (out_elempack == 1)
122         {
123             // flatten
124             flatten(bottom_blob, top_blob, opt);
125             if (top_blob.empty())
126                 return -100;
127 
128             top_blob.dims = 2;
129             top_blob.w = _w;
130             top_blob.h = _h;
131             top_blob.cstep = (size_t)_w * _h;
132             top_blob.elemsize = out_elemsize;
133             top_blob.elempack = out_elempack;
134 
135             return 0;
136         }
137 
138         // flatten
139         Mat bottom_blob_flattened = bottom_blob;
140         {
141             Option opt_flatten = opt;
142             opt_flatten.blob_allocator = opt.workspace_allocator;
143 
144             flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
145             if (bottom_blob_flattened.empty())
146                 return -100;
147         }
148 
149         top_blob.create(_w, _h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
150         if (top_blob.empty())
151             return -100;
152 
153         int outw = top_blob.w;
154         int outh = top_blob.h;
155 
156 #if __SSE2__
157 #if __AVX__
158         if (out_elempack == 8)
159         {
160             #pragma omp parallel for num_threads(opt.num_threads)
161             for (int i = 0; i < outh; i++)
162             {
163                 const float* ptr0 = (const float*)bottom_blob_flattened + outw * i * 8;
164                 const float* ptr1 = (const float*)bottom_blob_flattened + outw * (i * 8 + 1);
165                 const float* ptr2 = (const float*)bottom_blob_flattened + outw * (i * 8 + 2);
166                 const float* ptr3 = (const float*)bottom_blob_flattened + outw * (i * 8 + 3);
167                 const float* ptr4 = (const float*)bottom_blob_flattened + outw * (i * 8 + 4);
168                 const float* ptr5 = (const float*)bottom_blob_flattened + outw * (i * 8 + 5);
169                 const float* ptr6 = (const float*)bottom_blob_flattened + outw * (i * 8 + 6);
170                 const float* ptr7 = (const float*)bottom_blob_flattened + outw * (i * 8 + 7);
171                 float* outptr = top_blob.row(i);
172 
173                 int j = 0;
174                 for (; j + 7 < outw; j += 8)
175                 {
176                     __m256 _row0 = _mm256_loadu_ps(ptr0);
177                     __m256 _row1 = _mm256_loadu_ps(ptr1);
178                     __m256 _row2 = _mm256_loadu_ps(ptr2);
179                     __m256 _row3 = _mm256_loadu_ps(ptr3);
180                     __m256 _row4 = _mm256_loadu_ps(ptr4);
181                     __m256 _row5 = _mm256_loadu_ps(ptr5);
182                     __m256 _row6 = _mm256_loadu_ps(ptr6);
183                     __m256 _row7 = _mm256_loadu_ps(ptr7);
184 
185                     transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7);
186 
187                     _mm256_storeu_ps(outptr, _row0);
188                     _mm256_storeu_ps(outptr + 8, _row1);
189                     _mm256_storeu_ps(outptr + 16, _row2);
190                     _mm256_storeu_ps(outptr + 24, _row3);
191                     _mm256_storeu_ps(outptr + 32, _row4);
192                     _mm256_storeu_ps(outptr + 40, _row5);
193                     _mm256_storeu_ps(outptr + 48, _row6);
194                     _mm256_storeu_ps(outptr + 56, _row7);
195 
196                     ptr0 += 8;
197                     ptr1 += 8;
198                     ptr2 += 8;
199                     ptr3 += 8;
200                     ptr4 += 8;
201                     ptr5 += 8;
202                     ptr6 += 8;
203                     ptr7 += 8;
204                     outptr += 64;
205                 }
206                 for (; j < outw; j++)
207                 {
208                     outptr[0] = *ptr0++;
209                     outptr[1] = *ptr1++;
210                     outptr[2] = *ptr2++;
211                     outptr[3] = *ptr3++;
212                     outptr[4] = *ptr4++;
213                     outptr[5] = *ptr5++;
214                     outptr[6] = *ptr6++;
215                     outptr[7] = *ptr7++;
216 
217                     outptr += 8;
218                 }
219             }
220         }
221 #endif // __AVX__
222 
223         if (out_elempack == 4)
224         {
225             #pragma omp parallel for num_threads(opt.num_threads)
226             for (int i = 0; i < outh; i++)
227             {
228                 const float* ptr0 = (const float*)bottom_blob_flattened + outw * i * 4;
229                 const float* ptr1 = (const float*)bottom_blob_flattened + outw * (i * 4 + 1);
230                 const float* ptr2 = (const float*)bottom_blob_flattened + outw * (i * 4 + 2);
231                 const float* ptr3 = (const float*)bottom_blob_flattened + outw * (i * 4 + 3);
232                 float* outptr = top_blob.row(i);
233 
234                 int j = 0;
235                 for (; j + 3 < outw; j += 4)
236                 {
237                     __m128 _row0 = _mm_loadu_ps(ptr0);
238                     __m128 _row1 = _mm_loadu_ps(ptr1);
239                     __m128 _row2 = _mm_loadu_ps(ptr2);
240                     __m128 _row3 = _mm_loadu_ps(ptr3);
241 
242                     _MM_TRANSPOSE4_PS(_row0, _row1, _row2, _row3);
243 
244                     _mm_storeu_ps(outptr, _row0);
245                     _mm_storeu_ps(outptr + 4, _row1);
246                     _mm_storeu_ps(outptr + 8, _row2);
247                     _mm_storeu_ps(outptr + 12, _row3);
248 
249                     ptr0 += 4;
250                     ptr1 += 4;
251                     ptr2 += 4;
252                     ptr3 += 4;
253                     outptr += 16;
254                 }
255                 for (; j < outw; j++)
256                 {
257                     outptr[0] = *ptr0++;
258                     outptr[1] = *ptr1++;
259                     outptr[2] = *ptr2++;
260                     outptr[3] = *ptr3++;
261 
262                     outptr += 4;
263                 }
264             }
265         }
266 #endif // __SSE2__
267     }
268 
269     if (ndim == 3)
270     {
271         int _w = w;
272         int _h = h;
273         int _c = c;
274 
275         if (_w == 0)
276             _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
277         if (_h == 0)
278             _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
279         if (_c == 0)
280             _c = dims == 3 ? bottom_blob.c * elempack : bottom_blob.c;
281 
282         if (_w == -1)
283             _w = total / _c / _h;
284         if (_h == -1)
285             _h = total / _c / _w;
286         if (_c == -1)
287             _c = total / _h / _w;
288 
289         int out_elempack = 1;
290 #if __SSE2__
291         if (opt.use_packing_layout)
292         {
293 #if __AVX__
294             out_elempack = _c % 8 == 0 ? 8 : _c % 4 == 0 ? 4 : 1;
295 #else
296             out_elempack = _c % 4 == 0 ? 4 : 1;
297 #endif
298         }
299 #endif // __SSE2__
300         size_t out_elemsize = elemsize / elempack * out_elempack;
301 
302         if (dims == 3 && bottom_blob.c == _c && elempack == out_elempack)
303         {
304             top_blob = bottom_blob;
305             top_blob.w = _w;
306             top_blob.h = _h;
307             return 0;
308         }
309 
310         // flatten
311         Mat bottom_blob_flattened = bottom_blob;
312         {
313             Option opt_flatten = opt;
314             opt_flatten.blob_allocator = opt.workspace_allocator;
315 
316             flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
317             if (bottom_blob_flattened.empty())
318                 return -100;
319         }
320 
321         top_blob.create(_w, _h, _c / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
322         if (top_blob.empty())
323             return -100;
324 
325         int size = top_blob.w * top_blob.h;
326 
327 #if __SSE2__
328 #if __AVX__
329         if (out_elempack == 8)
330         {
331             #pragma omp parallel for num_threads(opt.num_threads)
332             for (int q = 0; q < top_blob.c; q++)
333             {
334                 const float* ptr0 = (const float*)bottom_blob_flattened + size * q * 8;
335                 const float* ptr1 = (const float*)bottom_blob_flattened + size * (q * 8 + 1);
336                 const float* ptr2 = (const float*)bottom_blob_flattened + size * (q * 8 + 2);
337                 const float* ptr3 = (const float*)bottom_blob_flattened + size * (q * 8 + 3);
338                 const float* ptr4 = (const float*)bottom_blob_flattened + size * (q * 8 + 4);
339                 const float* ptr5 = (const float*)bottom_blob_flattened + size * (q * 8 + 5);
340                 const float* ptr6 = (const float*)bottom_blob_flattened + size * (q * 8 + 6);
341                 const float* ptr7 = (const float*)bottom_blob_flattened + size * (q * 8 + 7);
342                 float* outptr = top_blob.channel(q);
343 
344                 int i = 0;
345                 for (; i + 7 < size; i += 8)
346                 {
347                     __m256 _row0 = _mm256_loadu_ps(ptr0);
348                     __m256 _row1 = _mm256_loadu_ps(ptr1);
349                     __m256 _row2 = _mm256_loadu_ps(ptr2);
350                     __m256 _row3 = _mm256_loadu_ps(ptr3);
351                     __m256 _row4 = _mm256_loadu_ps(ptr4);
352                     __m256 _row5 = _mm256_loadu_ps(ptr5);
353                     __m256 _row6 = _mm256_loadu_ps(ptr6);
354                     __m256 _row7 = _mm256_loadu_ps(ptr7);
355 
356                     transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7);
357 
358                     _mm256_storeu_ps(outptr, _row0);
359                     _mm256_storeu_ps(outptr + 8, _row1);
360                     _mm256_storeu_ps(outptr + 16, _row2);
361                     _mm256_storeu_ps(outptr + 24, _row3);
362                     _mm256_storeu_ps(outptr + 32, _row4);
363                     _mm256_storeu_ps(outptr + 40, _row5);
364                     _mm256_storeu_ps(outptr + 48, _row6);
365                     _mm256_storeu_ps(outptr + 56, _row7);
366 
367                     ptr0 += 8;
368                     ptr1 += 8;
369                     ptr2 += 8;
370                     ptr3 += 8;
371                     ptr4 += 8;
372                     ptr5 += 8;
373                     ptr6 += 8;
374                     ptr7 += 8;
375                     outptr += 64;
376                 }
377                 for (; i < size; i++)
378                 {
379                     outptr[0] = *ptr0++;
380                     outptr[1] = *ptr1++;
381                     outptr[2] = *ptr2++;
382                     outptr[3] = *ptr3++;
383                     outptr[4] = *ptr4++;
384                     outptr[5] = *ptr5++;
385                     outptr[6] = *ptr6++;
386                     outptr[7] = *ptr7++;
387 
388                     outptr += 8;
389                 }
390             }
391         }
392 #endif // __AVX__
393 
394         if (out_elempack == 4)
395         {
396             #pragma omp parallel for num_threads(opt.num_threads)
397             for (int q = 0; q < top_blob.c; q++)
398             {
399                 const float* ptr0 = (const float*)bottom_blob_flattened + size * q * 4;
400                 const float* ptr1 = (const float*)bottom_blob_flattened + size * (q * 4 + 1);
401                 const float* ptr2 = (const float*)bottom_blob_flattened + size * (q * 4 + 2);
402                 const float* ptr3 = (const float*)bottom_blob_flattened + size * (q * 4 + 3);
403                 float* outptr = top_blob.channel(q);
404 
405                 int i = 0;
406                 for (; i + 3 < size; i += 4)
407                 {
408                     __m128 _row0 = _mm_loadu_ps(ptr0);
409                     __m128 _row1 = _mm_loadu_ps(ptr1);
410                     __m128 _row2 = _mm_loadu_ps(ptr2);
411                     __m128 _row3 = _mm_loadu_ps(ptr3);
412 
413                     _MM_TRANSPOSE4_PS(_row0, _row1, _row2, _row3);
414 
415                     _mm_storeu_ps(outptr, _row0);
416                     _mm_storeu_ps(outptr + 4, _row1);
417                     _mm_storeu_ps(outptr + 8, _row2);
418                     _mm_storeu_ps(outptr + 12, _row3);
419 
420                     ptr0 += 4;
421                     ptr1 += 4;
422                     ptr2 += 4;
423                     ptr3 += 4;
424                     outptr += 16;
425                 }
426                 for (; i < size; i++)
427                 {
428                     outptr[0] = *ptr0++;
429                     outptr[1] = *ptr1++;
430                     outptr[2] = *ptr2++;
431                     outptr[3] = *ptr3++;
432 
433                     outptr += 4;
434                 }
435             }
436         }
437 #endif // __SSE2__
438 
439         if (out_elempack == 1)
440         {
441             #pragma omp parallel for num_threads(opt.num_threads)
442             for (int q = 0; q < top_blob.c; q++)
443             {
444                 const float* ptr = (const float*)bottom_blob_flattened + size * q;
445                 float* outptr = top_blob.channel(q);
446 
447                 int i = 0;
448 #if __SSE2__
449 #if __AVX__
450                 for (; i + 7 < size; i += 8)
451                 {
452                     __m256 _v = _mm256_loadu_ps(ptr);
453                     _mm256_storeu_ps(outptr, _v);
454                     ptr += 8;
455                     outptr += 8;
456                 }
457 #endif
458                 for (; i + 3 < size; i += 4)
459                 {
460                     __m128 _v = _mm_loadu_ps(ptr);
461                     _mm_storeu_ps(outptr, _v);
462                     ptr += 4;
463                     outptr += 4;
464                 }
465 #endif // __SSE2__
466                 for (; i < size; i++)
467                 {
468                     *outptr++ = *ptr++;
469                 }
470             }
471         }
472     }
473 
474     return 0;
475 }
476 
477 } // namespace ncnn
478