1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "reshape_arm.h"
16 
17 #if __ARM_NEON
18 #include <arm_neon.h>
19 #endif // __ARM_NEON
20 
21 namespace ncnn {
22 
Reshape_arm()23 Reshape_arm::Reshape_arm()
24 {
25 #if __ARM_NEON
26     support_packing = true;
27 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
28     support_fp16_storage = true;
29 #endif
30 #endif // __ARM_NEON
31 
32 #if NCNN_BF16
33     support_bf16_storage = true;
34 #endif
35 }
36 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const37 int Reshape_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
38 {
39     int elembits = bottom_blob.elembits();
40 
41 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
42     if (opt.use_fp16_storage && elembits == 16)
43         return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
44 #endif
45 
46 #if NCNN_BF16
47     if (opt.use_bf16_storage && elembits == 16)
48         return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
49 #endif
50 
51     int elempack = bottom_blob.elempack;
52 
53     if (permute == 1)
54     {
55         // TODO implement permute on-the-fly
56         Option opt_pack = opt;
57         opt_pack.blob_allocator = opt.workspace_allocator;
58 
59         Mat bottom_blob_unpacked;
60         convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);
61 
62         Mat top_blob_unpacked;
63         int ret = Reshape::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
64         if (ret != 0)
65             return ret;
66 
67         int out_elempack = 1;
68         if (opt.use_packing_layout)
69         {
70             // resolve dst_elempack
71             int dims = top_blob_unpacked.dims;
72             if (dims == 1) out_elempack = top_blob_unpacked.w % 4 == 0 ? 4 : 1;
73             if (dims == 2) out_elempack = top_blob_unpacked.h % 4 == 0 ? 4 : 1;
74             if (dims == 3 || dims == 4) out_elempack = top_blob_unpacked.c % 4 == 0 ? 4 : 1;
75         }
76         convert_packing(top_blob_unpacked, top_blob, out_elempack, opt);
77 
78         return 0;
79     }
80 
81     if (ndim == 1)
82     {
83         // flatten
84         flatten(bottom_blob, top_blob, opt);
85         if (top_blob.empty())
86             return -100;
87 
88         return 0;
89     }
90 
91     int dims = bottom_blob.dims;
92     size_t elemsize = bottom_blob.elemsize;
93 
94     int total = bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * elempack;
95 
96     if (ndim == 2)
97     {
98         int _w = w;
99         int _h = h;
100 
101         if (_w == 0)
102             _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
103         if (_h == 0)
104             _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
105 
106         if (_w == -1)
107             _w = total / _h;
108         if (_h == -1)
109             _h = total / _w;
110 
111         int out_elempack = opt.use_packing_layout && _h % 4 == 0 ? 4 : 1;
112         size_t out_elemsize = elemsize / elempack * out_elempack;
113 
114         if (dims == 2 && bottom_blob.h * elempack == _h && elempack == out_elempack)
115         {
116             top_blob = bottom_blob;
117             return 0;
118         }
119 
120         if (out_elempack == 1)
121         {
122             // flatten
123             flatten(bottom_blob, top_blob, opt);
124             if (top_blob.empty())
125                 return -100;
126 
127             top_blob.dims = 2;
128             top_blob.w = _w;
129             top_blob.h = _h;
130             top_blob.cstep = _w * _h;
131             top_blob.elemsize = out_elemsize;
132             top_blob.elempack = out_elempack;
133 
134             return 0;
135         }
136 
137         // flatten
138         Mat bottom_blob_flattened = bottom_blob;
139         {
140             Option opt_flatten = opt;
141             opt_flatten.blob_allocator = opt.workspace_allocator;
142 
143             flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
144             if (bottom_blob_flattened.empty())
145                 return -100;
146         }
147 
148         top_blob.create(_w, _h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
149         if (top_blob.empty())
150             return -100;
151 
152         int outw = top_blob.w;
153         int outh = top_blob.h;
154 
155         // assert out_elempack == 4
156 
157         #pragma omp parallel for num_threads(opt.num_threads)
158         for (int i = 0; i < outh; i++)
159         {
160             const float* ptr0 = (const float*)bottom_blob_flattened + outw * i * 4;
161             const float* ptr1 = (const float*)bottom_blob_flattened + outw * (i * 4 + 1);
162             const float* ptr2 = (const float*)bottom_blob_flattened + outw * (i * 4 + 2);
163             const float* ptr3 = (const float*)bottom_blob_flattened + outw * (i * 4 + 3);
164             float* outptr = (float*)top_blob.row(i);
165 
166             int j = 0;
167 #if __ARM_NEON
168             for (; j + 3 < outw; j += 4)
169             {
170                 float32x4x4_t _v4;
171                 _v4.val[0] = vld1q_f32(ptr0);
172                 _v4.val[1] = vld1q_f32(ptr1);
173                 _v4.val[2] = vld1q_f32(ptr2);
174                 _v4.val[3] = vld1q_f32(ptr3);
175 
176                 vst4q_f32(outptr, _v4);
177 
178                 ptr0 += 4;
179                 ptr1 += 4;
180                 ptr2 += 4;
181                 ptr3 += 4;
182                 outptr += 16;
183             }
184 #endif
185             for (; j < outw; j++)
186             {
187                 outptr[0] = *ptr0++;
188                 outptr[1] = *ptr1++;
189                 outptr[2] = *ptr2++;
190                 outptr[3] = *ptr3++;
191 
192                 outptr += 4;
193             }
194         }
195     }
196 
197     if (ndim == 3 || ndim == 4)
198     {
199         int _w = w;
200         int _h = h;
201         int _d = d;
202         int _c = c;
203 
204         if (ndim == 3)
205         {
206             if (_w == 0)
207                 _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
208             if (_h == 0)
209                 _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
210             if (_c == 0)
211                 _c = dims == 3 ? bottom_blob.c * elempack : bottom_blob.c;
212 
213             if (_w == -1)
214                 _w = total / _c / _h;
215             if (_h == -1)
216                 _h = total / _c / _w;
217             if (_c == -1)
218                 _c = total / _h / _w;
219         }
220         else // if (ndim == 4)
221         {
222             if (_w == 0)
223                 _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
224             if (_h == 0)
225                 _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
226             if (_d == 0)
227                 _d = bottom_blob.d;
228             if (_c == 0)
229                 _c = (dims == 3 || dims == 4) ? bottom_blob.c * elempack : bottom_blob.c;
230 
231             if (_w == -1)
232                 _w = total / _c / _d / _h;
233             if (_h == -1)
234                 _h = total / _c / _d / _w;
235             if (_d == -1)
236                 _d = total / _c / _h / _w;
237             if (_c == -1)
238                 _c = total / _d / _h / _w;
239         }
240 
241         int out_elempack = opt.use_packing_layout && _c % 4 == 0 ? 4 : 1;
242         size_t out_elemsize = elemsize / elempack * out_elempack;
243 
244         if (dims == 3 && bottom_blob.c * elempack == _c && elempack == out_elempack)
245         {
246             top_blob = bottom_blob;
247             top_blob.w = _w;
248             top_blob.h = _h;
249             return 0;
250         }
251         if (dims == 4 && bottom_blob.c * elempack == _c && elempack == out_elempack)
252         {
253             top_blob = bottom_blob;
254             top_blob.w = _w;
255             top_blob.h = _h;
256             top_blob.d = _d;
257             return 0;
258         }
259 
260         // flatten
261         Mat bottom_blob_flattened = bottom_blob;
262         {
263             Option opt_flatten = opt;
264             opt_flatten.blob_allocator = opt.workspace_allocator;
265 
266             flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
267             if (bottom_blob_flattened.empty())
268                 return -100;
269         }
270 
271         if (ndim == 3)
272         {
273             top_blob.create(_w, _h, _c / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
274         }
275         else // if (ndim == 4)
276         {
277             top_blob.create(_w, _h, _d, _c / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
278         }
279         if (top_blob.empty())
280             return -100;
281 
282         int size = top_blob.w * top_blob.h * top_blob.d;
283 
284         if (out_elempack == 4)
285         {
286             #pragma omp parallel for num_threads(opt.num_threads)
287             for (int q = 0; q < top_blob.c; q++)
288             {
289                 const float* ptr0 = (const float*)bottom_blob_flattened + size * q * 4;
290                 const float* ptr1 = (const float*)bottom_blob_flattened + size * (q * 4 + 1);
291                 const float* ptr2 = (const float*)bottom_blob_flattened + size * (q * 4 + 2);
292                 const float* ptr3 = (const float*)bottom_blob_flattened + size * (q * 4 + 3);
293                 float* outptr = top_blob.channel(q);
294 
295                 int i = 0;
296 #if __ARM_NEON
297                 for (; i + 3 < size; i += 4)
298                 {
299                     float32x4x4_t _v4;
300                     _v4.val[0] = vld1q_f32(ptr0);
301                     _v4.val[1] = vld1q_f32(ptr1);
302                     _v4.val[2] = vld1q_f32(ptr2);
303                     _v4.val[3] = vld1q_f32(ptr3);
304 
305                     vst4q_f32(outptr, _v4);
306 
307                     ptr0 += 4;
308                     ptr1 += 4;
309                     ptr2 += 4;
310                     ptr3 += 4;
311                     outptr += 16;
312                 }
313 #endif
314                 for (; i < size; i++)
315                 {
316                     outptr[0] = *ptr0++;
317                     outptr[1] = *ptr1++;
318                     outptr[2] = *ptr2++;
319                     outptr[3] = *ptr3++;
320 
321                     outptr += 4;
322                 }
323             }
324         }
325 
326         if (out_elempack == 1)
327         {
328             #pragma omp parallel for num_threads(opt.num_threads)
329             for (int q = 0; q < top_blob.c; q++)
330             {
331                 const float* ptr = (const float*)bottom_blob_flattened + size * q;
332                 float* outptr = top_blob.channel(q);
333 
334                 int i = 0;
335 #if __ARM_NEON
336                 for (; i + 3 < size; i += 4)
337                 {
338                     float32x4_t _v = vld1q_f32(ptr);
339                     vst1q_f32(outptr, _v);
340                     ptr += 4;
341                     outptr += 4;
342                 }
343 #endif
344                 for (; i < size; i++)
345                 {
346                     *outptr++ = *ptr++;
347                 }
348             }
349         }
350     }
351 
352     return 0;
353 }
354 
forward_bf16s_fp16s(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const355 int Reshape_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
356 {
357     int elempack = bottom_blob.elempack;
358 
359     if (permute == 1)
360     {
361         // TODO implement permute on-the-fly
362         Option opt_pack = opt;
363         opt_pack.blob_allocator = opt.workspace_allocator;
364 
365         Mat bottom_blob_unpacked;
366         convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);
367 
368         Mat bottom_blob_unpacked_fp32;
369         cast_bfloat16_to_float32(bottom_blob_unpacked, bottom_blob_unpacked_fp32, opt_pack);
370 
371         Mat top_blob_unpacked_fp32;
372         int ret = Reshape::forward(bottom_blob_unpacked_fp32, top_blob_unpacked_fp32, opt_pack);
373         if (ret != 0)
374             return ret;
375 
376         Mat top_blob_unpacked;
377         cast_float32_to_bfloat16(top_blob_unpacked_fp32, top_blob_unpacked, opt_pack);
378 
379         int out_elempack = 1;
380         if (opt.use_packing_layout)
381         {
382             // resolve dst_elempack
383             int dims = top_blob_unpacked.dims;
384             if (dims == 1) out_elempack = opt.use_fp16_arithmetic && top_blob_unpacked.w % 8 == 0 ? 8 : top_blob_unpacked.w % 4 == 0 ? 4 : 1;
385             if (dims == 2) out_elempack = opt.use_fp16_arithmetic && top_blob_unpacked.h % 8 == 0 ? 8 : top_blob_unpacked.h % 4 == 0 ? 4 : 1;
386             if (dims == 3 || dims == 4) out_elempack = opt.use_fp16_arithmetic && top_blob_unpacked.c % 8 == 0 ? 8 : top_blob_unpacked.c % 4 == 0 ? 4 : 1;
387         }
388         convert_packing(top_blob_unpacked, top_blob, out_elempack, opt);
389 
390         return 0;
391     }
392 
393     if (ndim == 1)
394     {
395         // flatten
396         flatten(bottom_blob, top_blob, opt);
397         if (top_blob.empty())
398             return -100;
399 
400         return 0;
401     }
402 
403     int dims = bottom_blob.dims;
404     size_t elemsize = bottom_blob.elemsize;
405 
406     int total = bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c * elempack;
407 
408     if (ndim == 2)
409     {
410         int _w = w;
411         int _h = h;
412 
413         if (_w == 0)
414             _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
415         if (_h == 0)
416             _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
417 
418         if (_w == -1)
419             _w = total / _h;
420         if (_h == -1)
421             _h = total / _w;
422 
423         int out_elempack = 1;
424         if (opt.use_packing_layout)
425         {
426             out_elempack = opt.use_fp16_arithmetic && _h % 8 == 0 ? 8 : _h % 4 == 0 ? 4 : 1;
427         }
428         size_t out_elemsize = elemsize / elempack * out_elempack;
429 
430         if (dims == 2 && bottom_blob.h * elempack == _h && elempack == out_elempack)
431         {
432             top_blob = bottom_blob;
433             return 0;
434         }
435 
436         if (out_elempack == 1)
437         {
438             // flatten
439             flatten(bottom_blob, top_blob, opt);
440             if (top_blob.empty())
441                 return -100;
442 
443             top_blob.dims = 2;
444             top_blob.w = _w;
445             top_blob.h = _h;
446             top_blob.cstep = _w * _h;
447             top_blob.elemsize = out_elemsize;
448             top_blob.elempack = out_elempack;
449 
450             return 0;
451         }
452 
453         // flatten
454         Mat bottom_blob_flattened = bottom_blob;
455         {
456             Option opt_flatten = opt;
457             opt_flatten.blob_allocator = opt.workspace_allocator;
458 
459             flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
460             if (bottom_blob_flattened.empty())
461                 return -100;
462         }
463 
464         top_blob.create(_w, _h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
465         if (top_blob.empty())
466             return -100;
467 
468         int outw = top_blob.w;
469         int outh = top_blob.h;
470 
471 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
472         if (out_elempack == 8)
473         {
474             #pragma omp parallel for num_threads(opt.num_threads)
475             for (int i = 0; i < outh; i++)
476             {
477                 const __fp16* ptr0 = (const __fp16*)bottom_blob_flattened + outw * i * 8;
478                 const __fp16* ptr1 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 1);
479                 const __fp16* ptr2 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 2);
480                 const __fp16* ptr3 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 3);
481                 const __fp16* ptr4 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 4);
482                 const __fp16* ptr5 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 5);
483                 const __fp16* ptr6 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 6);
484                 const __fp16* ptr7 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 7);
485                 __fp16* outptr = top_blob.row<__fp16>(i);
486 
487                 int j = 0;
488                 for (; j + 3 < outw; j += 4)
489                 {
490                     float16x8_t _p01 = vcombine_f16(vld1_f16(ptr0), vld1_f16(ptr1));
491                     float16x8_t _p23 = vcombine_f16(vld1_f16(ptr2), vld1_f16(ptr3));
492                     float16x8_t _p45 = vcombine_f16(vld1_f16(ptr4), vld1_f16(ptr5));
493                     float16x8_t _p67 = vcombine_f16(vld1_f16(ptr6), vld1_f16(ptr7));
494 
495                     float16x8x2_t _p0415 = vzipq_f16(_p01, _p45);
496                     float16x8x2_t _p2637 = vzipq_f16(_p23, _p67);
497 
498                     float16x8x4_t _v4;
499                     _v4.val[0] = _p0415.val[0];
500                     _v4.val[1] = _p0415.val[1];
501                     _v4.val[2] = _p2637.val[0];
502                     _v4.val[3] = _p2637.val[1];
503 
504                     vst4q_f16(outptr, _v4);
505 
506                     ptr0 += 4;
507                     ptr1 += 4;
508                     ptr2 += 4;
509                     ptr3 += 4;
510                     ptr4 += 4;
511                     ptr5 += 4;
512                     ptr6 += 4;
513                     ptr7 += 4;
514                     outptr += 32;
515                 }
516                 for (; j < outw; j++)
517                 {
518                     outptr[0] = *ptr0++;
519                     outptr[1] = *ptr1++;
520                     outptr[2] = *ptr2++;
521                     outptr[3] = *ptr3++;
522                     outptr[4] = *ptr4++;
523                     outptr[5] = *ptr5++;
524                     outptr[6] = *ptr6++;
525                     outptr[7] = *ptr7++;
526 
527                     outptr += 8;
528                 }
529             }
530         }
531 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
532 
533         if (out_elempack == 4)
534         {
535             #pragma omp parallel for num_threads(opt.num_threads)
536             for (int i = 0; i < outh; i++)
537             {
538                 const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + outw * i * 4;
539                 const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 1);
540                 const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 2);
541                 const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 3);
542                 unsigned short* outptr = top_blob.row<unsigned short>(i);
543 
544                 int j = 0;
545 #if __ARM_NEON
546                 for (; j + 3 < outw; j += 4)
547                 {
548                     uint16x4x4_t _v4;
549                     _v4.val[0] = vld1_u16(ptr0);
550                     _v4.val[1] = vld1_u16(ptr1);
551                     _v4.val[2] = vld1_u16(ptr2);
552                     _v4.val[3] = vld1_u16(ptr3);
553 
554                     vst4_u16(outptr, _v4);
555 
556                     ptr0 += 4;
557                     ptr1 += 4;
558                     ptr2 += 4;
559                     ptr3 += 4;
560                     outptr += 16;
561                 }
562 #endif
563                 for (; j < outw; j++)
564                 {
565                     outptr[0] = *ptr0++;
566                     outptr[1] = *ptr1++;
567                     outptr[2] = *ptr2++;
568                     outptr[3] = *ptr3++;
569 
570                     outptr += 4;
571                 }
572             }
573         }
574     }
575 
576     if (ndim == 3 || ndim == 4)
577     {
578         int _w = w;
579         int _h = h;
580         int _d = d;
581         int _c = c;
582 
583         if (ndim == 3)
584         {
585             if (_w == 0)
586                 _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
587             if (_h == 0)
588                 _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
589             if (_c == 0)
590                 _c = dims == 3 ? bottom_blob.c * elempack : bottom_blob.c;
591 
592             if (_w == -1)
593                 _w = total / _c / _h;
594             if (_h == -1)
595                 _h = total / _c / _w;
596             if (_c == -1)
597                 _c = total / _h / _w;
598         }
599         else // if (ndim == 4)
600         {
601             if (_w == 0)
602                 _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
603             if (_h == 0)
604                 _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
605             if (_d == 0)
606                 _d = bottom_blob.d;
607             if (_c == 0)
608                 _c = (dims == 3 || dims == 4) ? bottom_blob.c * elempack : bottom_blob.c;
609 
610             if (_w == -1)
611                 _w = total / _c / _d / _h;
612             if (_h == -1)
613                 _h = total / _c / _d / _w;
614             if (_d == -1)
615                 _d = total / _c / _h / _w;
616             if (_c == -1)
617                 _c = total / _d / _h / _w;
618         }
619 
620         int out_elempack = 1;
621         if (opt.use_packing_layout)
622         {
623             out_elempack = opt.use_fp16_arithmetic && _c % 8 == 0 ? 8 : _c % 4 == 0 ? 4 : 1;
624         }
625         size_t out_elemsize = elemsize / elempack * out_elempack;
626 
627         if (dims == 3 && bottom_blob.c * elempack == _c && elempack == out_elempack)
628         {
629             top_blob = bottom_blob;
630             top_blob.w = _w;
631             top_blob.h = _h;
632             return 0;
633         }
634         if (dims == 4 && bottom_blob.c * elempack == _c && elempack == out_elempack)
635         {
636             top_blob = bottom_blob;
637             top_blob.w = _w;
638             top_blob.h = _h;
639             top_blob.d = _d;
640             return 0;
641         }
642 
643         // flatten
644         Mat bottom_blob_flattened = bottom_blob;
645         {
646             Option opt_flatten = opt;
647             opt_flatten.blob_allocator = opt.workspace_allocator;
648 
649             flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
650             if (bottom_blob_flattened.empty())
651                 return -100;
652         }
653 
654         if (ndim == 3)
655         {
656             top_blob.create(_w, _h, _c / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
657         }
658         else // if (ndim == 4)
659         {
660             top_blob.create(_w, _h, _d, _c / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
661         }
662         if (top_blob.empty())
663             return -100;
664 
665         int size = top_blob.w * top_blob.h * top_blob.d;
666 
667 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
668         if (out_elempack == 8)
669         {
670             #pragma omp parallel for num_threads(opt.num_threads)
671             for (int q = 0; q < top_blob.c; q++)
672             {
673                 const __fp16* ptr0 = (const __fp16*)bottom_blob_flattened + size * q * 8;
674                 const __fp16* ptr1 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 1);
675                 const __fp16* ptr2 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 2);
676                 const __fp16* ptr3 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 3);
677                 const __fp16* ptr4 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 4);
678                 const __fp16* ptr5 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 5);
679                 const __fp16* ptr6 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 6);
680                 const __fp16* ptr7 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 7);
681                 __fp16* outptr = top_blob.channel(q);
682 
683                 int i = 0;
684                 for (; i + 3 < size; i += 4)
685                 {
686                     float16x8_t _p01 = vcombine_f16(vld1_f16(ptr0), vld1_f16(ptr1));
687                     float16x8_t _p23 = vcombine_f16(vld1_f16(ptr2), vld1_f16(ptr3));
688                     float16x8_t _p45 = vcombine_f16(vld1_f16(ptr4), vld1_f16(ptr5));
689                     float16x8_t _p67 = vcombine_f16(vld1_f16(ptr6), vld1_f16(ptr7));
690 
691                     float16x8x2_t _p0415 = vzipq_f16(_p01, _p45);
692                     float16x8x2_t _p2637 = vzipq_f16(_p23, _p67);
693 
694                     float16x8x4_t _v4;
695                     _v4.val[0] = _p0415.val[0];
696                     _v4.val[1] = _p0415.val[1];
697                     _v4.val[2] = _p2637.val[0];
698                     _v4.val[3] = _p2637.val[1];
699 
700                     vst4q_f16(outptr, _v4);
701 
702                     ptr0 += 4;
703                     ptr1 += 4;
704                     ptr2 += 4;
705                     ptr3 += 4;
706                     ptr4 += 4;
707                     ptr5 += 4;
708                     ptr6 += 4;
709                     ptr7 += 4;
710                     outptr += 32;
711                 }
712                 for (; i < size; i++)
713                 {
714                     outptr[0] = *ptr0++;
715                     outptr[1] = *ptr1++;
716                     outptr[2] = *ptr2++;
717                     outptr[3] = *ptr3++;
718                     outptr[4] = *ptr4++;
719                     outptr[5] = *ptr5++;
720                     outptr[6] = *ptr6++;
721                     outptr[7] = *ptr7++;
722 
723                     outptr += 8;
724                 }
725             }
726         }
727 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
728 
729         if (out_elempack == 4)
730         {
731             #pragma omp parallel for num_threads(opt.num_threads)
732             for (int q = 0; q < top_blob.c; q++)
733             {
734                 const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + size * q * 4;
735                 const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 1);
736                 const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 2);
737                 const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 3);
738                 unsigned short* outptr = top_blob.channel(q);
739 
740                 int i = 0;
741 #if __ARM_NEON
742                 for (; i + 3 < size; i += 4)
743                 {
744                     uint16x4x4_t _v4;
745                     _v4.val[0] = vld1_u16(ptr0);
746                     _v4.val[1] = vld1_u16(ptr1);
747                     _v4.val[2] = vld1_u16(ptr2);
748                     _v4.val[3] = vld1_u16(ptr3);
749 
750                     vst4_u16(outptr, _v4);
751 
752                     ptr0 += 4;
753                     ptr1 += 4;
754                     ptr2 += 4;
755                     ptr3 += 4;
756                     outptr += 16;
757                 }
758 #endif
759                 for (; i < size; i++)
760                 {
761                     outptr[0] = *ptr0++;
762                     outptr[1] = *ptr1++;
763                     outptr[2] = *ptr2++;
764                     outptr[3] = *ptr3++;
765 
766                     outptr += 4;
767                 }
768             }
769         }
770 
771         if (out_elempack == 1)
772         {
773             #pragma omp parallel for num_threads(opt.num_threads)
774             for (int q = 0; q < top_blob.c; q++)
775             {
776                 const unsigned short* ptr = (const unsigned short*)bottom_blob_flattened + size * q;
777                 unsigned short* outptr = top_blob.channel(q);
778 
779                 int i = 0;
780 #if __ARM_NEON
781                 for (; i + 3 < size; i += 4)
782                 {
783                     uint16x4_t _v = vld1_u16(ptr);
784                     vst1_u16(outptr, _v);
785                     ptr += 4;
786                     outptr += 4;
787                 }
788 #endif
789                 for (; i < size; i++)
790                 {
791                     *outptr++ = *ptr++;
792                 }
793             }
794         }
795     }
796 
797     return 0;
798 }
799 
800 } // namespace ncnn
801