1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "shufflechannel_arm.h"
16 
17 #include "layer_type.h"
18 
19 #if __ARM_NEON
20 #include <arm_neon.h>
21 #endif // __ARM_NEON
22 
23 namespace ncnn {
24 
// Constructor: advertise the optional storage/packing layouts this ARM
// implementation can consume, so the framework may feed packed and/or
// 16-bit blobs directly to forward().
ShuffleChannel_arm::ShuffleChannel_arm()
{
#if __ARM_NEON
    // NEON build: packed (elempack 4, and 8 for fp16) layouts are handled
    // by the vectorized paths in forward()/forward_bf16s_fp16s().
    support_packing = true;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // armv8.2 half-precision vector support -> accept fp16 storage blobs.
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

#if NCNN_BF16
    // bf16 payloads are shuffled as raw 16-bit words, no arithmetic needed.
    support_bf16_storage = true;
#endif
}
38 
// Channel shuffle (ShuffleNet-style group interleave) with NEON fast paths
// for elempack=4 fp32 data. 16-bit storage (fp16/bf16) is dispatched to
// forward_bf16s_fp16s(), which only moves 16-bit payloads and therefore
// serves both representations.
// Returns 0 on success, -100 on allocation failure, or the base-layer result.
int ShuffleChannel_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int elembits = bottom_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    // fp16 storage: shuffling never interprets element values, so the bf16
    // code path is reused verbatim for fp16 data.
    if (opt.use_fp16_storage && elembits == 16)
        return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
#endif

#if NCNN_BF16
    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
#endif

    int channels = bottom_blob.c;
    int elempack = bottom_blob.elempack;

    // reverse mode inverts the permutation by swapping the roles of the
    // group count and channels-per-group (counted in scalar channels)
    int _group = reverse ? channels * elempack / group : group;

    // a single group makes the shuffle an identity - just share the blob
    if (_group == 1)
    {
        top_blob = bottom_blob;
        return 0;
    }

#if __ARM_NEON
    if (elempack == 4)
    {
        if (_group == 2 && channels % _group != 0)
        {
            // special case: 2 groups with an odd packed-channel count.
            // The group boundary then falls in the middle of a pack-4
            // channel, so group 1's lanes straddle two packed channels and
            // must be re-aligned with vext before interleaving with group 0.
            int w = bottom_blob.w;
            int h = bottom_blob.h;
            int size = w * h;
            size_t elemsize = bottom_blob.elemsize;

            top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            int channels_per_group = channels / _group;

            // TODO unroll me
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr0 = bottom_blob.channel(q);
                // group 1 starts 2 scalar lanes into channel(channels_per_group),
                // so its 4 matching lanes span these two consecutive channels
                const float* ptr1 = bottom_blob.channel(channels_per_group + q);
                const float* ptr2 = bottom_blob.channel(channels_per_group + q + 1);
                float* outptr0 = top_blob.channel(q * 2);
                float* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);
                    float32x4_t _p2 = vld1q_f32(ptr2);

                    // high half of _p1 + low half of _p2 = the 4 group-1
                    // lanes that pair with _p0
                    float32x4_t _p12 = vextq_f32(_p1, _p2, 2);

                    float32x4x2_t _p01 = vzipq_f32(_p0, _p12);

                    vst1q_f32(outptr0, _p01.val[0]);
                    vst1q_f32(outptr1, _p01.val[1]);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }

            // handle the last channel
            {
                // ptr0's low lanes hold group 0's final scalars; ptr1 + 2
                // points at group 1's final scalars in the last packed channel
                const float* ptr0 = bottom_blob.channel(channels_per_group);
                const float* ptr1 = bottom_blob.channel(channels_per_group + channels_per_group);
                float* outptr0 = top_blob.channel(channels_per_group * 2);

                ptr1 += 2;

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    // NOTE(review): on the final iteration this load reads 2
                    // lanes past the channel's payload (only lanes 0-1 are
                    // consumed) - presumably safe due to Mat allocation
                    // padding, but worth confirming
                    float32x4_t _p1 = vld1q_f32(ptr1);

                    float32x4x2_t _p01 = vzipq_f32(_p0, _p1);

                    // only the low zip half is meaningful output here
                    vst1q_f32(outptr0, _p01.val[0]);

                    ptr0 += 4;
                    ptr1 += 4;
                    outptr0 += 4;
                }
            }

            return 0;
        }

        if (_group > 4 || channels % _group != 0)
        {
            // slow path for too large group or shuffle inside elempack
            // unpack to elempack=1, run the generic reference shuffle,
            // then re-pack the result
            Option opt_pack = opt;
            opt_pack.blob_allocator = opt.workspace_allocator;

            Mat bottom_blob_unpacked;
            convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

            Mat top_blob_unpacked;
            int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
            if (ret != 0)
                return ret;

            convert_packing(top_blob_unpacked, top_blob, elempack, opt);

            return 0;
        }

        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int size = w * h;
        size_t elemsize = bottom_blob.elemsize;

        top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int channels_per_group = channels / _group;

        if (_group == 2)
        {
            // 2 groups: each output channel pair is a plain zip (interleave)
            // of one channel from each group
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr0 = bottom_blob.channel(q);
                const float* ptr1 = bottom_blob.channel(channels_per_group + q);
                float* outptr0 = top_blob.channel(q * 2);
                float* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);

                    float32x4x2_t _p01 = vzipq_f32(_p0, _p1);

                    vst1q_f32(outptr0, _p01.val[0]);
                    vst1q_f32(outptr1, _p01.val[1]);

                    ptr0 += 4;
                    ptr1 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }
        }

        if (_group == 3)
        {
            // 3 groups: with lanes _p0 = 0 1 2 3, _p1 = 4 5 6 7,
            // _p2 = 8 9 x y (x=10, y=11), the shuffled outputs are
            // 0 4 8 1 / 5 9 2 6 / x 3 7 y - variable names spell the lanes
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr0 = bottom_blob.channel(q);
                const float* ptr1 = bottom_blob.channel(channels_per_group + q);
                const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                float* outptr0 = top_blob.channel(q * 3);
                float* outptr1 = top_blob.channel(q * 3 + 1);
                float* outptr2 = top_blob.channel(q * 3 + 2);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);
                    float32x4_t _p2 = vld1q_f32(ptr2);

                    float32x4x2_t _p01 = vzipq_f32(_p0, _p1);
                    float32x4x2_t _p12 = vzipq_f32(_p1, _p2);

                    float32x4_t _0415 = _p01.val[0];
                    float32x4_t _2637 = _p01.val[1];
                    float32x4_t _4859 = _p12.val[0];
                    float32x4_t _6x7y = _p12.val[1];

                    float32x2_t _15 = vget_high_f32(_0415);
                    float32x2_t _37 = vget_high_f32(_2637);
                    float32x2_t _48 = vget_low_f32(_4859);
                    float32x2_t _6x = vget_low_f32(_6x7y);

                    float32x2_t _81 = vext_f32(_48, _15, 1);
                    float32x2_t _x3 = vext_f32(_6x, _37, 1);

                    float32x4_t _0481 = vcombine_f32(vget_low_f32(_0415), _81);
                    float32x4_t _5926 = vextq_f32(_4859, _2637, 2);
                    float32x4_t _x37y = vcombine_f32(_x3, vget_high_f32(_6x7y));

                    vst1q_f32(outptr0, _0481);
                    vst1q_f32(outptr1, _5926);
                    vst1q_f32(outptr2, _x37y);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                }
            }
        }

        if (_group == 4)
        {
            // 4 groups with pack-4: the shuffle is exactly a 4x4 transpose of
            // the lane matrix formed by one channel from each group
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr0 = bottom_blob.channel(q);
                const float* ptr1 = bottom_blob.channel(channels_per_group + q);
                const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                const float* ptr3 = bottom_blob.channel(channels_per_group * 3 + q);
                float* outptr0 = top_blob.channel(q * 4);
                float* outptr1 = top_blob.channel(q * 4 + 1);
                float* outptr2 = top_blob.channel(q * 4 + 2);
                float* outptr3 = top_blob.channel(q * 4 + 3);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);
                    float32x4_t _p2 = vld1q_f32(ptr2);
                    float32x4_t _p3 = vld1q_f32(ptr3);

                    // transpose 4x4
                    float32x4x2_t _p01 = vtrnq_f32(_p0, _p1);
                    float32x4x2_t _p23 = vtrnq_f32(_p2, _p3);
                    _p0 = vcombine_f32(vget_low_f32(_p01.val[0]), vget_low_f32(_p23.val[0]));
                    _p1 = vcombine_f32(vget_low_f32(_p01.val[1]), vget_low_f32(_p23.val[1]));
                    _p2 = vcombine_f32(vget_high_f32(_p01.val[0]), vget_high_f32(_p23.val[0]));
                    _p3 = vcombine_f32(vget_high_f32(_p01.val[1]), vget_high_f32(_p23.val[1]));

                    vst1q_f32(outptr0, _p0);
                    vst1q_f32(outptr1, _p1);
                    vst1q_f32(outptr2, _p2);
                    vst1q_f32(outptr3, _p3);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                    outptr3 += 4;
                }
            }
        }

        return 0;
    }
