1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "reshape_x86.h"
16
17 #if __SSE2__
18 #include <emmintrin.h>
19 #endif // __SSE2__
20
21 #include "x86_usability.h"
22
23 namespace ncnn {
24
// Constructor: advertise packed-layout support so the framework may feed this
// layer blobs with elempack > 1 when SSE2 is available.
Reshape_x86::Reshape_x86()
{
#if __SSE2__
    // forward() below understands elempack 4 (SSE2) and 8 (AVX) blobs.
    support_packing = true;
#endif // __SSE2__
}
31
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const32 int Reshape_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
33 {
34 int elempack = bottom_blob.elempack;
35
36 if (permute == 1)
37 {
38 // TODO implement permute on-the-fly
39 Option opt_pack = opt;
40 opt_pack.blob_allocator = opt.workspace_allocator;
41
42 Mat bottom_blob_unpacked;
43 convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);
44
45 Mat top_blob_unpacked;
46 int ret = Reshape::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
47 if (ret != 0)
48 return ret;
49
50 int out_elempack = 1;
51 #if __SSE2__
52 if (opt.use_packing_layout)
53 {
54 // resolve dst_elempack
55 int dims = top_blob_unpacked.dims;
56 #if __AVX__
57 if (dims == 1) out_elempack = top_blob_unpacked.w % 8 == 0 ? 8 : top_blob_unpacked.w % 4 == 0 ? 4 : 1;
58 if (dims == 2) out_elempack = top_blob_unpacked.h % 8 == 0 ? 8 : top_blob_unpacked.h % 4 == 0 ? 4 : 1;
59 if (dims == 3) out_elempack = top_blob_unpacked.c % 8 == 0 ? 8 : top_blob_unpacked.c % 4 == 0 ? 4 : 1;
60 #else
61 if (dims == 1) out_elempack = top_blob_unpacked.w % 4 == 0 ? 4 : 1;
62 if (dims == 2) out_elempack = top_blob_unpacked.h % 4 == 0 ? 4 : 1;
63 if (dims == 3) out_elempack = top_blob_unpacked.c % 4 == 0 ? 4 : 1;
64 #endif
65 }
66 #endif // __SSE2__
67 convert_packing(top_blob_unpacked, top_blob, out_elempack, opt);
68
69 return 0;
70 }
71
72 if (ndim == 1)
73 {
74 // flatten
75 flatten(bottom_blob, top_blob, opt);
76 if (top_blob.empty())
77 return -100;
78
79 return 0;
80 }
81
82 int dims = bottom_blob.dims;
83 size_t elemsize = bottom_blob.elemsize;
84
85 int total = bottom_blob.w * bottom_blob.h * bottom_blob.c * elempack;
86
87 if (ndim == 2)
88 {
89 int _w = w;
90 int _h = h;
91
92 if (_w == 0)
93 _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
94 if (_h == 0)
95 _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
96
97 if (_w == -1)
98 _w = total / _h;
99 if (_h == -1)
100 _h = total / _w;
101
102 int out_elempack = 1;
103 #if __SSE2__
104 if (opt.use_packing_layout)
105 {
106 #if __AVX__
107 out_elempack = _h % 8 == 0 ? 8 : _h % 4 == 0 ? 4 : 1;
108 #else
109 out_elempack = _h % 4 == 0 ? 4 : 1;
110 #endif
111 }
112 #endif // __SSE2__
113 size_t out_elemsize = elemsize / elempack * out_elempack;
114
115 if (dims == 2 && bottom_blob.h == _h && elempack == out_elempack)
116 {
117 top_blob = bottom_blob;
118 return 0;
119 }
120
121 if (out_elempack == 1)
122 {
123 // flatten
124 flatten(bottom_blob, top_blob, opt);
125 if (top_blob.empty())
126 return -100;
127
128 top_blob.dims = 2;
129 top_blob.w = _w;
130 top_blob.h = _h;
131 top_blob.cstep = (size_t)_w * _h;
132 top_blob.elemsize = out_elemsize;
133 top_blob.elempack = out_elempack;
134
135 return 0;
136 }
137
138 // flatten
139 Mat bottom_blob_flattened = bottom_blob;
140 {
141 Option opt_flatten = opt;
142 opt_flatten.blob_allocator = opt.workspace_allocator;
143
144 flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
145 if (bottom_blob_flattened.empty())
146 return -100;
147 }
148
149 top_blob.create(_w, _h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
150 if (top_blob.empty())
151 return -100;
152
153 int outw = top_blob.w;
154 int outh = top_blob.h;
155
156 #if __SSE2__
157 #if __AVX__
158 if (out_elempack == 8)
159 {
160 #pragma omp parallel for num_threads(opt.num_threads)
161 for (int i = 0; i < outh; i++)
162 {
163 const float* ptr0 = (const float*)bottom_blob_flattened + outw * i * 8;
164 const float* ptr1 = (const float*)bottom_blob_flattened + outw * (i * 8 + 1);
165 const float* ptr2 = (const float*)bottom_blob_flattened + outw * (i * 8 + 2);
166 const float* ptr3 = (const float*)bottom_blob_flattened + outw * (i * 8 + 3);
167 const float* ptr4 = (const float*)bottom_blob_flattened + outw * (i * 8 + 4);
168 const float* ptr5 = (const float*)bottom_blob_flattened + outw * (i * 8 + 5);
169 const float* ptr6 = (const float*)bottom_blob_flattened + outw * (i * 8 + 6);
170 const float* ptr7 = (const float*)bottom_blob_flattened + outw * (i * 8 + 7);
171 float* outptr = top_blob.row(i);
172
173 int j = 0;
174 for (; j + 7 < outw; j += 8)
175 {
176 __m256 _row0 = _mm256_loadu_ps(ptr0);
177 __m256 _row1 = _mm256_loadu_ps(ptr1);
178 __m256 _row2 = _mm256_loadu_ps(ptr2);
179 __m256 _row3 = _mm256_loadu_ps(ptr3);
180 __m256 _row4 = _mm256_loadu_ps(ptr4);
181 __m256 _row5 = _mm256_loadu_ps(ptr5);
182 __m256 _row6 = _mm256_loadu_ps(ptr6);
183 __m256 _row7 = _mm256_loadu_ps(ptr7);
184
185 transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7);
186
187 _mm256_storeu_ps(outptr, _row0);
188 _mm256_storeu_ps(outptr + 8, _row1);
189 _mm256_storeu_ps(outptr + 16, _row2);
190 _mm256_storeu_ps(outptr + 24, _row3);
191 _mm256_storeu_ps(outptr + 32, _row4);
192 _mm256_storeu_ps(outptr + 40, _row5);
193 _mm256_storeu_ps(outptr + 48, _row6);
194 _mm256_storeu_ps(outptr + 56, _row7);
195
196 ptr0 += 8;
197 ptr1 += 8;
198 ptr2 += 8;
199 ptr3 += 8;
200 ptr4 += 8;
201 ptr5 += 8;
202 ptr6 += 8;
203 ptr7 += 8;
204 outptr += 64;
205 }
206 for (; j < outw; j++)
207 {
208 outptr[0] = *ptr0++;
209 outptr[1] = *ptr1++;
210 outptr[2] = *ptr2++;
211 outptr[3] = *ptr3++;
212 outptr[4] = *ptr4++;
213 outptr[5] = *ptr5++;
214 outptr[6] = *ptr6++;
215 outptr[7] = *ptr7++;
216
217 outptr += 8;
218 }
219 }
220 }
221 #endif // __AVX__
222
223 if (out_elempack == 4)
224 {
225 #pragma omp parallel for num_threads(opt.num_threads)
226 for (int i = 0; i < outh; i++)
227 {
228 const float* ptr0 = (const float*)bottom_blob_flattened + outw * i * 4;
229 const float* ptr1 = (const float*)bottom_blob_flattened + outw * (i * 4 + 1);
230 const float* ptr2 = (const float*)bottom_blob_flattened + outw * (i * 4 + 2);
231 const float* ptr3 = (const float*)bottom_blob_flattened + outw * (i * 4 + 3);
232 float* outptr = top_blob.row(i);
233
234 int j = 0;
235 for (; j + 3 < outw; j += 4)
236 {
237 __m128 _row0 = _mm_loadu_ps(ptr0);
238 __m128 _row1 = _mm_loadu_ps(ptr1);
239 __m128 _row2 = _mm_loadu_ps(ptr2);
240 __m128 _row3 = _mm_loadu_ps(ptr3);
241
242 _MM_TRANSPOSE4_PS(_row0, _row1, _row2, _row3);
243
244 _mm_storeu_ps(outptr, _row0);
245 _mm_storeu_ps(outptr + 4, _row1);
246 _mm_storeu_ps(outptr + 8, _row2);
247 _mm_storeu_ps(outptr + 12, _row3);
248
249 ptr0 += 4;
250 ptr1 += 4;
251 ptr2 += 4;
252 ptr3 += 4;
253 outptr += 16;
254 }
255 for (; j < outw; j++)
256 {
257 outptr[0] = *ptr0++;
258 outptr[1] = *ptr1++;
259 outptr[2] = *ptr2++;
260 outptr[3] = *ptr3++;
261
262 outptr += 4;
263 }
264 }
265 }
266 #endif // __SSE2__
267 }
268
269 if (ndim == 3)
270 {
271 int _w = w;
272 int _h = h;
273 int _c = c;
274
275 if (_w == 0)
276 _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
277 if (_h == 0)
278 _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
279 if (_c == 0)
280 _c = dims == 3 ? bottom_blob.c * elempack : bottom_blob.c;
281
282 if (_w == -1)
283 _w = total / _c / _h;
284 if (_h == -1)
285 _h = total / _c / _w;
286 if (_c == -1)
287 _c = total / _h / _w;
288
289 int out_elempack = 1;
290 #if __SSE2__
291 if (opt.use_packing_layout)
292 {
293 #if __AVX__
294 out_elempack = _c % 8 == 0 ? 8 : _c % 4 == 0 ? 4 : 1;
295 #else
296 out_elempack = _c % 4 == 0 ? 4 : 1;
297 #endif
298 }
299 #endif // __SSE2__
300 size_t out_elemsize = elemsize / elempack * out_elempack;
301
302 if (dims == 3 && bottom_blob.c == _c && elempack == out_elempack)
303 {
304 top_blob = bottom_blob;
305 top_blob.w = _w;
306 top_blob.h = _h;
307 return 0;
308 }
309
310 // flatten
311 Mat bottom_blob_flattened = bottom_blob;
312 {
313 Option opt_flatten = opt;
314 opt_flatten.blob_allocator = opt.workspace_allocator;
315
316 flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
317 if (bottom_blob_flattened.empty())
318 return -100;
319 }
320
321 top_blob.create(_w, _h, _c / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
322 if (top_blob.empty())
323 return -100;
324
325 int size = top_blob.w * top_blob.h;
326
327 #if __SSE2__
328 #if __AVX__
329 if (out_elempack == 8)
330 {
331 #pragma omp parallel for num_threads(opt.num_threads)
332 for (int q = 0; q < top_blob.c; q++)
333 {
334 const float* ptr0 = (const float*)bottom_blob_flattened + size * q * 8;
335 const float* ptr1 = (const float*)bottom_blob_flattened + size * (q * 8 + 1);
336 const float* ptr2 = (const float*)bottom_blob_flattened + size * (q * 8 + 2);
337 const float* ptr3 = (const float*)bottom_blob_flattened + size * (q * 8 + 3);
338 const float* ptr4 = (const float*)bottom_blob_flattened + size * (q * 8 + 4);
339 const float* ptr5 = (const float*)bottom_blob_flattened + size * (q * 8 + 5);
340 const float* ptr6 = (const float*)bottom_blob_flattened + size * (q * 8 + 6);
341 const float* ptr7 = (const float*)bottom_blob_flattened + size * (q * 8 + 7);
342 float* outptr = top_blob.channel(q);
343
344 int i = 0;
345 for (; i + 7 < size; i += 8)
346 {
347 __m256 _row0 = _mm256_loadu_ps(ptr0);
348 __m256 _row1 = _mm256_loadu_ps(ptr1);
349 __m256 _row2 = _mm256_loadu_ps(ptr2);
350 __m256 _row3 = _mm256_loadu_ps(ptr3);
351 __m256 _row4 = _mm256_loadu_ps(ptr4);
352 __m256 _row5 = _mm256_loadu_ps(ptr5);
353 __m256 _row6 = _mm256_loadu_ps(ptr6);
354 __m256 _row7 = _mm256_loadu_ps(ptr7);
355
356 transpose8_ps(_row0, _row1, _row2, _row3, _row4, _row5, _row6, _row7);
357
358 _mm256_storeu_ps(outptr, _row0);
359 _mm256_storeu_ps(outptr + 8, _row1);
360 _mm256_storeu_ps(outptr + 16, _row2);
361 _mm256_storeu_ps(outptr + 24, _row3);
362 _mm256_storeu_ps(outptr + 32, _row4);
363 _mm256_storeu_ps(outptr + 40, _row5);
364 _mm256_storeu_ps(outptr + 48, _row6);
365 _mm256_storeu_ps(outptr + 56, _row7);
366
367 ptr0 += 8;
368 ptr1 += 8;
369 ptr2 += 8;
370 ptr3 += 8;
371 ptr4 += 8;
372 ptr5 += 8;
373 ptr6 += 8;
374 ptr7 += 8;
375 outptr += 64;
376 }
377 for (; i < size; i++)
378 {
379 outptr[0] = *ptr0++;
380 outptr[1] = *ptr1++;
381 outptr[2] = *ptr2++;
382 outptr[3] = *ptr3++;
383 outptr[4] = *ptr4++;
384 outptr[5] = *ptr5++;
385 outptr[6] = *ptr6++;
386 outptr[7] = *ptr7++;
387
388 outptr += 8;
389 }
390 }
391 }
392 #endif // __AVX__
393
394 if (out_elempack == 4)
395 {
396 #pragma omp parallel for num_threads(opt.num_threads)
397 for (int q = 0; q < top_blob.c; q++)
398 {
399 const float* ptr0 = (const float*)bottom_blob_flattened + size * q * 4;
400 const float* ptr1 = (const float*)bottom_blob_flattened + size * (q * 4 + 1);
401 const float* ptr2 = (const float*)bottom_blob_flattened + size * (q * 4 + 2);
402 const float* ptr3 = (const float*)bottom_blob_flattened + size * (q * 4 + 3);
403 float* outptr = top_blob.channel(q);
404
405 int i = 0;
406 for (; i + 3 < size; i += 4)
407 {
408 __m128 _row0 = _mm_loadu_ps(ptr0);
409 __m128 _row1 = _mm_loadu_ps(ptr1);
410 __m128 _row2 = _mm_loadu_ps(ptr2);
411 __m128 _row3 = _mm_loadu_ps(ptr3);
412
413 _MM_TRANSPOSE4_PS(_row0, _row1, _row2, _row3);
414
415 _mm_storeu_ps(outptr, _row0);
416 _mm_storeu_ps(outptr + 4, _row1);
417 _mm_storeu_ps(outptr + 8, _row2);
418 _mm_storeu_ps(outptr + 12, _row3);
419
420 ptr0 += 4;
421 ptr1 += 4;
422 ptr2 += 4;
423 ptr3 += 4;
424 outptr += 16;
425 }
426 for (; i < size; i++)
427 {
428 outptr[0] = *ptr0++;
429 outptr[1] = *ptr1++;
430 outptr[2] = *ptr2++;
431 outptr[3] = *ptr3++;
432
433 outptr += 4;
434 }
435 }
436 }
437 #endif // __SSE2__
438
439 if (out_elempack == 1)
440 {
441 #pragma omp parallel for num_threads(opt.num_threads)
442 for (int q = 0; q < top_blob.c; q++)
443 {
444 const float* ptr = (const float*)bottom_blob_flattened + size * q;
445 float* outptr = top_blob.channel(q);
446
447 int i = 0;
448 #if __SSE2__
449 #if __AVX__
450 for (; i + 7 < size; i += 8)
451 {
452 __m256 _v = _mm256_loadu_ps(ptr);
453 _mm256_storeu_ps(outptr, _v);
454 ptr += 8;
455 outptr += 8;
456 }
457 #endif
458 for (; i + 3 < size; i += 4)
459 {
460 __m128 _v = _mm_loadu_ps(ptr);
461 _mm_storeu_ps(outptr, _v);
462 ptr += 4;
463 outptr += 4;
464 }
465 #endif // __SSE2__
466 for (; i < size; i++)
467 {
468 *outptr++ = *ptr++;
469 }
470 }
471 }
472 }
473
474 return 0;
475 }
476
477 } // namespace ncnn
478