// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "reshape_arm.h"

#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON

namespace ncnn {

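// Reshape_arm enables the ARM-specific fast paths: packed (elempack 4/8)
// layouts when NEON is available, fp16 storage on targets with FP16 vector
// arithmetic, and bf16 storage unconditionally.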
Reshape_arm::Reshape_arm()
{
#if __ARM_NEON
    support_packing = true;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    support_fp16_storage = true;
#endif
#endif // __ARM_NEON

    support_bf16_storage = true;
}

int Reshape_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int elembits = bottom_blob.elembits();

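    // fp16 and bf16 storage are both 16 bits per element, so both are routed
    // to the shared forward_bf16s_fp16s() path below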
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (opt.use_fp16_storage && elembits == 16)
        return forward_bf16s_fp16s(bottom_blob, top_blob, opt);
#endif

    if (opt.use_bf16_storage && elembits == 16)
        return forward_bf16s_fp16s(bottom_blob, top_blob, opt);

    int elempack = bottom_blob.elempack;

    if (permute == 1)
    {
        // TODO implement permute on-the-fly
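        // until then, unpack to elempack 1, run the generic Reshape::forward,
        // and repack the result afterwards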
        Option opt_pack = opt;
        opt_pack.blob_allocator = opt.workspace_allocator;

        Mat bottom_blob_unpacked;
        convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

        Mat top_blob_unpacked;
        int ret = Reshape::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
        if (ret != 0)
            return ret;

        int out_elempack = 1;
        if (opt.use_packing_layout)
        {
            // resolve dst_elempack
            int dims = top_blob_unpacked.dims;
            if (dims == 1) out_elempack = top_blob_unpacked.w % 4 == 0 ? 4 : 1;
            if (dims == 2) out_elempack = top_blob_unpacked.h % 4 == 0 ? 4 : 1;
            if (dims == 3) out_elempack = top_blob_unpacked.c % 4 == 0 ? 4 : 1;
        }
        convert_packing(top_blob_unpacked, top_blob, out_elempack, opt);

        return 0;
    }

    if (ndim == 1)
    {
        // flatten
        flatten(bottom_blob, top_blob, opt);
        if (top_blob.empty())
            return -100;

        return 0;
    }

    int dims = bottom_blob.dims;
    size_t elemsize = bottom_blob.elemsize;

    int total = bottom_blob.w * bottom_blob.h * bottom_blob.c * elempack;

    if (ndim == 2)
    {
        int _w = w;
        int _h = h;

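        // reshape parameter semantics: 0 keeps the corresponding input
        // dimension, -1 is inferred from the remaining element count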
        if (_w == 0)
            _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
        if (_h == 0)
            _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;

        if (_w == -1)
            _w = total / _h;
        if (_h == -1)
            _h = total / _w;

        int out_elempack = opt.use_packing_layout && _h % 4 == 0 ? 4 : 1;
        size_t out_elemsize = elemsize / elempack * out_elempack;

        if (dims == 2 && bottom_blob.h == _h && elempack == out_elempack)
        {
            top_blob = bottom_blob;
            return 0;
        }

        if (out_elempack == 1)
        {
            // flatten
            flatten(bottom_blob, top_blob, opt);
            if (top_blob.empty())
                return -100;

            top_blob.dims = 2;
            top_blob.w = _w;
            top_blob.h = _h;
            top_blob.cstep = _w * _h;
            top_blob.elemsize = out_elemsize;
            top_blob.elempack = out_elempack;

            return 0;
        }

        // flatten
        Mat bottom_blob_flattened = bottom_blob;
        {
            Option opt_flatten = opt;
            opt_flatten.blob_allocator = opt.workspace_allocator;

            flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
            if (bottom_blob_flattened.empty())
                return -100;
        }

        top_blob.create(_w, _h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int outw = top_blob.w;
        int outh = top_blob.h;

        // assert out_elempack == 4

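        // each packed output row gathers 4 consecutive rows of the flattened
        // blob; vst4q_f32 interleaves the four source vectors so the output
        // lands directly in elempack-4 order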
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < outh; i++)
        {
            const float* ptr0 = (const float*)bottom_blob_flattened + outw * i * 4;
            const float* ptr1 = (const float*)bottom_blob_flattened + outw * (i * 4 + 1);
            const float* ptr2 = (const float*)bottom_blob_flattened + outw * (i * 4 + 2);
            const float* ptr3 = (const float*)bottom_blob_flattened + outw * (i * 4 + 3);
            float* outptr = (float*)top_blob.row(i);

            int j = 0;
#if __ARM_NEON
            for (; j + 3 < outw; j += 4)
            {
                float32x4x4_t _v4;
                _v4.val[0] = vld1q_f32(ptr0);
                _v4.val[1] = vld1q_f32(ptr1);
                _v4.val[2] = vld1q_f32(ptr2);
                _v4.val[3] = vld1q_f32(ptr3);

                vst4q_f32(outptr, _v4);

                ptr0 += 4;
                ptr1 += 4;
                ptr2 += 4;
                ptr3 += 4;
                outptr += 16;
            }
#endif
            for (; j < outw; j++)
            {
                outptr[0] = *ptr0++;
                outptr[1] = *ptr1++;
                outptr[2] = *ptr2++;
                outptr[3] = *ptr3++;

                outptr += 4;
            }
        }
    }

    if (ndim == 3)
    {
        int _w = w;
        int _h = h;
        int _c = c;

        if (_w == 0)
            _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
        if (_h == 0)
            _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
        if (_c == 0)
            _c = dims == 3 ? bottom_blob.c * elempack : bottom_blob.c;

        if (_w == -1)
            _w = total / _c / _h;
        if (_h == -1)
            _h = total / _c / _w;
        if (_c == -1)
            _c = total / _h / _w;

        int out_elempack = opt.use_packing_layout && _c % 4 == 0 ? 4 : 1;
        size_t out_elemsize = elemsize / elempack * out_elempack;

        if (dims == 3 && bottom_blob.c == _c && elempack == out_elempack)
        {
            top_blob = bottom_blob;
            top_blob.w = _w;
            top_blob.h = _h;
            return 0;
        }

        // flatten
        Mat bottom_blob_flattened = bottom_blob;
        {
            Option opt_flatten = opt;
            opt_flatten.blob_allocator = opt.workspace_allocator;

            flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
            if (bottom_blob_flattened.empty())
                return -100;
        }

        top_blob.create(_w, _h, _c / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int size = top_blob.w * top_blob.h;

        if (out_elempack == 4)
        {
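            // interleave 4 consecutive flattened channels into one
            // elempack-4 output channel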
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < top_blob.c; q++)
            {
                const float* ptr0 = (const float*)bottom_blob_flattened + size * q * 4;
                const float* ptr1 = (const float*)bottom_blob_flattened + size * (q * 4 + 1);
                const float* ptr2 = (const float*)bottom_blob_flattened + size * (q * 4 + 2);
                const float* ptr3 = (const float*)bottom_blob_flattened + size * (q * 4 + 3);
                float* outptr = top_blob.channel(q);

                int i = 0;
#if __ARM_NEON
                for (; i + 3 < size; i += 4)
                {
                    float32x4x4_t _v4;
                    _v4.val[0] = vld1q_f32(ptr0);
                    _v4.val[1] = vld1q_f32(ptr1);
                    _v4.val[2] = vld1q_f32(ptr2);
                    _v4.val[3] = vld1q_f32(ptr3);

                    vst4q_f32(outptr, _v4);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    outptr += 16;
                }
#endif
                for (; i < size; i++)
                {
                    outptr[0] = *ptr0++;
                    outptr[1] = *ptr1++;
                    outptr[2] = *ptr2++;
                    outptr[3] = *ptr3++;

                    outptr += 4;
                }
            }
        }

        if (out_elempack == 1)
        {
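            // no repacking needed; copy channel by channel because the
            // destination channels may be padded (cstep can exceed w * h)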
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < top_blob.c; q++)
            {
                const float* ptr = (const float*)bottom_blob_flattened + size * q;
                float* outptr = top_blob.channel(q);

                int i = 0;
#if __ARM_NEON
                for (; i + 3 < size; i += 4)
                {
                    float32x4_t _v = vld1q_f32(ptr);
                    vst1q_f32(outptr, _v);
                    ptr += 4;
                    outptr += 4;
                }
#endif
                for (; i < size; i++)
                {
                    *outptr++ = *ptr++;
                }
            }
        }
    }

    return 0;
}

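// shared 16-bit path for bf16 and fp16 storage; the copy kernels move raw
// 16-bit words, with dedicated elempack-8 kernels when fp16 arithmetic is on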
int Reshape_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int elempack = bottom_blob.elempack;

    if (permute == 1)
    {
        // TODO implement permute on-the-fly
        Option opt_pack = opt;
        opt_pack.blob_allocator = opt.workspace_allocator;

        Mat bottom_blob_unpacked;
        convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);

        Mat bottom_blob_unpacked_fp32;
        cast_bfloat16_to_float32(bottom_blob_unpacked, bottom_blob_unpacked_fp32, opt_pack);

        Mat top_blob_unpacked_fp32;
        int ret = Reshape::forward(bottom_blob_unpacked_fp32, top_blob_unpacked_fp32, opt_pack);
        if (ret != 0)
            return ret;

        Mat top_blob_unpacked;
        cast_float32_to_bfloat16(top_blob_unpacked_fp32, top_blob_unpacked, opt_pack);

        int out_elempack = 1;
        if (opt.use_packing_layout)
        {
            // resolve dst_elempack
            int dims = top_blob_unpacked.dims;
            if (dims == 1) out_elempack = opt.use_fp16_arithmetic && top_blob_unpacked.w % 8 == 0 ? 8 : top_blob_unpacked.w % 4 == 0 ? 4 : 1;
            if (dims == 2) out_elempack = opt.use_fp16_arithmetic && top_blob_unpacked.h % 8 == 0 ? 8 : top_blob_unpacked.h % 4 == 0 ? 4 : 1;
            if (dims == 3) out_elempack = opt.use_fp16_arithmetic && top_blob_unpacked.c % 8 == 0 ? 8 : top_blob_unpacked.c % 4 == 0 ? 4 : 1;
        }
        convert_packing(top_blob_unpacked, top_blob, out_elempack, opt);

        return 0;
    }

    if (ndim == 1)
    {
        // flatten
        flatten(bottom_blob, top_blob, opt);
        if (top_blob.empty())
            return -100;

        return 0;
    }

    int dims = bottom_blob.dims;
    size_t elemsize = bottom_blob.elemsize;

    int total = bottom_blob.w * bottom_blob.h * bottom_blob.c * elempack;

    if (ndim == 2)
    {
        int _w = w;
        int _h = h;

        if (_w == 0)
            _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
        if (_h == 0)
            _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;

        if (_w == -1)
            _w = total / _h;
        if (_h == -1)
            _h = total / _w;

        int out_elempack = 1;
        if (opt.use_packing_layout)
        {
            out_elempack = opt.use_fp16_arithmetic && _h % 8 == 0 ? 8 : _h % 4 == 0 ? 4 : 1;
        }
        size_t out_elemsize = elemsize / elempack * out_elempack;

        if (dims == 2 && bottom_blob.h == _h && elempack == out_elempack)
        {
            top_blob = bottom_blob;
            return 0;
        }

        if (out_elempack == 1)
        {
            // flatten
            flatten(bottom_blob, top_blob, opt);
            if (top_blob.empty())
                return -100;

            top_blob.dims = 2;
            top_blob.w = _w;
            top_blob.h = _h;
            top_blob.cstep = _w * _h;
            top_blob.elemsize = out_elemsize;
            top_blob.elempack = out_elempack;

            return 0;
        }

        // flatten
        Mat bottom_blob_flattened = bottom_blob;
        {
            Option opt_flatten = opt;
            opt_flatten.blob_allocator = opt.workspace_allocator;

            flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
            if (bottom_blob_flattened.empty())
                return -100;
        }

        top_blob.create(_w, _h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int outw = top_blob.w;
        int outh = top_blob.h;

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        if (out_elempack == 8)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < outh; i++)
            {
                const __fp16* ptr0 = (const __fp16*)bottom_blob_flattened + outw * i * 8;
                const __fp16* ptr1 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 1);
                const __fp16* ptr2 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 2);
                const __fp16* ptr3 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 3);
                const __fp16* ptr4 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 4);
                const __fp16* ptr5 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 5);
                const __fp16* ptr6 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 6);
                const __fp16* ptr7 = (const __fp16*)bottom_blob_flattened + outw * (i * 8 + 7);
                __fp16* outptr = top_blob.row<__fp16>(i);

                int j = 0;
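                // vzipq_f16 + vst4q_f16 together produce the 8-wide
                // interleave: every group of 8 output halves is
                // { ptr0[j], ptr1[j], ..., ptr7[j] }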
                for (; j + 3 < outw; j += 4)
                {
                    float16x8_t _p01 = vcombine_f16(vld1_f16(ptr0), vld1_f16(ptr1));
                    float16x8_t _p23 = vcombine_f16(vld1_f16(ptr2), vld1_f16(ptr3));
                    float16x8_t _p45 = vcombine_f16(vld1_f16(ptr4), vld1_f16(ptr5));
                    float16x8_t _p67 = vcombine_f16(vld1_f16(ptr6), vld1_f16(ptr7));

                    float16x8x2_t _p0415 = vzipq_f16(_p01, _p45);
                    float16x8x2_t _p2637 = vzipq_f16(_p23, _p67);

                    float16x8x4_t _v4;
                    _v4.val[0] = _p0415.val[0];
                    _v4.val[1] = _p0415.val[1];
                    _v4.val[2] = _p2637.val[0];
                    _v4.val[3] = _p2637.val[1];

                    vst4q_f16(outptr, _v4);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    ptr4 += 4;
                    ptr5 += 4;
                    ptr6 += 4;
                    ptr7 += 4;
                    outptr += 32;
                }
                for (; j < outw; j++)
                {
                    outptr[0] = *ptr0++;
                    outptr[1] = *ptr1++;
                    outptr[2] = *ptr2++;
                    outptr[3] = *ptr3++;
                    outptr[4] = *ptr4++;
                    outptr[5] = *ptr5++;
                    outptr[6] = *ptr6++;
                    outptr[7] = *ptr7++;

                    outptr += 8;
                }
            }
        }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

        if (out_elempack == 4)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int i = 0; i < outh; i++)
            {
                const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + outw * i * 4;
                const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 1);
                const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 2);
                const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + outw * (i * 4 + 3);
                unsigned short* outptr = top_blob.row<unsigned short>(i);

                int j = 0;
#if __ARM_NEON
                for (; j + 3 < outw; j += 4)
                {
                    uint16x4x4_t _v4;
                    _v4.val[0] = vld1_u16(ptr0);
                    _v4.val[1] = vld1_u16(ptr1);
                    _v4.val[2] = vld1_u16(ptr2);
                    _v4.val[3] = vld1_u16(ptr3);

                    vst4_u16(outptr, _v4);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    outptr += 16;
                }
#endif
                for (; j < outw; j++)
                {
                    outptr[0] = *ptr0++;
                    outptr[1] = *ptr1++;
                    outptr[2] = *ptr2++;
                    outptr[3] = *ptr3++;

                    outptr += 4;
                }
            }
        }
    }

    if (ndim == 3)
    {
        int _w = w;
        int _h = h;
        int _c = c;

        if (_w == 0)
            _w = dims == 1 ? bottom_blob.w * elempack : bottom_blob.w;
        if (_h == 0)
            _h = dims == 2 ? bottom_blob.h * elempack : bottom_blob.h;
        if (_c == 0)
            _c = dims == 3 ? bottom_blob.c * elempack : bottom_blob.c;

        if (_w == -1)
            _w = total / _c / _h;
        if (_h == -1)
            _h = total / _c / _w;
        if (_c == -1)
            _c = total / _h / _w;

        int out_elempack = 1;
        if (opt.use_packing_layout)
        {
            out_elempack = opt.use_fp16_arithmetic && _c % 8 == 0 ? 8 : _c % 4 == 0 ? 4 : 1;
        }
        size_t out_elemsize = elemsize / elempack * out_elempack;

        if (dims == 3 && bottom_blob.c == _c && elempack == out_elempack)
        {
            top_blob = bottom_blob;
            top_blob.w = _w;
            top_blob.h = _h;
            return 0;
        }

        // flatten
        Mat bottom_blob_flattened = bottom_blob;
        {
            Option opt_flatten = opt;
            opt_flatten.blob_allocator = opt.workspace_allocator;

            flatten(bottom_blob, bottom_blob_flattened, opt_flatten);
            if (bottom_blob_flattened.empty())
                return -100;
        }

        top_blob.create(_w, _h, _c / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int size = top_blob.w * top_blob.h;

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        if (out_elempack == 8)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < top_blob.c; q++)
            {
                const __fp16* ptr0 = (const __fp16*)bottom_blob_flattened + size * q * 8;
                const __fp16* ptr1 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 1);
                const __fp16* ptr2 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 2);
                const __fp16* ptr3 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 3);
                const __fp16* ptr4 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 4);
                const __fp16* ptr5 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 5);
                const __fp16* ptr6 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 6);
                const __fp16* ptr7 = (const __fp16*)bottom_blob_flattened + size * (q * 8 + 7);
                __fp16* outptr = top_blob.channel(q);

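                // same zip + vst4q interleave as the 2-d case above, applied
                // to 8 consecutive flattened channels per packed channel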
                int i = 0;
                for (; i + 3 < size; i += 4)
                {
                    float16x8_t _p01 = vcombine_f16(vld1_f16(ptr0), vld1_f16(ptr1));
                    float16x8_t _p23 = vcombine_f16(vld1_f16(ptr2), vld1_f16(ptr3));
                    float16x8_t _p45 = vcombine_f16(vld1_f16(ptr4), vld1_f16(ptr5));
                    float16x8_t _p67 = vcombine_f16(vld1_f16(ptr6), vld1_f16(ptr7));

                    float16x8x2_t _p0415 = vzipq_f16(_p01, _p45);
                    float16x8x2_t _p2637 = vzipq_f16(_p23, _p67);

                    float16x8x4_t _v4;
                    _v4.val[0] = _p0415.val[0];
                    _v4.val[1] = _p0415.val[1];
                    _v4.val[2] = _p2637.val[0];
                    _v4.val[3] = _p2637.val[1];

                    vst4q_f16(outptr, _v4);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    ptr4 += 4;
                    ptr5 += 4;
                    ptr6 += 4;
                    ptr7 += 4;
                    outptr += 32;
                }
                for (; i < size; i++)
                {
                    outptr[0] = *ptr0++;
                    outptr[1] = *ptr1++;
                    outptr[2] = *ptr2++;
                    outptr[3] = *ptr3++;
                    outptr[4] = *ptr4++;
                    outptr[5] = *ptr5++;
                    outptr[6] = *ptr6++;
                    outptr[7] = *ptr7++;

                    outptr += 8;
                }
            }
        }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

        if (out_elempack == 4)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < top_blob.c; q++)
            {
                const unsigned short* ptr0 = (const unsigned short*)bottom_blob_flattened + size * q * 4;
                const unsigned short* ptr1 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 1);
                const unsigned short* ptr2 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 2);
                const unsigned short* ptr3 = (const unsigned short*)bottom_blob_flattened + size * (q * 4 + 3);
                unsigned short* outptr = top_blob.channel(q);

                int i = 0;
#if __ARM_NEON
                for (; i + 3 < size; i += 4)
                {
                    uint16x4x4_t _v4;
                    _v4.val[0] = vld1_u16(ptr0);
                    _v4.val[1] = vld1_u16(ptr1);
                    _v4.val[2] = vld1_u16(ptr2);
                    _v4.val[3] = vld1_u16(ptr3);

                    vst4_u16(outptr, _v4);

                    ptr0 += 4;
                    ptr1 += 4;
                    ptr2 += 4;
                    ptr3 += 4;
                    outptr += 16;
                }
#endif
                for (; i < size; i++)
                {
                    outptr[0] = *ptr0++;
                    outptr[1] = *ptr1++;
                    outptr[2] = *ptr2++;
                    outptr[3] = *ptr3++;

                    outptr += 4;
                }
            }
        }

        if (out_elempack == 1)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < top_blob.c; q++)
            {
                const unsigned short* ptr = (const unsigned short*)bottom_blob_flattened + size * q;
                unsigned short* outptr = top_blob.channel(q);

                int i = 0;
#if __ARM_NEON
                for (; i + 3 < size; i += 4)
                {
                    uint16x4_t _v = vld1_u16(ptr);
                    vst1_u16(outptr, _v);
                    ptr += 4;
                    outptr += 4;
                }
#endif
                for (; i < size; i++)
                {
                    *outptr++ = *ptr++;
                }
            }
        }
    }

    return 0;
}

} // namespace ncnn