#endif // __ARM_NEON

    // elempack == 1 (or no NEON): defer to the generic reference implementation
    return ShuffleChannel::forward(bottom_blob, top_blob, opt);
}
295 
// 16-bit storage variant of forward(). A channel shuffle only permutes
// elements, so the same code serves fp16 and bf16 blobs: the elempack=4 path
// moves raw unsigned short payloads, and the elempack=8 path (armv8.2 fp16
// builds) uses __fp16 vectors purely as 16-bit carriers.
// Returns 0 on success, -100 on allocation failure, or the base-layer result.
int ShuffleChannel_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int channels = bottom_blob.c;
    int elempack = bottom_blob.elempack;

    // reverse mode inverts the permutation by swapping the roles of the
    // group count and channels-per-group (counted in scalar channels)
    int _group = reverse ? channels * elempack / group : group;

    // a single group makes the shuffle an identity - just share the blob
    if (_group == 1)
    {
        top_blob = bottom_blob;
        return 0;
    }

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (elempack == 8)
    {
        if (_group == 2 && channels % _group != 0)
        {
            // special case: 2 groups with an odd packed-channel count.
            // The group boundary falls at lane 4 of a pack-8 channel, so
            // group 1's lanes straddle two packed channels and must be
            // re-aligned with vext before interleaving with group 0.
            int w = bottom_blob.w;
            int h = bottom_blob.h;
            int size = w * h;
            size_t elemsize = bottom_blob.elemsize;

            top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            int channels_per_group = channels / _group;

            // TODO unroll me
            for (int q = 0; q < channels_per_group; q++)
            {
                const __fp16* ptr0 = bottom_blob.channel(q);
                // group 1 starts 4 scalar lanes into channel(channels_per_group),
                // so its 8 matching lanes span these two consecutive channels
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + q);
                const __fp16* ptr2 = bottom_blob.channel(channels_per_group + q + 1);
                __fp16* outptr0 = top_blob.channel(q * 2);
                __fp16* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    float16x8_t _p0 = vld1q_f16(ptr0);
                    float16x8_t _p1 = vld1q_f16(ptr1);
                    float16x8_t _p2 = vld1q_f16(ptr2);

                    // high half of _p1 + low half of _p2 = the 8 group-1
                    // lanes that pair with _p0
                    float16x8_t _p12 = vextq_f16(_p1, _p2, 4);

                    float16x8x2_t _p01 = vzipq_f16(_p0, _p12);

                    vst1q_f16(outptr0, _p01.val[0]);
                    vst1q_f16(outptr1, _p01.val[1]);

                    ptr0 += 8;
                    ptr1 += 8;
                    ptr2 += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                }
            }

            // handle the last channel
            {
                // ptr0's low 4 lanes hold group 0's final scalars; ptr1 + 4
                // points at group 1's final scalars in the last packed channel
                const __fp16* ptr0 = bottom_blob.channel(channels_per_group);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + channels_per_group);
                __fp16* outptr0 = top_blob.channel(channels_per_group * 2);

                ptr1 += 4;

                for (int i = 0; i < size; i++)
                {
                    float16x4_t _p0 = vld1_f16(ptr0);
                    float16x4_t _p1 = vld1_f16(ptr1);

                    float16x4x2_t _p01 = vzip_f16(_p0, _p1);

                    // the zipped 8 lanes form the one remaining output channel
                    vst1_f16(outptr0, _p01.val[0]);
                    vst1_f16(outptr0 + 4, _p01.val[1]);

                    ptr0 += 8;
                    ptr1 += 8;
                    outptr0 += 8;
                }
            }

            return 0;
        }

        if (_group > 4 || channels % _group != 0)
        {
            // slow path for too large group or shuffle inside elempack
            // unpack to elempack=1, run the generic reference shuffle,
            // then re-pack the result
            Option opt_pack = opt;
            opt_pack.blob_allocator = opt.workspace_allocator;

            Mat bottom_blob_unpacked;
            convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

            Mat top_blob_unpacked;
            int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
            if (ret != 0)
                return ret;

            convert_packing(top_blob_unpacked, top_blob, elempack, opt);

            return 0;
        }

        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int size = w * h;
        size_t elemsize = bottom_blob.elemsize;

        top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int channels_per_group = channels / _group;

        if (_group == 2)
        {
            // 2 groups: each output channel pair is a plain zip (interleave)
            // of one channel from each group
            for (int q = 0; q < channels_per_group; q++)
            {
                const __fp16* ptr0 = bottom_blob.channel(q);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + q);
                __fp16* outptr0 = top_blob.channel(q * 2);
                __fp16* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    float16x8_t _p0 = vld1q_f16(ptr0);
                    float16x8_t _p1 = vld1q_f16(ptr1);

                    float16x8x2_t _p01 = vzipq_f16(_p0, _p1);

                    vst1q_f16(outptr0, _p01.val[0]);
                    vst1q_f16(outptr1, _p01.val[1]);

                    ptr0 += 8;
                    ptr1 += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                }
            }
        }

        if (_group == 3)
        {
            // 3 groups: round-trip through a stack buffer with a 3-way
            // interleaving store, then reload the three output vectors
            for (int q = 0; q < channels_per_group; q++)
            {
                const __fp16* ptr0 = bottom_blob.channel(q);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + q);
                const __fp16* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                __fp16* outptr0 = top_blob.channel(q * 3);
                __fp16* outptr1 = top_blob.channel(q * 3 + 1);
                __fp16* outptr2 = top_blob.channel(q * 3 + 2);

                for (int i = 0; i < size; i++)
                {
                    float16x8_t _p0 = vld1q_f16(ptr0);
                    float16x8_t _p1 = vld1q_f16(ptr1);
                    float16x8_t _p2 = vld1q_f16(ptr2);

                    // TODO figure out a faster way

                    // 01234567        08g19h2a
                    // 89abcdef   ->   i3bj4ck5
                    // ghijklmn        dl6em7fn

                    float16x8x3_t _p012;
                    _p012.val[0] = _p0;
                    _p012.val[1] = _p1;
                    _p012.val[2] = _p2;

                    // vst3q interleaves the three vectors element-wise,
                    // which is exactly the group-3 shuffle order
                    __fp16 tmp[24];
                    vst3q_f16(&tmp[0], _p012);

                    _p0 = vld1q_f16(&tmp[0]);
                    _p1 = vld1q_f16(&tmp[8]);
                    _p2 = vld1q_f16(&tmp[16]);

                    vst1q_f16(outptr0, _p0);
                    vst1q_f16(outptr1, _p1);
                    vst1q_f16(outptr2, _p2);

                    ptr0 += 8;
                    ptr1 += 8;
                    ptr2 += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                    outptr2 += 8;
                }
            }
        }

        if (_group == 4)
        {
            // 4 groups: interleave via a 4x4 transpose done in two stages
            // (16-bit trn, then 32-bit trn on reinterpreted pairs)
            for (int q = 0; q < channels_per_group; q++)
            {
                const __fp16* ptr0 = bottom_blob.channel(q);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + q);
                const __fp16* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                const __fp16* ptr3 = bottom_blob.channel(channels_per_group * 3 + q);
                __fp16* outptr0 = top_blob.channel(q * 4);
                __fp16* outptr1 = top_blob.channel(q * 4 + 1);
                __fp16* outptr2 = top_blob.channel(q * 4 + 2);
                __fp16* outptr3 = top_blob.channel(q * 4 + 3);

                for (int i = 0; i < size; i++)
                {
                    float16x8_t _p0 = vld1q_f16(ptr0);
                    float16x8_t _p1 = vld1q_f16(ptr1);
                    float16x8_t _p2 = vld1q_f16(ptr2);
                    float16x8_t _p3 = vld1q_f16(ptr3);

                    // transpose 4x4
                    float16x8x2_t _p01 = vtrnq_f16(_p0, _p1);
                    float16x8x2_t _p23 = vtrnq_f16(_p2, _p3);
                    uint32x4x2_t _p02 = vtrnq_u32(vreinterpretq_u32_f16(_p01.val[0]), vreinterpretq_u32_f16(_p23.val[0]));
                    uint32x4x2_t _p13 = vtrnq_u32(vreinterpretq_u32_f16(_p01.val[1]), vreinterpretq_u32_f16(_p23.val[1]));
                    _p0 = vreinterpretq_f16_u32(_p02.val[0]);
                    _p1 = vreinterpretq_f16_u32(_p13.val[0]);
                    _p2 = vreinterpretq_f16_u32(_p02.val[1]);
                    _p3 = vreinterpretq_f16_u32(_p13.val[1]);

                    vst1q_f16(outptr0, vcombine_f16(vget_low_f16(_p0), vget_low_f16(_p1)));
                    vst1q_f16(outptr1, vcombine_f16(vget_low_f16(_p2), vget_low_f16(_p3)));
                    vst1q_f16(outptr2, vcombine_f16(vget_high_f16(_p0), vget_high_f16(_p1)));
                    vst1q_f16(outptr3, vcombine_f16(vget_high_f16(_p2), vget_high_f16(_p3)));

                    ptr0 += 8;
                    ptr1 += 8;
                    ptr2 += 8;
                    ptr3 += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                    outptr2 += 8;
                    outptr3 += 8;
                }
            }
        }

        return 0;
    }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if __ARM_NEON
    if (elempack == 4)
    {
        // pack-4 path: identical structure to the fp32 path in forward(),
        // but operating on opaque 16-bit words (fp16 or bf16 alike)
        if (_group == 2 && channels % _group != 0)
        {
            // special case: 2 groups with an odd packed-channel count -
            // group 1's lanes straddle two packed channels (see forward())
            int w = bottom_blob.w;
            int h = bottom_blob.h;
            int size = w * h;
            size_t elemsize = bottom_blob.elemsize;

            top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            int channels_per_group = channels / _group;

            // TODO unroll me
            for (int q = 0; q < channels_per_group; q++)
            {
                const unsigned short* ptr0 = bottom_blob.channel(q);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + q);
                const unsigned short* ptr2 = bottom_blob.channel(channels_per_group + q + 1);
                unsigned short* outptr0 = top_blob.channel(q * 2);
                unsigned short* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);
                    uint16x4_t _p2 = vld1_u16(ptr2);

                    // high half of _p1 + low half of _p2 = the 4 group-1
                    // lanes that pair with _p0
                    uint16x4_t _p12 = vext_u16(_p1, _p2, 2);

                    uint16x4x2_t _p01 = vzip_u16(_p0, _p12);

                    vst1_u16(outptr0, _p01.val[0]);
                    vst1_u16(outptr1, _p01.val[1]);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }

            // handle the last channel
            {
                // ptr0's low lanes hold group 0's final scalars; ptr1 + 2
                // points at group 1's final scalars in the last packed channel
                const unsigned short* ptr0 = bottom_blob.channel(channels_per_group);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + channels_per_group);
                unsigned short* outptr0 = top_blob.channel(channels_per_group * 2);

                ptr1 += 2;

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    // NOTE(review): on the final iteration this load reads 2
                    // lanes past the channel's payload (only lanes 0-1 are
                    // consumed) - presumably safe due to Mat allocation
                    // padding, but worth confirming
                    uint16x4_t _p1 = vld1_u16(ptr1);

                    uint16x4x2_t _p01 = vzip_u16(_p0, _p1);

                    // only the low zip half is meaningful output here
                    vst1_u16(outptr0, _p01.val[0]);

                    ptr0 += 4;
                    ptr1 += 4;
                    outptr0 += 4;
                }
            }

            return 0;
        }

        if (_group > 4 || channels % _group != 0)
        {
            // slow path for too large group or shuffle inside elempack
            // unpack to elempack=1, run the generic reference shuffle,
            // then re-pack the result
            Option opt_pack = opt;
            opt_pack.blob_allocator = opt.workspace_allocator;

            Mat bottom_blob_unpacked;
            convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

            Mat top_blob_unpacked;
            int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
            if (ret != 0)
                return ret;

            convert_packing(top_blob_unpacked, top_blob, elempack, opt);

            return 0;
        }

        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int size = w * h;
        size_t elemsize = bottom_blob.elemsize;

        top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int channels_per_group = channels / _group;

        if (_group == 2)
        {
            // 2 groups: each output channel pair is a plain zip (interleave)
            // of one channel from each group
            for (int q = 0; q < channels_per_group; q++)
            {
                const unsigned short* ptr0 = bottom_blob.channel(q);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + q);
                unsigned short* outptr0 = top_blob.channel(q * 2);
                unsigned short* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);

                    uint16x4x2_t _p01 = vzip_u16(_p0, _p1);

                    vst1_u16(outptr0, _p01.val[0]);
                    vst1_u16(outptr1, _p01.val[1]);

                    ptr0 += 4;
                    ptr1 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }
        }

        if (_group == 3)
        {
            // 3 groups: with lanes _p0 = 0 1 2 3, _p1 = 4 5 6 7,
            // _p2 = 8 9 x y (x=10, y=11), the shuffled outputs are
            // 0 4 8 1 / 5 9 2 6 / x 3 7 y - variable names spell the lanes
            for (int q = 0; q < channels_per_group; q++)
            {
                const unsigned short* ptr0 = bottom_blob.channel(q);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + q);
                const unsigned short* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                unsigned short* outptr0 = top_blob.channel(q * 3);
                unsigned short* outptr1 = top_blob.channel(q * 3 + 1);
                unsigned short* outptr2 = top_blob.channel(q * 3 + 2);

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);
                    uint16x4_t _p2 = vld1_u16(ptr2);

                    // TODO figure out a faster way
                    uint16x4x2_t _p01 = vzip_u16(_p0, _p1);
                    uint16x4x2_t _p12 = vzip_u16(_p1, _p2);

                    uint32x2_t _0415 = vreinterpret_u32_u16(_p01.val[0]);
                    uint16x4_t _2637 = _p01.val[1];
                    uint16x4_t _4859 = _p12.val[0];
                    uint32x2_t _6x7y = vreinterpret_u32_u16(_p12.val[1]);

                    // build the 8,1 / x,3 lane pairs by reversing _p2 and
                    // transposing against _p0
                    uint16x4_t _98yx = vrev32_u16(_p2);
                    uint16x4x2_t _90y281x3 = vtrn_u16(_98yx, _p0);

                    uint32x2_t _81x3 = vreinterpret_u32_u16(_90y281x3.val[1]);

                    // 32-bit transposes splice the pairs into output order
                    uint32x2x2_t _048115x3 = vtrn_u32(_0415, _81x3);
                    uint32x2x2_t _816xx37y = vtrn_u32(_81x3, _6x7y);

                    uint16x4_t _0481 = vreinterpret_u16_u32(_048115x3.val[0]);
                    uint16x4_t _5926 = vext_u16(_4859, _2637, 2);
                    uint16x4_t _x37y = vreinterpret_u16_u32(_816xx37y.val[1]);

                    vst1_u16(outptr0, _0481);
                    vst1_u16(outptr1, _5926);
                    vst1_u16(outptr2, _x37y);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                }
            }
        }

        if (_group == 4)
        {
            // 4 groups with pack-4: the shuffle is exactly a 4x4 transpose,
            // done as 16-bit trn followed by 32-bit trn on reinterpreted pairs
            for (int q = 0; q < channels_per_group; q++)
            {
                const unsigned short* ptr0 = bottom_blob.channel(q);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + q);
                const unsigned short* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                const unsigned short* ptr3 = bottom_blob.channel(channels_per_group * 3 + q);
                unsigned short* outptr0 = top_blob.channel(q * 4);
                unsigned short* outptr1 = top_blob.channel(q * 4 + 1);
                unsigned short* outptr2 = top_blob.channel(q * 4 + 2);
                unsigned short* outptr3 = top_blob.channel(q * 4 + 3);

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);
                    uint16x4_t _p2 = vld1_u16(ptr2);
                    uint16x4_t _p3 = vld1_u16(ptr3);

                    // transpose 4x4
                    uint16x4x2_t _p01 = vtrn_u16(_p0, _p1);
                    uint16x4x2_t _p23 = vtrn_u16(_p2, _p3);
                    uint32x2x2_t _p02 = vtrn_u32(vreinterpret_u32_u16(_p01.val[0]), vreinterpret_u32_u16(_p23.val[0]));
                    uint32x2x2_t _p13 = vtrn_u32(vreinterpret_u32_u16(_p01.val[1]), vreinterpret_u32_u16(_p23.val[1]));
                    _p0 = vreinterpret_u16_u32(_p02.val[0]);
                    _p1 = vreinterpret_u16_u32(_p13.val[0]);
                    _p2 = vreinterpret_u16_u32(_p02.val[1]);
                    _p3 = vreinterpret_u16_u32(_p13.val[1]);

                    vst1_u16(outptr0, _p0);
                    vst1_u16(outptr1, _p1);
                    vst1_u16(outptr2, _p2);
                    vst1_u16(outptr3, _p3);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                    outptr3 += 4;
                }
            }
        }

        return 0;
    }
#endif // __ARM_NEON

    // elempack == 1 (or no NEON): defer to the generic reference implementation
    return ShuffleChannel::forward(bottom_blob, top_blob, opt);
}
773 
774 } // namespace ncnn
775