1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "packing_arm.h"
16 
17 #if __ARM_NEON
18 #include <arm_neon.h>
19 #endif // __ARM_NEON
20 
21 namespace ncnn {
22 
Packing_arm()23 Packing_arm::Packing_arm()
24 {
25     support_packing = true;
26 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
27     support_fp16_storage = true;
28 #endif
29 
30     support_bf16_storage = true;
31 }
32 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const33 int Packing_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
34 {
35     int elembits = bottom_blob.elembits();
36 
37     if (elembits == 8)
38         return forward_int8(bottom_blob, top_blob, opt);
39 
40 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
41     if (opt.use_fp16_storage && elembits == 16)
42         return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
43 #endif
44 
45     if (opt.use_bf16_storage && elembits == 16)
46         return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
47 
48     if (use_padding)
49     {
50         return Packing::forward(bottom_blob, top_blob, opt);
51     }
52 
53     if (elembits != 32)
54     {
55         // non-fp32 type
56         return Packing::forward(bottom_blob, top_blob, opt);
57     }
58 
59     size_t elemsize = bottom_blob.elemsize;
60     int elempack = bottom_blob.elempack;
61 
62     if (elempack == out_elempack)
63     {
64         top_blob = bottom_blob;
65         return 0;
66     }
67 
68     bool pack1to4 = elempack == 1 && out_elempack == 4;
69     bool pack4to1 = elempack == 4 && out_elempack == 1;
70 
71     if (!pack1to4 && !pack4to1)
72     {
73         return Packing::forward(bottom_blob, top_blob, opt);
74     }
75 
76     int w = bottom_blob.w;
77     int h = bottom_blob.h;
78     int d = bottom_blob.d;
79     int channels = bottom_blob.c;
80     int dims = bottom_blob.dims;
81 
82     if (!use_padding)
83     {
84         // identity if use_padding not allowed
85         if (dims == 1 && w * elempack % out_elempack != 0)
86         {
87             top_blob = bottom_blob;
88             return 0;
89         }
90         if (dims == 2 && h * elempack % out_elempack != 0)
91         {
92             top_blob = bottom_blob;
93             return 0;
94         }
95         if ((dims == 3 || dims == 4) && channels * elempack % out_elempack != 0)
96         {
97             top_blob = bottom_blob;
98             return 0;
99         }
100     }
101 
102     if (dims == 1)
103     {
104         top_blob = bottom_blob;
105         top_blob.w = w * elempack / out_elempack;
106         top_blob.cstep = w * elempack / out_elempack;
107         top_blob.elemsize = elemsize / elempack * out_elempack;
108         top_blob.elempack = out_elempack;
109         return 0;
110     }
111 
112     if (dims == 2)
113     {
114         int outh = h * elempack / out_elempack;
115         size_t out_elemsize = elemsize / elempack * out_elempack;
116 
117         top_blob.create(w, outh, out_elemsize, out_elempack, opt.blob_allocator);
118         if (top_blob.empty())
119             return -100;
120 
121         if (pack1to4)
122         {
123             #pragma omp parallel for num_threads(opt.num_threads)
124             for (int i = 0; i < outh; i++)
125             {
126                 const float* r0 = bottom_blob.row(i * 4);
127                 const float* r1 = bottom_blob.row(i * 4 + 1);
128                 const float* r2 = bottom_blob.row(i * 4 + 2);
129                 const float* r3 = bottom_blob.row(i * 4 + 3);
130 
131                 float* outptr = top_blob.row(i);
132 
133                 int j = 0;
134 #if __ARM_NEON
135                 for (; j + 3 < w; j += 4)
136                 {
137                     float32x4x4_t _p;
138                     _p.val[0] = vld1q_f32(r0);
139                     _p.val[1] = vld1q_f32(r1);
140                     _p.val[2] = vld1q_f32(r2);
141                     _p.val[3] = vld1q_f32(r3);
142                     vst4q_f32(outptr, _p);
143 
144                     r0 += 4;
145                     r1 += 4;
146                     r2 += 4;
147                     r3 += 4;
148                     outptr += 16;
149                 }
150 #endif
151                 for (; j < w; j++)
152                 {
153                     outptr[0] = *r0++;
154                     outptr[1] = *r1++;
155                     outptr[2] = *r2++;
156                     outptr[3] = *r3++;
157 
158                     outptr += 4;
159                 }
160             }
161         }
162         if (pack4to1)
163         {
164             #pragma omp parallel for num_threads(opt.num_threads)
165             for (int i = 0; i < h; i++)
166             {
167                 const float* r0 = bottom_blob.row(i);
168 
169                 float* outptr0 = top_blob.row(i * 4);
170                 float* outptr1 = top_blob.row(i * 4 + 1);
171                 float* outptr2 = top_blob.row(i * 4 + 2);
172                 float* outptr3 = top_blob.row(i * 4 + 3);
173 
174                 int j = 0;
175 #if __ARM_NEON
176                 for (; j + 3 < w; j += 4)
177                 {
178                     float32x4x4_t _p = vld4q_f32(r0);
179                     vst1q_f32(outptr0, _p.val[0]);
180                     vst1q_f32(outptr1, _p.val[1]);
181                     vst1q_f32(outptr2, _p.val[2]);
182                     vst1q_f32(outptr3, _p.val[3]);
183 
184                     r0 += 16;
185                     outptr0 += 4;
186                     outptr1 += 4;
187                     outptr2 += 4;
188                     outptr3 += 4;
189                 }
190 #endif
191                 for (; j < w; j++)
192                 {
193                     *outptr0++ = r0[0];
194                     *outptr1++ = r0[1];
195                     *outptr2++ = r0[2];
196                     *outptr3++ = r0[3];
197 
198                     r0 += 4;
199                 }
200             }
201         }
202 
203         return 0;
204     }
205 
206     if (dims == 3 || dims == 4)
207     {
208         int size = w * h * d;
209         int outc = channels * elempack / out_elempack;
210         size_t out_elemsize = elemsize / elempack * out_elempack;
211 
212         if (dims == 3)
213             top_blob.create(w, h, outc, out_elemsize, out_elempack, opt.blob_allocator);
214         else // if (dims == 4)
215             top_blob.create(w, h, d, outc, out_elemsize, out_elempack, opt.blob_allocator);
216         if (top_blob.empty())
217             return -100;
218 
219         if (pack1to4)
220         {
221             #pragma omp parallel for num_threads(opt.num_threads)
222             for (int q = 0; q < outc; q++)
223             {
224                 const float* r0 = bottom_blob.channel(q * 4);
225                 const float* r1 = bottom_blob.channel(q * 4 + 1);
226                 const float* r2 = bottom_blob.channel(q * 4 + 2);
227                 const float* r3 = bottom_blob.channel(q * 4 + 3);
228 
229                 float* outptr = top_blob.channel(q);
230 
231                 int i = 0;
232 #if __ARM_NEON
233                 for (; i + 3 < size; i += 4)
234                 {
235                     float32x4x4_t _p;
236                     _p.val[0] = vld1q_f32(r0);
237                     _p.val[1] = vld1q_f32(r1);
238                     _p.val[2] = vld1q_f32(r2);
239                     _p.val[3] = vld1q_f32(r3);
240                     vst4q_f32(outptr, _p);
241 
242                     r0 += 4;
243                     r1 += 4;
244                     r2 += 4;
245                     r3 += 4;
246                     outptr += 16;
247                 }
248 #endif
249                 for (; i < size; i++)
250                 {
251                     outptr[0] = *r0++;
252                     outptr[1] = *r1++;
253                     outptr[2] = *r2++;
254                     outptr[3] = *r3++;
255 
256                     outptr += 4;
257                 }
258             }
259         }
260         if (pack4to1)
261         {
262             #pragma omp parallel for num_threads(opt.num_threads)
263             for (int q = 0; q < channels; q++)
264             {
265                 const float* r0 = bottom_blob.channel(q);
266 
267                 float* outptr0 = top_blob.channel(q * 4);
268                 float* outptr1 = top_blob.channel(q * 4 + 1);
269                 float* outptr2 = top_blob.channel(q * 4 + 2);
270                 float* outptr3 = top_blob.channel(q * 4 + 3);
271 
272                 int i = 0;
273 #if __ARM_NEON
274                 for (; i + 3 < size; i += 4)
275                 {
276                     float32x4x4_t _p = vld4q_f32(r0);
277                     vst1q_f32(outptr0, _p.val[0]);
278                     vst1q_f32(outptr1, _p.val[1]);
279                     vst1q_f32(outptr2, _p.val[2]);
280                     vst1q_f32(outptr3, _p.val[3]);
281 
282                     r0 += 16;
283                     outptr0 += 4;
284                     outptr1 += 4;
285                     outptr2 += 4;
286                     outptr3 += 4;
287                 }
288 #endif
289                 for (; i < size; i++)
290                 {
291                     *outptr0++ = r0[0];
292                     *outptr1++ = r0[1];
293                     *outptr2++ = r0[2];
294                     *outptr3++ = r0[3];
295 
296                     r0 += 4;
297                 }
298             }
299         }
300 
301         return 0;
302     }
303 
304     return 0;
305 }
306 
forward_bf16s_fp16s(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const307 int Packing_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
308 {
309     if (use_padding)
310     {
311         return Packing::forward(bottom_blob, top_blob, opt);
312     }
313 
314     size_t elemsize = bottom_blob.elemsize;
315     int elempack = bottom_blob.elempack;
316 
317     if (elempack == out_elempack)
318     {
319         top_blob = bottom_blob;
320         return 0;
321     }
322 
323     bool pack1to4 = elempack == 1 && out_elempack == 4;
324     bool pack4to1 = elempack == 4 && out_elempack == 1;
325     bool pack1to8 = elempack == 1 && out_elempack == 8;
326     bool pack8to1 = elempack == 8 && out_elempack == 1;
327     bool pack4to8 = elempack == 4 && out_elempack == 8;
328     bool pack8to4 = elempack == 8 && out_elempack == 4;
329 
330     if (!pack1to4 && !pack4to1 && !pack1to8 && !pack8to1 && !pack4to8 && !pack8to4)
331     {
332         return Packing::forward(bottom_blob, top_blob, opt);
333     }
334 
335     int w = bottom_blob.w;
336     int h = bottom_blob.h;
337     int d = bottom_blob.d;
338     int channels = bottom_blob.c;
339     int dims = bottom_blob.dims;
340 
341     if (!use_padding)
342     {
343         // identity if use_padding not allowed
344         if (dims == 1 && w * elempack % out_elempack != 0)
345         {
346             top_blob = bottom_blob;
347             return 0;
348         }
349         if (dims == 2 && h * elempack % out_elempack != 0)
350         {
351             top_blob = bottom_blob;
352             return 0;
353         }
354         if ((dims == 3 || dims == 4) && channels * elempack % out_elempack != 0)
355         {
356             top_blob = bottom_blob;
357             return 0;
358         }
359     }
360 
361     if (dims == 1)
362     {
363         top_blob = bottom_blob;
364         top_blob.w = w * elempack / out_elempack;
365         top_blob.cstep = w * elempack / out_elempack;
366         top_blob.elemsize = elemsize / elempack * out_elempack;
367         top_blob.elempack = out_elempack;
368         return 0;
369     }
370 
371     if (dims == 2)
372     {
373         int outh = h * elempack / out_elempack;
374         size_t out_elemsize = elemsize / elempack * out_elempack;
375 
376         top_blob.create(w, outh, out_elemsize, out_elempack, opt.blob_allocator);
377         if (top_blob.empty())
378             return -100;
379 
380         if (pack1to4)
381         {
382             #pragma omp parallel for num_threads(opt.num_threads)
383             for (int i = 0; i < outh; i++)
384             {
385                 const unsigned short* r0 = bottom_blob.row<const unsigned short>(i * 4);
386                 const unsigned short* r1 = bottom_blob.row<const unsigned short>(i * 4 + 1);
387                 const unsigned short* r2 = bottom_blob.row<const unsigned short>(i * 4 + 2);
388                 const unsigned short* r3 = bottom_blob.row<const unsigned short>(i * 4 + 3);
389 
390                 unsigned short* outptr = top_blob.row<unsigned short>(i);
391 
392                 int j = 0;
393 #if __ARM_NEON
394                 for (; j + 3 < w; j += 4)
395                 {
396                     uint16x4x4_t _p;
397                     _p.val[0] = vld1_u16(r0);
398                     _p.val[1] = vld1_u16(r1);
399                     _p.val[2] = vld1_u16(r2);
400                     _p.val[3] = vld1_u16(r3);
401                     vst4_u16(outptr, _p);
402 
403                     r0 += 4;
404                     r1 += 4;
405                     r2 += 4;
406                     r3 += 4;
407                     outptr += 16;
408                 }
409 #endif
410                 for (; j < w; j++)
411                 {
412                     outptr[0] = *r0++;
413                     outptr[1] = *r1++;
414                     outptr[2] = *r2++;
415                     outptr[3] = *r3++;
416 
417                     outptr += 4;
418                 }
419             }
420         }
421         if (pack4to1)
422         {
423             #pragma omp parallel for num_threads(opt.num_threads)
424             for (int i = 0; i < h; i++)
425             {
426                 const unsigned short* r0 = bottom_blob.row<const unsigned short>(i);
427 
428                 unsigned short* outptr0 = top_blob.row<unsigned short>(i * 4);
429                 unsigned short* outptr1 = top_blob.row<unsigned short>(i * 4 + 1);
430                 unsigned short* outptr2 = top_blob.row<unsigned short>(i * 4 + 2);
431                 unsigned short* outptr3 = top_blob.row<unsigned short>(i * 4 + 3);
432 
433                 int j = 0;
434 #if __ARM_NEON
435                 for (; j + 3 < w; j += 4)
436                 {
437                     uint16x4x4_t _p = vld4_u16(r0);
438                     vst1_u16(outptr0, _p.val[0]);
439                     vst1_u16(outptr1, _p.val[1]);
440                     vst1_u16(outptr2, _p.val[2]);
441                     vst1_u16(outptr3, _p.val[3]);
442 
443                     r0 += 16;
444                     outptr0 += 4;
445                     outptr1 += 4;
446                     outptr2 += 4;
447                     outptr3 += 4;
448                 }
449 #endif
450                 for (; j < w; j++)
451                 {
452                     *outptr0++ = r0[0];
453                     *outptr1++ = r0[1];
454                     *outptr2++ = r0[2];
455                     *outptr3++ = r0[3];
456 
457                     r0 += 4;
458                 }
459             }
460         }
461         if (pack1to8)
462         {
463             #pragma omp parallel for num_threads(opt.num_threads)
464             for (int i = 0; i < outh; i++)
465             {
466                 const unsigned short* r0 = bottom_blob.row<const unsigned short>(i * 8);
467                 const unsigned short* r1 = bottom_blob.row<const unsigned short>(i * 8 + 1);
468                 const unsigned short* r2 = bottom_blob.row<const unsigned short>(i * 8 + 2);
469                 const unsigned short* r3 = bottom_blob.row<const unsigned short>(i * 8 + 3);
470                 const unsigned short* r4 = bottom_blob.row<const unsigned short>(i * 8 + 4);
471                 const unsigned short* r5 = bottom_blob.row<const unsigned short>(i * 8 + 5);
472                 const unsigned short* r6 = bottom_blob.row<const unsigned short>(i * 8 + 6);
473                 const unsigned short* r7 = bottom_blob.row<const unsigned short>(i * 8 + 7);
474 
475                 unsigned short* outptr = top_blob.row<unsigned short>(i);
476 
477                 int j = 0;
478 #if __ARM_NEON
479                 for (; j + 7 < w; j += 8)
480                 {
481                     // transpose 8x8
482 #if __aarch64__
483                     asm volatile(
484                         "ld1    {v0.8h}, [%0], #16      \n"
485                         "ld1    {v1.8h}, [%1], #16      \n"
486                         "ld1    {v2.8h}, [%2], #16      \n"
487                         "ld1    {v3.8h}, [%3], #16      \n"
488                         "ld1    {v4.8h}, [%4], #16      \n"
489                         "ld1    {v5.8h}, [%5], #16      \n"
490                         "ld1    {v6.8h}, [%6], #16      \n"
491                         "ld1    {v7.8h}, [%7], #16      \n"
492 
493                         "zip1   v16.8h, v0.8h, v4.8h    \n"
494                         "zip2   v20.8h, v0.8h, v4.8h    \n"
495                         "zip1   v17.8h, v1.8h, v5.8h    \n"
496                         "zip2   v21.8h, v1.8h, v5.8h    \n"
497                         "zip1   v18.8h, v2.8h, v6.8h    \n"
498                         "zip2   v22.8h, v2.8h, v6.8h    \n"
499                         "zip1   v19.8h, v3.8h, v7.8h    \n"
500                         "zip2   v23.8h, v3.8h, v7.8h    \n"
501 
502                         "st4    {v16.8h, v17.8h, v18.8h, v19.8h}, [%8], #64 \n"
503                         "st4    {v20.8h, v21.8h, v22.8h, v23.8h}, [%8], #64 \n"
504                         : "=r"(r0),    // %0
505                         "=r"(r1),    // %1
506                         "=r"(r2),    // %2
507                         "=r"(r3),    // %3
508                         "=r"(r4),    // %4
509                         "=r"(r5),    // %5
510                         "=r"(r6),    // %6
511                         "=r"(r7),    // %7
512                         "=r"(outptr) // %8
513                         : "0"(r0),
514                         "1"(r1),
515                         "2"(r2),
516                         "3"(r3),
517                         "4"(r4),
518                         "5"(r5),
519                         "6"(r6),
520                         "7"(r7),
521                         "8"(outptr)
522                         : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
523 #else
524                     asm volatile(
525                         "vld1.u16   {d16-d17}, [%0 : 128]! \n"
526                         "vld1.u16   {d18-d19}, [%1 : 128]! \n"
527                         "vld1.u16   {d20-d21}, [%2 : 128]! \n"
528                         "vld1.u16   {d22-d23}, [%3 : 128]! \n"
529                         "vld1.u16   {d24-d25}, [%4 : 128]! \n"
530                         "vld1.u16   {d26-d27}, [%5 : 128]! \n"
531                         "vld1.u16   {d28-d29}, [%6 : 128]! \n"
532                         "vld1.u16   {d30-d31}, [%7 : 128]! \n"
533 
534                         "vtrn.u16   q8, q9              \n"
535                         "vtrn.u16   q10, q11            \n"
536                         "vtrn.u16   q12, q13            \n"
537                         "vtrn.u16   q14, q15            \n"
538 
539                         "vtrn.u32   q8, q10             \n"
540                         "vtrn.u32   q9, q11             \n"
541                         "vtrn.u32   q12, q14            \n"
542                         "vtrn.u32   q13, q15            \n"
543 
544                         "vswp       d17, d24            \n"
545                         "vswp       d19, d26            \n"
546                         "vswp       d21, d28            \n"
547                         "vswp       d23, d30            \n"
548 
549                         "vstm       %8!, {d16-d23}      \n"
550                         "vstm       %8!, {d24-d31}      \n"
551                         : "=r"(r0),    // %0
552                         "=r"(r1),    // %1
553                         "=r"(r2),    // %2
554                         "=r"(r3),    // %3
555                         "=r"(r4),    // %4
556                         "=r"(r5),    // %5
557                         "=r"(r6),    // %6
558                         "=r"(r7),    // %7
559                         "=r"(outptr) // %8
560                         : "0"(r0),
561                         "1"(r1),
562                         "2"(r2),
563                         "3"(r3),
564                         "4"(r4),
565                         "5"(r5),
566                         "6"(r6),
567                         "7"(r7),
568                         "8"(outptr)
569                         : "memory", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
570 #endif
571                 }
572 #endif
573                 for (; j < w; j++)
574                 {
575                     outptr[0] = *r0++;
576                     outptr[1] = *r1++;
577                     outptr[2] = *r2++;
578                     outptr[3] = *r3++;
579                     outptr[4] = *r4++;
580                     outptr[5] = *r5++;
581                     outptr[6] = *r6++;
582                     outptr[7] = *r7++;
583 
584                     outptr += 8;
585                 }
586             }
587         }
588         if (pack8to1)
589         {
590             #pragma omp parallel for num_threads(opt.num_threads)
591             for (int i = 0; i < h; i++)
592             {
593                 const unsigned short* r0 = bottom_blob.row<const unsigned short>(i);
594 
595                 unsigned short* outptr0 = top_blob.row<unsigned short>(i * 8);
596                 unsigned short* outptr1 = top_blob.row<unsigned short>(i * 8 + 1);
597                 unsigned short* outptr2 = top_blob.row<unsigned short>(i * 8 + 2);
598                 unsigned short* outptr3 = top_blob.row<unsigned short>(i * 8 + 3);
599                 unsigned short* outptr4 = top_blob.row<unsigned short>(i * 8 + 4);
600                 unsigned short* outptr5 = top_blob.row<unsigned short>(i * 8 + 5);
601                 unsigned short* outptr6 = top_blob.row<unsigned short>(i * 8 + 6);
602                 unsigned short* outptr7 = top_blob.row<unsigned short>(i * 8 + 7);
603 
604                 int j = 0;
605 #if __ARM_NEON
606                 for (; j + 7 < w; j += 8)
607                 {
608                     // transpose 8x8
609 #if __aarch64__
610                     asm volatile(
611                         "ld4    {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n"
612                         "ld4    {v4.8h, v5.8h, v6.8h, v7.8h}, [%0], #64 \n"
613 
614                         "uzp1   v16.8h, v0.8h, v4.8h    \n"
615                         "uzp2   v20.8h, v0.8h, v4.8h    \n"
616                         "uzp1   v17.8h, v1.8h, v5.8h    \n"
617                         "uzp2   v21.8h, v1.8h, v5.8h    \n"
618                         "uzp1   v18.8h, v2.8h, v6.8h    \n"
619                         "uzp2   v22.8h, v2.8h, v6.8h    \n"
620                         "uzp1   v19.8h, v3.8h, v7.8h    \n"
621                         "uzp2   v23.8h, v3.8h, v7.8h    \n"
622 
623                         "st1    {v16.8h}, [%1], #16      \n"
624                         "st1    {v17.8h}, [%2], #16      \n"
625                         "st1    {v18.8h}, [%3], #16      \n"
626                         "st1    {v19.8h}, [%4], #16      \n"
627                         "st1    {v20.8h}, [%5], #16      \n"
628                         "st1    {v21.8h}, [%6], #16      \n"
629                         "st1    {v22.8h}, [%7], #16      \n"
630                         "st1    {v23.8h}, [%8], #16      \n"
631                         : "=r"(r0),      // %0
632                         "=r"(outptr0), // %1
633                         "=r"(outptr1), // %2
634                         "=r"(outptr2), // %3
635                         "=r"(outptr3), // %4
636                         "=r"(outptr4), // %5
637                         "=r"(outptr5), // %6
638                         "=r"(outptr6), // %7
639                         "=r"(outptr7)  // %8
640                         : "0"(r0),
641                         "1"(outptr0),
642                         "2"(outptr1),
643                         "3"(outptr2),
644                         "4"(outptr3),
645                         "5"(outptr4),
646                         "6"(outptr5),
647                         "7"(outptr6),
648                         "8"(outptr7)
649                         : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
650 #else
651                     asm volatile(
652                         "vldm       %0!, {d16-d23}      \n"
653                         "vldm       %0!, {d24-d31}      \n"
654 
655                         "vtrn.u16   q8, q9              \n"
656                         "vtrn.u16   q10, q11            \n"
657                         "vtrn.u16   q12, q13            \n"
658                         "vtrn.u16   q14, q15            \n"
659 
660                         "vtrn.u32   q8, q10             \n"
661                         "vtrn.u32   q9, q11             \n"
662                         "vtrn.u32   q12, q14            \n"
663                         "vtrn.u32   q13, q15            \n"
664 
665                         "vswp       d17, d24            \n"
666                         "vswp       d19, d26            \n"
667                         "vswp       d21, d28            \n"
668                         "vswp       d23, d30            \n"
669 
670                         "vst1.u16   {d16-d17}, [%1 : 128]! \n"
671                         "vst1.u16   {d18-d19}, [%2 : 128]! \n"
672                         "vst1.u16   {d20-d21}, [%3 : 128]! \n"
673                         "vst1.u16   {d22-d23}, [%4 : 128]! \n"
674                         "vst1.u16   {d24-d25}, [%5 : 128]! \n"
675                         "vst1.u16   {d26-d27}, [%6 : 128]! \n"
676                         "vst1.u16   {d28-d29}, [%7 : 128]! \n"
677                         "vst1.u16   {d30-d31}, [%8 : 128]! \n"
678                         : "=r"(r0),      // %0
679                         "=r"(outptr0), // %1
680                         "=r"(outptr1), // %2
681                         "=r"(outptr2), // %3
682                         "=r"(outptr3), // %4
683                         "=r"(outptr4), // %5
684                         "=r"(outptr5), // %6
685                         "=r"(outptr6), // %7
686                         "=r"(outptr7)  // %8
687                         : "0"(r0),
688                         "1"(outptr0),
689                         "2"(outptr1),
690                         "3"(outptr2),
691                         "4"(outptr3),
692                         "5"(outptr4),
693                         "6"(outptr5),
694                         "7"(outptr6),
695                         "8"(outptr7)
696                         : "memory", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
697 #endif
698                 }
699 #endif
700                 for (; j < w; j++)
701                 {
702                     *outptr0++ = r0[0];
703                     *outptr1++ = r0[1];
704                     *outptr2++ = r0[2];
705                     *outptr3++ = r0[3];
706                     *outptr4++ = r0[4];
707                     *outptr5++ = r0[5];
708                     *outptr6++ = r0[6];
709                     *outptr7++ = r0[7];
710 
711                     r0 += 8;
712                 }
713             }
714         }
715         if (pack4to8)
716         {
717             #pragma omp parallel for num_threads(opt.num_threads)
718             for (int i = 0; i < outh; i++)
719             {
720                 const unsigned short* r0 = bottom_blob.row<const unsigned short>(i * 2);
721                 const unsigned short* r1 = bottom_blob.row<const unsigned short>(i * 2 + 1);
722 
723                 unsigned short* outptr = top_blob.row<unsigned short>(i);
724 
725                 int j = 0;
726 #if __ARM_NEON
727                 for (; j + 1 < w; j += 2)
728                 {
729 #if __aarch64__
730                     asm volatile(
731                         "ld1    {v0.8h}, [%0], #16      \n"
732                         "ld1    {v1.8h}, [%1], #16      \n"
733 
734                         "zip1   v2.2d, v0.2d, v1.2d     \n"
735                         "zip2   v3.2d, v0.2d, v1.2d     \n"
736 
737                         "st1    {v2.8h, v3.8h}, [%2], #32\n"
738                         : "=r"(r0),    // %0
739                         "=r"(r1),    // %1
740                         "=r"(outptr) // %2
741                         : "0"(r0),
742                         "1"(r1),
743                         "2"(outptr)
744                         : "memory", "v0", "v1", "v2", "v3");
745 #else
746                     asm volatile(
747                         "vld1.u16   {d0-d1}, [%0 :128]! \n"
748                         "vld1.u16   {d2-d3}, [%1 :128]! \n"
749 
750                         "vswp       d1, d2              \n"
751 
752                         "vst1.u16   {d0-d3}, [%2 :128]! \n"
753                         : "=r"(r0),    // %0
754                         "=r"(r1),    // %1
755                         "=r"(outptr) // %2
756                         : "0"(r0),
757                         "1"(r1),
758                         "2"(outptr)
759                         : "memory", "q0", "q1");
760 #endif
761                 }
762 #endif
763                 for (; j < w; j++)
764                 {
765                     outptr[0] = r0[0];
766                     outptr[1] = r0[1];
767                     outptr[2] = r0[2];
768                     outptr[3] = r0[3];
769                     outptr[4] = r1[0];
770                     outptr[5] = r1[1];
771                     outptr[6] = r1[2];
772                     outptr[7] = r1[3];
773 
774                     r0 += 4;
775                     r1 += 4;
776                     outptr += 8;
777                 }
778             }
779         }
780         if (pack8to4)
781         {
782             #pragma omp parallel for num_threads(opt.num_threads)
783             for (int i = 0; i < h; i++)
784             {
785                 const unsigned short* r0 = bottom_blob.row<const unsigned short>(i);
786 
787                 unsigned short* outptr0 = top_blob.row<unsigned short>(i * 2);
788                 unsigned short* outptr1 = top_blob.row<unsigned short>(i * 2 + 1);
789 
790                 int j = 0;
791 #if __ARM_NEON
792                 for (; j + 1 < w; j += 2)
793                 {
794 #if __aarch64__
795                     asm volatile(
796                         "ld1    {v0.8h, v1.8h}, [%0], #32 \n"
797 
798                         "uzp1   v2.2d, v0.2d, v1.2d     \n"
799                         "uzp2   v3.2d, v0.2d, v1.2d     \n"
800 
801                         "st1    {v2.8h}, [%1], #16      \n"
802                         "st1    {v3.8h}, [%2], #16      \n"
803                         : "=r"(r0),      // %0
804                         "=r"(outptr0), // %1
805                         "=r"(outptr1)  // %2
806                         : "0"(r0),
807                         "1"(outptr0),
808                         "2"(outptr1)
809                         : "memory", "v0", "v1", "v2", "v3");
810 #else
811                     asm volatile(
812                         "vld1.u16   {d0-d3}, [%0 :128]! \n"
813 
814                         "vswp       d1, d2              \n"
815 
816                         "vst1.u16   {d0-d1}, [%1 :128]! \n"
817                         "vst1.u16   {d2-d3}, [%2 :128]! \n"
818                         : "=r"(r0),      // %0
819                         "=r"(outptr0), // %1
820                         "=r"(outptr1)  // %2
821                         : "0"(r0),
822                         "1"(outptr0),
823                         "2"(outptr1)
824                         : "memory", "q0", "q1");
825 #endif
826                 }
827 #endif
828                 for (; j < w; j++)
829                 {
830                     outptr0[0] = r0[0];
831                     outptr0[1] = r0[1];
832                     outptr0[2] = r0[2];
833                     outptr0[3] = r0[3];
834                     outptr1[0] = r0[4];
835                     outptr1[1] = r0[5];
836                     outptr1[2] = r0[6];
837                     outptr1[3] = r0[7];
838 
839                     r0 += 8;
840                     outptr0 += 4;
841                     outptr1 += 4;
842                 }
843             }
844         }
845 
846         return 0;
847     }
848 
849     if (dims == 3 || dims == 4)
850     {
851         int size = w * h * d;
852         int outc = channels * elempack / out_elempack;
853         size_t out_elemsize = elemsize / elempack * out_elempack;
854 
855         if (dims == 3)
856             top_blob.create(w, h, outc, out_elemsize, out_elempack, opt.blob_allocator);
857         else // if (dims == 4)
858             top_blob.create(w, h, d, outc, out_elemsize, out_elempack, opt.blob_allocator);
859         if (top_blob.empty())
860             return -100;
861 
862         if (pack1to4)
863         {
864             #pragma omp parallel for num_threads(opt.num_threads)
865             for (int q = 0; q < outc; q++)
866             {
867                 const unsigned short* r0 = bottom_blob.channel(q * 4);
868                 const unsigned short* r1 = bottom_blob.channel(q * 4 + 1);
869                 const unsigned short* r2 = bottom_blob.channel(q * 4 + 2);
870                 const unsigned short* r3 = bottom_blob.channel(q * 4 + 3);
871 
872                 unsigned short* outptr = top_blob.channel(q);
873 
874                 int i = 0;
875 #if __ARM_NEON
876                 for (; i + 3 < size; i += 4)
877                 {
878                     uint16x4x4_t _p;
879                     _p.val[0] = vld1_u16(r0);
880                     _p.val[1] = vld1_u16(r1);
881                     _p.val[2] = vld1_u16(r2);
882                     _p.val[3] = vld1_u16(r3);
883                     vst4_u16(outptr, _p);
884 
885                     r0 += 4;
886                     r1 += 4;
887                     r2 += 4;
888                     r3 += 4;
889                     outptr += 16;
890                 }
891 #endif
892                 for (; i < size; i++)
893                 {
894                     outptr[0] = *r0++;
895                     outptr[1] = *r1++;
896                     outptr[2] = *r2++;
897                     outptr[3] = *r3++;
898 
899                     outptr += 4;
900                 }
901             }
902         }
903         if (pack4to1)
904         {
905             #pragma omp parallel for num_threads(opt.num_threads)
906             for (int q = 0; q < channels; q++)
907             {
908                 const unsigned short* r0 = bottom_blob.channel(q);
909 
910                 unsigned short* outptr0 = top_blob.channel(q * 4);
911                 unsigned short* outptr1 = top_blob.channel(q * 4 + 1);
912                 unsigned short* outptr2 = top_blob.channel(q * 4 + 2);
913                 unsigned short* outptr3 = top_blob.channel(q * 4 + 3);
914 
915                 int i = 0;
916 #if __ARM_NEON
917                 for (; i + 3 < size; i += 4)
918                 {
919                     uint16x4x4_t _p = vld4_u16(r0);
920                     vst1_u16(outptr0, _p.val[0]);
921                     vst1_u16(outptr1, _p.val[1]);
922                     vst1_u16(outptr2, _p.val[2]);
923                     vst1_u16(outptr3, _p.val[3]);
924 
925                     r0 += 16;
926                     outptr0 += 4;
927                     outptr1 += 4;
928                     outptr2 += 4;
929                     outptr3 += 4;
930                 }
931 #endif
932                 for (; i < size; i++)
933                 {
934                     *outptr0++ = r0[0];
935                     *outptr1++ = r0[1];
936                     *outptr2++ = r0[2];
937                     *outptr3++ = r0[3];
938 
939                     r0 += 4;
940                 }
941             }
942         }
943         if (pack1to8)
944         {
945             #pragma omp parallel for num_threads(opt.num_threads)
946             for (int q = 0; q < outc; q++)
947             {
948                 const unsigned short* r0 = bottom_blob.channel(q * 8);
949                 const unsigned short* r1 = bottom_blob.channel(q * 8 + 1);
950                 const unsigned short* r2 = bottom_blob.channel(q * 8 + 2);
951                 const unsigned short* r3 = bottom_blob.channel(q * 8 + 3);
952                 const unsigned short* r4 = bottom_blob.channel(q * 8 + 4);
953                 const unsigned short* r5 = bottom_blob.channel(q * 8 + 5);
954                 const unsigned short* r6 = bottom_blob.channel(q * 8 + 6);
955                 const unsigned short* r7 = bottom_blob.channel(q * 8 + 7);
956 
957                 unsigned short* outptr = top_blob.channel(q);
958 
959                 int i = 0;
960 #if __ARM_NEON
961                 for (; i + 7 < size; i += 8)
962                 {
963                     // transpose 8x8
964 #if __aarch64__
965                     asm volatile(
966                         "ld1    {v0.8h}, [%0], #16      \n"
967                         "ld1    {v1.8h}, [%1], #16      \n"
968                         "ld1    {v2.8h}, [%2], #16      \n"
969                         "ld1    {v3.8h}, [%3], #16      \n"
970                         "ld1    {v4.8h}, [%4], #16      \n"
971                         "ld1    {v5.8h}, [%5], #16      \n"
972                         "ld1    {v6.8h}, [%6], #16      \n"
973                         "ld1    {v7.8h}, [%7], #16      \n"
974 
975                         "zip1   v16.8h, v0.8h, v4.8h    \n"
976                         "zip2   v20.8h, v0.8h, v4.8h    \n"
977                         "zip1   v17.8h, v1.8h, v5.8h    \n"
978                         "zip2   v21.8h, v1.8h, v5.8h    \n"
979                         "zip1   v18.8h, v2.8h, v6.8h    \n"
980                         "zip2   v22.8h, v2.8h, v6.8h    \n"
981                         "zip1   v19.8h, v3.8h, v7.8h    \n"
982                         "zip2   v23.8h, v3.8h, v7.8h    \n"
983 
984                         "st4    {v16.8h, v17.8h, v18.8h, v19.8h}, [%8], #64 \n"
985                         "st4    {v20.8h, v21.8h, v22.8h, v23.8h}, [%8], #64 \n"
986                         : "=r"(r0),    // %0
987                         "=r"(r1),    // %1
988                         "=r"(r2),    // %2
989                         "=r"(r3),    // %3
990                         "=r"(r4),    // %4
991                         "=r"(r5),    // %5
992                         "=r"(r6),    // %6
993                         "=r"(r7),    // %7
994                         "=r"(outptr) // %8
995                         : "0"(r0),
996                         "1"(r1),
997                         "2"(r2),
998                         "3"(r3),
999                         "4"(r4),
1000                         "5"(r5),
1001                         "6"(r6),
1002                         "7"(r7),
1003                         "8"(outptr)
1004                         : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
1005 #else
1006                     asm volatile(
1007                         "vld1.u16   {d16-d17}, [%0 : 128]! \n"
1008                         "vld1.u16   {d18-d19}, [%1 : 128]! \n"
1009                         "vld1.u16   {d20-d21}, [%2 : 128]! \n"
1010                         "vld1.u16   {d22-d23}, [%3 : 128]! \n"
1011                         "vld1.u16   {d24-d25}, [%4 : 128]! \n"
1012                         "vld1.u16   {d26-d27}, [%5 : 128]! \n"
1013                         "vld1.u16   {d28-d29}, [%6 : 128]! \n"
1014                         "vld1.u16   {d30-d31}, [%7 : 128]! \n"
1015 
1016                         "vtrn.u16   q8, q9              \n"
1017                         "vtrn.u16   q10, q11            \n"
1018                         "vtrn.u16   q12, q13            \n"
1019                         "vtrn.u16   q14, q15            \n"
1020 
1021                         "vtrn.u32   q8, q10             \n"
1022                         "vtrn.u32   q9, q11             \n"
1023                         "vtrn.u32   q12, q14            \n"
1024                         "vtrn.u32   q13, q15            \n"
1025 
1026                         "vswp       d17, d24            \n"
1027                         "vswp       d19, d26            \n"
1028                         "vswp       d21, d28            \n"
1029                         "vswp       d23, d30            \n"
1030 
1031                         "vstm       %8!, {d16-d23}      \n"
1032                         "vstm       %8!, {d24-d31}      \n"
1033                         : "=r"(r0),    // %0
1034                         "=r"(r1),    // %1
1035                         "=r"(r2),    // %2
1036                         "=r"(r3),    // %3
1037                         "=r"(r4),    // %4
1038                         "=r"(r5),    // %5
1039                         "=r"(r6),    // %6
1040                         "=r"(r7),    // %7
1041                         "=r"(outptr) // %8
1042                         : "0"(r0),
1043                         "1"(r1),
1044                         "2"(r2),
1045                         "3"(r3),
1046                         "4"(r4),
1047                         "5"(r5),
1048                         "6"(r6),
1049                         "7"(r7),
1050                         "8"(outptr)
1051                         : "memory", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1052 #endif
1053                 }
1054 #endif
1055                 for (; i < size; i++)
1056                 {
1057                     outptr[0] = *r0++;
1058                     outptr[1] = *r1++;
1059                     outptr[2] = *r2++;
1060                     outptr[3] = *r3++;
1061                     outptr[4] = *r4++;
1062                     outptr[5] = *r5++;
1063                     outptr[6] = *r6++;
1064                     outptr[7] = *r7++;
1065 
1066                     outptr += 8;
1067                 }
1068             }
1069         }
1070         if (pack8to1)
1071         {
1072             #pragma omp parallel for num_threads(opt.num_threads)
1073             for (int q = 0; q < channels; q++)
1074             {
1075                 const unsigned short* r0 = bottom_blob.channel(q);
1076 
1077                 unsigned short* outptr0 = top_blob.channel(q * 8);
1078                 unsigned short* outptr1 = top_blob.channel(q * 8 + 1);
1079                 unsigned short* outptr2 = top_blob.channel(q * 8 + 2);
1080                 unsigned short* outptr3 = top_blob.channel(q * 8 + 3);
1081                 unsigned short* outptr4 = top_blob.channel(q * 8 + 4);
1082                 unsigned short* outptr5 = top_blob.channel(q * 8 + 5);
1083                 unsigned short* outptr6 = top_blob.channel(q * 8 + 6);
1084                 unsigned short* outptr7 = top_blob.channel(q * 8 + 7);
1085 
1086                 int i = 0;
1087 #if __ARM_NEON
1088                 for (; i + 7 < size; i += 8)
1089                 {
1090                     // transpose 8x8
1091 #if __aarch64__
1092                     asm volatile(
1093                         "ld4    {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n"
1094                         "ld4    {v4.8h, v5.8h, v6.8h, v7.8h}, [%0], #64 \n"
1095 
1096                         "uzp1   v16.8h, v0.8h, v4.8h    \n"
1097                         "uzp2   v20.8h, v0.8h, v4.8h    \n"
1098                         "uzp1   v17.8h, v1.8h, v5.8h    \n"
1099                         "uzp2   v21.8h, v1.8h, v5.8h    \n"
1100                         "uzp1   v18.8h, v2.8h, v6.8h    \n"
1101                         "uzp2   v22.8h, v2.8h, v6.8h    \n"
1102                         "uzp1   v19.8h, v3.8h, v7.8h    \n"
1103                         "uzp2   v23.8h, v3.8h, v7.8h    \n"
1104 
1105                         "st1    {v16.8h}, [%1], #16      \n"
1106                         "st1    {v17.8h}, [%2], #16      \n"
1107                         "st1    {v18.8h}, [%3], #16      \n"
1108                         "st1    {v19.8h}, [%4], #16      \n"
1109                         "st1    {v20.8h}, [%5], #16      \n"
1110                         "st1    {v21.8h}, [%6], #16      \n"
1111                         "st1    {v22.8h}, [%7], #16      \n"
1112                         "st1    {v23.8h}, [%8], #16      \n"
1113                         : "=r"(r0),      // %0
1114                         "=r"(outptr0), // %1
1115                         "=r"(outptr1), // %2
1116                         "=r"(outptr2), // %3
1117                         "=r"(outptr3), // %4
1118                         "=r"(outptr4), // %5
1119                         "=r"(outptr5), // %6
1120                         "=r"(outptr6), // %7
1121                         "=r"(outptr7)  // %8
1122                         : "0"(r0),
1123                         "1"(outptr0),
1124                         "2"(outptr1),
1125                         "3"(outptr2),
1126                         "4"(outptr3),
1127                         "5"(outptr4),
1128                         "6"(outptr5),
1129                         "7"(outptr6),
1130                         "8"(outptr7)
1131                         : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
1132 #else
1133                     asm volatile(
1134                         "vldm       %0!, {d16-d23}      \n"
1135                         "vldm       %0!, {d24-d31}      \n"
1136 
1137                         "vtrn.u16   q8, q9              \n"
1138                         "vtrn.u16   q10, q11            \n"
1139                         "vtrn.u16   q12, q13            \n"
1140                         "vtrn.u16   q14, q15            \n"
1141 
1142                         "vtrn.u32   q8, q10             \n"
1143                         "vtrn.u32   q9, q11             \n"
1144                         "vtrn.u32   q12, q14            \n"
1145                         "vtrn.u32   q13, q15            \n"
1146 
1147                         "vswp       d17, d24            \n"
1148                         "vswp       d19, d26            \n"
1149                         "vswp       d21, d28            \n"
1150                         "vswp       d23, d30            \n"
1151 
1152                         "vst1.u16   {d16-d17}, [%1 : 128]! \n"
1153                         "vst1.u16   {d18-d19}, [%2 : 128]! \n"
1154                         "vst1.u16   {d20-d21}, [%3 : 128]! \n"
1155                         "vst1.u16   {d22-d23}, [%4 : 128]! \n"
1156                         "vst1.u16   {d24-d25}, [%5 : 128]! \n"
1157                         "vst1.u16   {d26-d27}, [%6 : 128]! \n"
1158                         "vst1.u16   {d28-d29}, [%7 : 128]! \n"
1159                         "vst1.u16   {d30-d31}, [%8 : 128]! \n"
1160                         : "=r"(r0),      // %0
1161                         "=r"(outptr0), // %1
1162                         "=r"(outptr1), // %2
1163                         "=r"(outptr2), // %3
1164                         "=r"(outptr3), // %4
1165                         "=r"(outptr4), // %5
1166                         "=r"(outptr5), // %6
1167                         "=r"(outptr6), // %7
1168                         "=r"(outptr7)  // %8
1169                         : "0"(r0),
1170                         "1"(outptr0),
1171                         "2"(outptr1),
1172                         "3"(outptr2),
1173                         "4"(outptr3),
1174                         "5"(outptr4),
1175                         "6"(outptr5),
1176                         "7"(outptr6),
1177                         "8"(outptr7)
1178                         : "memory", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1179 #endif
1180                 }
1181 #endif
1182                 for (; i < size; i++)
1183                 {
1184                     *outptr0++ = r0[0];
1185                     *outptr1++ = r0[1];
1186                     *outptr2++ = r0[2];
1187                     *outptr3++ = r0[3];
1188                     *outptr4++ = r0[4];
1189                     *outptr5++ = r0[5];
1190                     *outptr6++ = r0[6];
1191                     *outptr7++ = r0[7];
1192 
1193                     r0 += 8;
1194                 }
1195             }
1196         }
1197         if (pack4to8)
1198         {
1199             #pragma omp parallel for num_threads(opt.num_threads)
1200             for (int q = 0; q < outc; q++)
1201             {
1202                 const unsigned short* r0 = bottom_blob.channel(q * 2);
1203                 const unsigned short* r1 = bottom_blob.channel(q * 2 + 1);
1204 
1205                 unsigned short* outptr = top_blob.channel(q);
1206 
1207                 int i = 0;
1208 #if __ARM_NEON
1209                 for (; i + 1 < size; i += 2)
1210                 {
1211 #if __aarch64__
1212                     asm volatile(
1213                         "ld1    {v0.8h}, [%0], #16      \n"
1214                         "ld1    {v1.8h}, [%1], #16      \n"
1215 
1216                         "zip1   v2.2d, v0.2d, v1.2d     \n"
1217                         "zip2   v3.2d, v0.2d, v1.2d     \n"
1218 
1219                         "st1    {v2.8h, v3.8h}, [%2], #32\n"
1220                         : "=r"(r0),    // %0
1221                         "=r"(r1),    // %1
1222                         "=r"(outptr) // %2
1223                         : "0"(r0),
1224                         "1"(r1),
1225                         "2"(outptr)
1226                         : "memory", "v0", "v1", "v2", "v3");
1227 #else
1228                     asm volatile(
1229                         "vld1.u16   {d0-d1}, [%0 :128]! \n"
1230                         "vld1.u16   {d2-d3}, [%1 :128]! \n"
1231 
1232                         "vswp       d1, d2              \n"
1233 
1234                         "vst1.u16   {d0-d3}, [%2 :128]! \n"
1235                         : "=r"(r0),    // %0
1236                         "=r"(r1),    // %1
1237                         "=r"(outptr) // %2
1238                         : "0"(r0),
1239                         "1"(r1),
1240                         "2"(outptr)
1241                         : "memory", "q0", "q1");
1242 #endif
1243                 }
1244 #endif
1245                 for (; i < size; i++)
1246                 {
1247                     outptr[0] = r0[0];
1248                     outptr[1] = r0[1];
1249                     outptr[2] = r0[2];
1250                     outptr[3] = r0[3];
1251                     outptr[4] = r1[0];
1252                     outptr[5] = r1[1];
1253                     outptr[6] = r1[2];
1254                     outptr[7] = r1[3];
1255 
1256                     r0 += 4;
1257                     r1 += 4;
1258                     outptr += 8;
1259                 }
1260             }
1261         }
1262         if (pack8to4)
1263         {
1264             #pragma omp parallel for num_threads(opt.num_threads)
1265             for (int q = 0; q < channels; q++)
1266             {
1267                 const unsigned short* r0 = bottom_blob.channel(q);
1268 
1269                 unsigned short* outptr0 = top_blob.channel(q * 2);
1270                 unsigned short* outptr1 = top_blob.channel(q * 2 + 1);
1271 
1272                 int i = 0;
1273 #if __ARM_NEON
1274                 for (; i + 1 < size; i += 2)
1275                 {
1276 #if __aarch64__
1277                     asm volatile(
1278                         "ld1    {v0.8h, v1.8h}, [%0], #32 \n"
1279 
1280                         "uzp1   v2.2d, v0.2d, v1.2d     \n"
1281                         "uzp2   v3.2d, v0.2d, v1.2d     \n"
1282 
1283                         "st1    {v2.8h}, [%1], #16      \n"
1284                         "st1    {v3.8h}, [%2], #16      \n"
1285                         : "=r"(r0),      // %0
1286                         "=r"(outptr0), // %1
1287                         "=r"(outptr1)  // %2
1288                         : "0"(r0),
1289                         "1"(outptr0),
1290                         "2"(outptr1)
1291                         : "memory", "v0", "v1", "v2", "v3");
1292 #else
1293                     asm volatile(
1294                         "vld1.u16   {d0-d3}, [%0 :128]! \n"
1295 
1296                         "vswp       d1, d2              \n"
1297 
1298                         "vst1.u16   {d0-d1}, [%1 :128]! \n"
1299                         "vst1.u16   {d2-d3}, [%2 :128]! \n"
1300                         : "=r"(r0),      // %0
1301                         "=r"(outptr0), // %1
1302                         "=r"(outptr1)  // %2
1303                         : "0"(r0),
1304                         "1"(outptr0),
1305                         "2"(outptr1)
1306                         : "memory", "q0", "q1");
1307 #endif
1308                 }
1309 #endif
1310                 for (; i < size; i++)
1311                 {
1312                     outptr0[0] = r0[0];
1313                     outptr0[1] = r0[1];
1314                     outptr0[2] = r0[2];
1315                     outptr0[3] = r0[3];
1316                     outptr1[0] = r0[4];
1317                     outptr1[1] = r0[5];
1318                     outptr1[2] = r0[6];
1319                     outptr1[3] = r0[7];
1320 
1321                     r0 += 8;
1322                     outptr0 += 4;
1323                     outptr1 += 4;
1324                 }
1325             }
1326         }
1327 
1328         return 0;
1329     }
1330 
1331     return 0;
1332 }
1333 
forward_int8(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const1334 int Packing_arm::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
1335 {
1336     if (use_padding)
1337     {
1338         return Packing::forward(bottom_blob, top_blob, opt);
1339     }
1340 
1341     size_t elemsize = bottom_blob.elemsize;
1342     int elempack = bottom_blob.elempack;
1343 
1344     if (elempack == out_elempack)
1345     {
1346         top_blob = bottom_blob;
1347         return 0;
1348     }
1349 
1350     bool pack1to8 = elempack == 1 && out_elempack == 8;
1351     bool pack8to1 = elempack == 8 && out_elempack == 1;
1352 
1353     if (!pack1to8 && !pack8to1)
1354     {
1355         return Packing::forward(bottom_blob, top_blob, opt);
1356     }
1357 
1358     int w = bottom_blob.w;
1359     int h = bottom_blob.h;
1360     int d = bottom_blob.d;
1361     int channels = bottom_blob.c;
1362     int dims = bottom_blob.dims;
1363 
1364     if (!use_padding)
1365     {
1366         // identity if use_padding not allowed
1367         if (dims == 1 && w * elempack % out_elempack != 0)
1368         {
1369             top_blob = bottom_blob;
1370             return 0;
1371         }
1372         if (dims == 2 && h * elempack % out_elempack != 0)
1373         {
1374             top_blob = bottom_blob;
1375             return 0;
1376         }
1377         if ((dims == 3 || dims == 4) && channels * elempack % out_elempack != 0)
1378         {
1379             top_blob = bottom_blob;
1380             return 0;
1381         }
1382     }
1383 
1384     if (dims == 1)
1385     {
1386         top_blob = bottom_blob;
1387         top_blob.w = w * elempack / out_elempack;
1388         top_blob.cstep = w * elempack / out_elempack;
1389         top_blob.elemsize = elemsize / elempack * out_elempack;
1390         top_blob.elempack = out_elempack;
1391         return 0;
1392     }
1393 
1394     if (dims == 2)
1395     {
1396         int outh = h * elempack / out_elempack;
1397         size_t out_elemsize = elemsize / elempack * out_elempack;
1398 
1399         top_blob.create(w, outh, out_elemsize, out_elempack, opt.blob_allocator);
1400         if (top_blob.empty())
1401             return -100;
1402 
1403         if (pack1to8)
1404         {
1405             #pragma omp parallel for num_threads(opt.num_threads)
1406             for (int i = 0; i < outh; i++)
1407             {
1408                 const signed char* r0 = bottom_blob.row<const signed char>(i * 8);
1409                 const signed char* r1 = bottom_blob.row<const signed char>(i * 8 + 1);
1410                 const signed char* r2 = bottom_blob.row<const signed char>(i * 8 + 2);
1411                 const signed char* r3 = bottom_blob.row<const signed char>(i * 8 + 3);
1412                 const signed char* r4 = bottom_blob.row<const signed char>(i * 8 + 4);
1413                 const signed char* r5 = bottom_blob.row<const signed char>(i * 8 + 5);
1414                 const signed char* r6 = bottom_blob.row<const signed char>(i * 8 + 6);
1415                 const signed char* r7 = bottom_blob.row<const signed char>(i * 8 + 7);
1416 
1417                 signed char* outptr = top_blob.row<signed char>(i);
1418 
1419                 int j = 0;
1420                 for (; j < w; j++)
1421                 {
1422                     outptr[0] = *r0++;
1423                     outptr[1] = *r1++;
1424                     outptr[2] = *r2++;
1425                     outptr[3] = *r3++;
1426                     outptr[4] = *r4++;
1427                     outptr[5] = *r5++;
1428                     outptr[6] = *r6++;
1429                     outptr[7] = *r7++;
1430 
1431                     outptr += 8;
1432                 }
1433             }
1434         }
1435         if (pack8to1)
1436         {
1437             #pragma omp parallel for num_threads(opt.num_threads)
1438             for (int i = 0; i < h; i++)
1439             {
1440                 const signed char* r0 = bottom_blob.row<const signed char>(i);
1441 
1442                 signed char* outptr0 = top_blob.row<signed char>(i * 8);
1443                 signed char* outptr1 = top_blob.row<signed char>(i * 8 + 1);
1444                 signed char* outptr2 = top_blob.row<signed char>(i * 8 + 2);
1445                 signed char* outptr3 = top_blob.row<signed char>(i * 8 + 3);
1446                 signed char* outptr4 = top_blob.row<signed char>(i * 8 + 4);
1447                 signed char* outptr5 = top_blob.row<signed char>(i * 8 + 5);
1448                 signed char* outptr6 = top_blob.row<signed char>(i * 8 + 6);
1449                 signed char* outptr7 = top_blob.row<signed char>(i * 8 + 7);
1450 
1451                 int j = 0;
1452                 for (; j < w; j++)
1453                 {
1454                     *outptr0++ = r0[0];
1455                     *outptr1++ = r0[1];
1456                     *outptr2++ = r0[2];
1457                     *outptr3++ = r0[3];
1458                     *outptr4++ = r0[4];
1459                     *outptr5++ = r0[5];
1460                     *outptr6++ = r0[6];
1461                     *outptr7++ = r0[7];
1462 
1463                     r0 += 8;
1464                 }
1465             }
1466         }
1467 
1468         return 0;
1469     }
1470 
1471     if (dims == 3 || dims == 4)
1472     {
1473         int size = w * h * d;
1474         int outc = channels * elempack / out_elempack;
1475         size_t out_elemsize = elemsize / elempack * out_elempack;
1476 
1477         if (dims == 3)
1478             top_blob.create(w, h, outc, out_elemsize, out_elempack, opt.blob_allocator);
1479         else // if (dims == 4)
1480             top_blob.create(w, h, d, outc, out_elemsize, out_elempack, opt.blob_allocator);
1481         if (top_blob.empty())
1482             return -100;
1483 
1484         if (pack1to8)
1485         {
1486             #pragma omp parallel for num_threads(opt.num_threads)
1487             for (int q = 0; q < outc; q++)
1488             {
1489                 const signed char* r0 = bottom_blob.channel(q * 8);
1490                 const signed char* r1 = bottom_blob.channel(q * 8 + 1);
1491                 const signed char* r2 = bottom_blob.channel(q * 8 + 2);
1492                 const signed char* r3 = bottom_blob.channel(q * 8 + 3);
1493                 const signed char* r4 = bottom_blob.channel(q * 8 + 4);
1494                 const signed char* r5 = bottom_blob.channel(q * 8 + 5);
1495                 const signed char* r6 = bottom_blob.channel(q * 8 + 6);
1496                 const signed char* r7 = bottom_blob.channel(q * 8 + 7);
1497 
1498                 signed char* outptr = top_blob.channel(q);
1499 
1500                 int i = 0;
1501                 for (; i < size; i++)
1502                 {
1503                     outptr[0] = *r0++;
1504                     outptr[1] = *r1++;
1505                     outptr[2] = *r2++;
1506                     outptr[3] = *r3++;
1507                     outptr[4] = *r4++;
1508                     outptr[5] = *r5++;
1509                     outptr[6] = *r6++;
1510                     outptr[7] = *r7++;
1511 
1512                     outptr += 8;
1513                 }
1514             }
1515         }
1516         if (pack8to1)
1517         {
1518             #pragma omp parallel for num_threads(opt.num_threads)
1519             for (int q = 0; q < channels; q++)
1520             {
1521                 const signed char* r0 = bottom_blob.channel(q);
1522 
1523                 signed char* outptr0 = top_blob.channel(q * 8);
1524                 signed char* outptr1 = top_blob.channel(q * 8 + 1);
1525                 signed char* outptr2 = top_blob.channel(q * 8 + 2);
1526                 signed char* outptr3 = top_blob.channel(q * 8 + 3);
1527                 signed char* outptr4 = top_blob.channel(q * 8 + 4);
1528                 signed char* outptr5 = top_blob.channel(q * 8 + 5);
1529                 signed char* outptr6 = top_blob.channel(q * 8 + 6);
1530                 signed char* outptr7 = top_blob.channel(q * 8 + 7);
1531 
1532                 int i = 0;
1533                 for (; i < size; i++)
1534                 {
1535                     *outptr0++ = r0[0];
1536                     *outptr1++ = r0[1];
1537                     *outptr2++ = r0[2];
1538                     *outptr3++ = r0[3];
1539                     *outptr4++ = r0[4];
1540                     *outptr5++ = r0[5];
1541                     *outptr6++ = r0[6];
1542                     *outptr7++ = r0[7];
1543 
1544                     r0 += 8;
1545                 }
1546             }
1547         }
1548 
1549         return 0;
1550     }
1551 
1552     return 0;
1553 }
1554 
1555 } // namespace ncnn
1556