1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
im2col_sgemm_neon(const Mat & bottom_im2col,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)15 static void im2col_sgemm_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
16 {
17     // Mat bottom_im2col(size, maxk, inch, 4u, 1, opt.workspace_allocator);
18 
19     const int size = bottom_im2col.w;
20     const int maxk = bottom_im2col.h;
21     const int inch = bottom_im2col.c;
22 
23     const int outch = top_blob.c;
24 
25     const float* bias = _bias;
26 
27     // permute
28     Mat tmp;
29 #if __ARM_NEON
30     if (size >= 8)
31         tmp.create(8 * maxk, inch, size / 8 + (size % 8) / 4 + size % 4, 4u, 1, opt.workspace_allocator);
32     else if (size >= 4)
33         tmp.create(4 * maxk, inch, size / 4 + size % 4, 4u, 1, opt.workspace_allocator);
34     else
35         tmp.create(maxk, inch, size, 4u, 1, opt.workspace_allocator);
36     {
37         int nn_size = size >> 3;
38         int remain_size_start = 0;
39 
40         #pragma omp parallel for num_threads(opt.num_threads)
41         for (int ii = 0; ii < nn_size; ii++)
42         {
43             int i = remain_size_start + ii * 8;
44 
45             float* tmpptr = tmp.channel(i / 8);
46 
47             for (int q = 0; q < inch; q++)
48             {
49                 const float* img0 = (const float*)bottom_im2col.channel(q) + i;
50 
51                 for (int k = 0; k < maxk; k++)
52                 {
53                     vst1q_f32(tmpptr, vld1q_f32(img0));
54                     vst1q_f32(tmpptr + 4, vld1q_f32(img0 + 4));
55                     img0 += size;
56                     tmpptr += 8;
57                 }
58             }
59         }
60 
61         remain_size_start += nn_size << 3;
62         nn_size = (size - remain_size_start) >> 2;
63 
64         #pragma omp parallel for num_threads(opt.num_threads)
65         for (int ii = 0; ii < nn_size; ii++)
66         {
67             int i = remain_size_start + ii * 4;
68 
69             float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
70 
71             for (int q = 0; q < inch; q++)
72             {
73                 const float* img0 = (const float*)bottom_im2col.channel(q) + i;
74 
75                 for (int k = 0; k < maxk; k++)
76                 {
77                     vst1q_f32(tmpptr, vld1q_f32(img0));
78                     img0 += size;
79                     tmpptr += 4;
80                 }
81             }
82         }
83 
84         remain_size_start += nn_size << 2;
85 
86         #pragma omp parallel for num_threads(opt.num_threads)
87         for (int i = remain_size_start; i < size; i++)
88         {
89             float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
90 
91             for (int q = 0; q < inch; q++)
92             {
93                 const float* img0 = (const float*)bottom_im2col.channel(q) + i;
94 
95                 for (int k = 0; k < maxk; k++)
96                 {
97                     tmpptr[0] = img0[0];
98                     img0 += size;
99                     tmpptr += 1;
100                 }
101             }
102         }
103     }
104 #else // __ARM_NEON
105     tmp.create(maxk, inch, size, 4u, 1, opt.workspace_allocator);
106     {
107         #pragma omp parallel for num_threads(opt.num_threads)
108         for (int i = 0; i < size; i++)
109         {
110             float* tmpptr = tmp.channel(i);
111 
112             for (int q = 0; q < inch; q++)
113             {
114                 const float* img0 = (const float*)bottom_im2col.channel(q) + i;
115 
116                 for (int k = 0; k < maxk; k++)
117                 {
118                     tmpptr[0] = img0[0];
119                     img0 += size;
120                     tmpptr += 1;
121                 }
122             }
123         }
124     }
125 #endif // __ARM_NEON
126 
127 #if __ARM_NEON
128     int nn_outch = 0;
129     int remain_outch_start = 0;
130 
131 #if __aarch64__
132     nn_outch = outch >> 3;
133     remain_outch_start = nn_outch << 3;
134 
135     #pragma omp parallel for num_threads(opt.num_threads)
136     for (int pp = 0; pp < nn_outch; pp++)
137     {
138         int p = pp * 8;
139 
140         float* outptr0 = top_blob.channel(p);
141         float* outptr1 = top_blob.channel(p + 1);
142         float* outptr2 = top_blob.channel(p + 2);
143         float* outptr3 = top_blob.channel(p + 3);
144         float* outptr4 = top_blob.channel(p + 4);
145         float* outptr5 = top_blob.channel(p + 5);
146         float* outptr6 = top_blob.channel(p + 6);
147         float* outptr7 = top_blob.channel(p + 7);
148 
149         const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
150         const float* biasptr = bias ? bias + p : zeros;
151 
152         int i = 0;
153         for (; i + 7 < size; i += 8)
154         {
155             const float* tmpptr = tmp.channel(i / 8);
156             const float* kptr = kernel.channel(p / 8);
157 
158             int nn = inch * maxk; // inch always > 0
159 
160             asm volatile(
161                 "ld1    {v0.4s, v1.4s}, [%20]   \n"
162                 "dup    v16.4s, v0.s[0]         \n"
163                 "dup    v17.4s, v0.s[0]         \n"
164                 "dup    v18.4s, v0.s[1]         \n"
165                 "dup    v19.4s, v0.s[1]         \n"
166                 "dup    v20.4s, v0.s[2]         \n"
167                 "dup    v21.4s, v0.s[2]         \n"
168                 "dup    v22.4s, v0.s[3]         \n"
169                 "dup    v23.4s, v0.s[3]         \n"
170                 "dup    v24.4s, v1.s[0]         \n"
171                 "dup    v25.4s, v1.s[0]         \n"
172                 "dup    v26.4s, v1.s[1]         \n"
173                 "dup    v27.4s, v1.s[1]         \n"
174                 "dup    v28.4s, v1.s[2]         \n"
175                 "dup    v29.4s, v1.s[2]         \n"
176                 "dup    v30.4s, v1.s[3]         \n"
177                 "dup    v31.4s, v1.s[3]         \n"
178 
179                 // inch loop
180                 "lsr    w4, %w21, #2            \n" // w4 = nn = inch >> 2
181                 "cmp    w4, #0                  \n"
182                 "beq    1f                      \n"
183 
184                 "0:                             \n"
185 
186                 "prfm   pldl1keep, [%8, #512]   \n"
187                 "ld1    {v8.4s, v9.4s, v10.4s, v11.4s}, [%8], #64   \n"
188 
189                 "prfm   pldl1keep, [%9, #512]   \n"
190                 "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64     \n"
191 
192                 "fmla   v16.4s, v8.4s, v0.s[0]  \n"
193                 "fmla   v18.4s, v8.4s, v0.s[1]  \n"
194                 "fmla   v20.4s, v8.4s, v0.s[2]  \n"
195                 "fmla   v22.4s, v8.4s, v0.s[3]  \n"
196 
197                 "fmla   v17.4s, v9.4s, v0.s[0]  \n"
198                 "fmla   v19.4s, v9.4s, v0.s[1]  \n"
199                 "fmla   v21.4s, v9.4s, v0.s[2]  \n"
200                 "fmla   v23.4s, v9.4s, v0.s[3]  \n"
201 
202                 "fmla   v24.4s, v8.4s, v1.s[0]  \n"
203                 "fmla   v26.4s, v8.4s, v1.s[1]  \n"
204                 "fmla   v28.4s, v8.4s, v1.s[2]  \n"
205                 "fmla   v30.4s, v8.4s, v1.s[3]  \n"
206 
207                 "fmla   v25.4s, v9.4s, v1.s[0]  \n"
208                 "fmla   v27.4s, v9.4s, v1.s[1]  \n"
209                 "fmla   v29.4s, v9.4s, v1.s[2]  \n"
210                 "fmla   v31.4s, v9.4s, v1.s[3]  \n"
211 
212                 "prfm   pldl1keep, [%8, #512]   \n"
213                 "ld1    {v12.4s, v13.4s, v14.4s, v15.4s}, [%8], #64 \n"
214 
215                 "fmla   v16.4s, v10.4s, v2.s[0] \n"
216                 "fmla   v18.4s, v10.4s, v2.s[1] \n"
217                 "fmla   v20.4s, v10.4s, v2.s[2] \n"
218                 "fmla   v22.4s, v10.4s, v2.s[3] \n"
219 
220                 "fmla   v17.4s, v11.4s, v2.s[0] \n"
221                 "fmla   v19.4s, v11.4s, v2.s[1] \n"
222                 "fmla   v21.4s, v11.4s, v2.s[2] \n"
223                 "fmla   v23.4s, v11.4s, v2.s[3] \n"
224 
225                 "fmla   v24.4s, v10.4s, v3.s[0] \n"
226                 "fmla   v26.4s, v10.4s, v3.s[1] \n"
227                 "fmla   v28.4s, v10.4s, v3.s[2] \n"
228                 "fmla   v30.4s, v10.4s, v3.s[3] \n"
229 
230                 "fmla   v25.4s, v11.4s, v3.s[0] \n"
231                 "fmla   v27.4s, v11.4s, v3.s[1] \n"
232                 "fmla   v29.4s, v11.4s, v3.s[2] \n"
233                 "fmla   v31.4s, v11.4s, v3.s[3] \n"
234 
235                 "prfm   pldl1keep, [%9, #512]   \n"
236                 "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%9], #64     \n"
237 
238                 "fmla   v16.4s, v12.4s, v4.s[0] \n"
239                 "fmla   v18.4s, v12.4s, v4.s[1] \n"
240                 "fmla   v20.4s, v12.4s, v4.s[2] \n"
241                 "fmla   v22.4s, v12.4s, v4.s[3] \n"
242 
243                 "fmla   v17.4s, v13.4s, v4.s[0] \n"
244                 "fmla   v19.4s, v13.4s, v4.s[1] \n"
245                 "fmla   v21.4s, v13.4s, v4.s[2] \n"
246                 "fmla   v23.4s, v13.4s, v4.s[3] \n"
247 
248                 "fmla   v24.4s, v12.4s, v5.s[0] \n"
249                 "fmla   v26.4s, v12.4s, v5.s[1] \n"
250                 "fmla   v28.4s, v12.4s, v5.s[2] \n"
251                 "fmla   v30.4s, v12.4s, v5.s[3] \n"
252 
253                 "fmla   v25.4s, v13.4s, v5.s[0] \n"
254                 "fmla   v27.4s, v13.4s, v5.s[1] \n"
255                 "fmla   v29.4s, v13.4s, v5.s[2] \n"
256                 "fmla   v31.4s, v13.4s, v5.s[3] \n"
257 
258                 "subs   w4, w4, #1              \n"
259 
260                 "fmla   v16.4s, v14.4s, v6.s[0] \n"
261                 "fmla   v18.4s, v14.4s, v6.s[1] \n"
262                 "fmla   v20.4s, v14.4s, v6.s[2] \n"
263                 "fmla   v22.4s, v14.4s, v6.s[3] \n"
264 
265                 "fmla   v17.4s, v15.4s, v6.s[0] \n"
266                 "fmla   v19.4s, v15.4s, v6.s[1] \n"
267                 "fmla   v21.4s, v15.4s, v6.s[2] \n"
268                 "fmla   v23.4s, v15.4s, v6.s[3] \n"
269 
270                 "fmla   v24.4s, v14.4s, v7.s[0] \n"
271                 "fmla   v26.4s, v14.4s, v7.s[1] \n"
272                 "fmla   v28.4s, v14.4s, v7.s[2] \n"
273                 "fmla   v30.4s, v14.4s, v7.s[3] \n"
274 
275                 "fmla   v25.4s, v15.4s, v7.s[0] \n"
276                 "fmla   v27.4s, v15.4s, v7.s[1] \n"
277                 "fmla   v29.4s, v15.4s, v7.s[2] \n"
278                 "fmla   v31.4s, v15.4s, v7.s[3] \n"
279 
280                 "bne    0b                      \n"
281 
282                 "1:                             \n"
283 
284                 // remain loop
285                 "and    w4, %w21, #3            \n" // w4 = remain = inch & 3;
286                 "cmp    w4, #0                  \n"
287                 "beq    3f                      \n"
288 
289                 "2:                             \n"
290 
291                 "prfm   pldl1keep, [%8, #256]   \n"
292                 "ld1    {v8.4s, v9.4s}, [%8], #32   \n"
293 
294                 "prfm   pldl1keep, [%9, #256]   \n"
295                 "ld1    {v0.4s, v1.4s}, [%9], #32   \n"
296 
297                 "fmla   v16.4s, v8.4s, v0.s[0]  \n"
298                 "fmla   v18.4s, v8.4s, v0.s[1]  \n"
299                 "fmla   v20.4s, v8.4s, v0.s[2]  \n"
300                 "fmla   v22.4s, v8.4s, v0.s[3]  \n"
301 
302                 "fmla   v17.4s, v9.4s, v0.s[0]  \n"
303                 "fmla   v19.4s, v9.4s, v0.s[1]  \n"
304                 "fmla   v21.4s, v9.4s, v0.s[2]  \n"
305                 "fmla   v23.4s, v9.4s, v0.s[3]  \n"
306 
307                 "subs   w4, w4, #1              \n"
308 
309                 "fmla   v24.4s, v8.4s, v1.s[0]  \n"
310                 "fmla   v26.4s, v8.4s, v1.s[1]  \n"
311                 "fmla   v28.4s, v8.4s, v1.s[2]  \n"
312                 "fmla   v30.4s, v8.4s, v1.s[3]  \n"
313 
314                 "fmla   v25.4s, v9.4s, v1.s[0]  \n"
315                 "fmla   v27.4s, v9.4s, v1.s[1]  \n"
316                 "fmla   v29.4s, v9.4s, v1.s[2]  \n"
317                 "fmla   v31.4s, v9.4s, v1.s[3]  \n"
318 
319                 "bne    2b                      \n"
320 
321                 "3:                             \n"
322 
323                 "st1    {v16.4s, v17.4s}, [%0], #32 \n"
324                 "st1    {v18.4s, v19.4s}, [%1], #32 \n"
325                 "st1    {v20.4s, v21.4s}, [%2], #32 \n"
326                 "st1    {v22.4s, v23.4s}, [%3], #32 \n"
327                 "st1    {v24.4s, v25.4s}, [%4], #32 \n"
328                 "st1    {v26.4s, v27.4s}, [%5], #32 \n"
329                 "st1    {v28.4s, v29.4s}, [%6], #32 \n"
330                 "st1    {v30.4s, v31.4s}, [%7], #32 \n"
331 
332                 : "=r"(outptr0), // %0
333                 "=r"(outptr1), // %1
334                 "=r"(outptr2), // %2
335                 "=r"(outptr3), // %3
336                 "=r"(outptr4), // %4
337                 "=r"(outptr5), // %5
338                 "=r"(outptr6), // %6
339                 "=r"(outptr7), // %7
340                 "=r"(tmpptr),  // %8
341                 "=r"(kptr)     // %9
342                 : "0"(outptr0),
343                 "1"(outptr1),
344                 "2"(outptr2),
345                 "3"(outptr3),
346                 "4"(outptr4),
347                 "5"(outptr5),
348                 "6"(outptr6),
349                 "7"(outptr7),
350                 "8"(tmpptr),
351                 "9"(kptr),
352                 "r"(biasptr), // %20
353                 "r"(nn)       // %21
354                 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
355         }
356         for (; i + 3 < size; i += 4)
357         {
358             const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
359             const float* kptr = kernel.channel(p / 8);
360 
361             int nn = inch * maxk; // inch always > 0
362 
363             asm volatile(
364                 "ld1    {v0.4s, v1.4s}, [%20]   \n"
365                 "dup    v16.4s, v0.s[0]         \n"
366                 "dup    v17.4s, v0.s[1]         \n"
367                 "dup    v18.4s, v0.s[2]         \n"
368                 "dup    v19.4s, v0.s[3]         \n"
369                 "dup    v20.4s, v1.s[0]         \n"
370                 "dup    v21.4s, v1.s[1]         \n"
371                 "dup    v22.4s, v1.s[2]         \n"
372                 "dup    v23.4s, v1.s[3]         \n"
373 
374                 // inch loop
375                 "lsr    w4, %w21, #2            \n" // w4 = nn = inch >> 2
376                 "cmp    w4, #0                  \n"
377                 "beq    1f                      \n"
378 
379                 "0:                             \n"
380 
381                 "prfm   pldl1keep, [%8, #512]   \n"
382                 "ld1    {v8.4s, v9.4s, v10.4s, v11.4s}, [%8], #64   \n"
383 
384                 "prfm   pldl1keep, [%9, #512]   \n"
385                 "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64     \n"
386 
387                 "fmla   v16.4s, v8.4s, v0.s[0]  \n"
388                 "fmla   v17.4s, v8.4s, v0.s[1]  \n"
389                 "fmla   v18.4s, v8.4s, v0.s[2]  \n"
390                 "fmla   v19.4s, v8.4s, v0.s[3]  \n"
391                 "fmla   v20.4s, v8.4s, v1.s[0]  \n"
392                 "fmla   v21.4s, v8.4s, v1.s[1]  \n"
393                 "fmla   v22.4s, v8.4s, v1.s[2]  \n"
394                 "fmla   v23.4s, v8.4s, v1.s[3]  \n"
395 
396                 "prfm   pldl1keep, [%9, #512]   \n"
397                 "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%9], #64     \n"
398 
399                 "fmla   v16.4s, v9.4s, v2.s[0]  \n"
400                 "fmla   v17.4s, v9.4s, v2.s[1]  \n"
401                 "fmla   v18.4s, v9.4s, v2.s[2]  \n"
402                 "fmla   v19.4s, v9.4s, v2.s[3]  \n"
403                 "fmla   v20.4s, v9.4s, v3.s[0]  \n"
404                 "fmla   v21.4s, v9.4s, v3.s[1]  \n"
405                 "fmla   v22.4s, v9.4s, v3.s[2]  \n"
406                 "fmla   v23.4s, v9.4s, v3.s[3]  \n"
407 
408                 "subs   w4, w4, #1              \n"
409 
410                 "fmla   v16.4s, v10.4s, v4.s[0] \n"
411                 "fmla   v17.4s, v10.4s, v4.s[1] \n"
412                 "fmla   v18.4s, v10.4s, v4.s[2] \n"
413                 "fmla   v19.4s, v10.4s, v4.s[3] \n"
414                 "fmla   v20.4s, v10.4s, v5.s[0] \n"
415                 "fmla   v21.4s, v10.4s, v5.s[1] \n"
416                 "fmla   v22.4s, v10.4s, v5.s[2] \n"
417                 "fmla   v23.4s, v10.4s, v5.s[3] \n"
418 
419                 "fmla   v16.4s, v11.4s, v6.s[0] \n"
420                 "fmla   v17.4s, v11.4s, v6.s[1] \n"
421                 "fmla   v18.4s, v11.4s, v6.s[2] \n"
422                 "fmla   v19.4s, v11.4s, v6.s[3] \n"
423                 "fmla   v20.4s, v11.4s, v7.s[0] \n"
424                 "fmla   v21.4s, v11.4s, v7.s[1] \n"
425                 "fmla   v22.4s, v11.4s, v7.s[2] \n"
426                 "fmla   v23.4s, v11.4s, v7.s[3] \n"
427 
428                 "bne    0b                      \n"
429 
430                 "1:                             \n"
431 
432                 // remain loop
433                 "and    w4, %w21, #3            \n" // w4 = remain = inch & 3;
434                 "cmp    w4, #0                  \n"
435                 "beq    3f                      \n"
436 
437                 "2:                             \n"
438 
439                 "prfm   pldl1keep, [%8, #128]   \n"
440                 "ld1    {v8.4s}, [%8], #16      \n"
441 
442                 "prfm   pldl1keep, [%9, #256]   \n"
443                 "ld1    {v0.4s, v1.4s}, [%9], #32   \n"
444 
445                 "fmla   v16.4s, v8.4s, v0.s[0]  \n"
446                 "fmla   v17.4s, v8.4s, v0.s[1]  \n"
447                 "fmla   v18.4s, v8.4s, v0.s[2]  \n"
448                 "fmla   v19.4s, v8.4s, v0.s[3]  \n"
449 
450                 "subs   w4, w4, #1              \n"
451 
452                 "fmla   v20.4s, v8.4s, v1.s[0]  \n"
453                 "fmla   v21.4s, v8.4s, v1.s[1]  \n"
454                 "fmla   v22.4s, v8.4s, v1.s[2]  \n"
455                 "fmla   v23.4s, v8.4s, v1.s[3]  \n"
456 
457                 "bne    2b                      \n"
458 
459                 "3:                             \n"
460 
461                 "st1    {v16.4s}, [%0], #16     \n"
462                 "st1    {v17.4s}, [%1], #16     \n"
463                 "st1    {v18.4s}, [%2], #16     \n"
464                 "st1    {v19.4s}, [%3], #16     \n"
465                 "st1    {v20.4s}, [%4], #16     \n"
466                 "st1    {v21.4s}, [%5], #16     \n"
467                 "st1    {v22.4s}, [%6], #16     \n"
468                 "st1    {v23.4s}, [%7], #16     \n"
469 
470                 : "=r"(outptr0), // %0
471                 "=r"(outptr1), // %1
472                 "=r"(outptr2), // %2
473                 "=r"(outptr3), // %3
474                 "=r"(outptr4), // %4
475                 "=r"(outptr5), // %5
476                 "=r"(outptr6), // %6
477                 "=r"(outptr7), // %7
478                 "=r"(tmpptr),  // %8
479                 "=r"(kptr)     // %9
480                 : "0"(outptr0),
481                 "1"(outptr1),
482                 "2"(outptr2),
483                 "3"(outptr3),
484                 "4"(outptr4),
485                 "5"(outptr5),
486                 "6"(outptr6),
487                 "7"(outptr7),
488                 "8"(tmpptr),
489                 "9"(kptr),
490                 "r"(biasptr), // %20
491                 "r"(nn)       // %21
492                 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
493         }
494         for (; i < size; i++)
495         {
496             const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
497             const float* kptr = kernel.channel(p / 8);
498 
499             int nn = inch * maxk; // inch always > 0
500 
501             asm volatile(
502                 "ld1    {v24.4s, v25.4s}, [%20] \n"
503 
504                 // inch loop
505                 "lsr    w4, %w21, #2            \n" // w4 = nn = inch >> 2
506                 "cmp    w4, #0                  \n"
507                 "beq    1f                      \n"
508 
509                 "eor    v16.16b, v16.16b, v16.16b  \n"
510                 "eor    v17.16b, v17.16b, v17.16b  \n"
511                 "eor    v18.16b, v18.16b, v18.16b  \n"
512                 "eor    v19.16b, v19.16b, v19.16b  \n"
513                 "eor    v20.16b, v20.16b, v20.16b  \n"
514                 "eor    v21.16b, v21.16b, v21.16b  \n"
515                 "eor    v22.16b, v22.16b, v22.16b  \n"
516                 "eor    v23.16b, v23.16b, v23.16b  \n"
517 
518                 "0:                             \n"
519 
520                 "prfm   pldl1keep, [%8, #128]   \n"
521                 "ld1    {v8.4s}, [%8], #16      \n"
522 
523                 "prfm   pldl1keep, [%9, #512]   \n"
524                 "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64     \n"
525 
526                 "fmla   v16.4s, v0.4s, v8.s[0]  \n"
527                 "fmla   v17.4s, v1.4s, v8.s[0]  \n"
528                 "fmla   v18.4s, v2.4s, v8.s[1]  \n"
529                 "fmla   v19.4s, v3.4s, v8.s[1]  \n"
530 
531                 "prfm   pldl1keep, [%9, #512]   \n"
532                 "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%9], #64     \n"
533 
534                 "subs   w4, w4, #1              \n"
535 
536                 "fmla   v20.4s, v4.4s, v8.s[2]  \n"
537                 "fmla   v21.4s, v5.4s, v8.s[2]  \n"
538                 "fmla   v22.4s, v6.4s, v8.s[3]  \n"
539                 "fmla   v23.4s, v7.4s, v8.s[3]  \n"
540 
541                 "bne    0b                      \n"
542 
543                 "fadd   v16.4s, v16.4s, v18.4s  \n"
544                 "fadd   v17.4s, v17.4s, v19.4s  \n"
545                 "fadd   v20.4s, v20.4s, v22.4s  \n"
546                 "fadd   v21.4s, v21.4s, v23.4s  \n"
547                 "fadd   v16.4s, v16.4s, v20.4s  \n"
548                 "fadd   v17.4s, v17.4s, v21.4s  \n"
549                 "fadd   v24.4s, v24.4s, v16.4s  \n"
550                 "fadd   v25.4s, v25.4s, v17.4s  \n"
551 
552                 "1:                             \n"
553 
554                 // remain loop
555                 "and    w4, %w21, #3            \n" // w4 = remain = inch & 3;
556                 "cmp    w4, #0                  \n"
557                 "beq    3f                      \n"
558 
559                 "2:                             \n"
560 
561                 "prfm   pldl1keep, [%8, #32]    \n"
562                 "ld1r   {v8.4s}, [%8], #4       \n"
563 
564                 "prfm   pldl1keep, [%9, #256]   \n"
565                 "ld1    {v0.4s, v1.4s}, [%9], #32   \n"
566 
567                 "subs   w4, w4, #1              \n"
568 
569                 "fmla   v24.4s, v8.4s, v0.4s    \n"
570                 "fmla   v25.4s, v8.4s, v1.4s    \n"
571 
572                 "bne    2b                      \n"
573 
574                 "3:                             \n"
575 
576                 "st1    {v24.s}[0],[%0], #4     \n"
577                 "st1    {v24.s}[1],[%1], #4     \n"
578                 "st1    {v24.s}[2],[%2], #4     \n"
579                 "st1    {v24.s}[3],[%3], #4     \n"
580                 "st1    {v25.s}[0],[%4], #4     \n"
581                 "st1    {v25.s}[1],[%5], #4     \n"
582                 "st1    {v25.s}[2],[%6], #4     \n"
583                 "st1    {v25.s}[3],[%7], #4     \n"
584 
585                 : "=r"(outptr0), // %0
586                 "=r"(outptr1), // %1
587                 "=r"(outptr2), // %2
588                 "=r"(outptr3), // %3
589                 "=r"(outptr4), // %4
590                 "=r"(outptr5), // %5
591                 "=r"(outptr6), // %6
592                 "=r"(outptr7), // %7
593                 "=r"(tmpptr),  // %8
594                 "=r"(kptr)     // %9
595                 : "0"(outptr0),
596                 "1"(outptr1),
597                 "2"(outptr2),
598                 "3"(outptr3),
599                 "4"(outptr4),
600                 "5"(outptr5),
601                 "6"(outptr6),
602                 "7"(outptr7),
603                 "8"(tmpptr),
604                 "9"(kptr),
605                 "r"(biasptr), // %20
606                 "r"(nn)       // %21
607                 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25");
608         }
609     }
610 #endif // __aarch64__
611 
612     nn_outch = (outch - remain_outch_start) >> 2;
613 
614     #pragma omp parallel for num_threads(opt.num_threads)
615     for (int pp = 0; pp < nn_outch; pp++)
616     {
617         int p = remain_outch_start + pp * 4;
618 
619         float* outptr0 = top_blob.channel(p);
620         float* outptr1 = top_blob.channel(p + 1);
621         float* outptr2 = top_blob.channel(p + 2);
622         float* outptr3 = top_blob.channel(p + 3);
623 
624         const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
625         const float* biasptr = bias ? bias + p : zeros;
626 
627         int i = 0;
628         for (; i + 7 < size; i += 8)
629         {
630             const float* tmpptr = tmp.channel(i / 8);
631 #if __aarch64__
632             const float* kptr = kernel.channel(p / 8 + (p % 8) / 4);
633 #else
634             const float* kptr = kernel.channel(p / 4);
635 #endif
636 
637             int nn = inch * maxk; // inch always > 0
638 
639 #if __aarch64__
640             asm volatile(
641                 "ld1    {v0.4s}, [%12]          \n"
642                 "dup    v8.4s, v0.s[0]          \n"
643                 "dup    v9.4s, v0.s[0]          \n"
644                 "dup    v10.4s, v0.s[1]         \n"
645                 "dup    v11.4s, v0.s[1]         \n"
646                 "dup    v12.4s, v0.s[2]         \n"
647                 "dup    v13.4s, v0.s[2]         \n"
648                 "dup    v14.4s, v0.s[3]         \n"
649                 "dup    v15.4s, v0.s[3]         \n"
650 
651                 // inch loop
652                 "lsr    w4, %w13, #2            \n" // w4 = nn = inch >> 2
653                 "cmp    w4, #0                  \n"
654                 "beq    1f                      \n"
655 
656                 "0:                             \n"
657 
658                 "prfm   pldl1keep, [%4, #512]   \n"
659                 "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%4], #64     \n"
660 
661                 "prfm   pldl1keep, [%5, #512]   \n"
662                 "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64     \n"
663 
664                 "fmla   v8.4s, v4.4s, v0.s[0]   \n"
665                 "fmla   v10.4s, v4.4s, v0.s[1]  \n"
666                 "fmla   v12.4s, v4.4s, v0.s[2]  \n"
667                 "fmla   v14.4s, v4.4s, v0.s[3]  \n"
668 
669                 "fmla   v9.4s, v5.4s, v0.s[0]   \n"
670                 "fmla   v11.4s, v5.4s, v0.s[1]  \n"
671                 "fmla   v13.4s, v5.4s, v0.s[2]  \n"
672                 "fmla   v15.4s, v5.4s, v0.s[3]  \n"
673 
674                 "prfm   pldl1keep, [%4, #512]   \n"
675                 "ld1    {v16.4s, v17.4s, v18.4s, v19.4s}, [%4], #64 \n"
676 
677                 "fmla   v8.4s, v6.4s, v1.s[0]   \n"
678                 "fmla   v10.4s, v6.4s, v1.s[1]  \n"
679                 "fmla   v12.4s, v6.4s, v1.s[2]  \n"
680                 "fmla   v14.4s, v6.4s, v1.s[3]  \n"
681 
682                 "fmla   v9.4s, v7.4s, v1.s[0]   \n"
683                 "fmla   v11.4s, v7.4s, v1.s[1]  \n"
684                 "fmla   v13.4s, v7.4s, v1.s[2]  \n"
685                 "fmla   v15.4s, v7.4s, v1.s[3]  \n"
686 
687                 "subs   w4, w4, #1              \n"
688 
689                 "fmla   v8.4s, v16.4s, v2.s[0]  \n"
690                 "fmla   v10.4s, v16.4s, v2.s[1] \n"
691                 "fmla   v12.4s, v16.4s, v2.s[2] \n"
692                 "fmla   v14.4s, v16.4s, v2.s[3] \n"
693 
694                 "fmla   v9.4s, v17.4s, v2.s[0]  \n"
695                 "fmla   v11.4s, v17.4s, v2.s[1] \n"
696                 "fmla   v13.4s, v17.4s, v2.s[2] \n"
697                 "fmla   v15.4s, v17.4s, v2.s[3] \n"
698 
699                 "fmla   v8.4s, v18.4s, v3.s[0]  \n"
700                 "fmla   v10.4s, v18.4s, v3.s[1] \n"
701                 "fmla   v12.4s, v18.4s, v3.s[2] \n"
702                 "fmla   v14.4s, v18.4s, v3.s[3] \n"
703 
704                 "fmla   v9.4s, v19.4s, v3.s[0]  \n"
705                 "fmla   v11.4s, v19.4s, v3.s[1] \n"
706                 "fmla   v13.4s, v19.4s, v3.s[2] \n"
707                 "fmla   v15.4s, v19.4s, v3.s[3] \n"
708 
709                 "bne    0b                      \n"
710 
711                 "1:                             \n"
712 
713                 // remain loop
714                 "and    w4, %w13, #3            \n" // w4 = remain = inch & 3;
715                 "cmp    w4, #0                  \n"
716                 "beq    3f                      \n"
717 
718                 "2:                             \n"
719 
720                 "prfm   pldl1keep, [%4, #256]   \n"
721                 "ld1    {v4.4s, v5.4s}, [%4], #32   \n"
722 
723                 "prfm   pldl1keep, [%5, #128]   \n"
724                 "ld1    {v0.4s}, [%5], #16      \n"
725 
726                 "fmla   v8.4s, v4.4s, v0.s[0]   \n"
727                 "fmla   v10.4s, v4.4s, v0.s[1]  \n"
728                 "fmla   v12.4s, v4.4s, v0.s[2]  \n"
729                 "fmla   v14.4s, v4.4s, v0.s[3]  \n"
730 
731                 "subs   w4, w4, #1              \n"
732 
733                 "fmla   v9.4s, v5.4s, v0.s[0]   \n"
734                 "fmla   v11.4s, v5.4s, v0.s[1]  \n"
735                 "fmla   v13.4s, v5.4s, v0.s[2]  \n"
736                 "fmla   v15.4s, v5.4s, v0.s[3]  \n"
737 
738                 "bne    2b                      \n"
739 
740                 "3:                             \n"
741 
742                 "st1    {v8.4s, v9.4s}, [%0], #32   \n"
743                 "st1    {v10.4s, v11.4s}, [%1], #32 \n"
744                 "st1    {v12.4s, v13.4s}, [%2], #32 \n"
745                 "st1    {v14.4s, v15.4s}, [%3], #32 \n"
746 
747                 : "=r"(outptr0), // %0
748                 "=r"(outptr1), // %1
749                 "=r"(outptr2), // %2
750                 "=r"(outptr3), // %3
751                 "=r"(tmpptr),  // %4
752                 "=r"(kptr)     // %5
753                 : "0"(outptr0),
754                 "1"(outptr1),
755                 "2"(outptr2),
756                 "3"(outptr3),
757                 "4"(tmpptr),
758                 "5"(kptr),
759                 "r"(biasptr), // %12
760                 "r"(nn)       // %13
761                 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
762 #else  // __aarch64__
763             asm volatile(
764                 "vld1.f32   {d0-d1}, [%12]      \n"
765                 "vdup.f32   q8, d0[0]           \n"
766                 "vdup.f32   q9, d0[0]           \n"
767                 "vdup.f32   q10, d0[1]          \n"
768                 "vdup.f32   q11, d0[1]          \n"
769                 "vdup.f32   q12, d1[0]          \n"
770                 "vdup.f32   q13, d1[0]          \n"
771                 "vdup.f32   q14, d1[1]          \n"
772                 "vdup.f32   q15, d1[1]          \n"
773 
774                 // inch loop
775                 "lsr        r4, %13, #2         \n" // r4 = nn = inch >> 2
776                 "cmp        r4, #0              \n"
777                 "beq        1f                  \n"
778 
779                 "0:                             \n"
780 
781                 "pld        [%4, #512]          \n"
782                 "vldm       %4!, {d8-d15}       \n"
783                 //                 "vld1.f32   {d8-d11}, [%4 :128]!    \n"
784                 //                 "vld1.f32   {d12-d15}, [%4 :128]!   \n"
785 
786                 "pld        [%5, #512]          \n"
787                 "vldm       %5!, {d0-d7}       \n"
788                 //                 "vld1.f32   {d0-d3}, [%5 :128]! \n"
789                 //                 "vld1.f32   {d4-d7}, [%5 :128]! \n"
790 
791                 "vmla.f32   q8, q4, d0[0]       \n"
792                 "vmla.f32   q10, q4, d0[1]      \n"
793                 "vmla.f32   q12, q4, d1[0]      \n"
794                 "vmla.f32   q14, q4, d1[1]      \n"
795 
796                 "vmla.f32   q9, q5, d0[0]       \n"
797                 "vmla.f32   q11, q5, d0[1]      \n"
798                 "vmla.f32   q13, q5, d1[0]      \n"
799                 "vmla.f32   q15, q5, d1[1]      \n"
800 
801                 "vmla.f32   q8, q6, d2[0]       \n"
802                 "vmla.f32   q10, q6, d2[1]      \n"
803                 "vmla.f32   q12, q6, d3[0]      \n"
804                 "vmla.f32   q14, q6, d3[1]      \n"
805 
806                 "vmla.f32   q9, q7, d2[0]       \n"
807                 "vmla.f32   q11, q7, d2[1]      \n"
808                 "vmla.f32   q13, q7, d3[0]      \n"
809                 "vmla.f32   q15, q7, d3[1]      \n"
810 
811                 "pld        [%4, #512]          \n"
812                 "vldm       %4!, {d8-d15}       \n"
813                 //                 "vld1.f32   {d8-d11}, [%4 :128]!    \n"
814                 //                 "vld1.f32   {d12-d15}, [%4 :128]!   \n"
815 
816                 "vmla.f32   q8, q4, d4[0]       \n"
817                 "vmla.f32   q10, q4, d4[1]      \n"
818                 "vmla.f32   q12, q4, d5[0]      \n"
819                 "vmla.f32   q14, q4, d5[1]      \n"
820 
821                 "vmla.f32   q9, q5, d4[0]       \n"
822                 "vmla.f32   q11, q5, d4[1]      \n"
823                 "vmla.f32   q13, q5, d5[0]      \n"
824                 "vmla.f32   q15, q5, d5[1]      \n"
825 
826                 "subs       r4, r4, #1          \n"
827 
828                 "vmla.f32   q8, q6, d6[0]       \n"
829                 "vmla.f32   q10, q6, d6[1]      \n"
830                 "vmla.f32   q12, q6, d7[0]      \n"
831                 "vmla.f32   q14, q6, d7[1]      \n"
832 
833                 "vmla.f32   q9, q7, d6[0]       \n"
834                 "vmla.f32   q11, q7, d6[1]      \n"
835                 "vmla.f32   q13, q7, d7[0]      \n"
836                 "vmla.f32   q15, q7, d7[1]      \n"
837 
838                 "bne        0b                  \n"
839 
840                 "1:                             \n"
841 
842                 // remain loop
843                 "and        r4, %13, #3         \n" // r4 = remain = inch & 3;
844                 "cmp        r4, #0              \n"
845                 "beq        3f                  \n"
846 
847                 "2:                             \n"
848 
849                 "pld        [%4, #256]          \n"
850                 "vld1.f32   {d8-d11}, [%4 :128]!    \n"
851 
852                 "pld        [%5, #128]          \n"
853                 "vld1.f32   {d0-d1}, [%5 :128]!     \n"
854 
855                 "vmla.f32   q8, q4, d0[0]       \n"
856                 "vmla.f32   q10, q4, d0[1]      \n"
857                 "vmla.f32   q12, q4, d1[0]      \n"
858                 "vmla.f32   q14, q4, d1[1]      \n"
859 
860                 "subs       r4, r4, #1          \n"
861 
862                 "vmla.f32   q9, q5, d0[0]       \n"
863                 "vmla.f32   q11, q5, d0[1]      \n"
864                 "vmla.f32   q13, q5, d1[0]      \n"
865                 "vmla.f32   q15, q5, d1[1]      \n"
866 
867                 "bne        2b                  \n"
868 
869                 "3:                             \n"
870 
871                 "vst1.f32   {d16-d19}, [%0 :128]!   \n"
872                 "vst1.f32   {d20-d23}, [%1 :128]!   \n"
873                 "vst1.f32   {d24-d27}, [%2 :128]!   \n"
874                 "vst1.f32   {d28-d31}, [%3 :128]!   \n"
875 
876                 : "=r"(outptr0), // %0
877                 "=r"(outptr1), // %1
878                 "=r"(outptr2), // %2
879                 "=r"(outptr3), // %3
880                 "=r"(tmpptr),  // %4
881                 "=r"(kptr)     // %5
882                 : "0"(outptr0),
883                 "1"(outptr1),
884                 "2"(outptr2),
885                 "3"(outptr3),
886                 "4"(tmpptr),
887                 "5"(kptr),
888                 "r"(biasptr), // %12
889                 "r"(nn)       // %13
890                 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
891 #endif // __aarch64__
892         }
893         for (; i + 3 < size; i += 4)
894         {
895             const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
896 #if __aarch64__
897             const float* kptr = kernel.channel(p / 8 + (p % 8) / 4);
898 #else
899             const float* kptr = kernel.channel(p / 4);
900 #endif
901 
902             int nn = inch * maxk; // inch always > 0
903 
904 #if __aarch64__
905             asm volatile(
906                 "ld1    {v0.4s}, [%12]          \n"
907                 "dup    v8.4s, v0.s[0]          \n"
908                 "dup    v9.4s, v0.s[1]          \n"
909                 "dup    v10.4s, v0.s[2]         \n"
910                 "dup    v11.4s, v0.s[3]         \n"
911 
912                 // inch loop
913                 "lsr    w4, %w13, #2            \n" // w4 = nn = inch >> 2
914                 "cmp    w4, #0                  \n"
915                 "beq    1f                      \n"
916 
917                 "0:                             \n"
918 
919                 "prfm   pldl1keep, [%4, #512]   \n"
920                 "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%4], #64     \n"
921 
922                 "prfm   pldl1keep, [%5, #512]   \n"
923                 "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64     \n"
924 
925                 "fmla   v8.4s, v4.4s, v0.s[0]   \n"
926                 "fmla   v9.4s, v4.4s, v0.s[1]   \n"
927                 "fmla   v10.4s, v4.4s, v0.s[2]  \n"
928                 "fmla   v11.4s, v4.4s, v0.s[3]  \n"
929 
930                 "fmla   v8.4s, v5.4s, v1.s[0]   \n"
931                 "fmla   v9.4s, v5.4s, v1.s[1]   \n"
932                 "fmla   v10.4s, v5.4s, v1.s[2]  \n"
933                 "fmla   v11.4s, v5.4s, v1.s[3]  \n"
934 
935                 "subs   w4, w4, #1              \n"
936 
937                 "fmla   v8.4s, v6.4s, v2.s[0]   \n"
938                 "fmla   v9.4s, v6.4s, v2.s[1]   \n"
939                 "fmla   v10.4s, v6.4s, v2.s[2]  \n"
940                 "fmla   v11.4s, v6.4s, v2.s[3]  \n"
941 
942                 "fmla   v8.4s, v7.4s, v3.s[0]   \n"
943                 "fmla   v9.4s, v7.4s, v3.s[1]   \n"
944                 "fmla   v10.4s, v7.4s, v3.s[2]  \n"
945                 "fmla   v11.4s, v7.4s, v3.s[3]  \n"
946 
947                 "bne    0b                      \n"
948 
949                 "1:                             \n"
950 
951                 // remain loop
952                 "and    w4, %w13, #3            \n" // w4 = remain = inch & 3;
953                 "cmp    w4, #0                  \n"
954                 "beq    3f                      \n"
955 
956                 "2:                             \n"
957 
958                 "prfm   pldl1keep, [%4, #128]   \n"
959                 "ld1    {v4.4s}, [%4], #16      \n"
960 
961                 "prfm   pldl1keep, [%5, #128]   \n"
962                 "ld1    {v0.4s}, [%5], #16      \n"
963 
964                 "subs   w4, w4, #1              \n"
965 
966                 "fmla   v8.4s, v4.4s, v0.s[0]   \n"
967                 "fmla   v9.4s, v4.4s, v0.s[1]   \n"
968                 "fmla   v10.4s, v4.4s, v0.s[2]  \n"
969                 "fmla   v11.4s, v4.4s, v0.s[3]  \n"
970 
971                 "bne    2b                      \n"
972 
973                 "3:                             \n"
974 
975                 "st1    {v8.4s}, [%0], #16      \n"
976                 "st1    {v9.4s}, [%1], #16      \n"
977                 "st1    {v10.4s}, [%2], #16     \n"
978                 "st1    {v11.4s}, [%3], #16     \n"
979 
980                 : "=r"(outptr0), // %0
981                 "=r"(outptr1), // %1
982                 "=r"(outptr2), // %2
983                 "=r"(outptr3), // %3
984                 "=r"(tmpptr),  // %4
985                 "=r"(kptr)     // %5
986                 : "0"(outptr0),
987                 "1"(outptr1),
988                 "2"(outptr2),
989                 "3"(outptr3),
990                 "4"(tmpptr),
991                 "5"(kptr),
992                 "r"(biasptr), // %12
993                 "r"(nn)       // %13
994                 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
995 #else  // __aarch64__
996             asm volatile(
997                 "vld1.f32   {d0-d1}, [%12]      \n"
998                 "vdup.f32   q8, d0[0]           \n"
999                 "vdup.f32   q9, d0[1]           \n"
1000                 "vdup.f32   q10, d1[0]          \n"
1001                 "vdup.f32   q11, d1[1]          \n"
1002 
1003                 // inch loop
1004                 "lsr        r4, %13, #2         \n" // r4 = nn = inch >> 2
1005                 "cmp        r4, #0              \n"
1006                 "beq        1f                  \n"
1007 
1008                 "0:                             \n"
1009 
1010                 "pld        [%4, #512]          \n"
1011                 "vldm       %4!, {d8-d15}       \n"
1012                 //                 "vld1.f32   {d8-d11}, [%4 :128]!    \n"
1013                 //                 "vld1.f32   {d12-d15}, [%4 :128]!   \n"
1014 
1015                 "pld        [%5, #512]          \n"
1016                 "vldm       %5!, {d0-d7}       \n"
1017                 //                 "vld1.f32   {d0-d3}, [%5 :128]! \n"
1018                 //                 "vld1.f32   {d4-d7}, [%5 :128]! \n"
1019 
1020                 "vmla.f32   q8, q4, d0[0]       \n"
1021                 "vmla.f32   q9, q4, d0[1]       \n"
1022                 "vmla.f32   q10, q4, d1[0]      \n"
1023                 "vmla.f32   q11, q4, d1[1]      \n"
1024 
1025                 "vmla.f32   q8, q5, d2[0]       \n"
1026                 "vmla.f32   q9, q5, d2[1]       \n"
1027                 "vmla.f32   q10, q5, d3[0]      \n"
1028                 "vmla.f32   q11, q5, d3[1]      \n"
1029 
1030                 "subs       r4, r4, #1          \n"
1031 
1032                 "vmla.f32   q8, q6, d4[0]       \n"
1033                 "vmla.f32   q9, q6, d4[1]       \n"
1034                 "vmla.f32   q10, q6, d5[0]      \n"
1035                 "vmla.f32   q11, q6, d5[1]      \n"
1036 
1037                 "vmla.f32   q8, q7, d6[0]       \n"
1038                 "vmla.f32   q9, q7, d6[1]       \n"
1039                 "vmla.f32   q10, q7, d7[0]      \n"
1040                 "vmla.f32   q11, q7, d7[1]      \n"
1041 
1042                 "bne        0b                  \n"
1043 
1044                 "1:                             \n"
1045 
1046                 // remain loop
1047                 "and        r4, %13, #3         \n" // r4 = remain = inch & 3;
1048                 "cmp        r4, #0              \n"
1049                 "beq        3f                  \n"
1050 
1051                 "2:                             \n"
1052 
1053                 "pld        [%4, #128]          \n"
1054                 "vld1.f32   {d8-d9}, [%4 :128]! \n"
1055 
1056                 "pld        [%5, #128]          \n"
1057                 "vld1.f32   {d0-d1}, [%5 :128]! \n"
1058 
1059                 "subs       r4, r4, #1          \n"
1060 
1061                 "vmla.f32   q8, q4, d0[0]       \n"
1062                 "vmla.f32   q9, q4, d0[1]       \n"
1063                 "vmla.f32   q10, q4, d1[0]      \n"
1064                 "vmla.f32   q11, q4, d1[1]      \n"
1065 
1066                 "bne        2b                  \n"
1067 
1068                 "3:                             \n"
1069 
1070                 "vst1.f32   {d16-d17}, [%0 :128]!   \n"
1071                 "vst1.f32   {d18-d19}, [%1 :128]!   \n"
1072                 "vst1.f32   {d20-d21}, [%2 :128]!   \n"
1073                 "vst1.f32   {d22-d23}, [%3 :128]!   \n"
1074 
1075                 : "=r"(outptr0), // %0
1076                 "=r"(outptr1), // %1
1077                 "=r"(outptr2), // %2
1078                 "=r"(outptr3), // %3
1079                 "=r"(tmpptr),  // %4
1080                 "=r"(kptr)     // %5
1081                 : "0"(outptr0),
1082                 "1"(outptr1),
1083                 "2"(outptr2),
1084                 "3"(outptr3),
1085                 "4"(tmpptr),
1086                 "5"(kptr),
1087                 "r"(biasptr), // %12
1088                 "r"(nn)       // %13
1089                 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
1090 #endif // __aarch64__
1091         }
1092         for (; i < size; i++)
1093         {
1094             const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
1095 #if __aarch64__
1096             const float* kptr = kernel.channel(p / 8 + (p % 8) / 4);
1097 #else
1098             const float* kptr = kernel.channel(p / 4);
1099 #endif
1100 
1101             int nn = inch * maxk; // inch always > 0
1102 
1103 #if __aarch64__
1104             asm volatile(
1105                 "ld1    {v12.4s}, [%12]         \n"
1106 
1107                 // inch loop
1108                 "lsr    w4, %w13, #2            \n" // w4 = nn = inch >> 2
1109                 "cmp    w4, #0                  \n"
1110                 "beq    1f                      \n"
1111 
1112                 "eor    v8.16b, v8.16b, v8.16b  \n"
1113                 "eor    v9.16b, v9.16b, v9.16b  \n"
1114                 "eor    v10.16b, v10.16b, v10.16b  \n"
1115                 "eor    v11.16b, v11.16b, v11.16b  \n"
1116 
1117                 "0:                             \n"
1118 
1119                 "prfm   pldl1keep, [%4, #128]   \n"
1120                 "ld1    {v4.4s}, [%4], #16      \n"
1121 
1122                 "prfm   pldl1keep, [%5, #512]   \n"
1123                 "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64     \n"
1124 
1125                 "subs   w4, w4, #1              \n"
1126 
1127                 "fmla   v8.4s, v0.4s, v4.s[0]   \n"
1128                 "fmla   v9.4s, v1.4s, v4.s[1]   \n"
1129                 "fmla   v10.4s, v2.4s, v4.s[2]  \n"
1130                 "fmla   v11.4s, v3.4s, v4.s[3]  \n"
1131 
1132                 "bne    0b                      \n"
1133 
1134                 "fadd   v8.4s, v8.4s, v9.4s     \n"
1135                 "fadd   v10.4s, v10.4s, v11.4s  \n"
1136                 "fadd   v8.4s, v8.4s, v10.4s    \n"
1137                 "fadd   v12.4s, v12.4s, v8.4s   \n"
1138 
1139                 "1:                             \n"
1140 
1141                 // remain loop
1142                 "and    w4, %w13, #3            \n" // w4 = remain = inch & 3;
1143                 "cmp    w4, #0                  \n"
1144                 "beq    3f                      \n"
1145 
1146                 "2:                             \n"
1147 
1148                 "prfm   pldl1keep, [%4, #32]    \n"
1149                 "ld1r   {v4.4s}, [%4], #4       \n"
1150 
1151                 "prfm   pldl1keep, [%5, #128]   \n"
1152                 "ld1    {v0.4s}, [%5], #16      \n"
1153 
1154                 "subs   w4, w4, #1              \n"
1155 
1156                 "fmla   v12.4s, v4.4s, v0.4s    \n"
1157 
1158                 "bne    2b                      \n"
1159 
1160                 "3:                             \n"
1161 
1162                 "st1    {v12.s}[0], [%0], #4    \n"
1163                 "st1    {v12.s}[1], [%1], #4    \n"
1164                 "st1    {v12.s}[2], [%2], #4    \n"
1165                 "st1    {v12.s}[3], [%3], #4    \n"
1166 
1167                 : "=r"(outptr0), // %0
1168                 "=r"(outptr1), // %1
1169                 "=r"(outptr2), // %2
1170                 "=r"(outptr3), // %3
1171                 "=r"(tmpptr),  // %4
1172                 "=r"(kptr)     // %5
1173                 : "0"(outptr0),
1174                 "1"(outptr1),
1175                 "2"(outptr2),
1176                 "3"(outptr3),
1177                 "4"(tmpptr),
1178                 "5"(kptr),
1179                 "r"(biasptr), // %12
1180                 "r"(nn)       // %13
1181                 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v8", "v9", "v10", "v11", "v12");
1182 #else  // __aarch64__
1183             asm volatile(
1184                 "vld1.f32   {d24-d25}, [%12]    \n"
1185 
1186                 // inch loop
1187                 "lsr        r4, %13, #2         \n" // r4 = nn = inch >> 2
1188                 "cmp        r4, #0              \n"
1189                 "beq        1f                  \n"
1190 
1191                 "veor       q8, q8, q8          \n"
1192                 "veor       q9, q9, q9          \n"
1193                 "veor       q10, q10, q10       \n"
1194                 "veor       q11, q11, q11       \n"
1195 
1196                 "0:                             \n"
1197 
1198                 "pld        [%4, #128]          \n"
1199                 "vld1.f32   {d8-d9}, [%4 :128]! \n"
1200 
1201                 "pld        [%5, #512]          \n"
1202                 "vldm       %5!, {d0-d7}       \n"
1203                 //                 "vld1.f32   {d0-d3}, [%5 :128]! \n"
1204                 //                 "vld1.f32   {d4-d7}, [%5 :128]! \n"
1205 
1206                 "subs       r4, r4, #1          \n"
1207 
1208                 "vmla.f32   q8, q0, d8[0]       \n"
1209                 "vmla.f32   q9, q1, d8[1]       \n"
1210                 "vmla.f32   q10, q2, d9[0]      \n"
1211                 "vmla.f32   q11, q3, d9[1]      \n"
1212 
1213                 "bne        0b                  \n"
1214 
1215                 "vadd.f32   q8, q8, q9          \n"
1216                 "vadd.f32   q10, q10, q11       \n"
1217                 "vadd.f32   q8, q8, q10         \n"
1218                 "vadd.f32   q12, q12, q8        \n"
1219 
1220                 "1:                             \n"
1221 
1222                 // remain loop
1223                 "and        r4, %13, #3         \n" // r4 = remain = inch & 3;
1224                 "cmp        r4, #0              \n"
1225                 "beq        3f                  \n"
1226 
1227                 "2:                             \n"
1228 
1229                 "pld        [%4, #32]           \n"
1230                 "vld1.f32   {d8[],d9[]}, [%4]!  \n"
1231 
1232                 "pld        [%5, #128]          \n"
1233                 "vld1.f32   {d0-d1}, [%5 :128]! \n"
1234 
1235                 "subs       r4, r4, #1          \n"
1236 
1237                 "vmla.f32   q12, q4, q0         \n"
1238 
1239                 "bne        2b                  \n"
1240 
1241                 "3:                             \n"
1242 
1243                 "vst1.f32   {d24[0]}, [%0]!     \n"
1244                 "vst1.f32   {d24[1]}, [%1]!     \n"
1245                 "vst1.f32   {d25[0]}, [%2]!     \n"
1246                 "vst1.f32   {d25[1]}, [%3]!     \n"
1247 
1248                 : "=r"(outptr0), // %0
1249                 "=r"(outptr1), // %1
1250                 "=r"(outptr2), // %2
1251                 "=r"(outptr3), // %3
1252                 "=r"(tmpptr),  // %4
1253                 "=r"(kptr)     // %5
1254                 : "0"(outptr0),
1255                 "1"(outptr1),
1256                 "2"(outptr2),
1257                 "3"(outptr3),
1258                 "4"(tmpptr),
1259                 "5"(kptr),
1260                 "r"(biasptr), // %12
1261                 "r"(nn)       // %13
1262                 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q8", "q9", "q10", "q11", "q12");
1263 #endif // __aarch64__
1264         }
1265     }
1266 
1267     remain_outch_start += nn_outch << 2;
1268 
1269     #pragma omp parallel for num_threads(opt.num_threads)
1270     for (int p = remain_outch_start; p < outch; p++)
1271     {
1272         float* outptr0 = top_blob.channel(p);
1273 
1274         const float bias0 = bias ? bias[p] : 0.f;
1275 
1276         int i = 0;
1277         for (; i + 7 < size; i += 8)
1278         {
1279             const float* tmpptr = tmp.channel(i / 8);
1280 #if __aarch64__
1281             const float* kptr = kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
1282 #else
1283             const float* kptr = kernel.channel(p / 4 + p % 4);
1284 #endif
1285 
1286             int nn = inch * maxk; // inch always > 0
1287 
1288 #if __aarch64__
1289             asm volatile(
1290                 "dup    v8.4s, %w6              \n"
1291                 "dup    v9.4s, %w6              \n"
1292 
1293                 // inch loop
1294                 "lsr    w4, %w7, #2             \n" // w4 = nn = inch >> 2
1295                 "cmp    w4, #0                  \n"
1296                 "beq    1f                      \n"
1297 
1298                 "0:                             \n"
1299 
1300                 "prfm   pldl1keep, [%1, #512]   \n"
1301                 "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%1], #64     \n"
1302 
1303                 "prfm   pldl1keep, [%2, #128]   \n"
1304                 "ld1    {v0.4s}, [%2], #16      \n"
1305 
1306                 "fmla   v8.4s, v4.4s, v0.s[0]   \n"
1307                 "fmla   v9.4s, v5.4s, v0.s[0]   \n"
1308 
1309                 "prfm   pldl1keep, [%1, #512]   \n"
1310                 "ld1    {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n"
1311 
1312                 "fmla   v8.4s, v6.4s, v0.s[1]   \n"
1313                 "fmla   v9.4s, v7.4s, v0.s[1]   \n"
1314 
1315                 "subs   w4, w4, #1              \n"
1316 
1317                 "fmla   v8.4s, v12.4s, v0.s[2]  \n"
1318                 "fmla   v9.4s, v13.4s, v0.s[2]  \n"
1319 
1320                 "fmla   v8.4s, v14.4s, v0.s[3]  \n"
1321                 "fmla   v9.4s, v15.4s, v0.s[3]  \n"
1322 
1323                 "bne    0b                      \n"
1324 
1325                 "1:                             \n"
1326 
1327                 // remain loop
1328                 "and    w4, %w7, #3             \n" // w4 = remain = inch & 3;
                "cmp    w4, #0                  \n"
                "beq    3f                      \n"

                "2:                             \n"

                "prfm   pldl1keep, [%1, #256]   \n"
                "ld1    {v4.4s, v5.4s}, [%1], #32   \n"

                "prfm   pldl1keep, [%2, #32]    \n"
                "ld1r   {v0.4s}, [%2], #4       \n"

                "subs   w4, w4, #1              \n"

                "fmla   v8.4s, v4.4s, v0.4s     \n"
                "fmla   v9.4s, v5.4s, v0.4s     \n"

                "bne    2b                      \n"

                "3:                             \n"

                "st1    {v8.4s, v9.4s}, [%0], #32   \n"

                : "=r"(outptr0), // %0
                "=r"(tmpptr),  // %1
                "=r"(kptr)     // %2
                : "0"(outptr0),
                "1"(tmpptr),
                "2"(kptr),
                "r"(bias0), // %6
                "r"(nn)     // %7
                : "cc", "memory", "x4", "v0", "v4", "v5", "v6", "v7", "v8", "v9", "v12", "v13", "v14", "v15");
#else  // __aarch64__
            asm volatile(
                "vdup.f32   q8, %6              \n"
                "vdup.f32   q9, %6              \n"

                // inch loop
                "lsr        r4, %7, #2          \n" // r4 = nn >> 2 (nn = inch * maxk)
                "cmp        r4, #0              \n"
                "beq        1f                  \n"

                "0:                             \n"

                "pld        [%1, #512]          \n"
                "vldm       %1!, {d8-d15}       \n"
                //                 "vld1.f32   {d8-d11}, [%1 :128]!    \n"
                //                 "vld1.f32   {d12-d15}, [%1 :128]!   \n"

                "pld        [%2, #128]          \n"
                "vld1.f32   {d0-d1}, [%2 :128]! \n"

                "vmla.f32   q8, q4, d0[0]       \n"
                "vmla.f32   q9, q5, d0[0]       \n"

                "pld        [%1, #512]          \n"
                "vldm       %1!, {d24-d31}      \n"
                //                 "vld1.f32   {d24-d27}, [%1 :128]!   \n"
                //                 "vld1.f32   {d28-d31}, [%1 :128]!   \n"

                "vmla.f32   q8, q6, d0[1]       \n"
                "vmla.f32   q9, q7, d0[1]       \n"

                "subs       r4, r4, #1          \n"

                "vmla.f32   q8, q12, d1[0]      \n"
                "vmla.f32   q9, q13, d1[0]      \n"

                "vmla.f32   q8, q14, d1[1]      \n"
                "vmla.f32   q9, q15, d1[1]      \n"

                "bne        0b                  \n"

                "1:                             \n"

                // remain loop
                "and        r4, %7, #3          \n" // r4 = remain = nn & 3
                "cmp        r4, #0              \n"
                "beq        3f                  \n"

                "2:                             \n"

                "pld        [%1, #256]          \n"
                "vld1.f32   {d8-d11}, [%1 :128]!    \n"

                "pld        [%2, #32]           \n"
                "vld1.f32   {d0[],d1[]}, [%2]!  \n"

                "subs       r4, r4, #1          \n"

                "vmla.f32   q8, q4, q0          \n"
                "vmla.f32   q9, q5, q0          \n"

                "bne        2b                  \n"

                "3:                             \n"

                "vst1.f32   {d16-d19}, [%0 :128]!   \n"

                : "=r"(outptr0), // %0
                "=r"(tmpptr),  // %1
                "=r"(kptr)     // %2
                : "0"(outptr0),
                "1"(tmpptr),
                "2"(kptr),
                "r"(bias0), // %6
                "r"(nn)     // %7
                : "cc", "memory", "r4", "q0", "q4", "q5", "q6", "q7", "q8", "q9", "q12", "q13", "q14", "q15");
#endif // __aarch64__
        }
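        // 4-column tiles: same scheme with a single accumulator register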
        for (; i + 3 < size; i += 4)
        {
            const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
#if __aarch64__
            const float* kptr = kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
#else
            const float* kptr = kernel.channel(p / 4 + p % 4);
#endif

            int nn = inch * maxk; // inch always > 0

#if __aarch64__
            asm volatile(
                "dup    v8.4s, %w6              \n"

                // inch loop
                "lsr    w4, %w7, #2             \n" // w4 = nn >> 2 (nn = inch * maxk)
                "cmp    w4, #0                  \n"
                "beq    1f                      \n"

                "0:                             \n"

                "prfm   pldl1keep, [%1, #512]   \n"
                "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%1], #64     \n"

                "prfm   pldl1keep, [%2, #128]   \n"
                "ld1    {v0.4s}, [%2], #16      \n"

                "subs   w4, w4, #1              \n"

                "fmla   v8.4s, v4.4s, v0.s[0]   \n"
                "fmla   v8.4s, v5.4s, v0.s[1]   \n"
                "fmla   v8.4s, v6.4s, v0.s[2]   \n"
                "fmla   v8.4s, v7.4s, v0.s[3]   \n"

                "bne    0b                      \n"

                "1:                             \n"

                // remain loop
                "and    w4, %w7, #3             \n" // w4 = remain = nn & 3
                "cmp    w4, #0                  \n"
                "beq    3f                      \n"

                "2:                             \n"

                "prfm   pldl1keep, [%1, #128]   \n"
                "ld1    {v4.4s}, [%1], #16      \n"

                "prfm   pldl1keep, [%2, #32]    \n"
                "ld1r   {v0.4s}, [%2], #4       \n"

                "subs   w4, w4, #1              \n"

                "fmla   v8.4s, v4.4s, v0.4s     \n"

                "bne    2b                      \n"

                "3:                             \n"

                "st1    {v8.4s}, [%0], #16      \n"

                : "=r"(outptr0), // %0
                "=r"(tmpptr),  // %1
                "=r"(kptr)     // %2
                : "0"(outptr0),
                "1"(tmpptr),
                "2"(kptr),
                "r"(bias0), // %6
                "r"(nn)     // %7
                : "cc", "memory", "x4", "v0", "v4", "v5", "v6", "v7", "v8");
#else  // __aarch64__
            asm volatile(
                "vdup.f32   q8, %6              \n"

                // inch loop
                "lsr        r4, %7, #2          \n" // r4 = nn >> 2 (nn = inch * maxk)
                "cmp        r4, #0              \n"
                "beq        1f                  \n"

                "0:                             \n"

                "pld        [%1, #512]          \n"
                "vldm       %1!, {d8-d15}       \n"
                //                 "vld1.f32   {d8-d11}, [%1 :128]!    \n"
                //                 "vld1.f32   {d12-d15}, [%1 :128]!   \n"

                "pld        [%2, #128]          \n"
                "vld1.f32   {d0-d1}, [%2]!      \n"

                "subs       r4, r4, #1          \n"

                "vmla.f32   q8, q4, d0[0]       \n"
                "vmla.f32   q8, q5, d0[1]       \n"
                "vmla.f32   q8, q6, d1[0]       \n"
                "vmla.f32   q8, q7, d1[1]       \n"

                "bne        0b                  \n"

                "1:                             \n"

                // remain loop
                "and        r4, %7, #3          \n" // r4 = remain = nn & 3
                "cmp        r4, #0              \n"
                "beq        3f                  \n"

                "2:                             \n"

                "pld        [%1, #128]          \n"
                "vld1.f32   {d8-d9}, [%1 :128]! \n"

                "pld        [%2, #32]           \n"
                "vld1.f32   {d0[],d1[]}, [%2]!  \n"

                "subs       r4, r4, #1          \n"

                "vmla.f32   q8, q4, q0          \n"

                "bne        2b                  \n"

                "3:                             \n"

                "vst1.f32   {d16-d17}, [%0 :128]!   \n"

                : "=r"(outptr0), // %0
                "=r"(tmpptr),  // %1
                "=r"(kptr)     // %2
                : "0"(outptr0),
                "1"(tmpptr),
                "2"(kptr),
                "r"(bias0), // %6
                "r"(nn)     // %7
                : "cc", "memory", "r4", "q0", "q4", "q5", "q6", "q7", "q8");
#endif // __aarch64__
        }
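        // leftover columns: plain dot product with NEON intrinsics plus a scalar tail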
        for (; i < size; i++)
        {
            const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
#if __aarch64__
            const float* kptr = kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
#else
            const float* kptr = kernel.channel(p / 4 + p % 4);
#endif

            int nn = inch * maxk; // inch always > 0

            float32x4_t _sum0 = vdupq_n_f32(0.f);

            int q = 0;
            for (; q + 3 < nn; q += 4)
            {
                float32x4_t _p0 = vld1q_f32(tmpptr);
                tmpptr += 4;

                float32x4_t _k0 = vld1q_f32(kptr);
                kptr += 4;

#if __aarch64__
                _sum0 = vfmaq_f32(_sum0, _p0, _k0);
#else
                _sum0 = vmlaq_f32(_sum0, _p0, _k0);
#endif
            }

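            // horizontal reduction of the 4 partial sums, then add the bias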
#if __aarch64__
            float sum0 = bias0 + vaddvq_f32(_sum0);
#else
            float32x2_t _ss = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0));
            float sum0 = bias0 + vget_lane_f32(vpadd_f32(_ss, _ss), 0);
#endif

            for (; q < nn; q++)
            {
                sum0 += tmpptr[0] * kptr[0];
                tmpptr++;
                kptr++;
            }

            outptr0[0] = sum0;

            outptr0++;
        }
    }
#else // __ARM_NEON
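    // reference path without NEON: one scalar dot product per output element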
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        float* outptr0 = top_blob.channel(p);

        const float bias0 = bias ? bias[p] : 0.f;

        for (int i = 0; i < size; i++)
        {
            const float* tmpptr = tmp.channel(i);
            const float* kptr = kernel.channel(p);

            int nn = inch * maxk; // inch always > 0

            float sum0 = bias0;

            for (int q = 0; q < nn; q++)
            {
                sum0 += tmpptr[0] * kptr[0];
                tmpptr++;
                kptr++;
            }

            outptr0[0] = sum0;

            outptr0++;
        }
    }
#endif // __ARM_NEON
}

static void convolution_im2col_sgemm_transform_kernel_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h)
{
    const int maxk = kernel_w * kernel_h;

    // interleave
    // src = maxk-inch-outch
    // dst = 8b-maxk-inch-outch/8b (aarch64 only) + 4b-maxk-inch-outch/4b + maxk-inch-outch
    Mat kernel = _kernel.reshape(maxk, inch, outch);
#if __ARM_NEON
#if __aarch64__
    kernel_tm.create(32 * maxk, inch / 4 + inch % 4, outch / 8 + (outch % 8) / 4 + outch % 4);
#else
    kernel_tm.create(16 * maxk, inch / 4 + inch % 4, outch / 4 + outch % 4);
#endif
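    // e.g. a 3x3 kernel with inch=16, outch=32 on aarch64 yields 4 kernel_tm channels,
    // each holding the 16*9*8 interleaved weights of one group of 8 output channels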

    int q = 0;
#if __aarch64__
    for (; q + 7 < outch; q += 8)
    {
        const Mat k0 = kernel.channel(q);
        const Mat k1 = kernel.channel(q + 1);
        const Mat k2 = kernel.channel(q + 2);
        const Mat k3 = kernel.channel(q + 3);
        const Mat k4 = kernel.channel(q + 4);
        const Mat k5 = kernel.channel(q + 5);
        const Mat k6 = kernel.channel(q + 6);
        const Mat k7 = kernel.channel(q + 7);

        float* g00 = kernel_tm.channel(q / 8);

        for (int p = 0; p < inch; p++)
        {
            const float* k00 = k0.row(p);
            const float* k10 = k1.row(p);
            const float* k20 = k2.row(p);
            const float* k30 = k3.row(p);
            const float* k40 = k4.row(p);
            const float* k50 = k5.row(p);
            const float* k60 = k6.row(p);
            const float* k70 = k7.row(p);

            for (int k = 0; k < maxk; k++)
            {
                g00[0] = k00[k];
                g00[1] = k10[k];
                g00[2] = k20[k];
                g00[3] = k30[k];
                g00[4] = k40[k];
                g00[5] = k50[k];
                g00[6] = k60[k];
                g00[7] = k70[k];

                g00 += 8;
            }
        }
    }
#endif // __aarch64__
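    // remaining output channels are packed in groups of 4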
    for (; q + 3 < outch; q += 4)
    {
        const Mat k0 = kernel.channel(q);
        const Mat k1 = kernel.channel(q + 1);
        const Mat k2 = kernel.channel(q + 2);
        const Mat k3 = kernel.channel(q + 3);

#if __aarch64__
        float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4);
#else
        float* g00 = kernel_tm.channel(q / 4);
#endif

        for (int p = 0; p < inch; p++)
        {
            const float* k00 = k0.row(p);
            const float* k10 = k1.row(p);
            const float* k20 = k2.row(p);
            const float* k30 = k3.row(p);

            for (int k = 0; k < maxk; k++)
            {
                g00[0] = k00[k];
                g00[1] = k10[k];
                g00[2] = k20[k];
                g00[3] = k30[k];

                g00 += 4;
            }
        }
    }
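    // leftover output channels are stored as-is, one per kernel_tm channel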
    for (; q < outch; q++)
    {
        const Mat k0 = kernel.channel(q);

#if __aarch64__
        float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + q % 4);
#else
        float* g00 = kernel_tm.channel(q / 4 + q % 4);
#endif

        for (int p = 0; p < inch; p++)
        {
            const float* k00 = k0.row(p);

            for (int k = 0; k < maxk; k++)
            {
                g00[0] = k00[k];

                g00 += 1;
            }
        }
    }
#else
    kernel_tm = kernel;
#endif // __ARM_NEON
}

static void convolution_im2col_sgemm_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    const int size = outw * outh;

    const int maxk = kernel_w * kernel_h;

    // im2col
    Mat bottom_im2col(size, maxk, inch, 4u, 1, opt.workspace_allocator);
    {
        const int gap = w * stride_h - outw * stride_w;

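        // for each (u,v) kernel tap, copy one strided outw x outh plane per input channel;
        // gap moves the source pointer from the end of one output row to the start of the
        // next sampled input row (stride_h rows down)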
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < inch; p++)
        {
            const Mat img = bottom_blob.channel(p);
            float* ptr = bottom_im2col.channel(p);

            for (int u = 0; u < kernel_h; u++)
            {
                for (int v = 0; v < kernel_w; v++)
                {
                    const float* sptr = img.row<const float>(dilation_h * u) + dilation_w * v;

                    for (int i = 0; i < outh; i++)
                    {
                        int j = 0;
                        for (; j + 3 < outw; j += 4)
                        {
                            ptr[0] = sptr[0];
                            ptr[1] = sptr[stride_w];
                            ptr[2] = sptr[stride_w * 2];
                            ptr[3] = sptr[stride_w * 3];

                            sptr += stride_w * 4;
                            ptr += 4;
                        }
                        for (; j + 1 < outw; j += 2)
                        {
                            ptr[0] = sptr[0];
                            ptr[1] = sptr[stride_w];

                            sptr += stride_w * 2;
                            ptr += 2;
                        }
                        for (; j < outw; j++)
                        {
                            ptr[0] = sptr[0];

                            sptr += stride_w;
                            ptr += 1;
                        }

                        sptr += gap;
                    }
                }
            }
        }
    }

    im2col_sgemm_neon(bottom_im2col, top_blob, kernel, _bias, opt);
}

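// Illustrative usage sketch (not part of this file; names such as weight_data,
// weight_sgemm_data and bottom_blob_bordered are assumptions): the convolution layer
// typically repacks the weights once at load time and then runs the sgemm path on the
// padded input for every inference.
//
//   Mat weight_sgemm_data;
//   convolution_im2col_sgemm_transform_kernel_neon(weight_data, weight_sgemm_data,
//                                                  num_input, num_output, kernel_w, kernel_h);
//   ...
//   convolution_im2col_sgemm_neon(bottom_blob_bordered, top_blob, weight_sgemm_data,
//                                 bias_data, kernel_w, kernel_h, dilation_w, dilation_h,
//                                 stride_w, stride_h, opt);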