1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "packing_arm.h"
16
17 #if __ARM_NEON
18 #include <arm_neon.h>
19 #endif // __ARM_NEON
20
21 namespace ncnn {
22
Packing_arm()23 Packing_arm::Packing_arm()
24 {
25 support_packing = true;
26 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
27 support_fp16_storage = true;
28 #endif
29
30 support_bf16_storage = true;
31 }
32
// Repack a blob between elempack layouts (fp32 path).
//
// Dispatch order:
//   - 8-bit elements  -> forward_int8()
//   - 16-bit elements -> forward_bf16s_fp16s() (fp16 and bf16 share it,
//     since repacking only moves 16-bit payloads and is type-agnostic)
//   - padding requested, non-fp32, or unsupported elempack conversion
//     -> generic Packing::forward()
// Only the elempack 1<->4 fp32 conversions are handled here, with NEON
// interleave/deinterleave on the inner loops.
//
// Returns 0 on success, -100 on allocation failure.
int Packing_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int elembits = bottom_blob.elembits();

    if (elembits == 8)
        return forward_int8(bottom_blob, top_blob, opt);

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
        return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s_fp16s(bottom_blob, top_blob, opt);

    if (use_padding)
    {
        // padding semantics are handled by the reference implementation
        return Packing::forward(bottom_blob, top_blob, opt);
    }

    if (elembits != 32)
    {
        // non-fp32 type
        return Packing::forward(bottom_blob, top_blob, opt);
    }

    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    if (elempack == out_elempack)
    {
        // already in the requested layout -- shallow copy
        top_blob = bottom_blob;
        return 0;
    }

    bool pack1to4 = elempack == 1 && out_elempack == 4;
    bool pack4to1 = elempack == 4 && out_elempack == 1;

    if (!pack1to4 && !pack4to1)
    {
        // only 1<->4 handled here; anything else goes to the generic path
        return Packing::forward(bottom_blob, top_blob, opt);
    }

    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int d = bottom_blob.d;
    int channels = bottom_blob.c;
    int dims = bottom_blob.dims;

    if (!use_padding)
    {
        // identity if use_padding not allowed
        // (the packing axis must divide evenly by out_elempack, otherwise
        // we would need padding -- so just pass the blob through unchanged)
        if (dims == 1 && w * elempack % out_elempack != 0)
        {
            top_blob = bottom_blob;
            return 0;
        }
        if (dims == 2 && h * elempack % out_elempack != 0)
        {
            top_blob = bottom_blob;
            return 0;
        }
        if ((dims == 3 || dims == 4) && channels * elempack % out_elempack != 0)
        {
            top_blob = bottom_blob;
            return 0;
        }
    }

    if (dims == 1)
    {
        // 1-d data is contiguous either way -- only the descriptor changes
        top_blob = bottom_blob;
        top_blob.w = w * elempack / out_elempack;
        top_blob.cstep = w * elempack / out_elempack;
        top_blob.elemsize = elemsize / elempack * out_elempack;
        top_blob.elempack = out_elempack;
        return 0;
    }

    if (dims == 2)
    {
        int outh = h * elempack / out_elempack;
        size_t out_elemsize = elemsize / elempack * out_elempack;

        top_blob.create(w, outh, out_elemsize, out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        if (pack1to4)
        {
            // interleave 4 source rows into one packed row
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < outh; i++)
            {
                const float* r0 = bottom_blob.row(i * 4);
                const float* r1 = bottom_blob.row(i * 4 + 1);
                const float* r2 = bottom_blob.row(i * 4 + 2);
                const float* r3 = bottom_blob.row(i * 4 + 3);

                float* outptr = top_blob.row(i);

                int j = 0;
#if __ARM_NEON
                // 4 lanes at a time: vst4q interleaves the four row vectors
                for (; j + 3 < w; j += 4)
                {
                    float32x4x4_t _p;
                    _p.val[0] = vld1q_f32(r0);
                    _p.val[1] = vld1q_f32(r1);
                    _p.val[2] = vld1q_f32(r2);
                    _p.val[3] = vld1q_f32(r3);
                    vst4q_f32(outptr, _p);

                    r0 += 4;
                    r1 += 4;
                    r2 += 4;
                    r3 += 4;
                    outptr += 16;
                }
#endif
                // scalar tail
                for (; j < w; j++)
                {
                    outptr[0] = *r0++;
                    outptr[1] = *r1++;
                    outptr[2] = *r2++;
                    outptr[3] = *r3++;

                    outptr += 4;
                }
            }
        }
        if (pack4to1)
        {
            // deinterleave one packed row into 4 destination rows
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                const float* r0 = bottom_blob.row(i);

                float* outptr0 = top_blob.row(i * 4);
                float* outptr1 = top_blob.row(i * 4 + 1);
                float* outptr2 = top_blob.row(i * 4 + 2);
                float* outptr3 = top_blob.row(i * 4 + 3);

                int j = 0;
#if __ARM_NEON
                // vld4q splits the interleaved stream back into 4 vectors
                for (; j + 3 < w; j += 4)
                {
                    float32x4x4_t _p = vld4q_f32(r0);
                    vst1q_f32(outptr0, _p.val[0]);
                    vst1q_f32(outptr1, _p.val[1]);
                    vst1q_f32(outptr2, _p.val[2]);
                    vst1q_f32(outptr3, _p.val[3]);

                    r0 += 16;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                    outptr3 += 4;
                }
#endif
                // scalar tail
                for (; j < w; j++)
                {
                    *outptr0++ = r0[0];
                    *outptr1++ = r0[1];
                    *outptr2++ = r0[2];
                    *outptr3++ = r0[3];

                    r0 += 4;
                }
            }
        }

        return 0;
    }

    if (dims == 3 || dims == 4)
    {
        // same scheme as dims==2, but the packing axis is the channel axis
        // and each channel is treated as a flat array of w*h*d elements
        int size = w * h * d;
        int outc = channels * elempack / out_elempack;
        size_t out_elemsize = elemsize / elempack * out_elempack;

        if (dims == 3)
            top_blob.create(w, h, outc, out_elemsize, out_elempack, opt.blob_allocator);
        else // if (dims == 4)
            top_blob.create(w, h, d, outc, out_elemsize, out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        if (pack1to4)
        {
            // interleave 4 source channels into one packed channel
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < outc; q++)
            {
                const float* r0 = bottom_blob.channel(q * 4);
                const float* r1 = bottom_blob.channel(q * 4 + 1);
                const float* r2 = bottom_blob.channel(q * 4 + 2);
                const float* r3 = bottom_blob.channel(q * 4 + 3);

                float* outptr = top_blob.channel(q);

                int i = 0;
#if __ARM_NEON
                for (; i + 3 < size; i += 4)
                {
                    float32x4x4_t _p;
                    _p.val[0] = vld1q_f32(r0);
                    _p.val[1] = vld1q_f32(r1);
                    _p.val[2] = vld1q_f32(r2);
                    _p.val[3] = vld1q_f32(r3);
                    vst4q_f32(outptr, _p);

                    r0 += 4;
                    r1 += 4;
                    r2 += 4;
                    r3 += 4;
                    outptr += 16;
                }
#endif
                // scalar tail
                for (; i < size; i++)
                {
                    outptr[0] = *r0++;
                    outptr[1] = *r1++;
                    outptr[2] = *r2++;
                    outptr[3] = *r3++;

                    outptr += 4;
                }
            }
        }
        if (pack4to1)
        {
            // deinterleave one packed channel into 4 destination channels
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const float* r0 = bottom_blob.channel(q);

                float* outptr0 = top_blob.channel(q * 4);
                float* outptr1 = top_blob.channel(q * 4 + 1);
                float* outptr2 = top_blob.channel(q * 4 + 2);
                float* outptr3 = top_blob.channel(q * 4 + 3);

                int i = 0;
#if __ARM_NEON
                for (; i + 3 < size; i += 4)
                {
                    float32x4x4_t _p = vld4q_f32(r0);
                    vst1q_f32(outptr0, _p.val[0]);
                    vst1q_f32(outptr1, _p.val[1]);
                    vst1q_f32(outptr2, _p.val[2]);
                    vst1q_f32(outptr3, _p.val[3]);

                    r0 += 16;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                    outptr3 += 4;
                }
#endif
                // scalar tail
                for (; i < size; i++)
                {
                    *outptr0++ = r0[0];
                    *outptr1++ = r0[1];
                    *outptr2++ = r0[2];
                    *outptr3++ = r0[3];

                    r0 += 4;
                }
            }
        }

        return 0;
    }

    return 0;
}
306
// Repack a blob of 16-bit elements (bf16 or fp16 -- the bits are moved
// verbatim, so both storage types share this routine) between elempack
// layouts 1, 4 and 8.
//
// Supported conversions: 1<->4, 1<->8, 4<->8; everything else, and any
// case requiring padding, falls back to the generic Packing::forward().
// The inner loops use NEON intrinsics for 1<->4 and hand-written inline
// asm (8x8 u16 transposes) for the elempack-8 conversions.
//
// Returns 0 on success, -100 on allocation failure.
int Packing_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    if (use_padding)
    {
        // padding semantics are handled by the reference implementation
        return Packing::forward(bottom_blob, top_blob, opt);
    }

    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    if (elempack == out_elempack)
    {
        // already in the requested layout -- shallow copy
        top_blob = bottom_blob;
        return 0;
    }

    bool pack1to4 = elempack == 1 && out_elempack == 4;
    bool pack4to1 = elempack == 4 && out_elempack == 1;
    bool pack1to8 = elempack == 1 && out_elempack == 8;
    bool pack8to1 = elempack == 8 && out_elempack == 1;
    bool pack4to8 = elempack == 4 && out_elempack == 8;
    bool pack8to4 = elempack == 8 && out_elempack == 4;

    if (!pack1to4 && !pack4to1 && !pack1to8 && !pack8to1 && !pack4to8 && !pack8to4)
    {
        return Packing::forward(bottom_blob, top_blob, opt);
    }

    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int d = bottom_blob.d;
    int channels = bottom_blob.c;
    int dims = bottom_blob.dims;

    if (!use_padding)
    {
        // identity if use_padding not allowed
        // (packing axis must divide evenly by out_elempack; otherwise
        // pass the blob through unchanged)
        if (dims == 1 && w * elempack % out_elempack != 0)
        {
            top_blob = bottom_blob;
            return 0;
        }
        if (dims == 2 && h * elempack % out_elempack != 0)
        {
            top_blob = bottom_blob;
            return 0;
        }
        if ((dims == 3 || dims == 4) && channels * elempack % out_elempack != 0)
        {
            top_blob = bottom_blob;
            return 0;
        }
    }

    if (dims == 1)
    {
        // 1-d data is contiguous either way -- only the descriptor changes
        top_blob = bottom_blob;
        top_blob.w = w * elempack / out_elempack;
        top_blob.cstep = w * elempack / out_elempack;
        top_blob.elemsize = elemsize / elempack * out_elempack;
        top_blob.elempack = out_elempack;
        return 0;
    }

    if (dims == 2)
    {
        int outh = h * elempack / out_elempack;
        size_t out_elemsize = elemsize / elempack * out_elempack;

        top_blob.create(w, outh, out_elemsize, out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        if (pack1to4)
        {
            // interleave 4 source rows into one packed row
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < outh; i++)
            {
                const unsigned short* r0 = bottom_blob.row<const unsigned short>(i * 4);
                const unsigned short* r1 = bottom_blob.row<const unsigned short>(i * 4 + 1);
                const unsigned short* r2 = bottom_blob.row<const unsigned short>(i * 4 + 2);
                const unsigned short* r3 = bottom_blob.row<const unsigned short>(i * 4 + 3);

                unsigned short* outptr = top_blob.row<unsigned short>(i);

                int j = 0;
#if __ARM_NEON
                // vst4 interleaves the four 4-lane u16 vectors
                for (; j + 3 < w; j += 4)
                {
                    uint16x4x4_t _p;
                    _p.val[0] = vld1_u16(r0);
                    _p.val[1] = vld1_u16(r1);
                    _p.val[2] = vld1_u16(r2);
                    _p.val[3] = vld1_u16(r3);
                    vst4_u16(outptr, _p);

                    r0 += 4;
                    r1 += 4;
                    r2 += 4;
                    r3 += 4;
                    outptr += 16;
                }
#endif
                // scalar tail
                for (; j < w; j++)
                {
                    outptr[0] = *r0++;
                    outptr[1] = *r1++;
                    outptr[2] = *r2++;
                    outptr[3] = *r3++;

                    outptr += 4;
                }
            }
        }
        if (pack4to1)
        {
            // deinterleave one packed row into 4 destination rows
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                const unsigned short* r0 = bottom_blob.row<const unsigned short>(i);

                unsigned short* outptr0 = top_blob.row<unsigned short>(i * 4);
                unsigned short* outptr1 = top_blob.row<unsigned short>(i * 4 + 1);
                unsigned short* outptr2 = top_blob.row<unsigned short>(i * 4 + 2);
                unsigned short* outptr3 = top_blob.row<unsigned short>(i * 4 + 3);

                int j = 0;
#if __ARM_NEON
                for (; j + 3 < w; j += 4)
                {
                    uint16x4x4_t _p = vld4_u16(r0);
                    vst1_u16(outptr0, _p.val[0]);
                    vst1_u16(outptr1, _p.val[1]);
                    vst1_u16(outptr2, _p.val[2]);
                    vst1_u16(outptr3, _p.val[3]);

                    r0 += 16;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                    outptr3 += 4;
                }
#endif
                // scalar tail
                for (; j < w; j++)
                {
                    *outptr0++ = r0[0];
                    *outptr1++ = r0[1];
                    *outptr2++ = r0[2];
                    *outptr3++ = r0[3];

                    r0 += 4;
                }
            }
        }
        if (pack1to8)
        {
            // interleave 8 source rows into one packed row
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < outh; i++)
            {
                const unsigned short* r0 = bottom_blob.row<const unsigned short>(i * 8);
                const unsigned short* r1 = bottom_blob.row<const unsigned short>(i * 8 + 1);
                const unsigned short* r2 = bottom_blob.row<const unsigned short>(i * 8 + 2);
                const unsigned short* r3 = bottom_blob.row<const unsigned short>(i * 8 + 3);
                const unsigned short* r4 = bottom_blob.row<const unsigned short>(i * 8 + 4);
                const unsigned short* r5 = bottom_blob.row<const unsigned short>(i * 8 + 5);
                const unsigned short* r6 = bottom_blob.row<const unsigned short>(i * 8 + 6);
                const unsigned short* r7 = bottom_blob.row<const unsigned short>(i * 8 + 7);

                unsigned short* outptr = top_blob.row<unsigned short>(i);

                int j = 0;
#if __ARM_NEON
                for (; j + 7 < w; j += 8)
                {
                    // transpose 8x8
                    // all pointer operands are read-write ("=r" paired with
                    // the matching "N" input) and are advanced by the
                    // post-increment addressing inside the asm
#if __aarch64__
                    // zip + st4 realizes the 8x8 u16 transpose
                    asm volatile(
                        "ld1    {v0.8h}, [%0], #16      \n"
                        "ld1    {v1.8h}, [%1], #16      \n"
                        "ld1    {v2.8h}, [%2], #16      \n"
                        "ld1    {v3.8h}, [%3], #16      \n"
                        "ld1    {v4.8h}, [%4], #16      \n"
                        "ld1    {v5.8h}, [%5], #16      \n"
                        "ld1    {v6.8h}, [%6], #16      \n"
                        "ld1    {v7.8h}, [%7], #16      \n"

                        "zip1   v16.8h, v0.8h, v4.8h    \n"
                        "zip2   v20.8h, v0.8h, v4.8h    \n"
                        "zip1   v17.8h, v1.8h, v5.8h    \n"
                        "zip2   v21.8h, v1.8h, v5.8h    \n"
                        "zip1   v18.8h, v2.8h, v6.8h    \n"
                        "zip2   v22.8h, v2.8h, v6.8h    \n"
                        "zip1   v19.8h, v3.8h, v7.8h    \n"
                        "zip2   v23.8h, v3.8h, v7.8h    \n"

                        "st4    {v16.8h, v17.8h, v18.8h, v19.8h}, [%8], #64 \n"
                        "st4    {v20.8h, v21.8h, v22.8h, v23.8h}, [%8], #64 \n"
                        : "=r"(r0),    // %0
                        "=r"(r1),      // %1
                        "=r"(r2),      // %2
                        "=r"(r3),      // %3
                        "=r"(r4),      // %4
                        "=r"(r5),      // %5
                        "=r"(r6),      // %6
                        "=r"(r7),      // %7
                        "=r"(outptr)   // %8
                        : "0"(r0),
                        "1"(r1),
                        "2"(r2),
                        "3"(r3),
                        "4"(r4),
                        "5"(r5),
                        "6"(r6),
                        "7"(r7),
                        "8"(outptr)
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                    // armv7: vtrn.16/vtrn.32 + vswp realizes the transpose
                    asm volatile(
                        "vld1.u16   {d16-d17}, [%0 : 128]! \n"
                        "vld1.u16   {d18-d19}, [%1 : 128]! \n"
                        "vld1.u16   {d20-d21}, [%2 : 128]! \n"
                        "vld1.u16   {d22-d23}, [%3 : 128]! \n"
                        "vld1.u16   {d24-d25}, [%4 : 128]! \n"
                        "vld1.u16   {d26-d27}, [%5 : 128]! \n"
                        "vld1.u16   {d28-d29}, [%6 : 128]! \n"
                        "vld1.u16   {d30-d31}, [%7 : 128]! \n"

                        "vtrn.u16   q8, q9              \n"
                        "vtrn.u16   q10, q11            \n"
                        "vtrn.u16   q12, q13            \n"
                        "vtrn.u16   q14, q15            \n"

                        "vtrn.u32   q8, q10             \n"
                        "vtrn.u32   q9, q11             \n"
                        "vtrn.u32   q12, q14            \n"
                        "vtrn.u32   q13, q15            \n"

                        "vswp       d17, d24            \n"
                        "vswp       d19, d26            \n"
                        "vswp       d21, d28            \n"
                        "vswp       d23, d30            \n"

                        "vstm       %8!, {d16-d23}      \n"
                        "vstm       %8!, {d24-d31}      \n"
                        : "=r"(r0),    // %0
                        "=r"(r1),      // %1
                        "=r"(r2),      // %2
                        "=r"(r3),      // %3
                        "=r"(r4),      // %4
                        "=r"(r5),      // %5
                        "=r"(r6),      // %6
                        "=r"(r7),      // %7
                        "=r"(outptr)   // %8
                        : "0"(r0),
                        "1"(r1),
                        "2"(r2),
                        "3"(r3),
                        "4"(r4),
                        "5"(r5),
                        "6"(r6),
                        "7"(r7),
                        "8"(outptr)
                        : "memory", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
                }
#endif
                // scalar tail
                for (; j < w; j++)
                {
                    outptr[0] = *r0++;
                    outptr[1] = *r1++;
                    outptr[2] = *r2++;
                    outptr[3] = *r3++;
                    outptr[4] = *r4++;
                    outptr[5] = *r5++;
                    outptr[6] = *r6++;
                    outptr[7] = *r7++;

                    outptr += 8;
                }
            }
        }
        if (pack8to1)
        {
            // deinterleave one packed row into 8 destination rows
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                const unsigned short* r0 = bottom_blob.row<const unsigned short>(i);

                unsigned short* outptr0 = top_blob.row<unsigned short>(i * 8);
                unsigned short* outptr1 = top_blob.row<unsigned short>(i * 8 + 1);
                unsigned short* outptr2 = top_blob.row<unsigned short>(i * 8 + 2);
                unsigned short* outptr3 = top_blob.row<unsigned short>(i * 8 + 3);
                unsigned short* outptr4 = top_blob.row<unsigned short>(i * 8 + 4);
                unsigned short* outptr5 = top_blob.row<unsigned short>(i * 8 + 5);
                unsigned short* outptr6 = top_blob.row<unsigned short>(i * 8 + 6);
                unsigned short* outptr7 = top_blob.row<unsigned short>(i * 8 + 7);

                int j = 0;
#if __ARM_NEON
                for (; j + 7 < w; j += 8)
                {
                    // transpose 8x8
#if __aarch64__
                    // ld4 + uzp realizes the inverse 8x8 u16 transpose
                    asm volatile(
                        "ld4    {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n"
                        "ld4    {v4.8h, v5.8h, v6.8h, v7.8h}, [%0], #64 \n"

                        "uzp1   v16.8h, v0.8h, v4.8h    \n"
                        "uzp2   v20.8h, v0.8h, v4.8h    \n"
                        "uzp1   v17.8h, v1.8h, v5.8h    \n"
                        "uzp2   v21.8h, v1.8h, v5.8h    \n"
                        "uzp1   v18.8h, v2.8h, v6.8h    \n"
                        "uzp2   v22.8h, v2.8h, v6.8h    \n"
                        "uzp1   v19.8h, v3.8h, v7.8h    \n"
                        "uzp2   v23.8h, v3.8h, v7.8h    \n"

                        "st1    {v16.8h}, [%1], #16     \n"
                        "st1    {v17.8h}, [%2], #16     \n"
                        "st1    {v18.8h}, [%3], #16     \n"
                        "st1    {v19.8h}, [%4], #16     \n"
                        "st1    {v20.8h}, [%5], #16     \n"
                        "st1    {v21.8h}, [%6], #16     \n"
                        "st1    {v22.8h}, [%7], #16     \n"
                        "st1    {v23.8h}, [%8], #16     \n"
                        : "=r"(r0),      // %0
                        "=r"(outptr0),   // %1
                        "=r"(outptr1),   // %2
                        "=r"(outptr2),   // %3
                        "=r"(outptr3),   // %4
                        "=r"(outptr4),   // %5
                        "=r"(outptr5),   // %6
                        "=r"(outptr6),   // %7
                        "=r"(outptr7)    // %8
                        : "0"(r0),
                        "1"(outptr0),
                        "2"(outptr1),
                        "3"(outptr2),
                        "4"(outptr3),
                        "5"(outptr4),
                        "6"(outptr5),
                        "7"(outptr6),
                        "8"(outptr7)
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                    // armv7: bulk load then vtrn/vswp transpose, per-row stores
                    asm volatile(
                        "vldm       %0!, {d16-d23}      \n"
                        "vldm       %0!, {d24-d31}      \n"

                        "vtrn.u16   q8, q9              \n"
                        "vtrn.u16   q10, q11            \n"
                        "vtrn.u16   q12, q13            \n"
                        "vtrn.u16   q14, q15            \n"

                        "vtrn.u32   q8, q10             \n"
                        "vtrn.u32   q9, q11             \n"
                        "vtrn.u32   q12, q14            \n"
                        "vtrn.u32   q13, q15            \n"

                        "vswp       d17, d24            \n"
                        "vswp       d19, d26            \n"
                        "vswp       d21, d28            \n"
                        "vswp       d23, d30            \n"

                        "vst1.u16   {d16-d17}, [%1 : 128]! \n"
                        "vst1.u16   {d18-d19}, [%2 : 128]! \n"
                        "vst1.u16   {d20-d21}, [%3 : 128]! \n"
                        "vst1.u16   {d22-d23}, [%4 : 128]! \n"
                        "vst1.u16   {d24-d25}, [%5 : 128]! \n"
                        "vst1.u16   {d26-d27}, [%6 : 128]! \n"
                        "vst1.u16   {d28-d29}, [%7 : 128]! \n"
                        "vst1.u16   {d30-d31}, [%8 : 128]! \n"
                        : "=r"(r0),      // %0
                        "=r"(outptr0),   // %1
                        "=r"(outptr1),   // %2
                        "=r"(outptr2),   // %3
                        "=r"(outptr3),   // %4
                        "=r"(outptr4),   // %5
                        "=r"(outptr5),   // %6
                        "=r"(outptr6),   // %7
                        "=r"(outptr7)    // %8
                        : "0"(r0),
                        "1"(outptr0),
                        "2"(outptr1),
                        "3"(outptr2),
                        "4"(outptr3),
                        "5"(outptr4),
                        "6"(outptr5),
                        "7"(outptr6),
                        "8"(outptr7)
                        : "memory", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
                }
#endif
                // scalar tail
                for (; j < w; j++)
                {
                    *outptr0++ = r0[0];
                    *outptr1++ = r0[1];
                    *outptr2++ = r0[2];
                    *outptr3++ = r0[3];
                    *outptr4++ = r0[4];
                    *outptr5++ = r0[5];
                    *outptr6++ = r0[6];
                    *outptr7++ = r0[7];

                    r0 += 8;
                }
            }
        }
        if (pack4to8)
        {
            // merge two pack-4 rows into one pack-8 row
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < outh; i++)
            {
                const unsigned short* r0 = bottom_blob.row<const unsigned short>(i * 2);
                const unsigned short* r1 = bottom_blob.row<const unsigned short>(i * 2 + 1);

                unsigned short* outptr = top_blob.row<unsigned short>(i);

                int j = 0;
#if __ARM_NEON
                // two packed-4 groups per iteration -> two packed-8 groups
                for (; j + 1 < w; j += 2)
                {
#if __aarch64__
                    asm volatile(
                        "ld1    {v0.8h}, [%0], #16      \n"
                        "ld1    {v1.8h}, [%1], #16      \n"

                        "zip1   v2.2d, v0.2d, v1.2d     \n"
                        "zip2   v3.2d, v0.2d, v1.2d     \n"

                        "st1    {v2.8h, v3.8h}, [%2], #32\n"
                        : "=r"(r0),    // %0
                        "=r"(r1),      // %1
                        "=r"(outptr)   // %2
                        : "0"(r0),
                        "1"(r1),
                        "2"(outptr)
                        : "memory", "v0", "v1", "v2", "v3");
#else
                    asm volatile(
                        "vld1.u16   {d0-d1}, [%0 :128]! \n"
                        "vld1.u16   {d2-d3}, [%1 :128]! \n"

                        "vswp       d1, d2              \n"

                        "vst1.u16   {d0-d3}, [%2 :128]! \n"
                        : "=r"(r0),    // %0
                        "=r"(r1),      // %1
                        "=r"(outptr)   // %2
                        : "0"(r0),
                        "1"(r1),
                        "2"(outptr)
                        : "memory", "q0", "q1");
#endif
                }
#endif
                // scalar tail
                for (; j < w; j++)
                {
                    outptr[0] = r0[0];
                    outptr[1] = r0[1];
                    outptr[2] = r0[2];
                    outptr[3] = r0[3];
                    outptr[4] = r1[0];
                    outptr[5] = r1[1];
                    outptr[6] = r1[2];
                    outptr[7] = r1[3];

                    r0 += 4;
                    r1 += 4;
                    outptr += 8;
                }
            }
        }
        if (pack8to4)
        {
            // split one pack-8 row into two pack-4 rows
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < h; i++)
            {
                const unsigned short* r0 = bottom_blob.row<const unsigned short>(i);

                unsigned short* outptr0 = top_blob.row<unsigned short>(i * 2);
                unsigned short* outptr1 = top_blob.row<unsigned short>(i * 2 + 1);

                int j = 0;
#if __ARM_NEON
                for (; j + 1 < w; j += 2)
                {
#if __aarch64__
                    asm volatile(
                        "ld1    {v0.8h, v1.8h}, [%0], #32 \n"

                        "uzp1   v2.2d, v0.2d, v1.2d     \n"
                        "uzp2   v3.2d, v0.2d, v1.2d     \n"

                        "st1    {v2.8h}, [%1], #16      \n"
                        "st1    {v3.8h}, [%2], #16      \n"
                        : "=r"(r0),      // %0
                        "=r"(outptr0),   // %1
                        "=r"(outptr1)    // %2
                        : "0"(r0),
                        "1"(outptr0),
                        "2"(outptr1)
                        : "memory", "v0", "v1", "v2", "v3");
#else
                    asm volatile(
                        "vld1.u16   {d0-d3}, [%0 :128]! \n"

                        "vswp       d1, d2              \n"

                        "vst1.u16   {d0-d1}, [%1 :128]! \n"
                        "vst1.u16   {d2-d3}, [%2 :128]! \n"
                        : "=r"(r0),      // %0
                        "=r"(outptr0),   // %1
                        "=r"(outptr1)    // %2
                        : "0"(r0),
                        "1"(outptr0),
                        "2"(outptr1)
                        : "memory", "q0", "q1");
#endif
                }
#endif
                // scalar tail
                for (; j < w; j++)
                {
                    outptr0[0] = r0[0];
                    outptr0[1] = r0[1];
                    outptr0[2] = r0[2];
                    outptr0[3] = r0[3];
                    outptr1[0] = r0[4];
                    outptr1[1] = r0[5];
                    outptr1[2] = r0[6];
                    outptr1[3] = r0[7];

                    r0 += 8;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }
        }

        return 0;
    }

    if (dims == 3 || dims == 4)
    {
        // same scheme as dims==2 but on the channel axis; each channel is a
        // flat array of w*h*d elements
        int size = w * h * d;
        int outc = channels * elempack / out_elempack;
        size_t out_elemsize = elemsize / elempack * out_elempack;

        if (dims == 3)
            top_blob.create(w, h, outc, out_elemsize, out_elempack, opt.blob_allocator);
        else // if (dims == 4)
            top_blob.create(w, h, d, outc, out_elemsize, out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        if (pack1to4)
        {
            // interleave 4 source channels into one packed channel
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < outc; q++)
            {
                const unsigned short* r0 = bottom_blob.channel(q * 4);
                const unsigned short* r1 = bottom_blob.channel(q * 4 + 1);
                const unsigned short* r2 = bottom_blob.channel(q * 4 + 2);
                const unsigned short* r3 = bottom_blob.channel(q * 4 + 3);

                unsigned short* outptr = top_blob.channel(q);

                int i = 0;
#if __ARM_NEON
                for (; i + 3 < size; i += 4)
                {
                    uint16x4x4_t _p;
                    _p.val[0] = vld1_u16(r0);
                    _p.val[1] = vld1_u16(r1);
                    _p.val[2] = vld1_u16(r2);
                    _p.val[3] = vld1_u16(r3);
                    vst4_u16(outptr, _p);

                    r0 += 4;
                    r1 += 4;
                    r2 += 4;
                    r3 += 4;
                    outptr += 16;
                }
#endif
                // scalar tail
                for (; i < size; i++)
                {
                    outptr[0] = *r0++;
                    outptr[1] = *r1++;
                    outptr[2] = *r2++;
                    outptr[3] = *r3++;

                    outptr += 4;
                }
            }
        }
        if (pack4to1)
        {
            // deinterleave one packed channel into 4 destination channels
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const unsigned short* r0 = bottom_blob.channel(q);

                unsigned short* outptr0 = top_blob.channel(q * 4);
                unsigned short* outptr1 = top_blob.channel(q * 4 + 1);
                unsigned short* outptr2 = top_blob.channel(q * 4 + 2);
                unsigned short* outptr3 = top_blob.channel(q * 4 + 3);

                int i = 0;
#if __ARM_NEON
                for (; i + 3 < size; i += 4)
                {
                    uint16x4x4_t _p = vld4_u16(r0);
                    vst1_u16(outptr0, _p.val[0]);
                    vst1_u16(outptr1, _p.val[1]);
                    vst1_u16(outptr2, _p.val[2]);
                    vst1_u16(outptr3, _p.val[3]);

                    r0 += 16;
                    outptr0 += 4;
                    outptr1 += 4;
                    outptr2 += 4;
                    outptr3 += 4;
                }
#endif
                // scalar tail
                for (; i < size; i++)
                {
                    *outptr0++ = r0[0];
                    *outptr1++ = r0[1];
                    *outptr2++ = r0[2];
                    *outptr3++ = r0[3];

                    r0 += 4;
                }
            }
        }
        if (pack1to8)
        {
            // interleave 8 source channels into one packed channel
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < outc; q++)
            {
                const unsigned short* r0 = bottom_blob.channel(q * 8);
                const unsigned short* r1 = bottom_blob.channel(q * 8 + 1);
                const unsigned short* r2 = bottom_blob.channel(q * 8 + 2);
                const unsigned short* r3 = bottom_blob.channel(q * 8 + 3);
                const unsigned short* r4 = bottom_blob.channel(q * 8 + 4);
                const unsigned short* r5 = bottom_blob.channel(q * 8 + 5);
                const unsigned short* r6 = bottom_blob.channel(q * 8 + 6);
                const unsigned short* r7 = bottom_blob.channel(q * 8 + 7);

                unsigned short* outptr = top_blob.channel(q);

                int i = 0;
#if __ARM_NEON
                for (; i + 7 < size; i += 8)
                {
                    // transpose 8x8
#if __aarch64__
                    // zip + st4 realizes the 8x8 u16 transpose
                    asm volatile(
                        "ld1    {v0.8h}, [%0], #16      \n"
                        "ld1    {v1.8h}, [%1], #16      \n"
                        "ld1    {v2.8h}, [%2], #16      \n"
                        "ld1    {v3.8h}, [%3], #16      \n"
                        "ld1    {v4.8h}, [%4], #16      \n"
                        "ld1    {v5.8h}, [%5], #16      \n"
                        "ld1    {v6.8h}, [%6], #16      \n"
                        "ld1    {v7.8h}, [%7], #16      \n"

                        "zip1   v16.8h, v0.8h, v4.8h    \n"
                        "zip2   v20.8h, v0.8h, v4.8h    \n"
                        "zip1   v17.8h, v1.8h, v5.8h    \n"
                        "zip2   v21.8h, v1.8h, v5.8h    \n"
                        "zip1   v18.8h, v2.8h, v6.8h    \n"
                        "zip2   v22.8h, v2.8h, v6.8h    \n"
                        "zip1   v19.8h, v3.8h, v7.8h    \n"
                        "zip2   v23.8h, v3.8h, v7.8h    \n"

                        "st4    {v16.8h, v17.8h, v18.8h, v19.8h}, [%8], #64 \n"
                        "st4    {v20.8h, v21.8h, v22.8h, v23.8h}, [%8], #64 \n"
                        : "=r"(r0),    // %0
                        "=r"(r1),      // %1
                        "=r"(r2),      // %2
                        "=r"(r3),      // %3
                        "=r"(r4),      // %4
                        "=r"(r5),      // %5
                        "=r"(r6),      // %6
                        "=r"(r7),      // %7
                        "=r"(outptr)   // %8
                        : "0"(r0),
                        "1"(r1),
                        "2"(r2),
                        "3"(r3),
                        "4"(r4),
                        "5"(r5),
                        "6"(r6),
                        "7"(r7),
                        "8"(outptr)
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                    // armv7: vtrn.16/vtrn.32 + vswp realizes the transpose
                    asm volatile(
                        "vld1.u16   {d16-d17}, [%0 : 128]! \n"
                        "vld1.u16   {d18-d19}, [%1 : 128]! \n"
                        "vld1.u16   {d20-d21}, [%2 : 128]! \n"
                        "vld1.u16   {d22-d23}, [%3 : 128]! \n"
                        "vld1.u16   {d24-d25}, [%4 : 128]! \n"
                        "vld1.u16   {d26-d27}, [%5 : 128]! \n"
                        "vld1.u16   {d28-d29}, [%6 : 128]! \n"
                        "vld1.u16   {d30-d31}, [%7 : 128]! \n"

                        "vtrn.u16   q8, q9              \n"
                        "vtrn.u16   q10, q11            \n"
                        "vtrn.u16   q12, q13            \n"
                        "vtrn.u16   q14, q15            \n"

                        "vtrn.u32   q8, q10             \n"
                        "vtrn.u32   q9, q11             \n"
                        "vtrn.u32   q12, q14            \n"
                        "vtrn.u32   q13, q15            \n"

                        "vswp       d17, d24            \n"
                        "vswp       d19, d26            \n"
                        "vswp       d21, d28            \n"
                        "vswp       d23, d30            \n"

                        "vstm       %8!, {d16-d23}      \n"
                        "vstm       %8!, {d24-d31}      \n"
                        : "=r"(r0),    // %0
                        "=r"(r1),      // %1
                        "=r"(r2),      // %2
                        "=r"(r3),      // %3
                        "=r"(r4),      // %4
                        "=r"(r5),      // %5
                        "=r"(r6),      // %6
                        "=r"(r7),      // %7
                        "=r"(outptr)   // %8
                        : "0"(r0),
                        "1"(r1),
                        "2"(r2),
                        "3"(r3),
                        "4"(r4),
                        "5"(r5),
                        "6"(r6),
                        "7"(r7),
                        "8"(outptr)
                        : "memory", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
                }
#endif
                // scalar tail
                for (; i < size; i++)
                {
                    outptr[0] = *r0++;
                    outptr[1] = *r1++;
                    outptr[2] = *r2++;
                    outptr[3] = *r3++;
                    outptr[4] = *r4++;
                    outptr[5] = *r5++;
                    outptr[6] = *r6++;
                    outptr[7] = *r7++;

                    outptr += 8;
                }
            }
        }
        if (pack8to1)
        {
            // deinterleave one packed channel into 8 destination channels
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const unsigned short* r0 = bottom_blob.channel(q);

                unsigned short* outptr0 = top_blob.channel(q * 8);
                unsigned short* outptr1 = top_blob.channel(q * 8 + 1);
                unsigned short* outptr2 = top_blob.channel(q * 8 + 2);
                unsigned short* outptr3 = top_blob.channel(q * 8 + 3);
                unsigned short* outptr4 = top_blob.channel(q * 8 + 4);
                unsigned short* outptr5 = top_blob.channel(q * 8 + 5);
                unsigned short* outptr6 = top_blob.channel(q * 8 + 6);
                unsigned short* outptr7 = top_blob.channel(q * 8 + 7);

                int i = 0;
#if __ARM_NEON
                for (; i + 7 < size; i += 8)
                {
                    // transpose 8x8
#if __aarch64__
                    // ld4 + uzp realizes the inverse 8x8 u16 transpose
                    asm volatile(
                        "ld4    {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n"
                        "ld4    {v4.8h, v5.8h, v6.8h, v7.8h}, [%0], #64 \n"

                        "uzp1   v16.8h, v0.8h, v4.8h    \n"
                        "uzp2   v20.8h, v0.8h, v4.8h    \n"
                        "uzp1   v17.8h, v1.8h, v5.8h    \n"
                        "uzp2   v21.8h, v1.8h, v5.8h    \n"
                        "uzp1   v18.8h, v2.8h, v6.8h    \n"
                        "uzp2   v22.8h, v2.8h, v6.8h    \n"
                        "uzp1   v19.8h, v3.8h, v7.8h    \n"
                        "uzp2   v23.8h, v3.8h, v7.8h    \n"

                        "st1    {v16.8h}, [%1], #16     \n"
                        "st1    {v17.8h}, [%2], #16     \n"
                        "st1    {v18.8h}, [%3], #16     \n"
                        "st1    {v19.8h}, [%4], #16     \n"
                        "st1    {v20.8h}, [%5], #16     \n"
                        "st1    {v21.8h}, [%6], #16     \n"
                        "st1    {v22.8h}, [%7], #16     \n"
                        "st1    {v23.8h}, [%8], #16     \n"
                        : "=r"(r0),      // %0
                        "=r"(outptr0),   // %1
                        "=r"(outptr1),   // %2
                        "=r"(outptr2),   // %3
                        "=r"(outptr3),   // %4
                        "=r"(outptr4),   // %5
                        "=r"(outptr5),   // %6
                        "=r"(outptr6),   // %7
                        "=r"(outptr7)    // %8
                        : "0"(r0),
                        "1"(outptr0),
                        "2"(outptr1),
                        "3"(outptr2),
                        "4"(outptr3),
                        "5"(outptr4),
                        "6"(outptr5),
                        "7"(outptr6),
                        "8"(outptr7)
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                    // armv7: bulk load then vtrn/vswp transpose, per-row stores
                    asm volatile(
                        "vldm       %0!, {d16-d23}      \n"
                        "vldm       %0!, {d24-d31}      \n"

                        "vtrn.u16   q8, q9              \n"
                        "vtrn.u16   q10, q11            \n"
                        "vtrn.u16   q12, q13            \n"
                        "vtrn.u16   q14, q15            \n"

                        "vtrn.u32   q8, q10             \n"
                        "vtrn.u32   q9, q11             \n"
                        "vtrn.u32   q12, q14            \n"
                        "vtrn.u32   q13, q15            \n"

                        "vswp       d17, d24            \n"
                        "vswp       d19, d26            \n"
                        "vswp       d21, d28            \n"
                        "vswp       d23, d30            \n"

                        "vst1.u16   {d16-d17}, [%1 : 128]! \n"
                        "vst1.u16   {d18-d19}, [%2 : 128]! \n"
                        "vst1.u16   {d20-d21}, [%3 : 128]! \n"
                        "vst1.u16   {d22-d23}, [%4 : 128]! \n"
                        "vst1.u16   {d24-d25}, [%5 : 128]! \n"
                        "vst1.u16   {d26-d27}, [%6 : 128]! \n"
                        "vst1.u16   {d28-d29}, [%7 : 128]! \n"
                        "vst1.u16   {d30-d31}, [%8 : 128]! \n"
                        : "=r"(r0),      // %0
                        "=r"(outptr0),   // %1
                        "=r"(outptr1),   // %2
                        "=r"(outptr2),   // %3
                        "=r"(outptr3),   // %4
                        "=r"(outptr4),   // %5
                        "=r"(outptr5),   // %6
                        "=r"(outptr6),   // %7
                        "=r"(outptr7)    // %8
                        : "0"(r0),
                        "1"(outptr0),
                        "2"(outptr1),
                        "3"(outptr2),
                        "4"(outptr3),
                        "5"(outptr4),
                        "6"(outptr5),
                        "7"(outptr6),
                        "8"(outptr7)
                        : "memory", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
                }
#endif
                // scalar tail
                for (; i < size; i++)
                {
                    *outptr0++ = r0[0];
                    *outptr1++ = r0[1];
                    *outptr2++ = r0[2];
                    *outptr3++ = r0[3];
                    *outptr4++ = r0[4];
                    *outptr5++ = r0[5];
                    *outptr6++ = r0[6];
                    *outptr7++ = r0[7];

                    r0 += 8;
                }
            }
        }
        if (pack4to8)
        {
            // merge two pack-4 channels into one pack-8 channel
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < outc; q++)
            {
                const unsigned short* r0 = bottom_blob.channel(q * 2);
                const unsigned short* r1 = bottom_blob.channel(q * 2 + 1);

                unsigned short* outptr = top_blob.channel(q);

                int i = 0;
#if __ARM_NEON
                for (; i + 1 < size; i += 2)
                {
#if __aarch64__
                    asm volatile(
                        "ld1    {v0.8h}, [%0], #16      \n"
                        "ld1    {v1.8h}, [%1], #16      \n"

                        "zip1   v2.2d, v0.2d, v1.2d     \n"
                        "zip2   v3.2d, v0.2d, v1.2d     \n"

                        "st1    {v2.8h, v3.8h}, [%2], #32\n"
                        : "=r"(r0),    // %0
                        "=r"(r1),      // %1
                        "=r"(outptr)   // %2
                        : "0"(r0),
                        "1"(r1),
                        "2"(outptr)
                        : "memory", "v0", "v1", "v2", "v3");
#else
                    asm volatile(
                        "vld1.u16   {d0-d1}, [%0 :128]! \n"
                        "vld1.u16   {d2-d3}, [%1 :128]! \n"

                        "vswp       d1, d2              \n"

                        "vst1.u16   {d0-d3}, [%2 :128]! \n"
                        : "=r"(r0),    // %0
                        "=r"(r1),      // %1
                        "=r"(outptr)   // %2
                        : "0"(r0),
                        "1"(r1),
                        "2"(outptr)
                        : "memory", "q0", "q1");
#endif
                }
#endif
                // scalar tail
                for (; i < size; i++)
                {
                    outptr[0] = r0[0];
                    outptr[1] = r0[1];
                    outptr[2] = r0[2];
                    outptr[3] = r0[3];
                    outptr[4] = r1[0];
                    outptr[5] = r1[1];
                    outptr[6] = r1[2];
                    outptr[7] = r1[3];

                    r0 += 4;
                    r1 += 4;
                    outptr += 8;
                }
            }
        }
        if (pack8to4)
        {
            // split one pack-8 channel into two pack-4 channels
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const unsigned short* r0 = bottom_blob.channel(q);

                unsigned short* outptr0 = top_blob.channel(q * 2);
                unsigned short* outptr1 = top_blob.channel(q * 2 + 1);

                int i = 0;
#if __ARM_NEON
                for (; i + 1 < size; i += 2)
                {
#if __aarch64__
                    asm volatile(
                        "ld1    {v0.8h, v1.8h}, [%0], #32 \n"

                        "uzp1   v2.2d, v0.2d, v1.2d     \n"
                        "uzp2   v3.2d, v0.2d, v1.2d     \n"

                        "st1    {v2.8h}, [%1], #16      \n"
                        "st1    {v3.8h}, [%2], #16      \n"
                        : "=r"(r0),      // %0
                        "=r"(outptr0),   // %1
                        "=r"(outptr1)    // %2
                        : "0"(r0),
                        "1"(outptr0),
                        "2"(outptr1)
                        : "memory", "v0", "v1", "v2", "v3");
#else
                    asm volatile(
                        "vld1.u16   {d0-d3}, [%0 :128]! \n"

                        "vswp       d1, d2              \n"

                        "vst1.u16   {d0-d1}, [%1 :128]! \n"
                        "vst1.u16   {d2-d3}, [%2 :128]! \n"
                        : "=r"(r0),      // %0
                        "=r"(outptr0),   // %1
                        "=r"(outptr1)    // %2
                        : "0"(r0),
                        "1"(outptr0),
                        "2"(outptr1)
                        : "memory", "q0", "q1");
#endif
                }
#endif
                // scalar tail
                for (; i < size; i++)
                {
                    outptr0[0] = r0[0];
                    outptr0[1] = r0[1];
                    outptr0[2] = r0[2];
                    outptr0[3] = r0[3];
                    outptr1[0] = r0[4];
                    outptr1[1] = r0[5];
                    outptr1[2] = r0[6];
                    outptr1[3] = r0[7];

                    r0 += 8;
                    outptr0 += 4;
                    outptr1 += 4;
                }
            }
        }

        return 0;
    }

    return 0;
}
1333
forward_int8(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const1334 int Packing_arm::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
1335 {
1336 if (use_padding)
1337 {
1338 return Packing::forward(bottom_blob, top_blob, opt);
1339 }
1340
1341 size_t elemsize = bottom_blob.elemsize;
1342 int elempack = bottom_blob.elempack;
1343
1344 if (elempack == out_elempack)
1345 {
1346 top_blob = bottom_blob;
1347 return 0;
1348 }
1349
1350 bool pack1to8 = elempack == 1 && out_elempack == 8;
1351 bool pack8to1 = elempack == 8 && out_elempack == 1;
1352
1353 if (!pack1to8 && !pack8to1)
1354 {
1355 return Packing::forward(bottom_blob, top_blob, opt);
1356 }
1357
1358 int w = bottom_blob.w;
1359 int h = bottom_blob.h;
1360 int d = bottom_blob.d;
1361 int channels = bottom_blob.c;
1362 int dims = bottom_blob.dims;
1363
1364 if (!use_padding)
1365 {
1366 // identity if use_padding not allowed
1367 if (dims == 1 && w * elempack % out_elempack != 0)
1368 {
1369 top_blob = bottom_blob;
1370 return 0;
1371 }
1372 if (dims == 2 && h * elempack % out_elempack != 0)
1373 {
1374 top_blob = bottom_blob;
1375 return 0;
1376 }
1377 if ((dims == 3 || dims == 4) && channels * elempack % out_elempack != 0)
1378 {
1379 top_blob = bottom_blob;
1380 return 0;
1381 }
1382 }
1383
1384 if (dims == 1)
1385 {
1386 top_blob = bottom_blob;
1387 top_blob.w = w * elempack / out_elempack;
1388 top_blob.cstep = w * elempack / out_elempack;
1389 top_blob.elemsize = elemsize / elempack * out_elempack;
1390 top_blob.elempack = out_elempack;
1391 return 0;
1392 }
1393
1394 if (dims == 2)
1395 {
1396 int outh = h * elempack / out_elempack;
1397 size_t out_elemsize = elemsize / elempack * out_elempack;
1398
1399 top_blob.create(w, outh, out_elemsize, out_elempack, opt.blob_allocator);
1400 if (top_blob.empty())
1401 return -100;
1402
1403 if (pack1to8)
1404 {
1405 #pragma omp parallel for num_threads(opt.num_threads)
1406 for (int i = 0; i < outh; i++)
1407 {
1408 const signed char* r0 = bottom_blob.row<const signed char>(i * 8);
1409 const signed char* r1 = bottom_blob.row<const signed char>(i * 8 + 1);
1410 const signed char* r2 = bottom_blob.row<const signed char>(i * 8 + 2);
1411 const signed char* r3 = bottom_blob.row<const signed char>(i * 8 + 3);
1412 const signed char* r4 = bottom_blob.row<const signed char>(i * 8 + 4);
1413 const signed char* r5 = bottom_blob.row<const signed char>(i * 8 + 5);
1414 const signed char* r6 = bottom_blob.row<const signed char>(i * 8 + 6);
1415 const signed char* r7 = bottom_blob.row<const signed char>(i * 8 + 7);
1416
1417 signed char* outptr = top_blob.row<signed char>(i);
1418
1419 int j = 0;
1420 for (; j < w; j++)
1421 {
1422 outptr[0] = *r0++;
1423 outptr[1] = *r1++;
1424 outptr[2] = *r2++;
1425 outptr[3] = *r3++;
1426 outptr[4] = *r4++;
1427 outptr[5] = *r5++;
1428 outptr[6] = *r6++;
1429 outptr[7] = *r7++;
1430
1431 outptr += 8;
1432 }
1433 }
1434 }
1435 if (pack8to1)
1436 {
1437 #pragma omp parallel for num_threads(opt.num_threads)
1438 for (int i = 0; i < h; i++)
1439 {
1440 const signed char* r0 = bottom_blob.row<const signed char>(i);
1441
1442 signed char* outptr0 = top_blob.row<signed char>(i * 8);
1443 signed char* outptr1 = top_blob.row<signed char>(i * 8 + 1);
1444 signed char* outptr2 = top_blob.row<signed char>(i * 8 + 2);
1445 signed char* outptr3 = top_blob.row<signed char>(i * 8 + 3);
1446 signed char* outptr4 = top_blob.row<signed char>(i * 8 + 4);
1447 signed char* outptr5 = top_blob.row<signed char>(i * 8 + 5);
1448 signed char* outptr6 = top_blob.row<signed char>(i * 8 + 6);
1449 signed char* outptr7 = top_blob.row<signed char>(i * 8 + 7);
1450
1451 int j = 0;
1452 for (; j < w; j++)
1453 {
1454 *outptr0++ = r0[0];
1455 *outptr1++ = r0[1];
1456 *outptr2++ = r0[2];
1457 *outptr3++ = r0[3];
1458 *outptr4++ = r0[4];
1459 *outptr5++ = r0[5];
1460 *outptr6++ = r0[6];
1461 *outptr7++ = r0[7];
1462
1463 r0 += 8;
1464 }
1465 }
1466 }
1467
1468 return 0;
1469 }
1470
1471 if (dims == 3 || dims == 4)
1472 {
1473 int size = w * h * d;
1474 int outc = channels * elempack / out_elempack;
1475 size_t out_elemsize = elemsize / elempack * out_elempack;
1476
1477 if (dims == 3)
1478 top_blob.create(w, h, outc, out_elemsize, out_elempack, opt.blob_allocator);
1479 else // if (dims == 4)
1480 top_blob.create(w, h, d, outc, out_elemsize, out_elempack, opt.blob_allocator);
1481 if (top_blob.empty())
1482 return -100;
1483
1484 if (pack1to8)
1485 {
1486 #pragma omp parallel for num_threads(opt.num_threads)
1487 for (int q = 0; q < outc; q++)
1488 {
1489 const signed char* r0 = bottom_blob.channel(q * 8);
1490 const signed char* r1 = bottom_blob.channel(q * 8 + 1);
1491 const signed char* r2 = bottom_blob.channel(q * 8 + 2);
1492 const signed char* r3 = bottom_blob.channel(q * 8 + 3);
1493 const signed char* r4 = bottom_blob.channel(q * 8 + 4);
1494 const signed char* r5 = bottom_blob.channel(q * 8 + 5);
1495 const signed char* r6 = bottom_blob.channel(q * 8 + 6);
1496 const signed char* r7 = bottom_blob.channel(q * 8 + 7);
1497
1498 signed char* outptr = top_blob.channel(q);
1499
1500 int i = 0;
1501 for (; i < size; i++)
1502 {
1503 outptr[0] = *r0++;
1504 outptr[1] = *r1++;
1505 outptr[2] = *r2++;
1506 outptr[3] = *r3++;
1507 outptr[4] = *r4++;
1508 outptr[5] = *r5++;
1509 outptr[6] = *r6++;
1510 outptr[7] = *r7++;
1511
1512 outptr += 8;
1513 }
1514 }
1515 }
1516 if (pack8to1)
1517 {
1518 #pragma omp parallel for num_threads(opt.num_threads)
1519 for (int q = 0; q < channels; q++)
1520 {
1521 const signed char* r0 = bottom_blob.channel(q);
1522
1523 signed char* outptr0 = top_blob.channel(q * 8);
1524 signed char* outptr1 = top_blob.channel(q * 8 + 1);
1525 signed char* outptr2 = top_blob.channel(q * 8 + 2);
1526 signed char* outptr3 = top_blob.channel(q * 8 + 3);
1527 signed char* outptr4 = top_blob.channel(q * 8 + 4);
1528 signed char* outptr5 = top_blob.channel(q * 8 + 5);
1529 signed char* outptr6 = top_blob.channel(q * 8 + 6);
1530 signed char* outptr7 = top_blob.channel(q * 8 + 7);
1531
1532 int i = 0;
1533 for (; i < size; i++)
1534 {
1535 *outptr0++ = r0[0];
1536 *outptr1++ = r0[1];
1537 *outptr2++ = r0[2];
1538 *outptr3++ = r0[3];
1539 *outptr4++ = r0[4];
1540 *outptr5++ = r0[5];
1541 *outptr6++ = r0[6];
1542 *outptr7++ = r0[7];
1543
1544 r0 += 8;
1545 }
1546 }
1547 }
1548
1549 return 0;
1550 }
1551
1552 return 0;
1553 }
1554
1555 } // namespace ncnn
1556