// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "shufflechannel_arm.h"

#include "layer_type.h"

#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON

namespace ncnn {

ShuffleChannel_arm::ShuffleChannel_arm()
{
#if __ARM_NEON
    support_packing = true;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

#if NCNN_BF16
    support_bf16_storage = true;
#endif
}

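// forward dispatches to the shared 16-bit path when fp16 or bf16 storage is
// active; a channel shuffle only relocates elements, so one routine can serve
// both 16-bit storage types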
int ShuffleChannel_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int elembits = bottom_blob.elembits();

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
        return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
#endif

#if NCNN_BF16
    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
#endif

    int channels = bottom_blob.c;
    int elempack = bottom_blob.elempack;

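    // reverse mode undoes a previous shuffle: the effective group count is the
    // complementary factor over the true channel count (channels * elempack)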
    int _group = reverse ? channels * elempack / group : group;

    if (_group == 1)
    {
        top_blob = bottom_blob;
        return 0;
    }

#if __ARM_NEON
    if (elempack == 4)
    {
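        // special case: odd packed channel count with two groups. The second
        // group then starts in the middle of a pack, so pairs of input packs
        // are stitched together with vext before interleaving.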
        if (_group == 2 && channels % _group != 0)
        {
            int w = bottom_blob.w;
            int h = bottom_blob.h;
            int size = w * h;
            size_t elemsize = bottom_blob.elemsize;

            top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            int channels_per_group = channels / _group;

            // TODO unroll me
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr0 = bottom_blob.channel(q);
                const float* ptr1 = bottom_blob.channel(channels_per_group + q);
                const float* ptr2 = bottom_blob.channel(channels_per_group + q + 1);
                float* outptr0 = top_blob.channel(q * 2);
                float* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);
                    float32x4_t _p2 = vld1q_f32(ptr2);

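                    // the second group starts at lane 2 of ptr1; vext stitches
                    // lanes 2,3 of _p1 with lanes 0,1 of _p2, then a zip with
                    // _p0 yields the interleaved output packs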
                    float32x4_t _p12 = vextq_f32(_p1, _p2, 2);

                    float32x4x2_t _p01 = vzipq_f32(_p0, _p12);

                    vst1q_f32(outptr0, _p01.val[0]);
                    vst1q_f32(outptr1, _p01.val[1]);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }

            // handle the last channel
            {
                const float* ptr0 = bottom_blob.channel(channels_per_group);
                const float* ptr1 = bottom_blob.channel(channels_per_group + channels_per_group);
                float* outptr0 = top_blob.channel(channels_per_group * 2);

                ptr1 += 2;
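                // ptr1 skips the two lanes the main loop already consumed via
                // vext; only the low zip half is stored since a single output
                // pack remains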

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);

                    float32x4x2_t _p01 = vzipq_f32(_p0, _p1);

                    vst1q_f32(outptr0, _p01.val[0]);

                    ptr0 += 4;
                    ptr1 += 4;
                    outptr0 += 4;
                }
            }

            return 0;
        }

        if (_group > 4 || channels % _group != 0)
        {
            // slow path for too large group or shuffle inside elempack
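            // unpack to elempack 1, run the generic implementation, then
            // repack; temporaries use the workspace allocator since they are
            // scratch data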
            Option opt_pack = opt;
            opt_pack.blob_allocator = opt.workspace_allocator;

            Mat bottom_blob_unpacked;
            convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

            Mat top_blob_unpacked;
            int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
            if (ret != 0)
                return ret;

            convert_packing(top_blob_unpacked, top_blob, elempack, opt);

            return 0;
        }

        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int size = w * h;
        size_t elemsize = bottom_blob.elemsize;

        top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int channels_per_group = channels / _group;

        if (_group == 2)
        {
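            // with two groups, zipping a pack from each group produces the
            // shuffled order directly, e.g. (0,4,1,5) and (2,6,3,7) in
            // true-channel lane indices for the first pack pair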
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr0 = bottom_blob.channel(q);
                const float* ptr1 = bottom_blob.channel(channels_per_group + q);
                float* outptr0 = top_blob.channel(q * 2);
                float* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);

                    float32x4x2_t _p01 = vzipq_f32(_p0, _p1);

                    vst1q_f32(outptr0, _p01.val[0]);
                    vst1q_f32(outptr1, _p01.val[1]);

                    ptr0 += 4;
                    ptr1 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }
        }

        if (_group == 3)
        {
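            // variables are named by true-channel lane content for the first
            // pack triple: _p0 = 0123, _p1 = 4567, _p2 = 89xy (x = 10, y = 11);
            // the shuffled output packs are 0481, 5926 and x37y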
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr0 = bottom_blob.channel(q);
                const float* ptr1 = bottom_blob.channel(channels_per_group + q);
                const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                float* outptr0 = top_blob.channel(q * 3);
                float* outptr1 = top_blob.channel(q * 3 + 1);
                float* outptr2 = top_blob.channel(q * 3 + 2);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);
                    float32x4_t _p2 = vld1q_f32(ptr2);

                    float32x4x2_t _p01 = vzipq_f32(_p0, _p1);
                    float32x4x2_t _p12 = vzipq_f32(_p1, _p2);

                    float32x4_t _0415 = _p01.val[0];
                    float32x4_t _2637 = _p01.val[1];
                    float32x4_t _4859 = _p12.val[0];
                    float32x4_t _6x7y = _p12.val[1];

                    float32x2_t _15 = vget_high_f32(_0415);
                    float32x2_t _37 = vget_high_f32(_2637);
                    float32x2_t _48 = vget_low_f32(_4859);
                    float32x2_t _6x = vget_low_f32(_6x7y);

                    float32x2_t _81 = vext_f32(_48, _15, 1);
                    float32x2_t _x3 = vext_f32(_6x, _37, 1);

                    float32x4_t _0481 = vcombine_f32(vget_low_f32(_0415), _81);
                    float32x4_t _5926 = vextq_f32(_4859, _2637, 2);
                    float32x4_t _x37y = vcombine_f32(_x3, vget_high_f32(_6x7y));

                    vst1q_f32(outptr0, _0481);
                    vst1q_f32(outptr1, _5926);
                    vst1q_f32(outptr2, _x37y);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                }
            }
        }

        if (_group == 4)
        {
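            // with four groups the shuffle of four input packs is exactly a
            // 4x4 lane transpose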
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr0 = bottom_blob.channel(q);
                const float* ptr1 = bottom_blob.channel(channels_per_group + q);
                const float* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                const float* ptr3 = bottom_blob.channel(channels_per_group * 3 + q);
                float* outptr0 = top_blob.channel(q * 4);
                float* outptr1 = top_blob.channel(q * 4 + 1);
                float* outptr2 = top_blob.channel(q * 4 + 2);
                float* outptr3 = top_blob.channel(q * 4 + 3);

                for (int i = 0; i < size; i++)
                {
                    float32x4_t _p0 = vld1q_f32(ptr0);
                    float32x4_t _p1 = vld1q_f32(ptr1);
                    float32x4_t _p2 = vld1q_f32(ptr2);
                    float32x4_t _p3 = vld1q_f32(ptr3);

                    // transpose 4x4
                    float32x4x2_t _p01 = vtrnq_f32(_p0, _p1);
                    float32x4x2_t _p23 = vtrnq_f32(_p2, _p3);
                    _p0 = vcombine_f32(vget_low_f32(_p01.val[0]), vget_low_f32(_p23.val[0]));
                    _p1 = vcombine_f32(vget_low_f32(_p01.val[1]), vget_low_f32(_p23.val[1]));
                    _p2 = vcombine_f32(vget_high_f32(_p01.val[0]), vget_high_f32(_p23.val[0]));
                    _p3 = vcombine_f32(vget_high_f32(_p01.val[1]), vget_high_f32(_p23.val[1]));

                    vst1q_f32(outptr0, _p0);
                    vst1q_f32(outptr1, _p1);
                    vst1q_f32(outptr2, _p2);
                    vst1q_f32(outptr3, _p3);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                    outptr3 += 4;
                }
            }
        }

        return 0;
    }
#endif // __ARM_NEON

    return ShuffleChannel::forward(bottom_blob, top_blob, opt);
}

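// 16-bit storage path. bf16 data is moved as raw 16-bit lanes; the fp16 build
// additionally handles elempack 8 with float16x8 vectors. No arithmetic is
// performed, only lane permutation.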
int ShuffleChannel_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int channels = bottom_blob.c;
    int elempack = bottom_blob.elempack;

    int _group = reverse ? channels * elempack / group : group;

    if (_group == 1)
    {
        top_blob = bottom_blob;
        return 0;
    }

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (elempack == 8)
    {
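        // odd packed channel count with two groups: the second group starts at
        // lane 4 of a pack, so input packs are stitched with vext as in the
        // fp32 path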
        if (_group == 2 && channels % _group != 0)
        {
            int w = bottom_blob.w;
            int h = bottom_blob.h;
            int size = w * h;
            size_t elemsize = bottom_blob.elemsize;

            top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            int channels_per_group = channels / _group;

            // TODO unroll me
            for (int q = 0; q < channels_per_group; q++)
            {
                const __fp16* ptr0 = bottom_blob.channel(q);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + q);
                const __fp16* ptr2 = bottom_blob.channel(channels_per_group + q + 1);
                __fp16* outptr0 = top_blob.channel(q * 2);
                __fp16* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    float16x8_t _p0 = vld1q_f16(ptr0);
                    float16x8_t _p1 = vld1q_f16(ptr1);
                    float16x8_t _p2 = vld1q_f16(ptr2);

                    float16x8_t _p12 = vextq_f16(_p1, _p2, 4);

                    float16x8x2_t _p01 = vzipq_f16(_p0, _p12);

                    vst1q_f16(outptr0, _p01.val[0]);
                    vst1q_f16(outptr1, _p01.val[1]);

                    ptr0 += 8;
                    ptr1 += 8;
                    ptr2 += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                }
            }

            // handle the last channel
            {
                const __fp16* ptr0 = bottom_blob.channel(channels_per_group);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + channels_per_group);
                __fp16* outptr0 = top_blob.channel(channels_per_group * 2);

                ptr1 += 4;
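                // ptr1 skips the four lanes already consumed by the main loop;
                // a 4-lane zip fills the single remaining output pack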

                for (int i = 0; i < size; i++)
                {
                    float16x4_t _p0 = vld1_f16(ptr0);
                    float16x4_t _p1 = vld1_f16(ptr1);

                    float16x4x2_t _p01 = vzip_f16(_p0, _p1);

                    vst1_f16(outptr0, _p01.val[0]);
                    vst1_f16(outptr0 + 4, _p01.val[1]);

                    ptr0 += 8;
                    ptr1 += 8;
                    outptr0 += 8;
                }
            }

            return 0;
        }

        if (_group > 4 || channels % _group != 0)
        {
            // slow path for too large group or shuffle inside elempack
            Option opt_pack = opt;
            opt_pack.blob_allocator = opt.workspace_allocator;

            Mat bottom_blob_unpacked;
            convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

            Mat top_blob_unpacked;
            int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
            if (ret != 0)
                return ret;

            convert_packing(top_blob_unpacked, top_blob, elempack, opt);

            return 0;
        }

        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int size = w * h;
        size_t elemsize = bottom_blob.elemsize;

        top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int channels_per_group = channels / _group;

        if (_group == 2)
        {
            for (int q = 0; q < channels_per_group; q++)
            {
                const __fp16* ptr0 = bottom_blob.channel(q);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + q);
                __fp16* outptr0 = top_blob.channel(q * 2);
                __fp16* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    float16x8_t _p0 = vld1q_f16(ptr0);
                    float16x8_t _p1 = vld1q_f16(ptr1);

                    float16x8x2_t _p01 = vzipq_f16(_p0, _p1);

                    vst1q_f16(outptr0, _p01.val[0]);
                    vst1q_f16(outptr1, _p01.val[1]);

                    ptr0 += 8;
                    ptr1 += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                }
            }
        }

        if (_group == 3)
        {
            for (int q = 0; q < channels_per_group; q++)
            {
                const __fp16* ptr0 = bottom_blob.channel(q);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + q);
                const __fp16* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                __fp16* outptr0 = top_blob.channel(q * 3);
                __fp16* outptr1 = top_blob.channel(q * 3 + 1);
                __fp16* outptr2 = top_blob.channel(q * 3 + 2);

                for (int i = 0; i < size; i++)
                {
                    float16x8_t _p0 = vld1q_f16(ptr0);
                    float16x8_t _p1 = vld1q_f16(ptr1);
                    float16x8_t _p2 = vld1q_f16(ptr2);

                    // TODO figure out a faster way

                    // 01234567    08g19h2a
                    // 89abcdef -> i3bj4ck5
                    // ghijklmn    dl6em7fn

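                    // vst3q_f16 interleaves the three sources element-wise,
                    // which is exactly the shuffled order; round-trip through
                    // a small stack buffer, then reload the three output packs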
                    float16x8x3_t _p012;
                    _p012.val[0] = _p0;
                    _p012.val[1] = _p1;
                    _p012.val[2] = _p2;

                    __fp16 tmp[24];
                    vst3q_f16(&tmp[0], _p012);

                    _p0 = vld1q_f16(&tmp[0]);
                    _p1 = vld1q_f16(&tmp[8]);
                    _p2 = vld1q_f16(&tmp[16]);

                    vst1q_f16(outptr0, _p0);
                    vst1q_f16(outptr1, _p1);
                    vst1q_f16(outptr2, _p2);

                    ptr0 += 8;
                    ptr1 += 8;
                    ptr2 += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                    outptr2 += 8;
                }
            }
        }

        if (_group == 4)
        {
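            // 8-lane variant of the 4x4 transpose: trn on 16-bit lanes, then
            // on 32-bit pairs, then the low/high halves are recombined in
            // ascending lane order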
            for (int q = 0; q < channels_per_group; q++)
            {
                const __fp16* ptr0 = bottom_blob.channel(q);
                const __fp16* ptr1 = bottom_blob.channel(channels_per_group + q);
                const __fp16* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                const __fp16* ptr3 = bottom_blob.channel(channels_per_group * 3 + q);
                __fp16* outptr0 = top_blob.channel(q * 4);
                __fp16* outptr1 = top_blob.channel(q * 4 + 1);
                __fp16* outptr2 = top_blob.channel(q * 4 + 2);
                __fp16* outptr3 = top_blob.channel(q * 4 + 3);

                for (int i = 0; i < size; i++)
                {
                    float16x8_t _p0 = vld1q_f16(ptr0);
                    float16x8_t _p1 = vld1q_f16(ptr1);
                    float16x8_t _p2 = vld1q_f16(ptr2);
                    float16x8_t _p3 = vld1q_f16(ptr3);

                    // transpose 4x4
                    float16x8x2_t _p01 = vtrnq_f16(_p0, _p1);
                    float16x8x2_t _p23 = vtrnq_f16(_p2, _p3);
                    uint32x4x2_t _p02 = vtrnq_u32(vreinterpretq_u32_f16(_p01.val[0]), vreinterpretq_u32_f16(_p23.val[0]));
                    uint32x4x2_t _p13 = vtrnq_u32(vreinterpretq_u32_f16(_p01.val[1]), vreinterpretq_u32_f16(_p23.val[1]));
                    _p0 = vreinterpretq_f16_u32(_p02.val[0]);
                    _p1 = vreinterpretq_f16_u32(_p13.val[0]);
                    _p2 = vreinterpretq_f16_u32(_p02.val[1]);
                    _p3 = vreinterpretq_f16_u32(_p13.val[1]);

                    vst1q_f16(outptr0, vcombine_f16(vget_low_f16(_p0), vget_low_f16(_p1)));
                    vst1q_f16(outptr1, vcombine_f16(vget_low_f16(_p2), vget_low_f16(_p3)));
                    vst1q_f16(outptr2, vcombine_f16(vget_high_f16(_p0), vget_high_f16(_p1)));
                    vst1q_f16(outptr3, vcombine_f16(vget_high_f16(_p2), vget_high_f16(_p3)));

                    ptr0 += 8;
                    ptr1 += 8;
                    ptr2 += 8;
                    ptr3 += 8;
                    outptr0 += 8;
                    outptr1 += 8;
                    outptr2 += 8;
                    outptr3 += 8;
                }
            }
        }

        return 0;
    }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if __ARM_NEON
    if (elempack == 4)
    {
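        // elempack 4 on 16-bit data (bf16, or fp16 without the fp16 arithmetic
        // extension): same lane moves as the fp32 path, performed on raw
        // unsigned short lanes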
        if (_group == 2 && channels % _group != 0)
        {
            int w = bottom_blob.w;
            int h = bottom_blob.h;
            int size = w * h;
            size_t elemsize = bottom_blob.elemsize;

            top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
            if (top_blob.empty())
                return -100;

            int channels_per_group = channels / _group;

            // TODO unroll me
            for (int q = 0; q < channels_per_group; q++)
            {
                const unsigned short* ptr0 = bottom_blob.channel(q);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + q);
                const unsigned short* ptr2 = bottom_blob.channel(channels_per_group + q + 1);
                unsigned short* outptr0 = top_blob.channel(q * 2);
                unsigned short* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);
                    uint16x4_t _p2 = vld1_u16(ptr2);

                    uint16x4_t _p12 = vext_u16(_p1, _p2, 2);

                    uint16x4x2_t _p01 = vzip_u16(_p0, _p12);

                    vst1_u16(outptr0, _p01.val[0]);
                    vst1_u16(outptr1, _p01.val[1]);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }

            // handle the last channel
            {
                const unsigned short* ptr0 = bottom_blob.channel(channels_per_group);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + channels_per_group);
                unsigned short* outptr0 = top_blob.channel(channels_per_group * 2);

                ptr1 += 2;

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);

                    uint16x4x2_t _p01 = vzip_u16(_p0, _p1);

                    vst1_u16(outptr0, _p01.val[0]);

                    ptr0 += 4;
                    ptr1 += 4;
                    outptr0 += 4;
                }
            }

            return 0;
        }

        if (_group > 4 || channels % _group != 0)
        {
            // slow path for too large group or shuffle inside elempack
            Option opt_pack = opt;
            opt_pack.blob_allocator = opt.workspace_allocator;

            Mat bottom_blob_unpacked;
            convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

            Mat top_blob_unpacked;
            int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
            if (ret != 0)
                return ret;

            convert_packing(top_blob_unpacked, top_blob, elempack, opt);

            return 0;
        }

        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int size = w * h;
        size_t elemsize = bottom_blob.elemsize;

        top_blob.create(w, h, channels, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int channels_per_group = channels / _group;

        if (_group == 2)
        {
            for (int q = 0; q < channels_per_group; q++)
            {
                const unsigned short* ptr0 = bottom_blob.channel(q);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + q);
                unsigned short* outptr0 = top_blob.channel(q * 2);
                unsigned short* outptr1 = top_blob.channel(q * 2 + 1);

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);

                    uint16x4x2_t _p01 = vzip_u16(_p0, _p1);

                    vst1_u16(outptr0, _p01.val[0]);
                    vst1_u16(outptr1, _p01.val[1]);

                    ptr0 += 4;
                    ptr1 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }
        }

        if (_group == 3)
        {
            for (int q = 0; q < channels_per_group; q++)
            {
                const unsigned short* ptr0 = bottom_blob.channel(q);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + q);
                const unsigned short* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                unsigned short* outptr0 = top_blob.channel(q * 3);
                unsigned short* outptr1 = top_blob.channel(q * 3 + 1);
                unsigned short* outptr2 = top_blob.channel(q * 3 + 2);

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);
                    uint16x4_t _p2 = vld1_u16(ptr2);

                    // TODO figure out a faster way
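                    // same 0481 / 5926 / x37y target order as the fp32 path;
                    // vrev32 and vtrn rebuild the (8,1) and (x,3) lane pairs
                    // that the zips alone cannot produce on 16-bit lanes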
                    uint16x4x2_t _p01 = vzip_u16(_p0, _p1);
                    uint16x4x2_t _p12 = vzip_u16(_p1, _p2);

                    uint32x2_t _0415 = vreinterpret_u32_u16(_p01.val[0]);
                    uint16x4_t _2637 = _p01.val[1];
                    uint16x4_t _4859 = _p12.val[0];
                    uint32x2_t _6x7y = vreinterpret_u32_u16(_p12.val[1]);

                    uint16x4_t _98yx = vrev32_u16(_p2);
                    uint16x4x2_t _90y281x3 = vtrn_u16(_98yx, _p0);

                    uint32x2_t _81x3 = vreinterpret_u32_u16(_90y281x3.val[1]);

                    uint32x2x2_t _048115x3 = vtrn_u32(_0415, _81x3);
                    uint32x2x2_t _816xx37y = vtrn_u32(_81x3, _6x7y);

                    uint16x4_t _0481 = vreinterpret_u16_u32(_048115x3.val[0]);
                    uint16x4_t _5926 = vext_u16(_4859, _2637, 2);
                    uint16x4_t _x37y = vreinterpret_u16_u32(_816xx37y.val[1]);

                    vst1_u16(outptr0, _0481);
                    vst1_u16(outptr1, _5926);
                    vst1_u16(outptr2, _x37y);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                }
            }
        }

        if (_group == 4)
        {
            for (int q = 0; q < channels_per_group; q++)
            {
                const unsigned short* ptr0 = bottom_blob.channel(q);
                const unsigned short* ptr1 = bottom_blob.channel(channels_per_group + q);
                const unsigned short* ptr2 = bottom_blob.channel(channels_per_group * 2 + q);
                const unsigned short* ptr3 = bottom_blob.channel(channels_per_group * 3 + q);
                unsigned short* outptr0 = top_blob.channel(q * 4);
                unsigned short* outptr1 = top_blob.channel(q * 4 + 1);
                unsigned short* outptr2 = top_blob.channel(q * 4 + 2);
                unsigned short* outptr3 = top_blob.channel(q * 4 + 3);

                for (int i = 0; i < size; i++)
                {
                    uint16x4_t _p0 = vld1_u16(ptr0);
                    uint16x4_t _p1 = vld1_u16(ptr1);
                    uint16x4_t _p2 = vld1_u16(ptr2);
                    uint16x4_t _p3 = vld1_u16(ptr3);

                    // transpose 4x4
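                    // two-stage transpose: vtrn_u16 interleaves 16-bit pairs,
                    // vtrn_u32 then swaps the 32-bit halves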
                    uint16x4x2_t _p01 = vtrn_u16(_p0, _p1);
                    uint16x4x2_t _p23 = vtrn_u16(_p2, _p3);
                    uint32x2x2_t _p02 = vtrn_u32(vreinterpret_u32_u16(_p01.val[0]), vreinterpret_u32_u16(_p23.val[0]));
                    uint32x2x2_t _p13 = vtrn_u32(vreinterpret_u32_u16(_p01.val[1]), vreinterpret_u32_u16(_p23.val[1]));
                    _p0 = vreinterpret_u16_u32(_p02.val[0]);
                    _p1 = vreinterpret_u16_u32(_p13.val[0]);
                    _p2 = vreinterpret_u16_u32(_p02.val[1]);
                    _p3 = vreinterpret_u16_u32(_p13.val[1]);

                    vst1_u16(outptr0, _p0);
                    vst1_u16(outptr1, _p1);
                    vst1_u16(outptr2, _p2);
                    vst1_u16(outptr3, _p3);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                    outptr3 += 4;
                }
            }
        }

        return 0;
    }
#endif // __ARM_NEON

    return ShuffleChannel::forward(bottom_blob, top_blob, opt);
}

} // namespace ncnn