// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void conv1x1s1_sgemm_transform_kernel_pack4_fp16sa_neon(const Mat& kernel, Mat& kernel_tm_pack4, int inch, int outch)
{
    // interleave
    // src layout: inch-outch
    // dst layout: 4b-4a-inch/4a-outch/4b (a = input channel pack of 4, b = output channel pack of 4)
    kernel_tm_pack4.create(2 * 1, inch / 4, (outch / 4) / 2 + (outch / 4) % 2, (size_t)2u * 16, 16);
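    // each channel holds one pack8 group of 8 output channels (or one
    // trailing pack4 group), so 8 (or 4) consecutive weights share the
    // same input channel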

    int q = 0;
    for (; q + 7 < outch; q += 8)
    {
        const float* k0 = (const float*)kernel + (q + 0) * inch;
        const float* k1 = (const float*)kernel + (q + 1) * inch;
        const float* k2 = (const float*)kernel + (q + 2) * inch;
        const float* k3 = (const float*)kernel + (q + 3) * inch;
        const float* k4 = (const float*)kernel + (q + 4) * inch;
        const float* k5 = (const float*)kernel + (q + 5) * inch;
        const float* k6 = (const float*)kernel + (q + 6) * inch;
        const float* k7 = (const float*)kernel + (q + 7) * inch;

        __fp16* g0 = kernel_tm_pack4.channel(q / 8);

        for (int p = 0; p + 3 < inch; p += 4)
        {
            g0[0] = (__fp16)k0[0];
            g0[1] = (__fp16)k1[0];
            g0[2] = (__fp16)k2[0];
            g0[3] = (__fp16)k3[0];
            g0[4] = (__fp16)k4[0];
            g0[5] = (__fp16)k5[0];
            g0[6] = (__fp16)k6[0];
            g0[7] = (__fp16)k7[0];

            g0[8] = (__fp16)k0[1];
            g0[9] = (__fp16)k1[1];
            g0[10] = (__fp16)k2[1];
            g0[11] = (__fp16)k3[1];
            g0[12] = (__fp16)k4[1];
            g0[13] = (__fp16)k5[1];
            g0[14] = (__fp16)k6[1];
            g0[15] = (__fp16)k7[1];

            g0[16] = (__fp16)k0[2];
            g0[17] = (__fp16)k1[2];
            g0[18] = (__fp16)k2[2];
            g0[19] = (__fp16)k3[2];
            g0[20] = (__fp16)k4[2];
            g0[21] = (__fp16)k5[2];
            g0[22] = (__fp16)k6[2];
            g0[23] = (__fp16)k7[2];

            g0[24] = (__fp16)k0[3];
            g0[25] = (__fp16)k1[3];
            g0[26] = (__fp16)k2[3];
            g0[27] = (__fp16)k3[3];
            g0[28] = (__fp16)k4[3];
            g0[29] = (__fp16)k5[3];
            g0[30] = (__fp16)k6[3];
            g0[31] = (__fp16)k7[3];

            k0 += 4;
            k1 += 4;
            k2 += 4;
            k3 += 4;
            k4 += 4;
            k5 += 4;
            k6 += 4;
            k7 += 4;
            g0 += 32;
        }
    }
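    // remaining output channels, interleaved 4 at a time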
    for (; q + 3 < outch; q += 4)
    {
        const float* k0 = (const float*)kernel + (q + 0) * inch;
        const float* k1 = (const float*)kernel + (q + 1) * inch;
        const float* k2 = (const float*)kernel + (q + 2) * inch;
        const float* k3 = (const float*)kernel + (q + 3) * inch;

        __fp16* g0 = kernel_tm_pack4.channel(q / 8 + (q % 8) / 4);

        for (int p = 0; p + 3 < inch; p += 4)
        {
            g0[0] = (__fp16)k0[0];
            g0[1] = (__fp16)k1[0];
            g0[2] = (__fp16)k2[0];
            g0[3] = (__fp16)k3[0];

            g0[4] = (__fp16)k0[1];
            g0[5] = (__fp16)k1[1];
            g0[6] = (__fp16)k2[1];
            g0[7] = (__fp16)k3[1];

            g0[8] = (__fp16)k0[2];
            g0[9] = (__fp16)k1[2];
            g0[10] = (__fp16)k2[2];
            g0[11] = (__fp16)k3[2];

            g0[12] = (__fp16)k0[3];
            g0[13] = (__fp16)k1[3];
            g0[14] = (__fp16)k2[3];
            g0[15] = (__fp16)k3[3];

            k0 += 4;
            k1 += 4;
            k2 += 4;
            k3 += 4;
            g0 += 16;
        }
    }
}
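
// the packed kernel produced above is what conv1x1s1_sgemm_pack4_fp16sa_neon
// below expects as its `kernel` argument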

static void conv1x1s1_sgemm_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    int outch = top_blob.c;

    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    const int size = w * h;

    const __fp16* bias = _bias;

    // interleave
    Mat tmp;
    if (size >= 8)
        tmp.create(8, inch, size / 8 + (size % 8) / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
    else if (size >= 4)
        tmp.create(4, inch, size / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
    else // if (size >= 1)
        tmp.create(1, inch, size, elemsize, elempack, opt.workspace_allocator);
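    // tmp holds the input pixels column-blocked: groups of 8, then groups of
    // 4, then single pixels, each block transposed out of pack4 layout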
    {
        int nn_size;
        int remain_size_start = 0;

        nn_size = (size - remain_size_start) >> 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 8;

            const __fp16* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            __fp16* tmpptr = tmp.channel(i / 8);

            for (int q = 0; q < inch; q++)
            {
                // transpose 4x8
                asm volatile(
                    "prfm   pldl1keep, [%0, #512]   \n"
                    "ld4    {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n"
                    "st1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
                    : "=r"(img0),  // %0
                    "=r"(tmpptr) // %1
                    : "0"(img0),
                    "1"(tmpptr)
                    : "memory", "v0", "v1", "v2", "v3");

                img0 += bottom_blob.cstep * 4;
            }
        }

        remain_size_start += nn_size << 3;
        nn_size = (size - remain_size_start) >> 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 4;

            const __fp16* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);

            for (int q = 0; q < inch; q++)
            {
                // transpose 4x4
                asm volatile(
                    "prfm   pldl1keep, [%0, #256]   \n"
                    "ld4    {v0.4h, v1.4h, v2.4h, v3.4h}, [%0] \n"
                    "st1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
                    : "=r"(img0),  // %0
                    "=r"(tmpptr) // %1
                    : "0"(img0),
                    "1"(tmpptr)
                    : "memory", "v0", "v1", "v2", "v3");

                img0 += bottom_blob.cstep * 4;
            }
        }

        remain_size_start += nn_size << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < size; i++)
        {
            const __fp16* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);

            for (int q = 0; q < inch; q++)
            {
                asm volatile(
                    "prfm   pldl1keep, [%0, #64]    \n"
                    "ld1    {v0.4h}, [%0]           \n"
                    "st1    {v0.4h}, [%1], #8       \n"
                    : "=r"(img0),  // %0
                    "=r"(tmpptr) // %1
                    : "0"(img0),
                    "1"(tmpptr)
                    : "memory", "v0");

                img0 += bottom_blob.cstep * 4;
            }
        }
    }

    int nn_outch = 0;
    int remain_outch_start = 0;

    nn_outch = outch >> 1;
    remain_outch_start = nn_outch << 1;

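    // each iteration computes two pack4 output channels from one pack8 kernel group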
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = pp * 2;

        __fp16* outptr0 = top_blob.channel(p);
        __fp16* outptr1 = top_blob.channel(p + 1);

        const __fp16 zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
        const __fp16* biasptr = bias ? bias + p * 4 : zeros;
        float16x8_t _bias0 = vld1q_f16(biasptr);

        int i = 0;
        for (; i + 7 < size; i += 8)
        {
            __fp16* tmpptr = tmp.channel(i / 8);
            const __fp16* kptr = kernel.channel(pp);

            int nn = inch; // inch always > 0

            asm volatile(
                "mov    v24.16b, %10.16b            \n"
                "mov    v25.16b, %10.16b            \n"
                "mov    v26.16b, %10.16b            \n"
                "mov    v27.16b, %10.16b            \n"
                "mov    v28.16b, %10.16b            \n"
                "mov    v29.16b, %10.16b            \n"
                "mov    v30.16b, %10.16b            \n"
                "mov    v31.16b, %10.16b            \n"

                "0:                                 \n"

                "prfm   pldl1keep, [%3, #512]       \n"
                "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%3], #64 \n" // r01 r23 r45 r67

                "prfm   pldl1keep, [%4, #512]       \n"
                "ld1    {v4.8h, v5.8h, v6.8h, v7.8h}, [%4], #64 \n" // k0123

                "fmla   v24.8h, v4.8h, v0.h[0]      \n"
                "fmla   v25.8h, v4.8h, v0.h[1]      \n"
                "fmla   v26.8h, v4.8h, v0.h[2]      \n"
                "fmla   v27.8h, v4.8h, v0.h[3]      \n"
                "fmla   v28.8h, v4.8h, v0.h[4]      \n"
                "fmla   v29.8h, v4.8h, v0.h[5]      \n"
                "fmla   v30.8h, v4.8h, v0.h[6]      \n"
                "fmla   v31.8h, v4.8h, v0.h[7]      \n"

                "fmla   v24.8h, v5.8h, v1.h[0]      \n"
                "fmla   v25.8h, v5.8h, v1.h[1]      \n"
                "fmla   v26.8h, v5.8h, v1.h[2]      \n"
                "fmla   v27.8h, v5.8h, v1.h[3]      \n"
                "fmla   v28.8h, v5.8h, v1.h[4]      \n"
                "fmla   v29.8h, v5.8h, v1.h[5]      \n"
                "fmla   v30.8h, v5.8h, v1.h[6]      \n"
                "fmla   v31.8h, v5.8h, v1.h[7]      \n"

                "fmla   v24.8h, v6.8h, v2.h[0]      \n"
                "fmla   v25.8h, v6.8h, v2.h[1]      \n"
                "fmla   v26.8h, v6.8h, v2.h[2]      \n"
                "fmla   v27.8h, v6.8h, v2.h[3]      \n"
                "fmla   v28.8h, v6.8h, v2.h[4]      \n"
                "fmla   v29.8h, v6.8h, v2.h[5]      \n"
                "fmla   v30.8h, v6.8h, v2.h[6]      \n"
                "fmla   v31.8h, v6.8h, v2.h[7]      \n"

                "subs   %w0, %w0, #1                \n"

                "fmla   v24.8h, v7.8h, v3.h[0]      \n"
                "fmla   v25.8h, v7.8h, v3.h[1]      \n"
                "fmla   v26.8h, v7.8h, v3.h[2]      \n"
                "fmla   v27.8h, v7.8h, v3.h[3]      \n"
                "fmla   v28.8h, v7.8h, v3.h[4]      \n"
                "fmla   v29.8h, v7.8h, v3.h[5]      \n"
                "fmla   v30.8h, v7.8h, v3.h[6]      \n"
                "fmla   v31.8h, v7.8h, v3.h[7]      \n"

                "bne    0b                          \n"

                "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
                "st1    {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"

                "ext    v24.16b, v24.16b, v24.16b, #8 \n"
                "ext    v25.16b, v25.16b, v25.16b, #8 \n"
                "ext    v26.16b, v26.16b, v26.16b, #8 \n"
                "ext    v27.16b, v27.16b, v27.16b, #8 \n"
                "ext    v28.16b, v28.16b, v28.16b, #8 \n"
                "ext    v29.16b, v29.16b, v29.16b, #8 \n"
                "ext    v30.16b, v30.16b, v30.16b, #8 \n"
                "ext    v31.16b, v31.16b, v31.16b, #8 \n"

                "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"
                "st1    {v28.4h, v29.4h, v30.4h, v31.4h}, [%2], #32 \n"

                : "=r"(nn),      // %0
                "=r"(outptr0), // %1
                "=r"(outptr1), // %2
                "=r"(tmpptr),  // %3
                "=r"(kptr)     // %4
                : "0"(nn),
                "1"(outptr0),
                "2"(outptr1),
                "3"(tmpptr),
                "4"(kptr),
                "w"(_bias0) // %10
                : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
        }
        for (; i + 3 < size; i += 4)
        {
            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
            const __fp16* kptr = kernel.channel(pp);

            int nn = inch; // inch always > 0

            asm volatile(
                "mov    v24.16b, %10.16b            \n"
                "mov    v25.16b, %10.16b            \n"
                "mov    v26.16b, %10.16b            \n"
                "mov    v27.16b, %10.16b            \n"

                "0:                                 \n"

                "prfm   pldl1keep, [%3, #256]       \n"
365                 "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%3], #32 \n" // r01 r23 r45 r67

                "prfm   pldl1keep, [%4, #512]       \n"
                "ld1    {v4.8h, v5.8h, v6.8h, v7.8h}, [%4], #64 \n" // k0123

                "fmla   v24.8h, v4.8h, v0.h[0]      \n"
                "fmla   v25.8h, v4.8h, v0.h[1]      \n"
                "fmla   v26.8h, v4.8h, v0.h[2]      \n"
                "fmla   v27.8h, v4.8h, v0.h[3]      \n"

                "fmla   v24.8h, v5.8h, v1.h[0]      \n"
                "fmla   v25.8h, v5.8h, v1.h[1]      \n"
                "fmla   v26.8h, v5.8h, v1.h[2]      \n"
                "fmla   v27.8h, v5.8h, v1.h[3]      \n"

                "fmla   v24.8h, v6.8h, v2.h[0]      \n"
                "fmla   v25.8h, v6.8h, v2.h[1]      \n"
                "fmla   v26.8h, v6.8h, v2.h[2]      \n"
                "fmla   v27.8h, v6.8h, v2.h[3]      \n"

                "subs   %w0, %w0, #1                \n"

                "fmla   v24.8h, v7.8h, v3.h[0]      \n"
                "fmla   v25.8h, v7.8h, v3.h[1]      \n"
                "fmla   v26.8h, v7.8h, v3.h[2]      \n"
                "fmla   v27.8h, v7.8h, v3.h[3]      \n"

                "bne    0b                          \n"

                "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"

                "ext    v24.16b, v24.16b, v24.16b, #8 \n"
                "ext    v25.16b, v25.16b, v25.16b, #8 \n"
                "ext    v26.16b, v26.16b, v26.16b, #8 \n"
                "ext    v27.16b, v27.16b, v27.16b, #8 \n"

                "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"

                : "=r"(nn),      // %0
                "=r"(outptr0), // %1
                "=r"(outptr1), // %2
                "=r"(tmpptr),  // %3
                "=r"(kptr)     // %4
                : "0"(nn),
                "1"(outptr0),
                "2"(outptr1),
                "3"(tmpptr),
                "4"(kptr),
                "w"(_bias0) // %10
                : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27");
        }
        for (; i < size; i++)
        {
            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
            const __fp16* kptr = kernel.channel(pp);

            float16x8_t _sum0 = _bias0;

            for (int q = 0; q < inch; q++)
            {
                float16x4_t _r0 = vld1_f16(tmpptr);

                float16x8_t _k0 = vld1q_f16(kptr);
                float16x8_t _k1 = vld1q_f16(kptr + 8);
                float16x8_t _k2 = vld1q_f16(kptr + 16);
                float16x8_t _k3 = vld1q_f16(kptr + 24);

                _sum0 = vfmaq_lane_f16(_sum0, _k0, _r0, 0);
                _sum0 = vfmaq_lane_f16(_sum0, _k1, _r0, 1);
                _sum0 = vfmaq_lane_f16(_sum0, _k2, _r0, 2);
                _sum0 = vfmaq_lane_f16(_sum0, _k3, _r0, 3);

                kptr += 32;
                tmpptr += 4;
            }

            vst1_f16(outptr0, vget_low_f16(_sum0));
            vst1_f16(outptr1, vget_high_f16(_sum0));

            outptr0 += 4;
            outptr1 += 4;
        }
    }

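    // remaining single pack4 output channels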
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = remain_outch_start; p < outch; p++)
    {
        __fp16* outptr0 = top_blob.channel(p);

        const __fp16 zeros[4] = {0.f, 0.f, 0.f, 0.f};
        const __fp16* biasptr = bias ? bias + p * 4 : zeros;
        float16x4_t _bias0 = vld1_f16(biasptr);

        int i = 0;
        for (; i + 7 < size; i += 8)
        {
            __fp16* tmpptr = tmp.channel(i / 8);
            const __fp16* kptr = kernel.channel(p / 2 + p % 2);

            int nn = inch; // inch always > 0

            asm volatile(
                "mov    v24.16b, %8.16b             \n"
                "mov    v25.16b, %8.16b             \n"
                "mov    v26.16b, %8.16b             \n"
                "mov    v27.16b, %8.16b             \n"
                "mov    v28.16b, %8.16b             \n"
                "mov    v29.16b, %8.16b             \n"
                "mov    v30.16b, %8.16b             \n"
                "mov    v31.16b, %8.16b             \n"

                "0:                                 \n"

                "prfm   pldl1keep, [%2, #512]       \n"
                "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%2], #64 \n" // r01 r23 r45 r67

                "prfm   pldl1keep, [%3, #256]       \n"
                "ld1    {v4.4h, v5.4h, v6.4h, v7.4h}, [%3], #32 \n" // k0123

                "fmla   v24.4h, v4.4h, v0.h[0]      \n"
                "fmla   v25.4h, v4.4h, v0.h[1]      \n"
                "fmla   v26.4h, v4.4h, v0.h[2]      \n"
                "fmla   v27.4h, v4.4h, v0.h[3]      \n"
                "fmla   v28.4h, v4.4h, v0.h[4]      \n"
                "fmla   v29.4h, v4.4h, v0.h[5]      \n"
                "fmla   v30.4h, v4.4h, v0.h[6]      \n"
                "fmla   v31.4h, v4.4h, v0.h[7]      \n"

                "fmla   v24.4h, v5.4h, v1.h[0]      \n"
                "fmla   v25.4h, v5.4h, v1.h[1]      \n"
                "fmla   v26.4h, v5.4h, v1.h[2]      \n"
                "fmla   v27.4h, v5.4h, v1.h[3]      \n"
                "fmla   v28.4h, v5.4h, v1.h[4]      \n"
                "fmla   v29.4h, v5.4h, v1.h[5]      \n"
                "fmla   v30.4h, v5.4h, v1.h[6]      \n"
                "fmla   v31.4h, v5.4h, v1.h[7]      \n"

                "fmla   v24.4h, v6.4h, v2.h[0]      \n"
                "fmla   v25.4h, v6.4h, v2.h[1]      \n"
                "fmla   v26.4h, v6.4h, v2.h[2]      \n"
                "fmla   v27.4h, v6.4h, v2.h[3]      \n"
                "fmla   v28.4h, v6.4h, v2.h[4]      \n"
                "fmla   v29.4h, v6.4h, v2.h[5]      \n"
                "fmla   v30.4h, v6.4h, v2.h[6]      \n"
                "fmla   v31.4h, v6.4h, v2.h[7]      \n"

                "subs   %w0, %w0, #1                \n"

                "fmla   v24.4h, v7.4h, v3.h[0]      \n"
                "fmla   v25.4h, v7.4h, v3.h[1]      \n"
                "fmla   v26.4h, v7.4h, v3.h[2]      \n"
                "fmla   v27.4h, v7.4h, v3.h[3]      \n"
                "fmla   v28.4h, v7.4h, v3.h[4]      \n"
                "fmla   v29.4h, v7.4h, v3.h[5]      \n"
                "fmla   v30.4h, v7.4h, v3.h[6]      \n"
                "fmla   v31.4h, v7.4h, v3.h[7]      \n"

                "bne    0b                          \n"

                "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
                "st1    {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"

                : "=r"(nn),      // %0
                "=r"(outptr0), // %1
                "=r"(tmpptr),  // %2
                "=r"(kptr)     // %3
                : "0"(nn),
                "1"(outptr0),
                "2"(tmpptr),
                "3"(kptr),
                "w"(_bias0) // %8
                : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
        }
        for (; i + 3 < size; i += 4)
        {
            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
            const __fp16* kptr = kernel.channel(p / 2 + p % 2);

            int nn = inch; // inch always > 0

            asm volatile(
                "mov    v24.16b, %8.16b             \n"
                "mov    v25.16b, %8.16b             \n"
                "mov    v26.16b, %8.16b             \n"
                "mov    v27.16b, %8.16b             \n"

                "0:                                 \n"

                "prfm   pldl1keep, [%2, #256]       \n"
554                 "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%2], #32 \n" // r01 r23 r45 r67

                "prfm   pldl1keep, [%3, #256]       \n"
                "ld1    {v4.4h, v5.4h, v6.4h, v7.4h}, [%3], #32 \n" // k0123

                "fmla   v24.4h, v4.4h, v0.h[0]      \n"
                "fmla   v25.4h, v4.4h, v0.h[1]      \n"
                "fmla   v26.4h, v4.4h, v0.h[2]      \n"
                "fmla   v27.4h, v4.4h, v0.h[3]      \n"

                "fmla   v24.4h, v5.4h, v1.h[0]      \n"
                "fmla   v25.4h, v5.4h, v1.h[1]      \n"
                "fmla   v26.4h, v5.4h, v1.h[2]      \n"
                "fmla   v27.4h, v5.4h, v1.h[3]      \n"

                "fmla   v24.4h, v6.4h, v2.h[0]      \n"
                "fmla   v25.4h, v6.4h, v2.h[1]      \n"
                "fmla   v26.4h, v6.4h, v2.h[2]      \n"
                "fmla   v27.4h, v6.4h, v2.h[3]      \n"

                "subs   %w0, %w0, #1                \n"

                "fmla   v24.4h, v7.4h, v3.h[0]      \n"
                "fmla   v25.4h, v7.4h, v3.h[1]      \n"
                "fmla   v26.4h, v7.4h, v3.h[2]      \n"
                "fmla   v27.4h, v7.4h, v3.h[3]      \n"

                "bne    0b                          \n"

                "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"

                : "=r"(nn),      // %0
                "=r"(outptr0), // %1
                "=r"(tmpptr),  // %2
                "=r"(kptr)     // %3
                : "0"(nn),
                "1"(outptr0),
                "2"(tmpptr),
                "3"(kptr),
                "w"(_bias0) // %8
                : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27");
        }
        for (; i < size; i++)
        {
            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
            const __fp16* kptr = kernel.channel(p / 2 + p % 2);

            float16x4_t _sum0 = _bias0;

            for (int q = 0; q < inch; q++)
            {
                float16x4_t _r0 = vld1_f16(tmpptr);

                float16x4_t _k0 = vld1_f16(kptr);
                float16x4_t _k1 = vld1_f16(kptr + 4);
                float16x4_t _k2 = vld1_f16(kptr + 8);
                float16x4_t _k3 = vld1_f16(kptr + 12);

                _sum0 = vfma_lane_f16(_sum0, _k0, _r0, 0);
                _sum0 = vfma_lane_f16(_sum0, _k1, _r0, 1);
                _sum0 = vfma_lane_f16(_sum0, _k2, _r0, 2);
                _sum0 = vfma_lane_f16(_sum0, _k3, _r0, 3);

                kptr += 16;
                tmpptr += 4;
            }

            vst1_f16(outptr0, _sum0);

            outptr0 += 4;
        }
    }

    //     // NOTE reference naive sgemm (expects an unpacked inch-outch kernel)
    //     for (; p < outch; p++)
    //     {
    //         Mat out0 = top_blob.channel(p);
    //
    //         const __fp16 bias0 = bias ? bias[p] : (__fp16)0.f;
    //
    //         __fp16* outptr0 = out0;
    //
    //         for (int i = 0; i < size; i++)
    //         {
    //             __fp16 sum = bias0;
    //
    //             const __fp16* kptr = kernel.channel(p);
    //
    //             for (int q = 0; q < inch; q++)
    //             {
    //                 const __fp16* img0 = bottom_blob.channel(q);
    //
    //                 sum += img0[i] * kptr[0];
    //                 kptr++;
    //             }
    //
    //             outptr0[i] = sum;
    //         }
    //     }
}

static void conv1x1s2_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;

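    // after each output row, skip the remainder of the current input row and
    // the whole next row (vertical stride 2), counted in pack4 elements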
    const int tailstep = (w - 2 * outw + w) * 4;

    Mat bottom_blob_shrinked;
    bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);
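    // gather the stride-2 samples into a dense blob, then reuse the stride-1
    // sgemm path on it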

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < channels; p++)
    {
        const __fp16* r0 = bottom_blob.channel(p);
        __fp16* outptr = bottom_blob_shrinked.channel(p);

        for (int i = 0; i < outh; i++)
        {
            int j = 0;
            for (; j + 3 < outw; j += 4)
            {
                float16x4_t _v0 = vld1_f16(r0);
                float16x4_t _v1 = vld1_f16(r0 + 8);
                float16x4_t _v2 = vld1_f16(r0 + 16);
                float16x4_t _v3 = vld1_f16(r0 + 24);
                float16x8_t _v01 = vcombine_f16(_v0, _v1);
                float16x8_t _v23 = vcombine_f16(_v2, _v3);
                vst1q_f16(outptr, _v01);
                vst1q_f16(outptr + 8, _v23);

                r0 += 32;
                outptr += 16;
            }
            for (; j + 1 < outw; j += 2)
            {
                float16x4_t _v0 = vld1_f16(r0);
                float16x4_t _v1 = vld1_f16(r0 + 8);
                float16x8_t _v = vcombine_f16(_v0, _v1);
                vst1q_f16(outptr, _v);

                r0 += 16;
                outptr += 8;
            }
            for (; j < outw; j++)
            {
                float16x4_t _v = vld1_f16(r0);
                vst1_f16(outptr, _v);

                r0 += 8;
                outptr += 4;
            }

            r0 += tailstep;
        }
    }

    conv1x1s1_sgemm_pack4_fp16sa_neon(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
}
