1 // BUG1989 is pleased to support the open source community by supporting ncnn available.
2 //
3 // Copyright (C) 2019 BUG1989. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
conv_im2col_sgemm_transform_kernel_neon(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_size)15 static void conv_im2col_sgemm_transform_kernel_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
16 {
17     const float* kernel = _kernel;
18 
19 #if __ARM_NEON && __aarch64__
20     // kernel memory packed 8 x 8
21     kernel_tm.create(8 * kernel_size, inch, outch / 8 + (outch % 8) / 4 + outch % 4);
22 #else
23     // kernel memory packed 4 x 8
24     kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4);
25 #endif
26 
27     int nn_outch = 0;
28     int remain_outch_start = 0;
29 
30 #if __ARM_NEON && __aarch64__
31     nn_outch = outch >> 3;
32     remain_outch_start = nn_outch << 3;
33 
34     for (int pp = 0; pp < nn_outch; pp++)
35     {
36         int p = pp * 8;
37 
38         const float* k0 = kernel + (p + 0) * inch * kernel_size;
39         const float* k1 = kernel + (p + 1) * inch * kernel_size;
40         const float* k2 = kernel + (p + 2) * inch * kernel_size;
41         const float* k3 = kernel + (p + 3) * inch * kernel_size;
42         const float* k4 = kernel + (p + 4) * inch * kernel_size;
43         const float* k5 = kernel + (p + 5) * inch * kernel_size;
44         const float* k6 = kernel + (p + 6) * inch * kernel_size;
45         const float* k7 = kernel + (p + 7) * inch * kernel_size;
46 
47         float* ktmp = kernel_tm.channel(p / 8);
48 
49         for (int q = 0; q < inch * kernel_size; q++)
50         {
51             ktmp[0] = k0[0];
52             ktmp[1] = k1[0];
53             ktmp[2] = k2[0];
54             ktmp[3] = k3[0];
55             ktmp[4] = k4[0];
56             ktmp[5] = k5[0];
57             ktmp[6] = k6[0];
58             ktmp[7] = k7[0];
59             ktmp += 8;
60 
61             k0 += 1;
62             k1 += 1;
63             k2 += 1;
64             k3 += 1;
65             k4 += 1;
66             k5 += 1;
67             k6 += 1;
68             k7 += 1;
69         }
70     }
71 #endif
72 
73     nn_outch = (outch - remain_outch_start) >> 2;
74 
75     for (int pp = 0; pp < nn_outch; pp++)
76     {
77         int p = remain_outch_start + pp * 4;
78 
79         const float* k0 = kernel + (p + 0) * inch * kernel_size;
80         const float* k1 = kernel + (p + 1) * inch * kernel_size;
81         const float* k2 = kernel + (p + 2) * inch * kernel_size;
82         const float* k3 = kernel + (p + 3) * inch * kernel_size;
83 
84 #if __ARM_NEON && __aarch64__
85         float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4);
86 #else
87         float* ktmp = kernel_tm.channel(p / 4);
88 #endif // __ARM_NEON && __aarch64__
89 
90         for (int q = 0; q < inch * kernel_size; q++)
91         {
92             ktmp[0] = k0[0];
93             ktmp[1] = k1[0];
94             ktmp[2] = k2[0];
95             ktmp[3] = k3[0];
96             ktmp += 4;
97 
98             k0 += 1;
99             k1 += 1;
100             k2 += 1;
101             k3 += 1;
102         }
103     }
104 
105     remain_outch_start += nn_outch << 2;
106 
107     for (int p = remain_outch_start; p < outch; p++)
108     {
109         const float* k0 = kernel + (p + 0) * inch * kernel_size;
110 
111 #if __ARM_NEON && __aarch64__
112         float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4);
113 #else
114         float* ktmp = kernel_tm.channel(p / 4 + p % 4);
115 #endif // __ARM_NEON && __aarch64__
116 
117         for (int q = 0; q < inch * kernel_size; q++)
118         {
119             ktmp[0] = k0[0];
120             ktmp++;
121             k0++;
122         }
123     }
124 }
125 
conv_im2col_sgemm_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel_tm,const Mat & _bias,const int kernel_w,const int kernel_h,const int stride_w,const int stride_h,const Option & opt)126 static void conv_im2col_sgemm_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias,
127                                    const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
128 {
129     int w = bottom_blob.w;
130     int inch = bottom_blob.c;
131     size_t elemsize = bottom_blob.elemsize;
132 
133     int outw = top_blob.w;
134     int outh = top_blob.h;
135     int outch = top_blob.c;
136 
137     const float* bias = _bias;
138 
139     // im2col
140     Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, elemsize, opt.workspace_allocator);
141     {
142         const int stride = kernel_h * kernel_w * outw * outh;
143         float* ret = (float*)bottom_im2col;
144 
145         #pragma omp parallel for num_threads(opt.num_threads)
146         for (int p = 0; p < inch; p++)
147         {
148             const float* input = bottom_blob.channel(p);
149             int retID = stride * p;
150             for (int u = 0; u < kernel_h; u++)
151             {
152                 for (int v = 0; v < kernel_w; v++)
153                 {
154                     for (int i = 0; i < outh; i++)
155                     {
156                         for (int j = 0; j < outw; j++)
157                         {
158                             int row = u + i * stride_h;
159                             int col = v + j * stride_w;
160                             int index = row * w + col;
161                             ret[retID] = input[index];
162                             retID++;
163                         }
164                     }
165                 }
166             }
167         }
168     }
169 
170     int kernel_size = kernel_w * kernel_h;
171     int out_size = outw * outh;
172 
173     // bottom_im2col memory packed 8 x 8
174     Mat bottom_tm(8 * kernel_size, inch, out_size / 8 + out_size % 8, elemsize, opt.workspace_allocator);
175     {
176         int nn_size = out_size >> 3;
177         int remain_size_start = nn_size << 3;
178 
179         #pragma omp parallel for num_threads(opt.num_threads)
180         for (int ii = 0; ii < nn_size; ii++)
181         {
182             int i = ii * 8;
183 
184             const float* img0 = bottom_im2col.channel(0);
185             img0 += i;
186 
187             float* tmpptr = bottom_tm.channel(i / 8);
188 
189             for (int q = 0; q < inch * kernel_size; q++)
190             {
191 #if __ARM_NEON
192 #if __aarch64__
193                 asm volatile(
194                     "prfm    pldl1keep, [%0, #256]   \n"
195                     "ld1     {v0.4s, v1.4s}, [%0]    \n"
196                     "st1     {v0.4s, v1.4s}, [%1]    \n"
197                     : "=r"(img0),  // %0
198                     "=r"(tmpptr) // %1
199                     : "0"(img0),
200                     "1"(tmpptr)
201                     : "cc", "memory", "v0", "v1");
202 #else
203                 asm volatile(
204                     "pld        [%0, #256]          \n"
205                     "vld1.f32   {d0-d3}, [%0]       \n"
206                     "vst1.f32   {d0-d3}, [%1]       \n"
207                     : "=r"(img0),  // %0
208                     "=r"(tmpptr) // %1
209                     : "0"(img0),
210                     "1"(tmpptr)
211                     : "memory", "q0", "q1");
212 #endif // __aarch64__
213 #else
214                 tmpptr[0] = img0[0];
215                 tmpptr[1] = img0[1];
216                 tmpptr[2] = img0[2];
217                 tmpptr[3] = img0[3];
218                 tmpptr[4] = img0[4];
219                 tmpptr[5] = img0[5];
220                 tmpptr[6] = img0[6];
221                 tmpptr[7] = img0[7];
222 #endif // __ARM_NEON
223                 tmpptr += 8;
224                 img0 += out_size;
225             }
226         }
227 
228         #pragma omp parallel for num_threads(opt.num_threads)
229         for (int i = remain_size_start; i < out_size; i++)
230         {
231             const float* img0 = bottom_im2col.channel(0);
232             img0 += i;
233 
234             float* tmpptr = bottom_tm.channel(i / 8 + i % 8);
235 
236             for (int q = 0; q < inch * kernel_size; q++)
237             {
238                 tmpptr[0] = img0[0];
239 
240                 tmpptr += 1;
241                 img0 += out_size;
242             }
243         }
244     }
245 
246     // sgemm(int M, int N, int L, float* A, float* B, float* C)
247     {
248         //int M = outch;                    // outch
249         int N = outw * outh;                // outsize or out stride
250         int L = kernel_w * kernel_h * inch; // ksize * inch
251 
252         int nn_outch = 0;
253         int remain_outch_start = 0;
254 
255 #if __aarch64__
256         nn_outch = outch >> 3;
257         remain_outch_start = nn_outch << 3;
258 
259         #pragma omp parallel for num_threads(opt.num_threads)
260         for (int pp = 0; pp < nn_outch; pp++)
261         {
262             int i = pp * 8;
263 
264             float* output0 = top_blob.channel(i);
265             float* output1 = top_blob.channel(i + 1);
266             float* output2 = top_blob.channel(i + 2);
267             float* output3 = top_blob.channel(i + 3);
268             float* output4 = top_blob.channel(i + 4);
269             float* output5 = top_blob.channel(i + 5);
270             float* output6 = top_blob.channel(i + 6);
271             float* output7 = top_blob.channel(i + 7);
272 
273             const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
274             const float* biasptr = bias ? bias + i : zeros;
275 
276             int j = 0;
277             for (; j + 7 < N; j = j + 8)
278             {
279                 const float* vb = bottom_tm.channel(j / 8);
280                 const float* va = kernel_tm.channel(i / 8);
281 #if __ARM_NEON
282                 asm volatile(
283                     "ld1    {v0.4s, v1.4s}, [%21]   \n"
284                     "dup    v16.4s, v0.s[0]         \n" // sum0
285                     "dup    v17.4s, v0.s[0]         \n"
286                     "dup    v18.4s, v0.s[1]         \n" // sum1
287                     "dup    v19.4s, v0.s[1]         \n"
288                     "dup    v20.4s, v0.s[2]         \n" // sum2
289                     "dup    v21.4s, v0.s[2]         \n"
290                     "dup    v22.4s, v0.s[3]         \n" // sum3
291                     "dup    v23.4s, v0.s[3]         \n"
292                     "dup    v24.4s, v1.s[0]         \n" // sum4
293                     "dup    v25.4s, v1.s[0]         \n"
294                     "dup    v26.4s, v1.s[1]         \n" // sum5
295                     "dup    v27.4s, v1.s[1]         \n"
296                     "dup    v28.4s, v1.s[2]         \n" // sum6
297                     "dup    v29.4s, v1.s[2]         \n"
298                     "dup    v30.4s, v1.s[3]         \n" // sum7
299                     "dup    v31.4s, v1.s[3]         \n"
300 
301                     "lsr         w4, %w20, #2            \n" // r4 = nn = L >> 2
302                     "cmp         w4, #0                  \n"
303                     "beq         1f                      \n"
304 
305                     "0:                                  \n" // for (; k+3<L; k=k+4)
306 
307                     "prfm   pldl1keep, [%9, #512]                       \n"
308                     "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64     \n" // kernel
309                     "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%9], #64     \n"
310 
311                     "prfm   pldl1keep, [%8, #512]                       \n"
312                     "ld1    {v8.4s, v9.4s, v10.4s, v11.4s}, [%8], #64   \n" // data
313                     "ld1    {v12.4s, v13.4s, v14.4s, v15.4s}, [%8], #64 \n"
314                     // k0
315                     "fmla    v16.4s, v8.4s, v0.s[0]      \n" // sum0 += (a00-a70) * k00
316                     "fmla    v17.4s, v9.4s, v0.s[0]      \n" //
317                     "fmla    v18.4s, v8.4s, v0.s[1]      \n" // sum1 += (a00-a70) * k10
318                     "fmla    v19.4s, v9.4s, v0.s[1]      \n" //
319                     "fmla    v20.4s, v8.4s, v0.s[2]      \n" // sum2 += (a00-a70) * k20
320                     "fmla    v21.4s, v9.4s, v0.s[2]      \n" //
321                     "fmla    v22.4s, v8.4s, v0.s[3]      \n" // sum3 += (a00-a70) * k30
322                     "fmla    v23.4s, v9.4s, v0.s[3]      \n" //
323                     "fmla    v24.4s, v8.4s, v1.s[0]      \n" // sum4 += (a00-a70) * k40
324                     "fmla    v25.4s, v9.4s, v1.s[0]      \n" //
325                     "fmla    v26.4s, v8.4s, v1.s[1]      \n" // sum5 += (a00-a70) * k50
326                     "fmla    v27.4s, v9.4s, v1.s[1]      \n" //
327                     "fmla    v28.4s, v8.4s, v1.s[2]      \n" // sum6 += (a00-a70) * k60
328                     "fmla    v29.4s, v9.4s, v1.s[2]      \n" //
329                     "fmla    v30.4s, v8.4s, v1.s[3]      \n" // sum7 += (a00-a70) * k70
330                     "fmla    v31.4s, v9.4s, v1.s[3]      \n" //
331                     // k1
332                     "fmla    v16.4s, v10.4s, v2.s[0]     \n" // sum0 += (a01-a71) * k01
333                     "fmla    v17.4s, v11.4s, v2.s[0]     \n" //
334                     "fmla    v18.4s, v10.4s, v2.s[1]     \n" // sum1 += (a01-a71) * k11
335                     "fmla    v19.4s, v11.4s, v2.s[1]     \n" //
336                     "fmla    v20.4s, v10.4s, v2.s[2]     \n" // sum2 += (a01-a71) * k21
337                     "fmla    v21.4s, v11.4s, v2.s[2]     \n" //
338                     "fmla    v22.4s, v10.4s, v2.s[3]     \n" // sum3 += (a01-a71) * k31
339                     "fmla    v23.4s, v11.4s, v2.s[3]     \n" //
340                     "fmla    v24.4s, v10.4s, v3.s[0]     \n" // sum4 += (a01-a71) * k41
341                     "fmla    v25.4s, v11.4s, v3.s[0]     \n" //
342                     "fmla    v26.4s, v10.4s, v3.s[1]     \n" // sum5 += (a01-a71) * k51
343                     "fmla    v27.4s, v11.4s, v3.s[1]     \n" //
344                     "fmla    v28.4s, v10.4s, v3.s[2]     \n" // sum6 += (a01-a71) * k61
345                     "fmla    v29.4s, v11.4s, v3.s[2]     \n" //
346                     "fmla    v30.4s, v10.4s, v3.s[3]     \n" // sum7 += (a01-a71) * k71
347                     "fmla    v31.4s, v11.4s, v3.s[3]     \n" //
348                     // k2
349                     "fmla    v16.4s, v12.4s, v4.s[0]     \n" // sum0 += (a02-a72) * k02
350                     "fmla    v17.4s, v13.4s, v4.s[0]     \n" //
351                     "fmla    v18.4s, v12.4s, v4.s[1]     \n" // sum1 += (a02-a72) * k12
352                     "fmla    v19.4s, v13.4s, v4.s[1]     \n" //
353                     "fmla    v20.4s, v12.4s, v4.s[2]     \n" // sum2 += (a02-a72) * k22
354                     "fmla    v21.4s, v13.4s, v4.s[2]     \n" //
355                     "fmla    v22.4s, v12.4s, v4.s[3]     \n" // sum3 += (a02-a72) * k32
356                     "fmla    v23.4s, v13.4s, v4.s[3]     \n" //
357                     "fmla    v24.4s, v12.4s, v5.s[0]     \n" // sum4 += (a02-a72) * k42
358                     "fmla    v25.4s, v13.4s, v5.s[0]     \n" //
359                     "fmla    v26.4s, v12.4s, v5.s[1]     \n" // sum5 += (a02-a72) * k52
360                     "fmla    v27.4s, v13.4s, v5.s[1]     \n" //
361                     "fmla    v28.4s, v12.4s, v5.s[2]     \n" // sum6 += (a02-a72) * k62
362                     "fmla    v29.4s, v13.4s, v5.s[2]     \n" //
363                     "fmla    v30.4s, v12.4s, v5.s[3]     \n" // sum7 += (a02-a72) * k72
364                     "fmla    v31.4s, v13.4s, v5.s[3]     \n" //
365                     // k3
366                     "fmla    v16.4s, v14.4s, v6.s[0]     \n" // sum0 += (a03-a73) * k03
367                     "fmla    v17.4s, v15.4s, v6.s[0]     \n" //
368                     "fmla    v18.4s, v14.4s, v6.s[1]     \n" // sum1 += (a03-a73) * k13
369                     "fmla    v19.4s, v15.4s, v6.s[1]     \n" //
370                     "fmla    v20.4s, v14.4s, v6.s[2]     \n" // sum2 += (a03-a73) * k23
371                     "fmla    v21.4s, v15.4s, v6.s[2]     \n" //
372                     "fmla    v22.4s, v14.4s, v6.s[3]     \n" // sum3 += (a03-a73) * k33
373                     "fmla    v23.4s, v15.4s, v6.s[3]     \n" //
374                     "fmla    v24.4s, v14.4s, v7.s[0]     \n" // sum4 += (a03-a73) * k43
375                     "fmla    v25.4s, v15.4s, v7.s[0]     \n" //
376                     "fmla    v26.4s, v14.4s, v7.s[1]     \n" // sum5 += (a03-a73) * k53
377                     "fmla    v27.4s, v15.4s, v7.s[1]     \n" //
378                     "fmla    v28.4s, v14.4s, v7.s[2]     \n" // sum6 += (a03-a73) * k63
379                     "fmla    v29.4s, v15.4s, v7.s[2]     \n" //
380                     "fmla    v30.4s, v14.4s, v7.s[3]     \n" // sum7 += (a03-a73) * k73
381                     "fmla    v31.4s, v15.4s, v7.s[3]     \n" //
382 
383                     "subs   w4, w4, #1                   \n"
384                     "bne    0b                           \n"
385 
386                     "1:                                  \n"
387 
388                     // remain loop
389                     "and    w4, %w20, #3                 \n" // w4 = remain = inch & 3;
390                     "cmp    w4, #0                       \n"
391                     "beq    3f                           \n"
392 
393                     "2:                                  \n"
394 
395                     "prfm   pldl1keep, [%9, #256]        \n"
396                     "ld1    {v0.4s, v1.4s}, [%9], #32    \n"
397 
398                     "prfm   pldl1keep, [%8, #256]        \n"
399                     "ld1    {v8.4s, v9.4s}, [%8], #32    \n"
400 
401                     // k0
402                     "fmla    v16.4s, v8.4s, v0.s[0]      \n" // sum0 += (a00-a70) * k00
403                     "fmla    v17.4s, v9.4s, v0.s[0]      \n" //
404                     "fmla    v18.4s, v8.4s, v0.s[1]      \n" // sum1 += (a00-a70) * k10
405                     "fmla    v19.4s, v9.4s, v0.s[1]      \n" //
406                     "fmla    v20.4s, v8.4s, v0.s[2]      \n" // sum2 += (a00-a70) * k20
407                     "fmla    v21.4s, v9.4s, v0.s[2]      \n" //
408                     "fmla    v22.4s, v8.4s, v0.s[3]      \n" // sum3 += (a00-a70) * k30
409                     "fmla    v23.4s, v9.4s, v0.s[3]      \n" //
410                     "fmla    v24.4s, v8.4s, v1.s[0]      \n" // sum4 += (a00-a70) * k40
411                     "fmla    v25.4s, v9.4s, v1.s[0]      \n" //
412                     "fmla    v26.4s, v8.4s, v1.s[1]      \n" // sum5 += (a00-a70) * k50
413                     "fmla    v27.4s, v9.4s, v1.s[1]      \n" //
414                     "fmla    v28.4s, v8.4s, v1.s[2]      \n" // sum6 += (a00-a70) * k60
415                     "fmla    v29.4s, v9.4s, v1.s[2]      \n" //
416                     "fmla    v30.4s, v8.4s, v1.s[3]      \n" // sum7 += (a00-a70) * k70
417                     "fmla    v31.4s, v9.4s, v1.s[3]      \n" //
418 
419                     "subs   w4, w4, #1                   \n"
420 
421                     "bne    2b                           \n"
422 
423                     "3:                                  \n"
424 
425                     "st1    {v16.4s, v17.4s}, [%0]       \n"
426                     "st1    {v18.4s, v19.4s}, [%1]       \n"
427                     "st1    {v20.4s, v21.4s}, [%2]       \n"
428                     "st1    {v22.4s, v23.4s}, [%3]       \n"
429                     "st1    {v24.4s, v25.4s}, [%4]       \n"
430                     "st1    {v26.4s, v27.4s}, [%5]       \n"
431                     "st1    {v28.4s, v29.4s}, [%6]       \n"
432                     "st1    {v30.4s, v31.4s}, [%7]       \n"
433 
434                     : "=r"(output0), // %0
435                     "=r"(output1), // %1
436                     "=r"(output2), // %2
437                     "=r"(output3), // %3
438                     "=r"(output4), // %4
439                     "=r"(output5), // %5
440                     "=r"(output6), // %6
441                     "=r"(output7), // %7
442                     "=r"(vb),      // %8
443                     "=r"(va)       // %9
444                     : "0"(output0),
445                     "1"(output1),
446                     "2"(output2),
447                     "3"(output3),
448                     "4"(output4),
449                     "5"(output5),
450                     "6"(output6),
451                     "7"(output7),
452                     "8"(vb),
453                     "9"(va),
454                     "r"(L),      // %20
455                     "r"(biasptr) // %21
456                     : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
457 #else
458                 float sum0[8] = {0};
459                 float sum1[8] = {0};
460                 float sum2[8] = {0};
461                 float sum3[8] = {0};
462                 float sum4[8] = {0};
463                 float sum5[8] = {0};
464                 float sum6[8] = {0};
465                 float sum7[8] = {0};
466 
467                 int k = 0;
468                 for (; k + 7 < L; k = k + 8)
469                 {
470                     for (int n = 0; n < 8; n++)
471                     {
472                         sum0[n] += va[0] * vb[n];
473                         sum1[n] += va[1] * vb[n];
474                         sum2[n] += va[2] * vb[n];
475                         sum3[n] += va[3] * vb[n];
476                         sum4[n] += va[4] * vb[n];
477                         sum5[n] += va[5] * vb[n];
478                         sum6[n] += va[6] * vb[n];
479                         sum7[n] += va[7] * vb[n];
480                         va += 8;
481 
482                         sum0[n] += va[0] * vb[n + 8];
483                         sum1[n] += va[1] * vb[n + 8];
484                         sum2[n] += va[2] * vb[n + 8];
485                         sum3[n] += va[3] * vb[n + 8];
486                         sum4[n] += va[4] * vb[n + 8];
487                         sum5[n] += va[5] * vb[n + 8];
488                         sum6[n] += va[6] * vb[n + 8];
489                         sum7[n] += va[7] * vb[n + 8];
490                         va += 8;
491 
492                         sum0[n] += va[0] * vb[n + 16];
493                         sum1[n] += va[1] * vb[n + 16];
494                         sum2[n] += va[2] * vb[n + 16];
495                         sum3[n] += va[3] * vb[n + 16];
496                         sum4[n] += va[4] * vb[n + 16];
497                         sum5[n] += va[5] * vb[n + 16];
498                         sum6[n] += va[6] * vb[n + 16];
499                         sum7[n] += va[7] * vb[n + 16];
500                         va += 8;
501 
502                         sum0[n] += va[0] * vb[n + 24];
503                         sum1[n] += va[1] * vb[n + 24];
504                         sum2[n] += va[2] * vb[n + 24];
505                         sum3[n] += va[3] * vb[n + 24];
506                         sum4[n] += va[4] * vb[n + 24];
507                         sum5[n] += va[5] * vb[n + 24];
508                         sum6[n] += va[6] * vb[n + 24];
509                         sum7[n] += va[7] * vb[n + 24];
510                         va += 8;
511 
512                         sum0[n] += va[0] * vb[n + 32];
513                         sum1[n] += va[1] * vb[n + 32];
514                         sum2[n] += va[2] * vb[n + 32];
515                         sum3[n] += va[3] * vb[n + 32];
516                         sum4[n] += va[4] * vb[n + 32];
517                         sum5[n] += va[5] * vb[n + 32];
518                         sum6[n] += va[6] * vb[n + 32];
519                         sum7[n] += va[7] * vb[n + 32];
520                         va += 8;
521 
522                         sum0[n] += va[0] * vb[n + 40];
523                         sum1[n] += va[1] * vb[n + 40];
524                         sum2[n] += va[2] * vb[n + 40];
525                         sum3[n] += va[3] * vb[n + 40];
526                         sum4[n] += va[4] * vb[n + 40];
527                         sum5[n] += va[5] * vb[n + 40];
528                         sum6[n] += va[6] * vb[n + 40];
529                         sum7[n] += va[7] * vb[n + 40];
530                         va += 8;
531 
532                         sum0[n] += va[0] * vb[n + 48];
533                         sum1[n] += va[1] * vb[n + 48];
534                         sum2[n] += va[2] * vb[n + 48];
535                         sum3[n] += va[3] * vb[n + 48];
536                         sum4[n] += va[4] * vb[n + 48];
537                         sum5[n] += va[5] * vb[n + 48];
538                         sum6[n] += va[6] * vb[n + 48];
539                         sum7[n] += va[7] * vb[n + 48];
540                         va += 8;
541 
542                         sum0[n] += va[0] * vb[n + 56];
543                         sum1[n] += va[1] * vb[n + 56];
544                         sum2[n] += va[2] * vb[n + 56];
545                         sum3[n] += va[3] * vb[n + 56];
546                         sum4[n] += va[4] * vb[n + 56];
547                         sum5[n] += va[5] * vb[n + 56];
548                         sum6[n] += va[6] * vb[n + 56];
549                         sum7[n] += va[7] * vb[n + 56];
550                         va -= 56;
551                     }
552 
553                     va += 64;
554                     vb += 64;
555                 }
556 
557                 for (; k < L; k++)
558                 {
559                     for (int n = 0; n < 8; n++)
560                     {
561                         sum0[n] += va[0] * vb[n];
562                         sum1[n] += va[1] * vb[n];
563                         sum2[n] += va[2] * vb[n];
564                         sum3[n] += va[3] * vb[n];
565                         sum4[n] += va[4] * vb[n];
566                         sum5[n] += va[5] * vb[n];
567                         sum6[n] += va[6] * vb[n];
568                         sum7[n] += va[7] * vb[n];
569                     }
570 
571                     va += 8;
572                     vb += 8;
573                 }
574 
575                 for (int n = 0; n < 8; n++)
576                 {
577                     output0[n] = sum0[n] + biasptr[0];
578                     output1[n] = sum1[n] + biasptr[1];
579                     output2[n] = sum2[n] + biasptr[2];
580                     output3[n] = sum3[n] + biasptr[3];
581                     output4[n] = sum4[n] + biasptr[4];
582                     output5[n] = sum5[n] + biasptr[5];
583                     output6[n] = sum6[n] + biasptr[6];
584                     output7[n] = sum7[n] + biasptr[7];
585                 }
586 #endif // __ARM_NEON
587                 output0 += 8;
588                 output1 += 8;
589                 output2 += 8;
590                 output3 += 8;
591                 output4 += 8;
592                 output5 += 8;
593                 output6 += 8;
594                 output7 += 8;
595             }
596 
597             for (; j < N; j++)
598             {
599                 const float* vb = bottom_tm.channel(j / 8 + j % 8);
600                 const float* va = kernel_tm.channel(i / 8);
601 
602 #if __ARM_NEON
603                 asm volatile(
604                     "ld1    {v14.4s, v15.4s}, [%21]      \n" // sum0_7 inital with bias
605                     "eor    v16.16b, v16.16b, v16.16b    \n" // sum0
606                     "eor    v17.16b, v17.16b, v17.16b    \n" // sum1
607                     "eor    v18.16b, v18.16b, v18.16b    \n" // sum2
608                     "eor    v19.16b, v19.16b, v19.16b    \n" // sum3
609                     "eor    v20.16b, v20.16b, v20.16b    \n" // sum4
610                     "eor    v21.16b, v21.16b, v21.16b    \n" // sum5
611                     "eor    v22.16b, v22.16b, v22.16b    \n" // sum6
612                     "eor    v23.16b, v23.16b, v23.16b    \n" // sum7
613 
614                     "lsr         w4, %w20, #2            \n" // r4 = nn = L >> 2
615                     "cmp         w4, #0                  \n"
616                     "beq         1f                      \n"
617 
618                     "0:                                  \n" // for (; k+3<L; k=k+4)
619 
620                     "prfm   pldl1keep, [%9, #256]                       \n"
621                     "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64     \n" // k
622                     "ld1    {v4.4s, v5.4s, v6.4s, v7.4s}, [%9], #64     \n"
623 
624                     "prfm   pldl1keep, [%8, #128]        \n"
625                     "ld1    {v8.4s}, [%8], #16           \n" // d
626 
627                     // k0
628                     "fmla    v16.4s, v0.4s, v8.s[0]      \n" // sum0 += (k00-k70) * a00
629                     "fmla    v17.4s, v1.4s, v8.s[0]      \n" //
630                     "fmla    v18.4s, v2.4s, v8.s[1]      \n" // sum1 += (k01-k71) * a10
631                     "fmla    v19.4s, v3.4s, v8.s[1]      \n" //
632                     "fmla    v20.4s, v4.4s, v8.s[2]      \n" // sum2 += (k02-k72) * a20
633                     "fmla    v21.4s, v5.4s, v8.s[2]      \n" //
634                     "fmla    v22.4s, v6.4s, v8.s[3]      \n" // sum3 += (k03-k73) * a30
635                     "fmla    v23.4s, v7.4s, v8.s[3]      \n" //
636 
637                     "subs   w4, w4, #1                   \n"
638                     "bne    0b                           \n"
639 
640                     "fadd   v16.4s, v16.4s, v18.4s       \n"
641                     "fadd   v17.4s, v17.4s, v19.4s       \n"
642                     "fadd   v20.4s, v20.4s, v22.4s       \n"
643                     "fadd   v21.4s, v21.4s, v23.4s       \n"
644                     "fadd   v16.4s, v16.4s, v20.4s       \n"
645                     "fadd   v17.4s, v17.4s, v21.4s       \n"
646                     "fadd   v14.4s, v14.4s, v16.4s       \n"
647                     "fadd   v15.4s, v15.4s, v17.4s       \n"
648 
649                     "1:                                  \n"
650 
651                     // remain loop
652                     "and    w4, %w20, #3                 \n" // w4 = remain = inch & 3;
653                     "cmp    w4, #0                       \n"
654                     "beq    3f                           \n"
655 
656                     "2:                                  \n"
657 
658                     "prfm   pldl1keep, [%9, #256]        \n"
659                     "ld1    {v0.4s, v1.4s}, [%9], #32    \n"
660                     "prfm   pldl1keep, [%8, #32]         \n"
661                     "ld1r   {v8.4s}, [%8], #4            \n"
662 
663                     // k0
664                     "fmla   v14.4s, v8.4s, v0.4s         \n" // sum0 += (k00-k70) * a00
665                     "fmla   v15.4s, v8.4s, v1.4s         \n" //
666 
667                     "subs   w4, w4, #1                   \n"
668 
669                     "bne    2b                           \n"
670 
671                     "3:                                  \n"
672 
673                     "st1    {v14.s}[0], [%0]             \n"
674                     "st1    {v14.s}[1], [%1]             \n"
675                     "st1    {v14.s}[2], [%2]             \n"
676                     "st1    {v14.s}[3], [%3]             \n"
677                     "st1    {v15.s}[0], [%4]             \n"
678                     "st1    {v15.s}[1], [%5]             \n"
679                     "st1    {v15.s}[2], [%6]             \n"
680                     "st1    {v15.s}[3], [%7]             \n"
681 
682                     : "=r"(output0), // %0
683                     "=r"(output1), // %1
684                     "=r"(output2), // %2
685                     "=r"(output3), // %3
686                     "=r"(output4), // %4
687                     "=r"(output5), // %5
688                     "=r"(output6), // %6
689                     "=r"(output7), // %7
690                     "=r"(vb),      // %8
691                     "=r"(va)       // %9
692                     : "0"(output0),
693                     "1"(output1),
694                     "2"(output2),
695                     "3"(output3),
696                     "4"(output4),
697                     "5"(output5),
698                     "6"(output6),
699                     "7"(output7),
700                     "8"(vb),
701                     "9"(va),
702                     "r"(L),      // %20
703                     "r"(biasptr) // %21
704                     : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
705 #else
706                 float sum0 = biasptr[0];
707                 float sum1 = biasptr[1];
708                 float sum2 = biasptr[2];
709                 float sum3 = biasptr[3];
710                 float sum4 = biasptr[4];
711                 float sum5 = biasptr[5];
712                 float sum6 = biasptr[6];
713                 float sum7 = biasptr[7];
714 
715                 for (int k = 0; k < L; k++)
716                 {
717                     sum0 += va[0] * vb[0];
718                     sum1 += va[1] * vb[0];
719                     sum2 += va[2] * vb[0];
720                     sum3 += va[3] * vb[0];
721                     sum4 += va[4] * vb[0];
722                     sum5 += va[5] * vb[0];
723                     sum6 += va[6] * vb[0];
724                     sum7 += va[7] * vb[0];
725 
726                     va += 8;
727                     vb += 1;
728                 }
729 
730                 output0[0] = sum0;
731                 output1[0] = sum1;
732                 output2[0] = sum2;
733                 output3[0] = sum3;
734                 output4[0] = sum4;
735                 output5[0] = sum5;
736                 output6[0] = sum6;
737                 output7[0] = sum7;
738 #endif // __ARM_NEON
739                 output0++;
740                 output1++;
741                 output2++;
742                 output3++;
743                 output4++;
744                 output5++;
745                 output6++;
746                 output7++;
747             }
748         }
749 #endif // __aarch64__
750 
751         nn_outch = (outch - remain_outch_start) >> 2;
752 
753         #pragma omp parallel for num_threads(opt.num_threads)
754         for (int pp = 0; pp < nn_outch; pp++)
755         {
756             int i = remain_outch_start + pp * 4;
757 
758             float* output0 = top_blob.channel(i);
759             float* output1 = top_blob.channel(i + 1);
760             float* output2 = top_blob.channel(i + 2);
761             float* output3 = top_blob.channel(i + 3);
762 
763             const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
764             const float* biasptr = bias ? bias + i : zeros;
765 
766             int j = 0;
767             for (; j + 7 < N; j = j + 8)
768             {
769                 const float* vb = bottom_tm.channel(j / 8);
770 #if __ARM_NEON && __aarch64__
771                 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
772 #else
773                 const float* va = kernel_tm.channel(i / 4);
774 #endif // __ARM_NEON && __aarch64__
775 
776 #if __ARM_NEON
777 #if __aarch64__
778                 asm volatile(
779                     "ld1    {v0.4s}, [%13]               \n"
780                     "dup    v16.4s, v0.s[0]              \n" // sum0
781                     "dup    v17.4s, v0.s[0]              \n"
782                     "dup    v18.4s, v0.s[1]              \n" // sum1
783                     "dup    v19.4s, v0.s[1]              \n"
784                     "dup    v20.4s, v0.s[2]              \n" // sum2
785                     "dup    v21.4s, v0.s[2]              \n"
786                     "dup    v22.4s, v0.s[3]              \n" // sum3
787                     "dup    v23.4s, v0.s[3]              \n"
788 
789                     "lsr         w4, %w12, #2            \n" // r4 = nn = L >> 2
790                     "cmp         w4, #0                  \n"
791                     "beq         1f                      \n"
792 
793                     "0:                                  \n" // for (; k+3<L; k=k+4)
794 
795                     "prfm   pldl1keep, [%5, #512]                       \n"
796                     "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64     \n" // kernel
797 
798                     "prfm   pldl1keep, [%4, #512]                       \n"
799                     "ld1    {v8.4s, v9.4s, v10.4s, v11.4s}, [%4], #64   \n" // data
800                     "ld1    {v12.4s, v13.4s, v14.4s, v15.4s}, [%4], #64 \n"
801 
802                     "subs   w4, w4, #1                   \n"
803                     // k0
804                     "fmla    v16.4s, v8.4s, v0.s[0]      \n" // sum0 += (a00-a70) * k00
805                     "fmla    v17.4s, v9.4s, v0.s[0]      \n" //
806                     "fmla    v18.4s, v8.4s, v0.s[1]      \n" // sum1 += (a00-a70) * k10
807                     "fmla    v19.4s, v9.4s, v0.s[1]      \n" //
808                     "fmla    v20.4s, v8.4s, v0.s[2]      \n" // sum2 += (a00-a70) * k20
809                     "fmla    v21.4s, v9.4s, v0.s[2]      \n" //
810                     "fmla    v22.4s, v8.4s, v0.s[3]      \n" // sum3 += (a00-a70) * k30
811                     "fmla    v23.4s, v9.4s, v0.s[3]      \n" //
812                     // k1
813                     "fmla    v16.4s, v10.4s, v1.s[0]     \n" // sum0 += (a01-a71) * k01
814                     "fmla    v17.4s, v11.4s, v1.s[0]     \n" //
815                     "fmla    v18.4s, v10.4s, v1.s[1]     \n" // sum1 += (a01-a71) * k11
816                     "fmla    v19.4s, v11.4s, v1.s[1]     \n" //
817                     "fmla    v20.4s, v10.4s, v1.s[2]     \n" // sum2 += (a01-a71) * k21
818                     "fmla    v21.4s, v11.4s, v1.s[2]     \n" //
819                     "fmla    v22.4s, v10.4s, v1.s[3]     \n" // sum3 += (a01-a71) * k31
820                     "fmla    v23.4s, v11.4s, v1.s[3]     \n" //
821                     // k2
822                     "fmla    v16.4s, v12.4s, v2.s[0]     \n" // sum0 += (a02-a72) * k02
823                     "fmla    v17.4s, v13.4s, v2.s[0]     \n" //
824                     "fmla    v18.4s, v12.4s, v2.s[1]     \n" // sum1 += (a02-a72) * k12
825                     "fmla    v19.4s, v13.4s, v2.s[1]     \n" //
826                     "fmla    v20.4s, v12.4s, v2.s[2]     \n" // sum2 += (a02-a72) * k22
827                     "fmla    v21.4s, v13.4s, v2.s[2]     \n" //
828                     "fmla    v22.4s, v12.4s, v2.s[3]     \n" // sum3 += (a02-a72) * k32
829                     "fmla    v23.4s, v13.4s, v2.s[3]     \n" //
830                     // k3
831                     "fmla    v16.4s, v14.4s, v3.s[0]     \n" // sum0 += (a03-a73) * k03
832                     "fmla    v17.4s, v15.4s, v3.s[0]     \n" //
833                     "fmla    v18.4s, v14.4s, v3.s[1]     \n" // sum1 += (a03-a73) * k13
834                     "fmla    v19.4s, v15.4s, v3.s[1]     \n" //
835                     "fmla    v20.4s, v14.4s, v3.s[2]     \n" // sum2 += (a03-a73) * k23
836                     "fmla    v21.4s, v15.4s, v3.s[2]     \n" //
837                     "fmla    v22.4s, v14.4s, v3.s[3]     \n" // sum3 += (a03-a73) * k33
838                     "fmla    v23.4s, v15.4s, v3.s[3]     \n" //
839 
840                     "bne    0b                           \n"
841 
842                     "1:                                  \n"
843 
844                     // remain loop
845                     "and    w4, %w12, #3                 \n" // w4 = remain = inch & 3;
846                     "cmp    w4, #0                       \n"
847                     "beq    3f                           \n"
848 
849                     "2:                                  \n"
850 
851                     "prfm   pldl1keep, [%5, #256]        \n"
852                     "ld1    {v0.4s}, [%5], #16           \n"
853                     "prfm   pldl1keep, [%4, #256]        \n"
854                     "ld1    {v8.4s, v9.4s}, [%4], #32    \n"
855                     // k0
856                     "fmla    v16.4s, v8.4s, v0.s[0]      \n" // sum0 += (a00-a70) * k00
857                     "fmla    v17.4s, v9.4s, v0.s[0]      \n" //
858                     "fmla    v18.4s, v8.4s, v0.s[1]      \n" // sum1 += (a00-a70) * k10
859                     "fmla    v19.4s, v9.4s, v0.s[1]      \n" //
860                     "fmla    v20.4s, v8.4s, v0.s[2]      \n" // sum2 += (a00-a70) * k20
861                     "fmla    v21.4s, v9.4s, v0.s[2]      \n" //
862                     "fmla    v22.4s, v8.4s, v0.s[3]      \n" // sum3 += (a00-a70) * k30
863                     "fmla    v23.4s, v9.4s, v0.s[3]      \n" //
864 
865                     "subs   w4, w4, #1                   \n"
866 
867                     "bne    2b                           \n"
868 
869                     "3:                                  \n"
870 
871                     "st1    {v16.4s, v17.4s}, [%0]       \n"
872                     "st1    {v18.4s, v19.4s}, [%1]       \n"
873                     "st1    {v20.4s, v21.4s}, [%2]       \n"
874                     "st1    {v22.4s, v23.4s}, [%3]       \n"
875 
876                     : "=r"(output0), // %0
877                     "=r"(output1), // %1
878                     "=r"(output2), // %2
879                     "=r"(output3), // %3
880                     "=r"(vb),      // %4
881                     "=r"(va)       // %5
882                     : "0"(output0),
883                     "1"(output1),
884                     "2"(output2),
885                     "3"(output3),
886                     "4"(vb),
887                     "5"(va),
888                     "r"(L),      // %12
889                     "r"(biasptr) // %13
890                     : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
891 #else
892                 asm volatile(
893                     "vld1.f32   {d0-d1}, [%13]      \n"
894                     "vdup.f32   q8, d0[0]           \n"
895                     "vdup.f32   q9, d0[0]           \n"
896                     "vdup.f32   q10, d0[1]          \n"
897                     "vdup.f32   q11, d0[1]          \n"
898                     "vdup.f32   q12, d1[0]          \n"
899                     "vdup.f32   q13, d1[0]          \n"
900                     "vdup.f32   q14, d1[1]          \n"
901                     "vdup.f32   q15, d1[1]          \n"
902 
903                     "lsr         r4, %12, #2        \n" // r4 = nn = L >> 2
904                     "cmp         r4, #0             \n"
905                     "beq         1f                 \n"
906 
907                     "0:                             \n" // for(; nn != 0; nn--)
908                     "pld        [%5, #512]          \n"
909                     "vldm       %5!, {d0-d7}        \n" // kernel
910                     "pld        [%4, #512]          \n"
911                     "vldm       %4!, {d8-d15}       \n" // data
912 
913                     "vmla.f32   q8, q4, d0[0]       \n" // sum0 = (a00-a07) * k00
914                     "vmla.f32   q9, q5, d0[0]       \n"
915                     "vmla.f32   q10, q4, d0[1]      \n" // sum1 = (a00-a07) * k10
916                     "vmla.f32   q11, q5, d0[1]      \n"
917                     "vmla.f32   q12, q4, d1[0]      \n" // sum2 = (a00-a07) * k20
918                     "vmla.f32   q13, q5, d1[0]      \n"
919                     "vmla.f32   q14, q4, d1[1]      \n" // sum3 = (a00-a07) * k30
920                     "vmla.f32   q15, q5, d1[1]      \n"
921 
922                     "vmla.f32   q8, q6, d2[0]       \n" // sum0 += (a10-a17) * k01
923                     "vmla.f32   q9, q7, d2[0]       \n"
924                     "vmla.f32   q10, q6, d2[1]      \n" // sum1 += (a10-a17) * k11
925                     "vmla.f32   q11, q7, d2[1]      \n"
926                     "vmla.f32   q12, q6, d3[0]      \n" // sum2 += (a10-a17) * k21
927                     "vmla.f32   q13, q7, d3[0]      \n"
928                     "vmla.f32   q14, q6, d3[1]      \n" // sum3 += (a10-a17) * k31
929                     "vmla.f32   q15, q7, d3[1]      \n"
930 
931                     "pld        [%4, #512]          \n"
932                     "vldm       %4!, {d8-d15}       \n" // data
933 
934                     "vmla.f32   q8, q4, d4[0]       \n" // sum0 += (a20-a27) * k02
935                     "vmla.f32   q9, q5, d4[0]       \n"
936                     "vmla.f32   q10, q4, d4[1]      \n" // sum1 += (a20-a27) * k12
937                     "vmla.f32   q11, q5, d4[1]      \n"
938                     "vmla.f32   q12, q4, d5[0]      \n" // sum2 += (a20-a27) * k22
939                     "vmla.f32   q13, q5, d5[0]      \n"
940                     "vmla.f32   q14, q4, d5[1]      \n" // sum3 += (a20-a27) * k32
941                     "vmla.f32   q15, q5, d5[1]      \n"
942 
943                     "vmla.f32   q8, q6, d6[0]       \n" // sum0 += (a30-a37) * k03
944                     "vmla.f32   q9, q7, d6[0]       \n"
945                     "vmla.f32   q10, q6, d6[1]      \n" // sum1 += (a30-a37) * k13
946                     "vmla.f32   q11, q7, d6[1]      \n"
947                     "vmla.f32   q12, q6, d7[0]      \n" // sum2 += (a30-a37) * k23
948                     "vmla.f32   q13, q7, d7[0]      \n"
949                     "vmla.f32   q14, q6, d7[1]      \n" // sum3 += (a30-a37) * k33
950                     "vmla.f32   q15, q7, d7[1]      \n"
951 
952                     "subs        r4, r4, #1         \n"
953                     "bne         0b                 \n" // end for
954 
955                     "1:                             \n"
956                     // remain loop
957                     "and         r4, %12, #3        \n" // r4 = remain = inch & 3
958                     "cmp         r4, #0             \n"
959                     "beq         3f                 \n"
960 
961                     "2:                             \n" // for(; remain != 0; remain--)
962 
963                     "pld        [%5, #128]          \n"
964                     "vld1.f32   {d0-d1}, [%5]!      \n"
965                     "pld        [%4, #256]          \n"
966                     "vld1.f32   {d8-d11}, [%4]!     \n"
967 
968                     "vmla.f32   q8, q4, d0[0]       \n" // sum0 += (a00-a70) * k00
969                     "vmla.f32   q9, q5, d0[0]       \n"
970                     "vmla.f32   q10, q4, d0[1]      \n" // sum1 += (a00-a70) * k10
971                     "vmla.f32   q11, q5, d0[1]      \n"
972                     "vmla.f32   q12, q4, d1[0]      \n" // sum2 += (a00-a70) * k20
973                     "vmla.f32   q13, q5, d1[0]      \n"
974                     "vmla.f32   q14, q4, d1[1]      \n" // sum3 += (a00-a70) * k30
975                     "vmla.f32   q15, q5, d1[1]      \n"
976 
977                     "subs        r4, r4, #1         \n"
978                     "bne         2b                 \n"
979 
980                     "3:                             \n" // store the result to memory
981                     "vst1.f32    {d16-d19}, [%0]    \n"
982                     "vst1.f32    {d20-d23}, [%1]    \n"
983                     "vst1.f32    {d24-d27}, [%2]    \n"
984                     "vst1.f32    {d28-d31}, [%3]    \n"
985 
986                     : "=r"(output0), // %0
987                     "=r"(output1), // %1
988                     "=r"(output2), // %2
989                     "=r"(output3), // %3
990                     "=r"(vb),      // %4
991                     "=r"(va)       // %5
992                     : "0"(output0),
993                     "1"(output1),
994                     "2"(output2),
995                     "3"(output3),
996                     "4"(vb),
997                     "5"(va),
998                     "r"(L),      // %12
999                     "r"(biasptr) // %13
1000                     : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1001 #endif // __aarch64__
1002 #else
1003                 float sum0[8] = {0};
1004                 float sum1[8] = {0};
1005                 float sum2[8] = {0};
1006                 float sum3[8] = {0};
1007 
1008                 int k = 0;
1009                 for (; k + 7 < L; k = k + 8)
1010                 {
1011                     for (int n = 0; n < 8; n++)
1012                     {
1013                         sum0[n] += va[0] * vb[n];
1014                         sum1[n] += va[1] * vb[n];
1015                         sum2[n] += va[2] * vb[n];
1016                         sum3[n] += va[3] * vb[n];
1017                         va += 4;
1018 
1019                         sum0[n] += va[0] * vb[n + 8];
1020                         sum1[n] += va[1] * vb[n + 8];
1021                         sum2[n] += va[2] * vb[n + 8];
1022                         sum3[n] += va[3] * vb[n + 8];
1023                         va += 4;
1024 
1025                         sum0[n] += va[0] * vb[n + 16];
1026                         sum1[n] += va[1] * vb[n + 16];
1027                         sum2[n] += va[2] * vb[n + 16];
1028                         sum3[n] += va[3] * vb[n + 16];
1029                         va += 4;
1030 
1031                         sum0[n] += va[0] * vb[n + 24];
1032                         sum1[n] += va[1] * vb[n + 24];
1033                         sum2[n] += va[2] * vb[n + 24];
1034                         sum3[n] += va[3] * vb[n + 24];
1035                         va += 4;
1036 
1037                         sum0[n] += va[0] * vb[n + 32];
1038                         sum1[n] += va[1] * vb[n + 32];
1039                         sum2[n] += va[2] * vb[n + 32];
1040                         sum3[n] += va[3] * vb[n + 32];
1041                         va += 4;
1042 
1043                         sum0[n] += va[0] * vb[n + 40];
1044                         sum1[n] += va[1] * vb[n + 40];
1045                         sum2[n] += va[2] * vb[n + 40];
1046                         sum3[n] += va[3] * vb[n + 40];
1047                         va += 4;
1048 
1049                         sum0[n] += va[0] * vb[n + 48];
1050                         sum1[n] += va[1] * vb[n + 48];
1051                         sum2[n] += va[2] * vb[n + 48];
1052                         sum3[n] += va[3] * vb[n + 48];
1053                         va += 4;
1054 
1055                         sum0[n] += va[0] * vb[n + 56];
1056                         sum1[n] += va[1] * vb[n + 56];
1057                         sum2[n] += va[2] * vb[n + 56];
1058                         sum3[n] += va[3] * vb[n + 56];
1059                         va -= 28;
1060                     }
1061 
1062                     va += 32;
1063                     vb += 64;
1064                 }
1065 
1066                 for (; k < L; k++)
1067                 {
1068                     for (int n = 0; n < 8; n++)
1069                     {
1070                         sum0[n] += va[0] * vb[n];
1071                         sum1[n] += va[1] * vb[n];
1072                         sum2[n] += va[2] * vb[n];
1073                         sum3[n] += va[3] * vb[n];
1074                     }
1075 
1076                     va += 4;
1077                     vb += 8;
1078                 }
1079 
1080                 for (int n = 0; n < 8; n++)
1081                 {
1082                     output0[n] = sum0[n] + biasptr[0];
1083                     output1[n] = sum1[n] + biasptr[1];
1084                     output2[n] = sum2[n] + biasptr[2];
1085                     output3[n] = sum3[n] + biasptr[3];
1086                 }
1087 #endif // __ARM_NEON
1088                 output0 += 8;
1089                 output1 += 8;
1090                 output2 += 8;
1091                 output3 += 8;
1092             }
1093 
1094             for (; j < N; j++)
1095             {
1096                 float* vb = bottom_tm.channel(j / 8 + j % 8);
1097 #if __ARM_NEON && __aarch64__
1098                 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
1099 #else
1100                 const float* va = kernel_tm.channel(i / 4);
1101 #endif // __ARM_NEON && __aarch64__
1102 
1103 #if __ARM_NEON
1104 #if __aarch64__
1105                 asm volatile(
1106                     "ld1    {v14.4s}, [%13]              \n" // sum0_3 inital with bias
1107 
1108                     "lsr         w4, %w12, #2            \n" // r4 = nn = L >> 2
1109                     "cmp         w4, #0                  \n"
1110                     "beq         1f                      \n"
1111 
1112                     "eor    v16.16b, v16.16b, v16.16b    \n" // sum0
1113                     "eor    v17.16b, v17.16b, v17.16b    \n" // sum1
1114                     "eor    v18.16b, v18.16b, v18.16b    \n" // sum2
1115                     "eor    v19.16b, v19.16b, v19.16b    \n" // sum3
1116 
1117                     "0:                                  \n" // for (; k+3<L; k=k+4)
1118 
1119                     "prfm   pldl1keep, [%5, #256]                       \n"
1120                     "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64     \n" // k
1121 
1122                     "prfm   pldl1keep, [%4, #128]        \n"
1123                     "ld1    {v8.4s}, [%4], #16           \n" // d
1124 
1125                     "subs   w4, w4, #1                   \n"
1126                     "fmla    v16.4s, v0.4s, v8.s[0]      \n" // sum0 += (k00-k30) * a00
1127                     "fmla    v17.4s, v1.4s, v8.s[1]      \n" // sum1 += (k01-k31) * a10
1128                     "fmla    v18.4s, v2.4s, v8.s[2]      \n" // sum2 += (k02-k32) * a20
1129                     "fmla    v19.4s, v3.4s, v8.s[3]      \n" // sum3 += (k03-k33) * a30
1130 
1131                     "bne    0b                           \n"
1132 
1133                     "fadd      v16.4s, v16.4s, v18.4s    \n"
1134                     "fadd      v17.4s, v17.4s, v19.4s    \n"
1135                     "fadd      v14.4s, v14.4s, v16.4s    \n"
1136                     "fadd      v14.4s, v14.4s, v17.4s    \n"
1137 
1138                     "1:                                  \n"
1139 
1140                     // remain loop
1141                     "and    w4, %w12, #3                 \n" // w4 = remain = inch & 3;
1142                     "cmp    w4, #0                       \n"
1143                     "beq    3f                           \n"
1144 
1145                     "2:                                  \n"
1146 
1147                     "prfm   pldl1keep, [%5, #128]        \n"
1148                     "ld1    {v0.4s}, [%5], #16           \n"
1149                     "prfm   pldl1keep, [%4, #32]         \n"
1150                     "ld1r   {v8.4s}, [%4], #4            \n"
1151 
1152                     "subs   w4, w4, #1                   \n"
1153                     // k0
1154                     "fmla   v14.4s, v8.4s, v0.4s         \n" // sum0 += (k00-k30) * a00
1155                     "bne    2b                           \n"
1156 
1157                     "3:                                  \n"
1158 
1159                     "st1    {v14.s}[0], [%0]             \n"
1160                     "st1    {v14.s}[1], [%1]             \n"
1161                     "st1    {v14.s}[2], [%2]             \n"
1162                     "st1    {v14.s}[3], [%3]             \n"
1163 
1164                     : "=r"(output0), // %0
1165                     "=r"(output1), // %1
1166                     "=r"(output2), // %2
1167                     "=r"(output3), // %3
1168                     "=r"(vb),      // %4
1169                     "=r"(va)       // %5
1170                     : "0"(output0),
1171                     "1"(output1),
1172                     "2"(output2),
1173                     "3"(output3),
1174                     "4"(vb),
1175                     "5"(va),
1176                     "r"(L),      // %12
1177                     "r"(biasptr) // %13
1178                     : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
1179 #else
1180                 asm volatile(
1181                     // inch loop
1182                     "vld1.f32   {d24-d25}, [%13]    \n"
1183 
1184                     "lsr         r4, %12, #2        \n" // r4 = nn = L >> 2
1185                     "cmp         r4, #0             \n"
1186                     "beq         1f                 \n"
1187 
1188                     "veor       q8, q8, q8          \n"
1189                     "veor       q9, q9, q9          \n"
1190                     "veor       q10, q10, q10       \n"
1191                     "veor       q11, q11, q11       \n"
1192 
1193                     "0:                             \n" // for(; nn != 0; nn--)
1194                     "pld        [%5, #512]          \n"
1195                     "vldm       %5!, {d0-d7}        \n" // kernel
1196                     "pld        [%4, #128]          \n"
1197                     "vld1.f32   {d8-d9}, [%4]!      \n" // data
1198 
1199                     "vmla.f32   q8, q0, d8[0]       \n" // (k00-k30) * a00
1200                     "vmla.f32   q9, q1, d8[1]       \n" // (k01-k31) * a01
1201                     "vmla.f32   q10, q2, d9[0]      \n" // (k02-k32) * a02
1202                     "vmla.f32   q11, q3, d9[1]      \n" // (k03-k33) * a03
1203 
1204                     "subs        r4, r4, #1         \n"
1205                     "bne         0b                 \n" // end for
1206 
1207                     "vadd.f32   q8, q8, q9          \n"
1208                     "vadd.f32   q10, q10, q11       \n"
1209                     "vadd.f32   q8, q8, q10         \n"
1210                     "vadd.f32   q12, q12, q8        \n"
1211 
1212                     "1:                             \n"
1213                     // remain loop
1214                     "and         r4, %12, #3        \n" // r4 = remain = inch & 3
1215                     "cmp         r4, #0             \n"
1216                     "beq         3f                 \n"
1217 
1218                     "2:                             \n" // for(; remain != 0; remain--)
1219                     "pld        [%5, #128]          \n"
1220                     "vld1.f32   {d0-d1}, [%5]!      \n"
1221                     "pld        [%4, #32]           \n"
1222                     "vld1.f32   {d8[],d9[]}, [%4]!  \n"
1223 
1224                     "subs       r4, r4, #1          \n"
1225 
1226                     "vmla.f32   q12, q0, q4         \n"
1227                     "bne         2b                 \n"
1228 
1229                     "3:                             \n" // store the result to memory
1230                     "vst1.f32    {d24[0]}, [%0]     \n"
1231                     "vst1.f32    {d24[1]}, [%1]     \n"
1232                     "vst1.f32    {d25[0]}, [%2]     \n"
1233                     "vst1.f32    {d25[1]}, [%3]     \n"
1234 
1235                     : "=r"(output0), // %0
1236                     "=r"(output1), // %1
1237                     "=r"(output2), // %2
1238                     "=r"(output3), // %3
1239                     "=r"(vb),      // %4
1240                     "=r"(va)       // %5
1241                     : "0"(output0),
1242                     "1"(output1),
1243                     "2"(output2),
1244                     "3"(output3),
1245                     "4"(vb),
1246                     "5"(va),
1247                     "r"(L),      // %12
1248                     "r"(biasptr) // %13
1249                     : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
1250 #endif // __aarch64__
1251 #else
1252                 float sum0 = biasptr[0];
1253                 float sum1 = biasptr[1];
1254                 float sum2 = biasptr[2];
1255                 float sum3 = biasptr[3];
1256 
1257                 for (int k = 0; k < L; k++)
1258                 {
1259                     sum0 += va[0] * vb[0];
1260                     sum1 += va[1] * vb[0];
1261                     sum2 += va[2] * vb[0];
1262                     sum3 += va[3] * vb[0];
1263 
1264                     va += 4;
1265                     vb += 1;
1266                 }
1267 
1268                 output0[0] = sum0;
1269                 output1[0] = sum1;
1270                 output2[0] = sum2;
1271                 output3[0] = sum3;
1272 #endif // __ARM_NEON
1273                 output0++;
1274                 output1++;
1275                 output2++;
1276                 output3++;
1277             }
1278         }
1279 
1280         remain_outch_start += nn_outch << 2;
1281 
1282         #pragma omp parallel for num_threads(opt.num_threads)
        // Tail path over output channels: channels not covered by the
        // 8x/4x-unrolled blocks above are computed one at a time.  Each
        // iteration produces one 1 x N output row as
        //   output[j] = bias0 + dot(kernel row i, packed input column j)
        // where the reduction length L is set earlier in this function
        // (presumably inch * kernel area -- not visible here; confirm
        // against the im2col packing above).
        for (int i = remain_outch_start; i < outch; i++)
        {
            float* output = top_blob.channel(i);

            // Per-channel bias; 0.f when the layer has no bias blob.
            const float bias0 = bias ? bias[i] : 0.f;

            int j = 0;
            // Main column loop: 8 output positions per iteration.  bottom_tm
            // stores each group of 8 columns interleaved in its own channel,
            // so channel(j / 8) addresses the pack for columns j..j+7.
            for (; j + 7 < N; j = j + 8)
            {
                const float* vb = bottom_tm.channel(j / 8);
                // kernel_tm channel layout (see transform above / in HEAD):
                // 8-wide packs first, then 4-wide, then single rows on
                // aarch64; 4-wide packs then singles on arm32.  This index
                // selects the single-row pack holding kernel row i.
#if __ARM_NEON && __aarch64__
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
#else
                const float* va = kernel_tm.channel(i / 4 + i % 4);
#endif // __ARM_NEON && __aarch64__

#if __ARM_NEON
#if __aarch64__
                // 1x8 micro-kernel: v16/v17 accumulate 8 output values,
                // initialized with the broadcast bias.  %1 = vb (packed
                // data), %2 = va (kernel row), %6 = L (reduction length).
                asm volatile(
                    "dup    v16.4s, %w7                  \n" // sum0 = bias0 (lanes 0-3)
                    "dup    v17.4s, %w7                  \n" // sum0n = bias0 (lanes 4-7)

                    "lsr         w4, %w6, #2             \n" // w4 = nn = L >> 2
                    "cmp         w4, #0                  \n"
                    "beq         1f                      \n"

                    "0:                                  \n" // for (; k+3<L; k=k+4)

                    "prfm   pldl1keep, [%2, #128]        \n"
                    "ld1    {v0.4s}, [%2], #16           \n" // 4 kernel scalars k0..k3

                    "prfm   pldl1keep, [%1, #512]        \n"
                    "ld1    {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64   \n" // data
                    "ld1    {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n"

                    // k0
                    "fmla    v16.4s, v8.4s, v0.s[0]      \n" // sum0 += (a00-a70) * k00
                    "fmla    v17.4s, v9.4s, v0.s[0]      \n" //
                    // k1
                    "fmla    v16.4s, v10.4s, v0.s[1]     \n" // sum0 += (a01-a71) * k01
                    "fmla    v17.4s, v11.4s, v0.s[1]     \n" //
                    // k2
                    "fmla    v16.4s, v12.4s, v0.s[2]     \n" // sum0 += (a02-a72) * k02
                    "fmla    v17.4s, v13.4s, v0.s[2]     \n" //
                    // k3
                    "fmla    v16.4s, v14.4s, v0.s[3]     \n" // sum0 += (a03-a73) * k03
                    "fmla    v17.4s, v15.4s, v0.s[3]     \n" //

                    "subs   w4, w4, #1                   \n"
                    "bne    0b                           \n"

                    "1:                                  \n"

                    // remain loop
                    "and    w4, %w6, #3                  \n" // w4 = remain = L & 3
                    "cmp    w4, #0                       \n"
                    "beq    3f                           \n"

                    "2:                                  \n"
                    "prfm   pldl1keep, [%2, #32]         \n"
                    "ld1r   {v0.4s}, [%2], #4            \n" // broadcast one kernel scalar
                    "prfm   pldl1keep, [%1, #256]        \n"
                    "ld1    {v8.4s, v9.4s}, [%1], #32    \n"

                    "subs   w4, w4, #1                   \n"
                    // k0
                    "fmla    v16.4s, v0.4s, v8.4s        \n" // sum0 += (a00-a70) * k00
                    "fmla    v17.4s, v0.4s, v9.4s        \n" //

                    "bne    2b                           \n"

                    "3:                                  \n"
                    "st1    {v16.4s, v17.4s}, [%0]       \n" // store 8 outputs; no writeback -- C code advances output below

                    : "=r"(output), // %0
                    "=r"(vb),     // %1
                    "=r"(va)      // %2
                    : "0"(output),
                    "1"(vb),
                    "2"(va),
                    "r"(L),    // %6
                    "r"(bias0) // %7
                    : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17");
#else
                // ARMv7 variant of the same 1x8 micro-kernel: q8/q9 hold the
                // 8 accumulators, seeded with the bias.
                asm volatile(
                    "vdup.f32   q8, %7              \n" // sum0 = bias0 (lanes 0-3)
                    "vdup.f32   q9, %7              \n" // sum0n = bias0 (lanes 4-7)
                    // reduction loop, 4 terms at a time
                    "lsr        r4, %6, #2          \n" // r4 = nn = L >> 2
                    "cmp        r4, #0              \n"
                    "beq        1f                  \n"

                    "0:                             \n"

                    "pld        [%1, #512]          \n"
                    "vldm       %1!, {d8-d15}       \n" // first 4 packs of 8 data values
                    "pld        [%2, #128]          \n"
                    "vld1.f32   {d0-d1}, [%2]!      \n" // 4 kernel scalars k0..k3

                    "vmla.f32   q8, q4, d0[0]       \n"
                    "vmla.f32   q9, q5, d0[0]       \n"

                    "pld        [%1, #512]          \n"
                    "vldm       %1!, {d24-d31}      \n" // next 4 packs of 8 data values

                    "vmla.f32   q8, q6, d0[1]       \n"
                    "vmla.f32   q9, q7, d0[1]       \n"

                    "subs       r4, r4, #1          \n"

                    "vmla.f32   q8, q12, d1[0]      \n"
                    "vmla.f32   q9, q13, d1[0]      \n"
                    "vmla.f32   q8, q14, d1[1]      \n"
                    "vmla.f32   q9, q15, d1[1]      \n"

                    "bne        0b                  \n"

                    "1:                             \n"
                    // remain loop
                    "and        r4, %6, #3          \n" // r4 = remain = L & 3
                    "cmp        r4, #0              \n"
                    "beq        3f                  \n"

                    "2:                             \n"
                    "pld        [%1, #256]          \n"
                    "vld1.f32   {d8-d11}, [%1]!     \n" // one pack of 8 data values
                    "pld        [%2, #32]           \n"
                    "vld1.f32   {d0[],d1[]}, [%2]!  \n" // broadcast one kernel scalar

                    "subs       r4, r4, #1          \n"

                    "vmla.f32   q8, q4, q0          \n"
                    "vmla.f32   q9, q5, q0          \n"
                    "bne        2b                  \n"

                    "3:                             \n"
                    "vst1.f32   {d16-d19}, [%0]     \n" // store 8 outputs; no writeback -- C code advances output below

                    : "=r"(output), // %0
                    "=r"(vb),     // %1
                    "=r"(va)      // %2
                    : "0"(output),
                    "1"(vb),
                    "2"(va),
                    "r"(L),    // %6
                    "r"(bias0) // %7
                    : "cc", "memory", "r4", "q0", "q4", "q5", "q6", "q7", "q8", "q9", "q12", "q13", "q14", "q15");
#endif // __aarch64__
#else
                // Portable scalar reference path: same 1x8 tile, with the
                // reduction unrolled by 8 (vb advances through the 8-wide
                // interleaved layout: element n of term k is at vb[k*8 + n]).
                float sum[8] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum[n] += va[0] * vb[n];
                        sum[n] += va[1] * vb[n + 8];
                        sum[n] += va[2] * vb[n + 16];
                        sum[n] += va[3] * vb[n + 24];
                        sum[n] += va[4] * vb[n + 32];
                        sum[n] += va[5] * vb[n + 40];
                        sum[n] += va[6] * vb[n + 48];
                        sum[n] += va[7] * vb[n + 56];
                    }

                    va += 8;
                    vb += 64;
                }

                // Leftover reduction terms (L % 8).
                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum[n] += va[0] * vb[n];
                    }

                    va += 1;
                    vb += 8;
                }

                // Bias is added here in the scalar path (the asm paths seed
                // their accumulators with it instead).
                for (int n = 0; n < 8; n++)
                {
                    output[n] = sum[n] + bias0;
                }
#endif // __ARM_NEON
                output += 8;
            }

            // Tail columns (N % 8): one output position per iteration.
            for (; j < N; j++)
            {
                // Leftover columns are packed one per channel after the
                // N/8 eight-wide channels, hence channel(j / 8 + j % 8).
                const float* vb = bottom_tm.channel(j / 8 + j % 8);
#if __ARM_NEON && __aarch64__
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
#else
                const float* va = kernel_tm.channel(i / 4 + i % 4);
#endif // __ARM_NEON && __aarch64__

                int k = 0;
#if __ARM_NEON
                // Vectorized dot product over the reduction, 4 terms at a time.
                float32x4_t _sum0 = vdupq_n_f32(0.f);

                for (; k + 3 < L; k += 4)
                {
                    float32x4_t _p0 = vld1q_f32(vb);
                    vb += 4;

                    float32x4_t _k0 = vld1q_f32(va);
                    va += 4;

#if __aarch64__
                    _sum0 = vfmaq_f32(_sum0, _p0, _k0);
#else
                    _sum0 = vmlaq_f32(_sum0, _p0, _k0);
#endif
                }

                // Horizontal add of the 4 partial sums, plus bias.
#if __aarch64__
                float sum0 = bias0 + vaddvq_f32(_sum0);
#else
                float32x2_t _ss = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0));
                float sum0 = bias0 + vget_lane_f32(vpadd_f32(_ss, _ss), 0);
#endif
#else
                float sum0 = bias0;
#endif // __ARM_NEON
                // Scalar tail of the reduction (L % 4, or all of L without NEON).
                for (; k < L; k++)
                {
                    sum0 += va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = sum0;

                output++;
            }
        }
1521     }
1522 }
1523