// BUG1989 is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 BUG1989. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#if __aarch64__

#if 1
#include "gemm_symm_int8.h"
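
// This branch maps the convolution onto the symmetric int8 GEMM from
// gemm_symm_int8.h: reorder_a() packs the transformed kernel as matrix A,
// reorder_b() packs the im2col buffer as matrix B, and int8kernel()
// computes C = A * B directly into the int32 output blob.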
static void conv_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
{
    const int m = outch;
    const int k = inch * kernel_size;
    kernel_tm.create(m * k, (size_t)1u);
    const int8_t* a = _kernel;
    int8_t* sa = kernel_tm;
    reorder_a((int8_t*)a, sa, m, k, k);
}

static void conv_im2col_sgemm_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm,
                                        const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // im2col
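    // Layout: one row of outw * outh samples per (channel, kernel tap) pair,
    // ordered channel-major, so row (p * kernel_h * kernel_w + u * kernel_w + v)
    // holds everything kernel tap (u,v) of channel p sees.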
    Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, 1UL, opt.workspace_allocator);
    {
        const int stride = kernel_h * kernel_w * outw * outh;
        signed char* ret = (signed char*)bottom_im2col;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < inch; p++)
        {
            const signed char* input = bottom_blob.channel(p);
            int retID = stride * p;
            for (int u = 0; u < kernel_h; u++)
            {
                for (int v = 0; v < kernel_w; v++)
                {
                    for (int i = 0; i < outh; i++)
                    {
                        for (int j = 0; j < outw; j++)
                        {
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    const int m = outch;
    const int n = outw * outh;
    const int k = inch * kernel_w * kernel_h;

    ncnn::Mat bottom_tm(k * n, (size_t)1u, opt.workspace_allocator);
    {
        const int8_t* pData = bottom_im2col;
        int8_t* pReorder = bottom_tm;
        reorder_b(pData, pReorder, k, n, n);
    }
    // GEMM
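    // Reference semantics (a plain scalar sketch; the packed operands let
    // int8kernel() compute the same result, with A the m x k kernel matrix,
    // B the k x n im2col matrix, and ldc the per-channel output stride):
    //
    //   for (int i = 0; i < m; i++)
    //       for (int j = 0; j < n; j++)
    //       {
    //           int32_t acc = 0;
    //           for (int p = 0; p < k; p++)
    //               acc += (int32_t)A[i * k + p] * B[p * n + j];
    //           C[i * ldc + j] = acc;
    //       }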
    int32_t* pc = top_blob;
    const int8_t* pa = kernel_tm;
    int8_t* pb = bottom_tm;
    const size_t ldc = top_blob.cstep;

    int8kernel((void*)pc, pa, pb, m, k, n, ldc, 0, 0, opt);
}
#else
static void conv_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
{
    const signed char* kernel = _kernel;

    // kernel memory packed 4 x 4
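    // Each channel of kernel_tm interleaves 4 output channels: value pairs
    // (k0[q], k0[q+1], k1[q], k1[q+1], ..., k3[q], k3[q+1]) in the main loop,
    // then single values (k0[q], ..., k3[q]) for an odd tail element.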
    kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u);

    int nn_outch = 0;
    int remain_outch_start = 0;

    nn_outch = outch >> 2;
    remain_outch_start = nn_outch << 2;

    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = pp * 4;

        const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
        const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
        const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
        const signed char* k3 = kernel + (p + 3) * inch * kernel_size;

        signed char* ktmp = kernel_tm.channel(p / 4);

        int q = 0;
        for (; q + 1 < inch * kernel_size; q += 2)
        {
            ktmp[0] = k0[0];
            ktmp[1] = k0[1];
            ktmp[2] = k1[0];
            ktmp[3] = k1[1];
            ktmp[4] = k2[0];
            ktmp[5] = k2[1];
            ktmp[6] = k3[0];
            ktmp[7] = k3[1];

            ktmp += 8;

            k0 += 2;
            k1 += 2;
            k2 += 2;
            k3 += 2;
        }

        for (; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp[1] = k1[0];
            ktmp[2] = k2[0];
            ktmp[3] = k3[0];
            ktmp += 4;

            k0 += 1;
            k1 += 1;
            k2 += 1;
            k3 += 1;
        }
    }

    for (int p = remain_outch_start; p < outch; p++)
    {
        const signed char* k0 = kernel + (p + 0) * inch * kernel_size;

        signed char* ktmp = kernel_tm.channel(p / 4 + p % 4);

        int q = 0;
        for (; q + 1 < inch * kernel_size; q = q + 2)
        {
            ktmp[0] = k0[0];
            ktmp[1] = k0[1];
            ktmp += 2;
            k0 += 2;
        }

        for (; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp++;
            k0++;
        }
    }
}

static void conv_im2col_sgemm_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm,
                                        const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // im2row
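    // Layout: one row per output position (i, j), holding that position's whole
    // receptive field flattened channel-major, inch * kernel_h * kernel_w bytes.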
    Mat bottom_im2row(kernel_h * kernel_w * inch, outw * outh, 1UL, opt.workspace_allocator);
    {
        int out_stride = kernel_h * kernel_w * inch * outw;
        signed char* ret = (signed char*)bottom_im2row;

        // #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < outh; i++)
        {
            int retID = out_stride * i;
            for (int j = 0; j < outw; j++)
            {
                for (int p = 0; p < inch; p++)
                {
                    const signed char* input = bottom_blob.channel(p);

                    for (int u = 0; u < kernel_h; u++)
                    {
                        for (int v = 0; v < kernel_w; v++)
                        {
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    int kernel_size = kernel_w * kernel_h;
    int out_size = outw * outh;

    // int M = outch;  // outch
    int N = outw * outh;                // outsize or out stride
    int K = kernel_w * kernel_h * inch; // ksize * inch

    // bottom_im2row memory packed 4 x 4
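    // Same interleave as the kernel packing: blocks of 4 im2row rows are stored
    // as value pairs from each row, so the GEMM can stream 4 output positions
    // per load; leftover rows are copied through unchanged.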
    Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, (size_t)1u, opt.workspace_allocator);
    {
        int nn_size = out_size >> 2;
        int remain_size_start = nn_size << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 4;

            const signed char* img0 = bottom_im2row.row<signed char>(i);
            const signed char* img1 = bottom_im2row.row<signed char>(i + 1);
            const signed char* img2 = bottom_im2row.row<signed char>(i + 2);
            const signed char* img3 = bottom_im2row.row<signed char>(i + 3);

            signed char* tmpptr = bottom_tm.channel(i / 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];
                tmpptr[2] = img1[0];
                tmpptr[3] = img1[1];
                tmpptr[4] = img2[0];
                tmpptr[5] = img2[1];
                tmpptr[6] = img3[0];
                tmpptr[7] = img3[1];

                tmpptr += 8;
                img0 += 2;
                img1 += 2;
                img2 += 2;
                img3 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img1[0];
                tmpptr[2] = img2[0];
                tmpptr[3] = img3[0];

                tmpptr += 4;
                img0 += 1;
                img1 += 1;
                img2 += 1;
                img3 += 1;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < out_size; i++)
        {
            const signed char* img0 = bottom_im2row.row<signed char>(i);

            signed char* tmpptr = bottom_tm.channel(i / 4 + i % 4);

            int q = 0;
            for (; q + 1 < inch * kernel_size; q = q + 2)
            {
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];

                tmpptr += 2;
                img0 += 2;
            }

            for (; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];

                tmpptr += 1;
                img0 += 1;
            }
        }
    }

    // 4x4
    // sgemm(int M, int N, int K, float* A, float* B, float* C)
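    // Micro-kernel: each step of the j loop produces a 4x4 tile of C, i.e. 4
    // output channels (rows of kernel_tm) times 4 output positions (columns of
    // bottom_tm), accumulated in int32. The NEON path widens int8 products to
    // int16 with smull/smlal2 and folds them into int32 with sadalp.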
    {
        // int M = outch;  // outch
        // int N = outw * outh; // outsize or out stride
        // int L = kernel_w * kernel_h * inch; // ksize * inch

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = pp * 4;

            int* output0 = top_blob.channel(i);
            int* output1 = top_blob.channel(i + 1);
            int* output2 = top_blob.channel(i + 2);
            int* output3 = top_blob.channel(i + 3);

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                const signed char* vb = bottom_tm.channel(j / 4);
                const signed char* va = kernel_tm.channel(i / 4);

#if __ARM_NEON
                asm volatile(
                    "prfm   pldl1keep, [%4, #128]        \n"
                    "prfm   pldl1keep, [%5, #128]        \n"
                    "eor    v16.16b, v16.16b, v16.16b    \n" // sum0
                    "eor    v17.16b, v17.16b, v17.16b    \n" // sum1
                    "eor    v18.16b, v18.16b, v18.16b    \n" // sum2
                    "eor    v19.16b, v19.16b, v19.16b    \n" // sum3

                    "lsr    w4, %w12, #2                 \n" // w4 = nn = K >> 2
                    "cmp    w4, #0                       \n"
                    "beq    1f                           \n"

                    "0:                                  \n" // for (; k+3<K; k=k+4)
                    "ld1    {v0.16b}, [%4]               \n" // i0, i1, i2, i3
                    "ld1    {v4.16b}, [%5]               \n" // k0, k1, k2, k3
                    "add    %4, %4, #16                  \n"
                    "add    %5, %5, #16                  \n"

                    "rev32    v1.8h, v0.8h               \n" // i1, i0, i3, i2
                    "rev64    v2.4s, v0.4s               \n" // i2, i3, i0, i1
                    "rev64    v3.8h, v0.8h               \n" // i3, i2, i1, i0

                    "smull    v8.8h, v4.8b, v0.8b        \n"
                    "smull    v9.8h, v4.8b, v1.8b        \n"
                    "smull    v10.8h, v4.8b, v2.8b       \n"
                    "smull    v11.8h, v4.8b, v3.8b       \n"

                    "prfm     pldl1keep, [%4, #128]      \n"
                    "prfm     pldl1keep, [%5, #128]      \n"

                    "smlal2   v8.8h, v4.16b, v0.16b      \n"
                    "smlal2   v9.8h, v4.16b, v1.16b      \n"
                    "smlal2   v10.8h, v4.16b, v2.16b     \n"
                    "smlal2   v11.8h, v4.16b, v3.16b     \n"

                    "sadalp   v16.4s, v8.8h              \n" // i0k0, i1k1, i2k2, i3k3
                    "sadalp   v17.4s, v9.8h              \n" // i1k0, i0k1, i3k2, i2k3
                    "sadalp   v18.4s, v10.8h             \n" // i2k0, i3k1, i0k2, i1k3
                    "sadalp   v19.4s, v11.8h             \n" // i3k0, i2k1, i1k2, i0k3

                    "subs     w4, w4, #1                 \n"
                    "bne      0b                         \n"

                    "1:                                  \n"

                    // remain loop
                    "and      w4, %w12, #3               \n" // w4 = remain = K & 3;
                    "cmp      w4, #0                     \n"
                    "beq      3f                         \n"

                    "lsr      w4, w4, #1                 \n" // w4 = nn = remain >> 1
                    "cmp      w4, #0                     \n"
                    "beq      3f                         \n"

                    "2:                                  \n" // for (; k+1<K; k=k+2)

                    "ld1      {v0.8b}, [%4]              \n" // i0, i1, i2, i3
                    "ld1      {v4.8b}, [%5]              \n" // k0, k1, k2, k3
                    "add      %4, %4, #8                 \n"
                    "add      %5, %5, #8                 \n"

                    "rev32    v1.4h, v0.4h               \n" // i1, i0, i3, i2
                    "rev64    v2.2s, v0.2s               \n" // i2, i3, i0, i1
                    "rev64    v3.4h, v0.4h               \n" // i3, i2, i1, i0

                    "smull    v8.8h, v4.8b, v0.8b        \n"
                    "smull    v9.8h, v4.8b, v1.8b        \n"
                    "smull    v10.8h, v4.8b, v2.8b       \n"
                    "smull    v11.8h, v4.8b, v3.8b       \n"
                    "sadalp   v16.4s, v8.8h              \n"
                    "sadalp   v17.4s, v9.8h              \n"
                    "sadalp   v18.4s, v10.8h             \n"
                    "sadalp   v19.4s, v11.8h             \n"

                    "subs     w4, w4, #1                 \n"
                    "bne      2b                         \n"

                    "3:                                  \n" // rearrange accumulators

                    "mov      v20.s[0], v16.s[0]         \n"
                    "mov      v20.s[1], v17.s[0]         \n"
                    "mov      v20.s[2], v18.s[0]         \n"
                    "mov      v20.s[3], v19.s[0]         \n"

                    "mov      v21.s[0], v17.s[1]         \n"
                    "mov      v21.s[1], v16.s[1]         \n"
                    "mov      v21.s[2], v19.s[1]         \n"
                    "mov      v21.s[3], v18.s[1]         \n"

                    "mov      v22.s[0], v18.s[2]         \n"
                    "mov      v22.s[1], v19.s[2]         \n"
                    "mov      v22.s[2], v16.s[2]         \n"
                    "mov      v22.s[3], v17.s[2]         \n"

                    "mov      v23.s[0], v19.s[3]         \n"
                    "mov      v23.s[1], v18.s[3]         \n"
                    "mov      v23.s[2], v17.s[3]         \n"
                    "mov      v23.s[3], v16.s[3]         \n"

                    "and      w4, %w12, #1               \n" // w4 = remain = K & 1;
                    "cmp      w4, #0                     \n"
                    "beq      5f                         \n"

                    "4:                                  \n"
                    "ld1      {v0.8b}, [%4]              \n"
                    "ld1      {v1.8b}, [%5]              \n"
                    "add      %4, %4, #4                 \n"
                    "add      %5, %5, #4                 \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // i0[0], i1[0], i2[0], i3[0]
                    "sshll    v1.8h, v1.8b, #0           \n" // k0[0], k1[0], k2[0], k3[0]

                    "smlal    v20.4s, v0.4h, v1.h[0]     \n" // i0k0, i1k0, i2k0, i3k0
                    "smlal    v21.4s, v0.4h, v1.h[1]     \n" // i0k1, i1k1, i2k1, i3k1
                    "smlal    v22.4s, v0.4h, v1.h[2]     \n" // i0k2, i1k2, i2k2, i3k2
                    "smlal    v23.4s, v0.4h, v1.h[3]     \n" // i0k3, i1k3, i2k3, i3k3

                    "subs     w4, w4, #1                 \n"

                    "bne      4b                         \n"

                    "5:                                  \n"

                    "st1      {v20.4s}, [%0]             \n"
                    "st1      {v21.4s}, [%1]             \n"
                    "st1      {v22.4s}, [%2]             \n"
                    "st1      {v23.4s}, [%3]             \n"

                    : "=r"(output0), // %0
                    "=r"(output1), // %1
                    "=r"(output2), // %2
                    "=r"(output3), // %3
                    "=r"(vb),      // %4
                    "=r"(va)       // %5
                    : "0"(output0),
                    "1"(output1),
                    "2"(output2),
                    "3"(output3),
                    "4"(vb),
                    "5"(va),
                    "r"(K) // %12
                    : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                int sum0[4] = {0};
                int sum1[4] = {0};
                int sum2[4] = {0};
                int sum3[4] = {0};

                int k = 0;

                for (; k + 1 < K; k = k + 2)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += (int)va[0] * vb[2 * n]; // k0
                        sum0[n] += (int)va[1] * vb[2 * n + 1];

                        sum1[n] += (int)va[2] * vb[2 * n]; // k1
                        sum1[n] += (int)va[3] * vb[2 * n + 1];

                        sum2[n] += (int)va[4] * vb[2 * n]; // k2
                        sum2[n] += (int)va[5] * vb[2 * n + 1];

                        sum3[n] += (int)va[6] * vb[2 * n]; // k3
                        sum3[n] += (int)va[7] * vb[2 * n + 1];
                    }

                    va += 8;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += (int)va[0] * vb[n];
                        sum1[n] += (int)va[1] * vb[n];
                        sum2[n] += (int)va[2] * vb[n];
                        sum3[n] += (int)va[3] * vb[n];
                    }

                    va += 4;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output0[n] = sum0[n];
                    output1[n] = sum1[n];
                    output2[n] = sum2[n];
                    output3[n] = sum3[n];
                }
#endif
                output0 += 4;
                output1 += 4;
                output2 += 4;
                output3 += 4;
            }

            for (; j < N; j++)
            {
                const signed char* vb = bottom_tm.channel(j / 4 + j % 4);
                const signed char* va = kernel_tm.channel(i / 4);

#if 0 //__ARM_NEON
                int32x4_t _sum = vdupq_n_s32(0);

                int k=0;

                for (; k+3<K; k=k+4)
                {
                    int8x8_t _r0 = vld1_s8(vb);     // i0[0-3]
                    int8x8x2_t _k = vld2_s8(va);    // k0[0-1], k1[0-1], k2[0-1], k3[0-1];k0[2-3], k1[2-3], k2[2-3], k3[2-3]

                    int16x8_t _r0_s16 = vmovl_s8(_r0);          // i0[0],i0[1],i0[2],i0[3]
                    int16x8_t _k02_s16 = vmovl_s8(_k.val[0]);   // k0[0],k1[0],k2[0],k3[0],k0[2],k1[2],k2[2],k3[2]
                    int16x8_t _k13_s16 = vmovl_s8(_k.val[1]);   // k0[1],k1[1],k2[1],k3[1],k0[3],k1[3],k2[3],k3[3]

                    _sum = vmlal_lane_s16(_sum, vget_low_s16(_k02_s16), vget_low_s16(_r0_s16), 0);    // i0[0]*k[0-3][0]
                    _sum = vmlal_lane_s16(_sum, vget_low_s16(_k13_s16), vget_low_s16(_r0_s16), 1);    // i0[1]*k[0-3][1]
                    _sum = vmlal_lane_s16(_sum, vget_high_s16(_k02_s16), vget_low_s16(_r0_s16), 2);   // i0[2]*k[0-3][2]
                    _sum = vmlal_lane_s16(_sum, vget_high_s16(_k13_s16), vget_low_s16(_r0_s16), 3);   // i0[3]*k[0-3][3]

                    va += 16;
                    vb += 4;
                }

                for (; k+1<K; k=k+2)
                {
                    int8x8_t _r0 = vld1_s8(vb);     // i0[0-3]
                    int8x8_t _k = vld1_s8(va);      // k0[0-1], k1[0-1], k2[0-1], k3[0-1]

                    _r0[2] = _r0[0];
                    _r0[3] = _r0[1];
                    _r0[4] = _r0[0];
                    _r0[5] = _r0[1];
                    _r0[6] = _r0[0];
                    _r0[7] = _r0[1];

                    int16x8_t _tp0 = vmull_s8(_k, _r0);
                    _sum = vpadalq_s16(_sum, _tp0);

                    va += 8;
                    vb += 2;
                }

                for (; k<K; k++)
                {
                    int8x8_t _r0 = vld1_s8(vb);     // i0[0-3]
                    int8x8_t _k = vld1_s8(va);      // k[0-3][0]

                    int16x8_t _tp0 = vmull_s8(_k, _r0);

                    _sum = vaddw_s16(_sum, vget_low_s16(_tp0));

                    va += 4;
                    vb += 1;
                }

                vst1q_lane_s32(output0, _sum, 0);
                vst1q_lane_s32(output1, _sum, 1);
                vst1q_lane_s32(output2, _sum, 2);
                vst1q_lane_s32(output3, _sum, 3);
#else
                int sum0 = 0;
                int sum1 = 0;
                int sum2 = 0;
                int sum3 = 0;

                int k = 0;

                for (; k + 1 < K; k = k + 2)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum0 += (int)va[1] * vb[1];

                    sum1 += (int)va[2] * vb[0];
                    sum1 += (int)va[3] * vb[1];

                    sum2 += (int)va[4] * vb[0];
                    sum2 += (int)va[5] * vb[1];

                    sum3 += (int)va[6] * vb[0];
                    sum3 += (int)va[7] * vb[1];

                    va += 8;
                    vb += 2;
                }

                for (; k < K; k++)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum1 += (int)va[1] * vb[0];
                    sum2 += (int)va[2] * vb[0];
                    sum3 += (int)va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
#endif
                output0++;
                output1++;
                output2++;
                output3++;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            int* output = top_blob.channel(i);

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                const signed char* vb = bottom_tm.channel(j / 4);
                const signed char* va = kernel_tm.channel(i / 4 + i % 4);

#if __ARM_NEON
                int32x4_t _sum = vdupq_n_s32(0);

                int k = 0;
                for (; k + 1 < K; k = k + 2)
                {
                    int8x8_t _r0 = vld1_s8(vb); // i0[0-1], i1[0-1], i2[0-1], i3[0-1]
                    int8x8_t _k = vld1_s8(va);  // k0[0-1]

                    _k[2] = _k[0];
                    _k[3] = _k[1];
                    _k[4] = _k[0];
                    _k[5] = _k[1];
                    _k[6] = _k[0];
                    _k[7] = _k[1];

                    int16x8_t _tp0 = vmull_s8(_k, _r0);
                    _sum = vpadalq_s16(_sum, _tp0);

                    va += 2;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    int8x8_t _r0 = vld1_s8(vb); // i0[0], i1[0], i2[0], i3[0]
                    int8x8_t _k = vld1_s8(va);  // k[0][0]

                    int16x8_t _r0_s16 = vmovl_s8(_r0);
                    int16x8_t _k_s16 = vmovl_s8(_k);

                    _sum = vmlal_lane_s16(_sum, vget_low_s16(_r0_s16), vget_low_s16(_k_s16), 0); // i0k0, i1k0, i2k0, i3k0

                    va += 1;
                    vb += 4;
                }

                vst1q_s32(output, _sum);
#else
                int sum[4] = {0};
                int k = 0;
                for (; k + 1 < K; k = k + 2)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += (int)va[0] * vb[2 * n];
                        sum[n] += (int)va[1] * vb[2 * n + 1];
                    }
                    va += 2;
                    vb += 8;
                }

                for (; k < K; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += (int)va[0] * vb[n];
                    }
                    va += 1;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output[n] = sum[n];
                }
#endif
                output += 4;
            }

            for (; j < N; j++)
            {
                int sum = 0;

                const signed char* vb = bottom_tm.channel(j / 4 + j % 4);
                const signed char* va = kernel_tm.channel(i / 4 + i % 4);

                for (int k = 0; k < K; k++)
                {
                    sum += (int)va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = sum;

                output++;
            }
        }
    }

    // // sgemm(int M, int N, int K, float* A, float* B, float* C)
    // {
    //     for (int i=0; i<M; i++)
    //     {
    //         int* output = top_blob.channel(i);

    //         for (int j=0; j<N; j++)
    //         {
    //             int sum = 0;

    //             signed char* vb = (signed char*)bottom_im2row + K * j;
    //             const signed char* va = kernel + K * i;

    //             for (int k=0; k<K; k++)
    //             {
    //                 sum += (int)va[0] * vb[0];

    //                 va += 1;
    //                 vb += 1;
    //             }
    //             output[0] = sum;

    //             output++;
    //         }
    //     }
    // }
}
#endif
#else
static void conv_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
{
    const signed char* kernel = _kernel;

#if __ARM_NEON && __aarch64__
    // kernel memory packed 8 x 8
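    // On aarch64, 8 consecutive output channels are interleaved value by value
    // (k0[q], k1[q], ..., k7[q]); leftovers fall back to 4-wide and then 1-wide
    // blocks, mirroring the tile widths of the GEMM below.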
    kernel_tm.create(8 * kernel_size, inch, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)1u);
#else
    // kernel memory packed 4 x 8
    kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u);
#endif

    int nn_outch = 0;
    int remain_outch_start = 0;

#if __ARM_NEON && __aarch64__
    nn_outch = outch >> 3;
    remain_outch_start = nn_outch << 3;

    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = pp * 8;

        const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
        const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
        const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
        const signed char* k3 = kernel + (p + 3) * inch * kernel_size;
        const signed char* k4 = kernel + (p + 4) * inch * kernel_size;
        const signed char* k5 = kernel + (p + 5) * inch * kernel_size;
        const signed char* k6 = kernel + (p + 6) * inch * kernel_size;
        const signed char* k7 = kernel + (p + 7) * inch * kernel_size;

        signed char* ktmp = kernel_tm.channel(p / 8);

        for (int q = 0; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp[1] = k1[0];
            ktmp[2] = k2[0];
            ktmp[3] = k3[0];
            ktmp[4] = k4[0];
            ktmp[5] = k5[0];
            ktmp[6] = k6[0];
            ktmp[7] = k7[0];
            ktmp += 8;

            k0 += 1;
            k1 += 1;
            k2 += 1;
            k3 += 1;
            k4 += 1;
            k5 += 1;
            k6 += 1;
            k7 += 1;
        }
    }
#endif

    nn_outch = (outch - remain_outch_start) >> 2;

    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = remain_outch_start + pp * 4;

        const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
        const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
        const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
        const signed char* k3 = kernel + (p + 3) * inch * kernel_size;

#if __ARM_NEON && __aarch64__
        signed char* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4);
#else
        signed char* ktmp = kernel_tm.channel(p / 4);
#endif // __ARM_NEON && __aarch64__

        for (int q = 0; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp[1] = k1[0];
            ktmp[2] = k2[0];
            ktmp[3] = k3[0];
            ktmp += 4;

            k0 += 1;
            k1 += 1;
            k2 += 1;
            k3 += 1;
        }
    }

    remain_outch_start += nn_outch << 2;

    for (int p = remain_outch_start; p < outch; p++)
    {
        const signed char* k0 = kernel + (p + 0) * inch * kernel_size;

#if __ARM_NEON && __aarch64__
        signed char* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4);
#else
        signed char* ktmp = kernel_tm.channel(p / 4 + p % 4);
#endif // __ARM_NEON && __aarch64__

        for (int q = 0; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp++;
            k0++;
        }
    }
}

static void conv_im2col_sgemm_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm,
                                        const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // im2col
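    // Layout: one row of outw * outh samples per (channel, kernel tap) pair,
    // ordered channel-major.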
    Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, 1UL, opt.workspace_allocator);
    {
        const int stride = kernel_h * kernel_w * outw * outh;
        signed char* ret = (signed char*)bottom_im2col;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < inch; p++)
        {
            const signed char* input = bottom_blob.channel(p);
            int retID = stride * p;
            for (int u = 0; u < kernel_h; u++)
            {
                for (int v = 0; v < kernel_w; v++)
                {
                    for (int i = 0; i < outh; i++)
                    {
                        for (int j = 0; j < outw; j++)
                        {
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    int kernel_size = kernel_w * kernel_h;
    int out_size = outw * outh;

    // bottom_im2col memory packed 8 x 8
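    // Blocks of 8 adjacent output positions are gathered so each GEMM step can
    // fetch 8 B values with one 8-byte vector copy; leftover positions are
    // stored one value per step.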
    Mat bottom_tm(8 * kernel_size, inch, out_size / 8 + out_size % 8, (size_t)1u, opt.workspace_allocator);
    {
        int nn_size = out_size >> 3;
        int remain_size_start = nn_size << 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 8;

            const signed char* img0 = bottom_im2col.channel(0);
            img0 += i;

            signed char* tmpptr = bottom_tm.channel(i / 8);

            for (int q = 0; q < inch * kernel_size; q++)
            {
#if __ARM_NEON
#if __aarch64__
                asm volatile(
                    "prfm    pldl1keep, [%0, #64]    \n"
                    "ld1     {v0.8b}, [%0]           \n"
                    "st1     {v0.8b}, [%1]           \n"
                    : "=r"(img0),  // %0
                    "=r"(tmpptr) // %1
                    : "0"(img0),
                    "1"(tmpptr)
                    : "cc", "memory", "v0");
#else
                asm volatile(
                    "pld        [%0, #64]     \n"
                    "vld1.s8   {d0}, [%0]     \n"
                    "vst1.s8   {d0}, [%1]     \n"
                    : "=r"(img0),  // %0
                    "=r"(tmpptr) // %1
                    : "0"(img0),
                    "1"(tmpptr)
                    : "cc", "memory", "d0");
#endif // __aarch64__
#else
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];
                tmpptr[2] = img0[2];
                tmpptr[3] = img0[3];
                tmpptr[4] = img0[4];
                tmpptr[5] = img0[5];
                tmpptr[6] = img0[6];
                tmpptr[7] = img0[7];
#endif // __ARM_NEON
                tmpptr += 8;
                img0 += out_size;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < out_size; i++)
        {
            const signed char* img0 = bottom_im2col.channel(0);
            img0 += i;

            signed char* tmpptr = bottom_tm.channel(i / 8 + i % 8);

            for (int q = 0; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];

                tmpptr += 1;
                img0 += out_size;
            }
        }
    }

    // sgemm(int M, int N, int L, float* A, float* B, float* C)
    {
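        // Micro-kernel: each step of the j loop produces an 8x8 tile of C, i.e.
        // 8 output channels times 8 output positions, accumulated in int32.
        // int8 operands are widened with sshll and multiply-accumulated with
        // smlal/smlal2 against lane-broadcast kernel values; leftover columns
        // are handled one output position at a time.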
        //int M = outch;  // outch
        int N = outw * outh;                // outsize or out stride
        int L = kernel_w * kernel_h * inch; // ksize * inch

        int nn_outch = 0;
        int remain_outch_start = 0;

#if __ARM_NEON && __aarch64__
        nn_outch = outch >> 3;
        remain_outch_start = nn_outch << 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = pp * 8;

            int* output0 = top_blob.channel(i);
            int* output1 = top_blob.channel(i + 1);
            int* output2 = top_blob.channel(i + 2);
            int* output3 = top_blob.channel(i + 3);
            int* output4 = top_blob.channel(i + 4);
            int* output5 = top_blob.channel(i + 5);
            int* output6 = top_blob.channel(i + 6);
            int* output7 = top_blob.channel(i + 7);

            int j = 0;
            for (; j + 7 < N; j = j + 8)
            {
                signed char* vb = bottom_tm.channel(j / 8);
                const signed char* va = kernel_tm.channel(i / 8);
#if __aarch64__
                asm volatile(
                    "eor    v16.16b, v16.16b, v16.16b    \n" // sum0
                    "eor    v17.16b, v17.16b, v17.16b    \n" // sum0n
                    "eor    v18.16b, v18.16b, v18.16b    \n" // sum1
                    "eor    v19.16b, v19.16b, v19.16b    \n" // sum1n
                    "eor    v20.16b, v20.16b, v20.16b    \n" // sum2
                    "eor    v21.16b, v21.16b, v21.16b    \n" // sum2n
                    "eor    v22.16b, v22.16b, v22.16b    \n" // sum3
                    "eor    v23.16b, v23.16b, v23.16b    \n" // sum3n
                    "eor    v24.16b, v24.16b, v24.16b    \n" // sum4
                    "eor    v25.16b, v25.16b, v25.16b    \n" // sum4n
                    "eor    v26.16b, v26.16b, v26.16b    \n" // sum5
                    "eor    v27.16b, v27.16b, v27.16b    \n" // sum5n
                    "eor    v28.16b, v28.16b, v28.16b    \n" // sum6
                    "eor    v29.16b, v29.16b, v29.16b    \n" // sum6n
                    "eor    v30.16b, v30.16b, v30.16b    \n" // sum7
                    "eor    v31.16b, v31.16b, v31.16b    \n" // sum7n

                    "lsr         w4, %w20, #2            \n" // w4 = nn = L >> 2
                    "cmp         w4, #0                  \n"
                    "beq         1f                      \n"

                    "0:                                  \n" // for (; k+3<L; k=k+4)

                    "prfm   pldl1keep, [%9, #128]                       \n"
                    "ld1    {v0.8b, v1.8b, v2.8b, v3.8b}, [%9], #32     \n"

                    "prfm   pldl1keep, [%8, #128]                       \n"
                    "ld1    {v8.8b, v9.8b, v10.8b, v11.8b}, [%8], #32   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00 - k70
                    "sshll    v1.8h, v1.8b, #0           \n" // k01 - k71
                    "sshll    v2.8h, v2.8b, #0           \n" // k02 - k72
                    "sshll    v3.8h, v3.8b, #0           \n" // k03 - k73

                    "sshll    v8.8h, v8.8b, #0           \n" // a00 - a70
                    "sshll    v9.8h, v9.8b, #0           \n" // a01 - a71
                    "sshll    v10.8h, v10.8b, #0         \n" // a02 - a72
                    "sshll    v11.8h, v11.8b, #0         \n" // a03 - a73
                    // k0
                    "smlal    v16.4s, v8.4h, v0.h[0]     \n" // sum0 += (a00-a70) * k00
                    "smlal2   v17.4s, v8.8h, v0.h[0]     \n" //
                    "smlal    v18.4s, v8.4h, v0.h[1]     \n" // sum1 += (a00-a70) * k10
                    "smlal2   v19.4s, v8.8h, v0.h[1]     \n" //
                    "smlal    v20.4s, v8.4h, v0.h[2]     \n" // sum2 += (a00-a70) * k20
                    "smlal2   v21.4s, v8.8h, v0.h[2]     \n" //
                    "smlal    v22.4s, v8.4h, v0.h[3]     \n" // sum3 += (a00-a70) * k30
                    "smlal2   v23.4s, v8.8h, v0.h[3]     \n" //
                    "smlal    v24.4s, v8.4h, v0.h[4]     \n" // sum4 += (a00-a70) * k40
                    "smlal2   v25.4s, v8.8h, v0.h[4]     \n" //
                    "smlal    v26.4s, v8.4h, v0.h[5]     \n" // sum5 += (a00-a70) * k50
                    "smlal2   v27.4s, v8.8h, v0.h[5]     \n" //
                    "smlal    v28.4s, v8.4h, v0.h[6]     \n" // sum6 += (a00-a70) * k60
                    "smlal2   v29.4s, v8.8h, v0.h[6]     \n" //
                    "smlal    v30.4s, v8.4h, v0.h[7]     \n" // sum7 += (a00-a70) * k70
                    "smlal2   v31.4s, v8.8h, v0.h[7]     \n" //
                    // k1
                    "smlal    v16.4s, v9.4h, v1.h[0]     \n" // sum0 += (a01-a71) * k01
                    "smlal2   v17.4s, v9.8h, v1.h[0]     \n" //
                    "smlal    v18.4s, v9.4h, v1.h[1]     \n" // sum1 += (a01-a71) * k11
                    "smlal2   v19.4s, v9.8h, v1.h[1]     \n" //
                    "smlal    v20.4s, v9.4h, v1.h[2]     \n" // sum2 += (a01-a71) * k21
                    "smlal2   v21.4s, v9.8h, v1.h[2]     \n" //
                    "smlal    v22.4s, v9.4h, v1.h[3]     \n" // sum3 += (a01-a71) * k31
                    "smlal2   v23.4s, v9.8h, v1.h[3]     \n" //
                    "smlal    v24.4s, v9.4h, v1.h[4]     \n" // sum4 += (a01-a71) * k41
                    "smlal2   v25.4s, v9.8h, v1.h[4]     \n" //
                    "smlal    v26.4s, v9.4h, v1.h[5]     \n" // sum5 += (a01-a71) * k51
                    "smlal2   v27.4s, v9.8h, v1.h[5]     \n" //
                    "smlal    v28.4s, v9.4h, v1.h[6]     \n" // sum6 += (a01-a71) * k61
                    "smlal2   v29.4s, v9.8h, v1.h[6]     \n" //
                    "smlal    v30.4s, v9.4h, v1.h[7]     \n" // sum7 += (a01-a71) * k71
                    "smlal2   v31.4s, v9.8h, v1.h[7]     \n" //
                    // k2
                    "smlal    v16.4s, v10.4h, v2.h[0]    \n" // sum0 += (a02-a72) * k02
                    "smlal2   v17.4s, v10.8h, v2.h[0]    \n" //
                    "smlal    v18.4s, v10.4h, v2.h[1]    \n" // sum1 += (a02-a72) * k12
                    "smlal2   v19.4s, v10.8h, v2.h[1]    \n" //
                    "smlal    v20.4s, v10.4h, v2.h[2]    \n" // sum2 += (a02-a72) * k22
                    "smlal2   v21.4s, v10.8h, v2.h[2]    \n" //
                    "smlal    v22.4s, v10.4h, v2.h[3]    \n" // sum3 += (a02-a72) * k32
                    "smlal2   v23.4s, v10.8h, v2.h[3]    \n" //
                    "smlal    v24.4s, v10.4h, v2.h[4]    \n" // sum4 += (a02-a72) * k42
                    "smlal2   v25.4s, v10.8h, v2.h[4]    \n" //
                    "smlal    v26.4s, v10.4h, v2.h[5]    \n" // sum5 += (a02-a72) * k52
                    "smlal2   v27.4s, v10.8h, v2.h[5]    \n" //
                    "smlal    v28.4s, v10.4h, v2.h[6]    \n" // sum6 += (a02-a72) * k62
                    "smlal2   v29.4s, v10.8h, v2.h[6]    \n" //
                    "smlal    v30.4s, v10.4h, v2.h[7]    \n" // sum7 += (a02-a72) * k72
                    "smlal2   v31.4s, v10.8h, v2.h[7]    \n" //
                    // k3
                    "smlal    v16.4s, v11.4h, v3.h[0]    \n" // sum0 += (a03-a73) * k03
                    "smlal2   v17.4s, v11.8h, v3.h[0]    \n" //
                    "smlal    v18.4s, v11.4h, v3.h[1]    \n" // sum1 += (a03-a73) * k13
                    "smlal2   v19.4s, v11.8h, v3.h[1]    \n" //
                    "smlal    v20.4s, v11.4h, v3.h[2]    \n" // sum2 += (a03-a73) * k23
                    "smlal2   v21.4s, v11.8h, v3.h[2]    \n" //
                    "smlal    v22.4s, v11.4h, v3.h[3]    \n" // sum3 += (a03-a73) * k33
                    "smlal2   v23.4s, v11.8h, v3.h[3]    \n" //
                    "smlal    v24.4s, v11.4h, v3.h[4]    \n" // sum4 += (a03-a73) * k43
                    "smlal2   v25.4s, v11.8h, v3.h[4]    \n" //
                    "smlal    v26.4s, v11.4h, v3.h[5]    \n" // sum5 += (a03-a73) * k53
                    "smlal2   v27.4s, v11.8h, v3.h[5]    \n" //
                    "smlal    v28.4s, v11.4h, v3.h[6]    \n" // sum6 += (a03-a73) * k63
                    "smlal2   v29.4s, v11.8h, v3.h[6]    \n" //
                    "smlal    v30.4s, v11.4h, v3.h[7]    \n" // sum7 += (a03-a73) * k73
                    "smlal2   v31.4s, v11.8h, v3.h[7]    \n" //

                    "subs   w4, w4, #1                   \n"
                    "bne    0b                           \n"

                    "1:                                  \n"

                    // remain loop
                    "and    w4, %w20, #3                 \n" // w4 = remain = L & 3;
1149                     "cmp    w4, #0                       \n"
1150                     "beq    3f                           \n"
1151 
1152                     "2:                                  \n"
1153 
1154                     "prfm   pldl1keep, [%9, #128]        \n"
1155                     "ld1    {v0.8b}, [%9], #8            \n"
1156 
1157                     "prfm   pldl1keep, [%8, #128]        \n"
1158                     "ld1    {v8.8b}, [%8], #8            \n"
1159 
1160                     "sshll    v0.8h, v0.8b, #0           \n" // k00 - k70
1161                     "sshll    v8.8h, v8.8b, #0           \n" // a00 - a70
1162 
1163                     // k0
1164                     "smlal    v16.4s, v8.4h, v0.h[0]     \n" // sum0 += (a00-a70) * k00
1165                     "smlal2   v17.4s, v8.8h, v0.h[0]     \n" //
1166                     "smlal    v18.4s, v8.4h, v0.h[1]     \n" // sum1 += (a00-a70) * k10
1167                     "smlal2   v19.4s, v8.8h, v0.h[1]     \n" //
1168                     "smlal    v20.4s, v8.4h, v0.h[2]     \n" // sum2 += (a00-a70) * k20
1169                     "smlal2   v21.4s, v8.8h, v0.h[2]     \n" //
1170                     "smlal    v22.4s, v8.4h, v0.h[3]     \n" // sum3 += (a00-a70) * k30
1171                     "smlal2   v23.4s, v8.8h, v0.h[3]     \n" //
1172                     "smlal    v24.4s, v8.4h, v0.h[4]     \n" // sum4 += (a00-a70) * k40
1173                     "smlal2   v25.4s, v8.8h, v0.h[4]     \n" //
1174                     "smlal    v26.4s, v8.4h, v0.h[5]     \n" // sum5 += (a00-a70) * k50
1175                     "smlal2   v27.4s, v8.8h, v0.h[5]     \n" //
1176                     "smlal    v28.4s, v8.4h, v0.h[6]     \n" // sum6 += (a00-a70) * k60
1177                     "smlal2   v29.4s, v8.8h, v0.h[6]     \n" //
1178                     "smlal    v30.4s, v8.4h, v0.h[7]     \n" // sum7 += (a00-a70) * k70
1179                     "smlal2   v31.4s, v8.8h, v0.h[7]     \n" //
1180 
1181                     "subs   w4, w4, #1                   \n"
1182 
1183                     "bne    2b                           \n"
1184 
1185                     "3:                                  \n"
1186 
1187                     "st1    {v16.4s, v17.4s}, [%0]       \n"
1188                     "st1    {v18.4s, v19.4s}, [%1]       \n"
1189                     "st1    {v20.4s, v21.4s}, [%2]       \n"
1190                     "st1    {v22.4s, v23.4s}, [%3]       \n"
1191                     "st1    {v24.4s, v25.4s}, [%4]       \n"
1192                     "st1    {v26.4s, v27.4s}, [%5]       \n"
1193                     "st1    {v28.4s, v29.4s}, [%6]       \n"
1194                     "st1    {v30.4s, v31.4s}, [%7]       \n"
1195 
1196                     : "=r"(output0), // %0
1197                     "=r"(output1), // %1
1198                     "=r"(output2), // %2
1199                     "=r"(output3), // %3
1200                     "=r"(output4), // %4
1201                     "=r"(output5), // %5
1202                     "=r"(output6), // %6
1203                     "=r"(output7), // %7
1204                     "=r"(vb),      // %8
1205                     "=r"(va)       // %9
1206                     : "0"(output0),
1207                     "1"(output1),
1208                     "2"(output2),
1209                     "3"(output3),
1210                     "4"(output4),
1211                     "5"(output5),
1212                     "6"(output6),
1213                     "7"(output7),
1214                     "8"(vb),
1215                     "9"(va),
1216                     "r"(L) // %20
1217                     : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
1218 #else
1219                 int sum0[8] = {0};
1220                 int sum1[8] = {0};
1221                 int sum2[8] = {0};
1222                 int sum3[8] = {0};
1223                 int sum4[8] = {0};
1224                 int sum5[8] = {0};
1225                 int sum6[8] = {0};
1226                 int sum7[8] = {0};
1227 
1228                 int k = 0;
1229                 for (; k + 7 < L; k = k + 8)
1230                 {
1231                     for (int n = 0; n < 8; n++)
1232                     {
1233                         sum0[n] += (int)va[0] * vb[n];
1234                         sum1[n] += (int)va[1] * vb[n];
1235                         sum2[n] += (int)va[2] * vb[n];
1236                         sum3[n] += (int)va[3] * vb[n];
1237                         sum4[n] += (int)va[4] * vb[n];
1238                         sum5[n] += (int)va[5] * vb[n];
1239                         sum6[n] += (int)va[6] * vb[n];
1240                         sum7[n] += (int)va[7] * vb[n];
1241                         va += 8;
1242 
1243                         sum0[n] += (int)va[0] * vb[n + 8];
1244                         sum1[n] += (int)va[1] * vb[n + 8];
                        sum2[n] += (int)va[2] * vb[n + 8];
                        sum3[n] += (int)va[3] * vb[n + 8];
                        sum4[n] += (int)va[4] * vb[n + 8];
                        sum5[n] += (int)va[5] * vb[n + 8];
                        sum6[n] += (int)va[6] * vb[n + 8];
                        sum7[n] += (int)va[7] * vb[n + 8];
                        va += 8;

                        sum0[n] += (int)va[0] * vb[n + 16];
                        sum1[n] += (int)va[1] * vb[n + 16];
                        sum2[n] += (int)va[2] * vb[n + 16];
                        sum3[n] += (int)va[3] * vb[n + 16];
                        sum4[n] += (int)va[4] * vb[n + 16];
                        sum5[n] += (int)va[5] * vb[n + 16];
                        sum6[n] += (int)va[6] * vb[n + 16];
                        sum7[n] += (int)va[7] * vb[n + 16];
                        va += 8;

                        sum0[n] += (int)va[0] * vb[n + 24];
                        sum1[n] += (int)va[1] * vb[n + 24];
                        sum2[n] += (int)va[2] * vb[n + 24];
                        sum3[n] += (int)va[3] * vb[n + 24];
                        sum4[n] += (int)va[4] * vb[n + 24];
                        sum5[n] += (int)va[5] * vb[n + 24];
                        sum6[n] += (int)va[6] * vb[n + 24];
                        sum7[n] += (int)va[7] * vb[n + 24];
                        va += 8;

                        sum0[n] += (int)va[0] * vb[n + 32];
                        sum1[n] += (int)va[1] * vb[n + 32];
                        sum2[n] += (int)va[2] * vb[n + 32];
                        sum3[n] += (int)va[3] * vb[n + 32];
                        sum4[n] += (int)va[4] * vb[n + 32];
                        sum5[n] += (int)va[5] * vb[n + 32];
                        sum6[n] += (int)va[6] * vb[n + 32];
                        sum7[n] += (int)va[7] * vb[n + 32];
                        va += 8;

                        sum0[n] += (int)va[0] * vb[n + 40];
                        sum1[n] += (int)va[1] * vb[n + 40];
                        sum2[n] += (int)va[2] * vb[n + 40];
                        sum3[n] += (int)va[3] * vb[n + 40];
                        sum4[n] += (int)va[4] * vb[n + 40];
                        sum5[n] += (int)va[5] * vb[n + 40];
                        sum6[n] += (int)va[6] * vb[n + 40];
                        sum7[n] += (int)va[7] * vb[n + 40];
                        va += 8;

                        sum0[n] += (int)va[0] * vb[n + 48];
                        sum1[n] += (int)va[1] * vb[n + 48];
                        sum2[n] += (int)va[2] * vb[n + 48];
                        sum3[n] += (int)va[3] * vb[n + 48];
                        sum4[n] += (int)va[4] * vb[n + 48];
                        sum5[n] += (int)va[5] * vb[n + 48];
                        sum6[n] += (int)va[6] * vb[n + 48];
                        sum7[n] += (int)va[7] * vb[n + 48];
                        va += 8;

                        sum0[n] += (int)va[0] * vb[n + 56];
                        sum1[n] += (int)va[1] * vb[n + 56];
                        sum2[n] += (int)va[2] * vb[n + 56];
                        sum3[n] += (int)va[3] * vb[n + 56];
                        sum4[n] += (int)va[4] * vb[n + 56];
                        sum5[n] += (int)va[5] * vb[n + 56];
                        sum6[n] += (int)va[6] * vb[n + 56];
                        sum7[n] += (int)va[7] * vb[n + 56];
                        va -= 56;
                    }

                    va += 64;
                    vb += 64;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += (int)va[0] * vb[n];
                        sum1[n] += (int)va[1] * vb[n];
                        sum2[n] += (int)va[2] * vb[n];
                        sum3[n] += (int)va[3] * vb[n];
                        sum4[n] += (int)va[4] * vb[n];
                        sum5[n] += (int)va[5] * vb[n];
                        sum6[n] += (int)va[6] * vb[n];
                        sum7[n] += (int)va[7] * vb[n];
                    }

                    va += 8;
                    vb += 8;
                }

                for (int n = 0; n < 8; n++)
                {
                    output0[n] = sum0[n];
                    output1[n] = sum1[n];
                    output2[n] = sum2[n];
                    output3[n] = sum3[n];
                    output4[n] = sum4[n];
                    output5[n] = sum5[n];
                    output6[n] = sum6[n];
                    output7[n] = sum7[n];
                }
#endif // __aarch64__
                output0 += 8;
                output1 += 8;
                output2 += 8;
                output3 += 8;
                output4 += 8;
                output5 += 8;
                output6 += 8;
                output7 += 8;
            }

            for (; j < N; j++)
            {
                signed char* vb = bottom_tm.channel(j / 8 + j % 8);
                const signed char* va = kernel_tm.channel(i / 8);

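                // 8x1 micro-kernel: 8 output channels, one leftover column.
                // The aarch64 path widens int8 to int16 with sshll, multiply-
                // accumulates four K steps per iteration with smlal/smlal2,
                // then folds the partial accumulators into v14/v15.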
#if __aarch64__
                asm volatile(
                    "eor    v14.16b, v14.16b, v14.16b    \n" // sum0_3
                    "eor    v15.16b, v15.16b, v15.16b    \n" // sum4_7
                    "eor    v16.16b, v16.16b, v16.16b    \n" // sum0
                    "eor    v17.16b, v17.16b, v17.16b    \n" // sum1
                    "eor    v18.16b, v18.16b, v18.16b    \n" // sum2
                    "eor    v19.16b, v19.16b, v19.16b    \n" // sum3
                    "eor    v20.16b, v20.16b, v20.16b    \n" // sum4
                    "eor    v21.16b, v21.16b, v21.16b    \n" // sum5
                    "eor    v22.16b, v22.16b, v22.16b    \n" // sum6
                    "eor    v23.16b, v23.16b, v23.16b    \n" // sum7

                    "lsr         w4, %w20, #2            \n" // w4 = nn = L >> 2
                    "cmp         w4, #0                  \n"
                    "beq         1f                      \n"

                    "0:                                  \n" // for (; k+3<L; k=k+4)

                    "prfm   pldl1keep, [%9, #128]                       \n"
                    "ld1    {v0.8b, v1.8b, v2.8b, v3.8b}, [%9], #32     \n" // k

                    //"prfm   pldl1keep, [%8, #128]      \n"
                    "ld1    {v4.8b}, [%8]                \n" // d
                    "add    %8, %8, #4                   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00 - k70
                    "sshll    v1.8h, v1.8b, #0           \n" // k01 - k71
                    "sshll    v2.8h, v2.8b, #0           \n" // k02 - k72
                    "sshll    v3.8h, v3.8b, #0           \n" // k03 - k73

                    "sshll    v4.8h, v4.8b, #0           \n" // a00 - a30

                    // k0
                    "smlal    v16.4s, v0.4h, v4.h[0]     \n" // sum0 += (k00-k70) * a00
                    "smlal2   v17.4s, v0.8h, v4.h[0]     \n" //
                    "smlal    v18.4s, v1.4h, v4.h[1]     \n" // sum1 += (k01-k71) * a10
                    "smlal2   v19.4s, v1.8h, v4.h[1]     \n" //
                    "smlal    v20.4s, v2.4h, v4.h[2]     \n" // sum2 += (k02-k72) * a20
                    "smlal2   v21.4s, v2.8h, v4.h[2]     \n" //
                    "smlal    v22.4s, v3.4h, v4.h[3]     \n" // sum3 += (k03-k73) * a30
                    "smlal2   v23.4s, v3.8h, v4.h[3]     \n" //

                    "subs   w4, w4, #1                   \n"
                    "bne    0b                           \n"

                    "add      v16.4s, v16.4s, v18.4s     \n"
                    "add      v17.4s, v17.4s, v19.4s     \n"
                    "add      v20.4s, v20.4s, v22.4s     \n"
                    "add      v21.4s, v21.4s, v23.4s     \n"
                    "add      v14.4s, v16.4s, v20.4s     \n"
                    "add      v15.4s, v17.4s, v21.4s     \n"

                    "1:                                  \n"

                    // remain loop
                    "and    w4, %w20, #3                 \n" // w4 = remain = L & 3
                    "cmp    w4, #0                       \n"
                    "beq    3f                           \n"

                    "2:                                  \n"

                    //"prfm   pldl1keep, [%9, #128]      \n"
                    "ld1    {v0.8b}, [%9], #8            \n"
                    //"prfm   pldl1keep, [%8, #128]      \n"
                    "ld1    {v4.8b}, [%8]                \n"
                    "add    %8, %8, #1                   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00 - k70
                    "sshll    v4.8h, v4.8b, #0           \n" // a00

                    // k0
                    "smlal    v14.4s, v0.4h, v4.h[0]     \n" // sum0 += (k00-k70) * a00
                    "smlal2   v15.4s, v0.8h, v4.h[0]     \n" //

                    "subs   w4, w4, #1                   \n"

                    "bne    2b                           \n"

                    "3:                                  \n"

                    "st1    {v14.s}[0], [%0]             \n"
                    "st1    {v14.s}[1], [%1]             \n"
                    "st1    {v14.s}[2], [%2]             \n"
                    "st1    {v14.s}[3], [%3]             \n"
                    "st1    {v15.s}[0], [%4]             \n"
                    "st1    {v15.s}[1], [%5]             \n"
                    "st1    {v15.s}[2], [%6]             \n"
                    "st1    {v15.s}[3], [%7]             \n"

                    : "=r"(output0), // %0
                    "=r"(output1), // %1
                    "=r"(output2), // %2
                    "=r"(output3), // %3
                    "=r"(output4), // %4
                    "=r"(output5), // %5
                    "=r"(output6), // %6
                    "=r"(output7), // %7
                    "=r"(vb),      // %8
                    "=r"(va)       // %9
                    : "0"(output0),
                    "1"(output1),
                    "2"(output2),
                    "3"(output3),
                    "4"(output4),
                    "5"(output5),
                    "6"(output6),
                    "7"(output7),
                    "8"(vb),
                    "9"(va),
                    "r"(L) // %20
                    : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                int sum0 = 0;
                int sum1 = 0;
                int sum2 = 0;
                int sum3 = 0;
                int sum4 = 0;
                int sum5 = 0;
                int sum6 = 0;
                int sum7 = 0;

                for (int k = 0; k < L; k++)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum1 += (int)va[1] * vb[0];
                    sum2 += (int)va[2] * vb[0];
                    sum3 += (int)va[3] * vb[0];
                    sum4 += (int)va[4] * vb[0];
                    sum5 += (int)va[5] * vb[0];
                    sum6 += (int)va[6] * vb[0];
                    sum7 += (int)va[7] * vb[0];

                    va += 8;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
                output4[0] = sum4;
                output5[0] = sum5;
                output6[0] = sum6;
                output7[0] = sum7;
#endif // __aarch64__
                output0++;
                output1++;
                output2++;
                output3++;
                output4++;
                output5++;
                output6++;
                output7++;
            }
        }
#endif // __ARM_NEON && __aarch64__

        nn_outch = (outch - remain_outch_start) >> 2;

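        // Second stage: process the remaining output channels in blocks of 4,
        // i.e. 4 rows of the GEMM at a time against 8-column tiles of the
        // packed input.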
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = remain_outch_start + pp * 4;

            int* output0 = top_blob.channel(i);
            int* output1 = top_blob.channel(i + 1);
            int* output2 = top_blob.channel(i + 2);
            int* output3 = top_blob.channel(i + 3);

            int j = 0;
            for (; j + 7 < N; j = j + 8)
            {
                signed char* vb = bottom_tm.channel(j / 8);
#if __ARM_NEON && __aarch64__
                const signed char* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
#else
                const signed char* va = kernel_tm.channel(i / 4);
#endif // __ARM_NEON && __aarch64__

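                // 4x8 micro-kernel: 4 output channels x 8 columns. The data
                // tile (vb) supplies the vector lanes; the kernel values act
                // as scalars via the lane-indexed multiply-accumulate.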
#if __ARM_NEON
#if __aarch64__
                asm volatile(
                    "eor    v16.16b, v16.16b, v16.16b    \n" // sum0
                    "eor    v17.16b, v17.16b, v17.16b    \n" // sum0n
                    "eor    v18.16b, v18.16b, v18.16b    \n" // sum1
                    "eor    v19.16b, v19.16b, v19.16b    \n" // sum1n
                    "eor    v20.16b, v20.16b, v20.16b    \n" // sum2
                    "eor    v21.16b, v21.16b, v21.16b    \n" // sum2n
                    "eor    v22.16b, v22.16b, v22.16b    \n" // sum3
                    "eor    v23.16b, v23.16b, v23.16b    \n" // sum3n

                    "lsr         w4, %w12, #2            \n" // w4 = nn = L >> 2
                    "cmp         w4, #0                  \n"
                    "beq         1f                      \n"

                    "0:                                  \n" // for (; k+3<L; k=k+4)

                    "prfm   pldl1keep, [%5, #128]        \n"
                    "ld1    {v0.8b, v1.8b}, [%5], #16    \n"

                    "prfm   pldl1keep, [%4, #128]                       \n"
                    "ld1    {v8.8b, v9.8b, v10.8b, v11.8b}, [%4], #32   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00 - k30,k01 - k31
                    "sshll    v1.8h, v1.8b, #0           \n" // k02 - k32,k03 - k33

                    "sshll    v8.8h, v8.8b, #0           \n" // a00 - a70
                    "sshll    v9.8h, v9.8b, #0           \n" // a01 - a71
                    "sshll    v10.8h, v10.8b, #0         \n" // a02 - a72
                    "sshll    v11.8h, v11.8b, #0         \n" // a03 - a73

                    // k0
                    "smlal    v16.4s, v8.4h, v0.h[0]     \n" // sum0 += (a00-a70) * k00
                    "smlal2   v17.4s, v8.8h, v0.h[0]     \n" //
                    "smlal    v18.4s, v8.4h, v0.h[1]     \n" // sum1 += (a00-a70) * k10
                    "smlal2   v19.4s, v8.8h, v0.h[1]     \n" //
                    "smlal    v20.4s, v8.4h, v0.h[2]     \n" // sum2 += (a00-a70) * k20
                    "smlal2   v21.4s, v8.8h, v0.h[2]     \n" //
                    "smlal    v22.4s, v8.4h, v0.h[3]     \n" // sum3 += (a00-a70) * k30
                    "smlal2   v23.4s, v8.8h, v0.h[3]     \n" //
                    // k1
                    "smlal    v16.4s, v9.4h, v0.h[4]     \n" // sum0 += (a01-a71) * k01
                    "smlal2   v17.4s, v9.8h, v0.h[4]     \n" //
                    "smlal    v18.4s, v9.4h, v0.h[5]     \n" // sum1 += (a01-a71) * k11
                    "smlal2   v19.4s, v9.8h, v0.h[5]     \n" //
                    "smlal    v20.4s, v9.4h, v0.h[6]     \n" // sum2 += (a01-a71) * k21
                    "smlal2   v21.4s, v9.8h, v0.h[6]     \n" //
                    "smlal    v22.4s, v9.4h, v0.h[7]     \n" // sum3 += (a01-a71) * k31
                    "smlal2   v23.4s, v9.8h, v0.h[7]     \n" //
                    // k2
                    "smlal    v16.4s, v10.4h, v1.h[0]    \n" // sum0 += (a02-a72) * k02
                    "smlal2   v17.4s, v10.8h, v1.h[0]    \n" //
                    "smlal    v18.4s, v10.4h, v1.h[1]    \n" // sum1 += (a02-a72) * k12
                    "smlal2   v19.4s, v10.8h, v1.h[1]    \n" //
                    "smlal    v20.4s, v10.4h, v1.h[2]    \n" // sum2 += (a02-a72) * k22
                    "smlal2   v21.4s, v10.8h, v1.h[2]    \n" //
                    "smlal    v22.4s, v10.4h, v1.h[3]    \n" // sum3 += (a02-a72) * k32
                    "smlal2   v23.4s, v10.8h, v1.h[3]    \n" //
                    // k3
                    "smlal    v16.4s, v11.4h, v1.h[4]    \n" // sum0 += (a03-a73) * k03
                    "smlal2   v17.4s, v11.8h, v1.h[4]    \n" //
                    "smlal    v18.4s, v11.4h, v1.h[5]    \n" // sum1 += (a03-a73) * k13
                    "smlal2   v19.4s, v11.8h, v1.h[5]    \n" //
                    "smlal    v20.4s, v11.4h, v1.h[6]    \n" // sum2 += (a03-a73) * k23
                    "smlal2   v21.4s, v11.8h, v1.h[6]    \n" //
                    "smlal    v22.4s, v11.4h, v1.h[7]    \n" // sum3 += (a03-a73) * k33
                    "smlal2   v23.4s, v11.8h, v1.h[7]    \n" //

                    "subs   w4, w4, #1                   \n"
                    "bne    0b                           \n"

                    "1:                                  \n"

                    // remain loop
                    "and    w4, %w12, #3                 \n" // w4 = remain = L & 3
                    "cmp    w4, #0                       \n"
                    "beq    3f                           \n"

                    "2:                                  \n"

                    //"prfm   pldl1keep, [%5, #128]      \n"
                    "ld1    {v0.8b}, [%5]                \n"
                    //"prfm   pldl1keep, [%4, #128]      \n"
                    "ld1    {v8.8b}, [%4], #8            \n"
                    "add    %5, %5, #4                   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00 - k30
                    "sshll    v8.8h, v8.8b, #0           \n" // a00 - a70

                    // k0
                    "smlal    v16.4s, v8.4h, v0.h[0]     \n" // sum0 += (a00-a70) * k00
                    "smlal2   v17.4s, v8.8h, v0.h[0]     \n" //
                    "smlal    v18.4s, v8.4h, v0.h[1]     \n" // sum1 += (a00-a70) * k10
                    "smlal2   v19.4s, v8.8h, v0.h[1]     \n" //
                    "smlal    v20.4s, v8.4h, v0.h[2]     \n" // sum2 += (a00-a70) * k20
                    "smlal2   v21.4s, v8.8h, v0.h[2]     \n" //
                    "smlal    v22.4s, v8.4h, v0.h[3]     \n" // sum3 += (a00-a70) * k30
                    "smlal2   v23.4s, v8.8h, v0.h[3]     \n" //

                    "subs   w4, w4, #1                   \n"

                    "bne    2b                           \n"

                    "3:                                  \n"

                    "st1    {v16.4s, v17.4s}, [%0]       \n"
                    "st1    {v18.4s, v19.4s}, [%1]       \n"
                    "st1    {v20.4s, v21.4s}, [%2]       \n"
                    "st1    {v22.4s, v23.4s}, [%3]       \n"

                    : "=r"(output0), // %0
                    "=r"(output1), // %1
                    "=r"(output2), // %2
                    "=r"(output3), // %3
                    "=r"(vb),      // %4
                    "=r"(va)       // %5
                    : "0"(output0),
                    "1"(output1),
                    "2"(output2),
                    "3"(output3),
                    "4"(vb),
                    "5"(va),
                    "r"(L) // %12
                    : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
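                // armv7 variant of the 4x8 micro-kernel: unrolls 8 K steps per
                // iteration, widening with vmovl.s8 and accumulating int32
                // products in q8-q15 via vmlal.s16.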
                asm volatile(
                    // K loop
                    "vmov.s32    q8, #0             \n"
                    "vmov.s32    q9, #0             \n"
                    "vmov.s32    q10, #0            \n"
                    "vmov.s32    q11, #0            \n"
                    "vmov.s32    q12, #0            \n"
                    "vmov.s32    q13, #0            \n"
                    "vmov.s32    q14, #0            \n"
                    "vmov.s32    q15, #0            \n"

                    "lsr         r4, %12, #3        \n" // r4 = nn = L >> 3
                    "cmp         r4, #0             \n"
                    "beq         1f                 \n"

                    "0:                             \n" // for(; nn != 0; nn--)
                    "pld         [%4, #128]         \n"
                    "vld1.s8     {d8-d11}, [%4]!    \n" // tmpr a00-a07,a10-a17,a20-a27,a30-a37    a(inch)(data)
                    "vmovl.s8    q7, d11            \n" // a30-a37
                    "vmovl.s8    q6, d10            \n" // a20-a27
                    "vmovl.s8    q5, d9             \n" // a10-a17
                    "vmovl.s8    q4, d8             \n" // a00-a07

                    "pld         [%5, #128]         \n"
                    "vld1.s8     {d0-d3}, [%5]!     \n" // kptr k00-k30,k01-k31, k02-k32,k03-k33, k04-k34,k05-k35, k06-k36,k07-k37    k(outch)(inch)
                    "vmovl.s8    q3, d3             \n" // k06-k36,k07-k37
                    "vmovl.s8    q2, d2             \n" // k04-k34,k05-k35
                    "vmovl.s8    q1, d1             \n" // k02-k32,k03-k33
                    "vmovl.s8    q0, d0             \n" // k00-k30,k01-k31

                    "vmlal.s16   q8, d8, d0[0]      \n" // sum0 = (a00-a07) * k00
                    "vmlal.s16   q9, d9, d0[0]      \n"
                    "vmlal.s16   q10, d8, d0[1]     \n" // sum1 = (a00-a07) * k10
                    "vmlal.s16   q11, d9, d0[1]     \n"
                    "vmlal.s16   q12, d8, d0[2]     \n" // sum2 = (a00-a07) * k20
                    "vmlal.s16   q13, d9, d0[2]     \n"
                    "vmlal.s16   q14, d8, d0[3]     \n" // sum3 = (a00-a07) * k30
                    "vmlal.s16   q15, d9, d0[3]     \n"

                    "vmlal.s16   q8, d10, d1[0]     \n" // sum0 += (a10-a17) * k01
                    "vmlal.s16   q9, d11, d1[0]     \n"
                    "vmlal.s16   q10, d10, d1[1]    \n" // sum1 += (a10-a17) * k11
                    "vmlal.s16   q11, d11, d1[1]    \n"
                    "vmlal.s16   q12, d10, d1[2]    \n" // sum2 += (a10-a17) * k21
                    "vmlal.s16   q13, d11, d1[2]    \n"
                    "vmlal.s16   q14, d10, d1[3]    \n" // sum3 += (a10-a17) * k31
                    "vmlal.s16   q15, d11, d1[3]    \n"

                    "pld         [%4, #128]         \n"
                    "vld1.s8     {d8-d9}, [%4]!     \n" // tmpr a40-a47,a50-a57    a(inch)(data)
                    "vmovl.s8    q5, d9             \n" // a50-a57
                    "vmovl.s8    q4, d8             \n" // a40-a47

                    "vmlal.s16   q8, d12, d2[0]     \n" // sum0 += (a20-a27) * k02
                    "vmlal.s16   q9, d13, d2[0]     \n"
                    "vmlal.s16   q10, d12, d2[1]    \n" // sum1 += (a20-a27) * k12
                    "vmlal.s16   q11, d13, d2[1]    \n"
                    "vmlal.s16   q12, d12, d2[2]    \n" // sum2 += (a20-a27) * k22
                    "vmlal.s16   q13, d13, d2[2]    \n"
                    "vmlal.s16   q14, d12, d2[3]    \n" // sum3 += (a20-a27) * k32
                    "vmlal.s16   q15, d13, d2[3]    \n"

                    "vmlal.s16   q8, d14, d3[0]     \n" // sum0 += (a30-a37) * k03
                    "vmlal.s16   q9, d15, d3[0]     \n"
                    "vmlal.s16   q10, d14, d3[1]    \n" // sum1 += (a30-a37) * k13
                    "vmlal.s16   q11, d15, d3[1]    \n"
                    "vmlal.s16   q12, d14, d3[2]    \n" // sum2 += (a30-a37) * k23
                    "vmlal.s16   q13, d15, d3[2]    \n"
                    "vmlal.s16   q14, d14, d3[3]    \n" // sum3 += (a30-a37) * k33
                    "vmlal.s16   q15, d15, d3[3]    \n"

                    "pld         [%4, #128]         \n"
                    "vld1.s8     {d0-d1}, [%4]!     \n" // tmpr a60-a67,a70-a77    a(inch)(data)
                    "vmovl.s8    q1, d1             \n" // a70-a77
                    "vmovl.s8    q0, d0             \n" // a60-a67

                    "vmlal.s16   q8, d8, d4[0]      \n" // sum0 += (a40-a47) * k04
                    "vmlal.s16   q9, d9, d4[0]      \n"
                    "vmlal.s16   q10, d8, d4[1]     \n" // sum1 += (a40-a47) * k14
                    "vmlal.s16   q11, d9, d4[1]     \n"
                    "vmlal.s16   q12, d8, d4[2]     \n" // sum2 += (a40-a47) * k24
                    "vmlal.s16   q13, d9, d4[2]     \n"
                    "vmlal.s16   q14, d8, d4[3]     \n" // sum3 += (a40-a47) * k34
                    "vmlal.s16   q15, d9, d4[3]     \n"

                    "vmlal.s16   q8, d10, d5[0]     \n" // sum0 += (a50-a57) * k05
                    "vmlal.s16   q9, d11, d5[0]     \n"
                    "vmlal.s16   q10, d10, d5[1]    \n" // sum1 += (a50-a57) * k15
                    "vmlal.s16   q11, d11, d5[1]    \n"
                    "vmlal.s16   q12, d10, d5[2]    \n" // sum2 += (a50-a57) * k25
                    "vmlal.s16   q13, d11, d5[2]    \n"
                    "vmlal.s16   q14, d10, d5[3]    \n" // sum3 += (a50-a57) * k35
                    "vmlal.s16   q15, d11, d5[3]    \n"

                    "vmlal.s16   q8, d0, d6[0]      \n" // sum0 += (a60-a67) * k06
                    "vmlal.s16   q9, d1, d6[0]      \n"
                    "vmlal.s16   q10, d0, d6[1]     \n" // sum1 += (a60-a67) * k16
                    "vmlal.s16   q11, d1, d6[1]     \n"
                    "vmlal.s16   q12, d0, d6[2]     \n" // sum2 += (a60-a67) * k26
                    "vmlal.s16   q13, d1, d6[2]     \n"
                    "vmlal.s16   q14, d0, d6[3]     \n" // sum3 += (a60-a67) * k36
                    "vmlal.s16   q15, d1, d6[3]     \n"

                    "vmlal.s16   q8, d2, d7[0]      \n" // sum0 += (a70-a77) * k07
                    "vmlal.s16   q9, d3, d7[0]      \n"
                    "vmlal.s16   q10, d2, d7[1]     \n" // sum1 += (a70-a77) * k17
                    "vmlal.s16   q11, d3, d7[1]     \n"
                    "vmlal.s16   q12, d2, d7[2]     \n" // sum2 += (a70-a77) * k27
                    "vmlal.s16   q13, d3, d7[2]     \n"
                    "vmlal.s16   q14, d2, d7[3]     \n" // sum3 += (a70-a77) * k37
                    "vmlal.s16   q15, d3, d7[3]     \n"

                    "subs        r4, r4, #1         \n"
                    "bne         0b                 \n" // end for

                    "1:                             \n"
                    // remain loop
                    "and         r4, %12, #7        \n" // r4 = remain = L & 7
                    "cmp         r4, #0             \n"
                    "beq         3f                 \n"

                    "2:                             \n" // for(; remain != 0; remain--)
                    "vld1.s8     {d2}, [%4]!        \n" // tmpr a00-a70    a(inch)(data)
                    "vld1.s8     {d0}, [%5]         \n" // kptr k00-k30    k(outch)(inch)
                    "vmovl.s8    q1, d2             \n"
                    "vmovl.s8    q0, d0             \n"
                    "add         %5, #4             \n"

                    "vmlal.s16   q8, d2, d0[0]      \n" // sum0 += (a00-a70) * k00
                    "vmlal.s16   q9, d3, d0[0]      \n"
                    "vmlal.s16   q10, d2, d0[1]     \n" // sum1 += (a00-a70) * k10
                    "vmlal.s16   q11, d3, d0[1]     \n"
                    "vmlal.s16   q12, d2, d0[2]     \n" // sum2 += (a00-a70) * k20
                    "vmlal.s16   q13, d3, d0[2]     \n"
                    "vmlal.s16   q14, d2, d0[3]     \n" // sum3 += (a00-a70) * k30
                    "vmlal.s16   q15, d3, d0[3]     \n"

                    "subs        r4, r4, #1         \n"
                    "bne         2b                 \n"

                    "3:                             \n" // store the result to memory
                    "vst1.s32    {d16-d19}, [%0]    \n"
                    "vst1.s32    {d20-d23}, [%1]    \n"
                    "vst1.s32    {d24-d27}, [%2]    \n"
                    "vst1.s32    {d28-d31}, [%3]    \n"

                    : "=r"(output0), // %0
                    "=r"(output1), // %1
                    "=r"(output2), // %2
                    "=r"(output3), // %3
                    "=r"(vb),      // %4
                    "=r"(va)       // %5
                    : "0"(output0),
                    "1"(output1),
                    "2"(output2),
                    "3"(output3),
                    "4"(vb),
                    "5"(va),
                    "r"(L) // %12
                    : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif // __aarch64__
#else
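                // Portable fallback for the 4x8 tile. va packs 4 output
                // channels per K step and vb packs 8 columns per K step, so
                // one unrolled pass consumes 8 K steps (va advances 32 bytes,
                // vb advances 64).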
                int sum0[8] = {0};
                int sum1[8] = {0};
                int sum2[8] = {0};
                int sum3[8] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += (int)va[0] * vb[n];
                        sum1[n] += (int)va[1] * vb[n];
                        sum2[n] += (int)va[2] * vb[n];
                        sum3[n] += (int)va[3] * vb[n];
                        va += 4;

                        sum0[n] += (int)va[0] * vb[n + 8];
                        sum1[n] += (int)va[1] * vb[n + 8];
                        sum2[n] += (int)va[2] * vb[n + 8];
                        sum3[n] += (int)va[3] * vb[n + 8];
                        va += 4;

                        sum0[n] += (int)va[0] * vb[n + 16];
                        sum1[n] += (int)va[1] * vb[n + 16];
                        sum2[n] += (int)va[2] * vb[n + 16];
                        sum3[n] += (int)va[3] * vb[n + 16];
                        va += 4;

                        sum0[n] += (int)va[0] * vb[n + 24];
                        sum1[n] += (int)va[1] * vb[n + 24];
                        sum2[n] += (int)va[2] * vb[n + 24];
                        sum3[n] += (int)va[3] * vb[n + 24];
                        va += 4;

                        sum0[n] += (int)va[0] * vb[n + 32];
                        sum1[n] += (int)va[1] * vb[n + 32];
                        sum2[n] += (int)va[2] * vb[n + 32];
                        sum3[n] += (int)va[3] * vb[n + 32];
                        va += 4;

                        sum0[n] += (int)va[0] * vb[n + 40];
                        sum1[n] += (int)va[1] * vb[n + 40];
                        sum2[n] += (int)va[2] * vb[n + 40];
                        sum3[n] += (int)va[3] * vb[n + 40];
                        va += 4;

                        sum0[n] += (int)va[0] * vb[n + 48];
                        sum1[n] += (int)va[1] * vb[n + 48];
                        sum2[n] += (int)va[2] * vb[n + 48];
                        sum3[n] += (int)va[3] * vb[n + 48];
                        va += 4;

                        sum0[n] += (int)va[0] * vb[n + 56];
                        sum1[n] += (int)va[1] * vb[n + 56];
                        sum2[n] += (int)va[2] * vb[n + 56];
                        sum3[n] += (int)va[3] * vb[n + 56];
                        va -= 28;
                    }

                    va += 32;
                    vb += 64;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += (int)va[0] * vb[n];
                        sum1[n] += (int)va[1] * vb[n];
                        sum2[n] += (int)va[2] * vb[n];
                        sum3[n] += (int)va[3] * vb[n];
                    }

                    va += 4;
                    vb += 8;
                }

                for (int n = 0; n < 8; n++)
                {
                    output0[n] = sum0[n];
                    output1[n] = sum1[n];
                    output2[n] = sum2[n];
                    output3[n] = sum3[n];
                }
#endif // __ARM_NEON
                output0 += 8;
                output1 += 8;
                output2 += 8;
                output3 += 8;
            }

            for (; j < N; j++)
            {
                signed char* vb = bottom_tm.channel(j / 8 + j % 8);
#if __ARM_NEON && __aarch64__
                const signed char* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
#else
                const signed char* va = kernel_tm.channel(i / 4);
#endif // __ARM_NEON && __aarch64__

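                // 4x1 micro-kernel: 4 output channels, one leftover column.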
#if __ARM_NEON
#if __aarch64__
                asm volatile(
                    "eor    v14.16b, v14.16b, v14.16b    \n" // sum0_3
                    "eor    v16.16b, v16.16b, v16.16b    \n" // sum0
                    "eor    v17.16b, v17.16b, v17.16b    \n" // sum1
                    "eor    v18.16b, v18.16b, v18.16b    \n" // sum2
                    "eor    v19.16b, v19.16b, v19.16b    \n" // sum3

                    "lsr         w4, %w12, #2            \n" // w4 = nn = L >> 2
                    "cmp         w4, #0                  \n"
                    "beq         1f                      \n"

                    "0:                                  \n" // for (; k+3<L; k=k+4)

                    "prfm   pldl1keep, [%5, #128]        \n"
                    "ld1    {v0.8b, v1.8b}, [%5], #16    \n" // k

                    //"prfm   pldl1keep, [%4, #128]      \n"
                    "ld1    {v4.8b}, [%4]                \n" // d
                    "add    %4, %4, #4                   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00 - k30,k01 - k31
                    "sshll    v1.8h, v1.8b, #0           \n" // k02 - k32,k03 - k33
                    "sshll    v4.8h, v4.8b, #0           \n" // a00 - a30

                    "subs   w4, w4, #1                   \n"
                    // k0
                    "smlal    v16.4s, v0.4h, v4.h[0]     \n" // sum0 += (k00-k30) * a00
                    "smlal2   v17.4s, v0.8h, v4.h[1]     \n" // sum1 += (k01-k31) * a10
                    "smlal    v18.4s, v1.4h, v4.h[2]     \n" // sum2 += (k02-k32) * a20
                    "smlal2   v19.4s, v1.8h, v4.h[3]     \n" // sum3 += (k03-k33) * a30

                    "bne    0b                           \n"

                    "add      v16.4s, v16.4s, v18.4s     \n"
                    "add      v17.4s, v17.4s, v19.4s     \n"
                    "add      v14.4s, v16.4s, v17.4s     \n"

                    "1:                                  \n"

                    // remain loop
                    "and    w4, %w12, #3                 \n" // w4 = remain = L & 3
                    "cmp    w4, #0                       \n"
                    "beq    3f                           \n"

                    "2:                                  \n"

                    //"prfm   pldl1keep, [%5, #128]      \n"
                    "ld1    {v0.8b}, [%5]                \n"
                    //"prfm   pldl1keep, [%4, #128]      \n"
                    "ld1    {v4.8b}, [%4]                \n"
                    "add    %4, %4, #1                   \n"
                    "add    %5, %5, #4                   \n"

                    "subs   w4, w4, #1                   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00 - k30
                    "sshll    v4.8h, v4.8b, #0           \n" // a00
                    // k0
                    "smlal    v14.4s, v0.4h, v4.h[0]     \n" // sum0 += (k00-k30) * a00

                    "bne    2b                           \n"

                    "3:                                  \n"

                    "st1    {v14.s}[0], [%0]             \n"
                    "st1    {v14.s}[1], [%1]             \n"
                    "st1    {v14.s}[2], [%2]             \n"
                    "st1    {v14.s}[3], [%3]             \n"

                    : "=r"(output0), // %0
                    "=r"(output1), // %1
                    "=r"(output2), // %2
                    "=r"(output3), // %3
                    "=r"(vb),      // %4
                    "=r"(va)       // %5
                    : "0"(output0),
                    "1"(output1),
                    "2"(output2),
                    "3"(output3),
                    "4"(vb),
                    "5"(va),
                    "r"(L) // %12
                    : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
#else
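                // armv7 variant: 8 K steps per iteration; the vadd chain after
                // the loop folds the eight partial accumulators q6-q13 into
                // q14, whose four lanes are the four outputs.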
                asm volatile(
                    // K loop
                    "veor        q6, q6, q6        \n"
                    "veor        q7, q7, q7        \n"
                    "veor        q8, q8, q8        \n"
                    "veor        q9, q9, q9        \n"
                    "veor        q10, q10, q10     \n"
                    "veor        q11, q11, q11     \n"
                    "veor        q12, q12, q12     \n"
                    "veor        q13, q13, q13     \n"
                    "vmov.s32    q14, #0           \n"

                    "lsr         r4, %12, #3       \n" // r4 = nn = L >> 3
                    "cmp         r4, #0            \n"
                    "beq         1f                \n"

                    "0:                            \n" // for(; nn != 0; nn--)
                    "pld         [%4, #128]        \n"
                    "vld1.s8     {d0}, [%4]!       \n" // tmpr a00-a07    a(inch)(data)
                    "vmovl.s8    q0, d0            \n" // a00-a07

                    "pld         [%5, #128]        \n"
                    "vld1.s8     {d2-d5}, [%5]!    \n" // kptr k00-k30,k01-k31, k02-k32,k03-k33, k04-k34,k05-k35, k06-k36,k07-k37    k(outch)(inch)
                    "vmovl.s8    q4, d5            \n" // k06-k36,k07-k37
                    "vmovl.s8    q3, d4            \n" // k04-k34,k05-k35
                    "vmovl.s8    q2, d3            \n" // k02-k32,k03-k33
                    "vmovl.s8    q1, d2            \n" // k00-k30,k01-k31

                    "vmlal.s16   q6, d2, d0[0]     \n" // (k00-k30) * a00
                    "vmlal.s16   q7, d3, d0[1]     \n" // (k01-k31) * a01
                    "vmlal.s16   q8, d4, d0[2]     \n" // (k02-k32) * a02
                    "vmlal.s16   q9, d5, d0[3]     \n" // (k03-k33) * a03
                    "vmlal.s16   q10, d6, d1[0]    \n" // (k04-k34) * a04
                    "vmlal.s16   q11, d7, d1[1]    \n" // (k05-k35) * a05
                    "vmlal.s16   q12, d8, d1[2]    \n" // (k06-k36) * a06
                    "vmlal.s16   q13, d9, d1[3]    \n" // (k07-k37) * a07

                    "subs        r4, r4, #1        \n"
                    "bne         0b                \n" // end for

                    "vadd.s32    q6, q6, q7        \n"
                    "vadd.s32    q9, q9, q8        \n"
                    "vadd.s32    q11, q11, q10     \n"
                    "vadd.s32    q13, q13, q12     \n"

                    "vadd.s32    q9, q9, q6        \n"
                    "vadd.s32    q13, q13, q11     \n"
                    "vadd.s32    q14, q13, q9      \n"

                    "1:                            \n"
                    // remain loop
                    "and         r4, %12, #7       \n" // r4 = remain = L & 7
                    "cmp         r4, #0            \n"
                    "beq         3f                \n"

                    "2:                            \n" // for(; remain != 0; remain--)
                    "vld1.s8     {d2}, [%4]        \n" // tmpr a00        a(inch)(data)
                    "vld1.s8     {d0}, [%5]        \n" // kptr k00-k30    k(outch)(inch)
                    "vmovl.s8    q1, d2            \n"
                    "vmovl.s8    q0, d0            \n"
                    "add         %4, #1            \n"
                    "add         %5, #4            \n"

                    "vmlal.s16   q14, d0, d2[0]    \n"

                    "subs        r4, r4, #1        \n"
                    "bne         2b                \n"

                    "3:                            \n" // store the result to memory
                    "vst1.s32    {d28[0]}, [%0]    \n"
                    "vst1.s32    {d28[1]}, [%1]    \n"
                    "vst1.s32    {d29[0]}, [%2]    \n"
                    "vst1.s32    {d29[1]}, [%3]    \n"

                    : "=r"(output0), // %0
                    "=r"(output1), // %1
                    "=r"(output2), // %2
                    "=r"(output3), // %3
                    "=r"(vb),      // %4
                    "=r"(va)       // %5
                    : "0"(output0),
                    "1"(output1),
                    "2"(output2),
                    "3"(output3),
                    "4"(vb),
                    "5"(va),
                    "r"(L) // %12
                    : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14");
#endif // __aarch64__
#else
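                // Portable fallback: plain dot products of 4 packed kernel
                // rows against a single input column.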
                int sum0 = 0;
                int sum1 = 0;
                int sum2 = 0;
                int sum3 = 0;

                for (int k = 0; k < L; k++)
                {
                    sum0 += (int)va[0] * vb[0];
                    sum1 += (int)va[1] * vb[0];
                    sum2 += (int)va[2] * vb[0];
                    sum3 += (int)va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
#endif // __ARM_NEON
                output0++;
                output1++;
                output2++;
                output3++;
            }
        }

        remain_outch_start += nn_outch << 2;

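        // Final stage: any output channel that did not fit an 8- or 4-block
        // is computed on its own, first in 8-column tiles, then element by
        // element.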
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            int* output = top_blob.channel(i);

            int j = 0;
            for (; j + 7 < N; j = j + 8)
            {
                signed char* vb = bottom_tm.channel(j / 8);
#if __ARM_NEON && __aarch64__
                const signed char* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
#else
                const signed char* va = kernel_tm.channel(i / 4 + i % 4);
#endif // __ARM_NEON && __aarch64__

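                // 1x8 micro-kernel: one output channel across 8 columns.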
#if __ARM_NEON
#if __aarch64__
                asm volatile(
                    "eor    v16.16b, v16.16b, v16.16b    \n" // sum0
                    "eor    v17.16b, v17.16b, v17.16b    \n" // sum0n

                    "lsr         w4, %w6, #2             \n" // w4 = nn = L >> 2
                    "cmp         w4, #0                  \n"
                    "beq         1f                      \n"

                    "0:                                  \n" // for (; k+3<L; k=k+4)

                    "prfm   pldl1keep, [%2, #128]        \n"
                    "ld1    {v0.8b}, [%2]                \n"

                    "prfm   pldl1keep, [%1, #128]                       \n"
                    "ld1    {v8.8b, v9.8b, v10.8b, v11.8b}, [%1], #32   \n"
                    "add    %2, %2, #4                   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00 - k03

                    "sshll    v8.8h, v8.8b, #0           \n" // a00 - a70
                    "sshll    v9.8h, v9.8b, #0           \n" // a01 - a71
                    "sshll    v10.8h, v10.8b, #0         \n" // a02 - a72
                    "sshll    v11.8h, v11.8b, #0         \n" // a03 - a73

                    // k0
                    "smlal    v16.4s, v8.4h, v0.h[0]     \n" // sum0 += (a00-a70) * k00
                    "smlal2   v17.4s, v8.8h, v0.h[0]     \n" //
                    // k1
                    "smlal    v16.4s, v9.4h, v0.h[1]     \n" // sum0 += (a01-a71) * k01
                    "smlal2   v17.4s, v9.8h, v0.h[1]     \n" //
                    // k2
                    "smlal    v16.4s, v10.4h, v0.h[2]    \n" // sum0 += (a02-a72) * k02
                    "smlal2   v17.4s, v10.8h, v0.h[2]    \n" //
                    // k3
                    "smlal    v16.4s, v11.4h, v0.h[3]    \n" // sum0 += (a03-a73) * k03
                    "smlal2   v17.4s, v11.8h, v0.h[3]    \n" //

                    "subs   w4, w4, #1                   \n"
                    "bne    0b                           \n"

                    "1:                                  \n"

                    // remain loop
                    "and    w4, %w6, #3                  \n" // w4 = remain = L & 3
                    "cmp    w4, #0                       \n"
                    "beq    3f                           \n"

                    "2:                                  \n"

                    //"prfm   pldl1keep, [%2, #128]      \n"
                    "ld1    {v0.8b}, [%2]                \n"
                    //"prfm   pldl1keep, [%1, #128]      \n"
                    "ld1    {v8.8b}, [%1], #8            \n"
                    "add    %2, %2, #1                   \n"

                    "sshll    v0.8h, v0.8b, #0           \n" // k00
                    "sshll    v8.8h, v8.8b, #0           \n" // a00 - a70

                    // k0
                    "smlal    v16.4s, v8.4h, v0.h[0]     \n" // sum0 += (a00-a70) * k00
                    "smlal2   v17.4s, v8.8h, v0.h[0]     \n" //

                    "subs   w4, w4, #1                   \n"

                    "bne    2b                           \n"

                    "3:                                  \n"

                    "st1    {v16.4s, v17.4s}, [%0]       \n"

                    : "=r"(output), // %0
                    "=r"(vb),     // %1
                    "=r"(va)      // %2
                    : "0"(output),
                    "1"(vb),
                    "2"(va),
                    "r"(L) // %6
                    : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17");
#else
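                // armv7 variant: two 4-row data loads per iteration cover 8 K
                // steps, with the 8 column sums kept in q6/q7.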
                asm volatile(
                    // K loop
                    "vmov.s32    q6, #0            \n"
                    "vmov.s32    q7, #0            \n"

                    "lsr         r4, %6, #3        \n" // r4 = nn = L >> 3
                    "cmp         r4, #0            \n"
                    "beq         1f                \n"

                    "0:                            \n" // for(; nn != 0; nn--)
                    "pld         [%1, #128]        \n"
                    "vld1.s8     {d4-d7}, [%1]!    \n" // tmpr a00-a07,a10-a17,a20-a27,a30-a37    a(inch)(data)
                    "vmovl.s8    q5, d7            \n" // a30-a37
                    "vmovl.s8    q4, d6            \n" // a20-a27
                    "vmovl.s8    q3, d5            \n" // a10-a17
                    "vmovl.s8    q2, d4            \n" // a00-a07

                    "pld         [%2, #128]        \n"
                    "vld1.s8     {d0}, [%2]!       \n" // kptr k00-k07    k(outch)(inch)
                    "vmovl.s8    q0, d0            \n" // k00-k07

                    "vmlal.s16   q6, d4, d0[0]     \n" // (a00-a07) * k00
                    "vmlal.s16   q7, d5, d0[0]     \n"
                    "vmlal.s16   q6, d6, d0[1]     \n" // (a10-a17) * k01
                    "vmlal.s16   q7, d7, d0[1]     \n"
                    "vmlal.s16   q6, d8, d0[2]     \n" // (a20-a27) * k02
                    "vmlal.s16   q7, d9, d0[2]     \n"
                    "vmlal.s16   q6, d10, d0[3]    \n" // (a30-a37) * k03
                    "vmlal.s16   q7, d11, d0[3]    \n"

                    "pld         [%1, #128]        \n"
                    "vld1.s8     {d4-d7}, [%1]!    \n" // tmpr a40-a47,a50-a57,a60-a67,a70-a77    a(inch)(data)
                    "vmovl.s8    q5, d7            \n" // a70-a77
                    "vmovl.s8    q4, d6            \n" // a60-a67
                    "vmovl.s8    q3, d5            \n" // a50-a57
                    "vmovl.s8    q2, d4            \n" // a40-a47

                    "vmlal.s16   q6, d4, d1[0]     \n" // (a40-a47) * k04
                    "vmlal.s16   q7, d5, d1[0]     \n"
                    "vmlal.s16   q6, d6, d1[1]     \n" // (a50-a57) * k05
                    "vmlal.s16   q7, d7, d1[1]     \n"
                    "vmlal.s16   q6, d8, d1[2]     \n" // (a60-a67) * k06
                    "vmlal.s16   q7, d9, d1[2]     \n"
                    "vmlal.s16   q6, d10, d1[3]    \n" // (a70-a77) * k07
                    "vmlal.s16   q7, d11, d1[3]    \n"

                    "subs        r4, r4, #1        \n"
                    "bne         0b                \n" // end for

                    "1:                            \n"
                    // remain loop
                    "and         r4, %6, #7        \n" // r4 = remain = L & 7
                    "cmp         r4, #0            \n"
                    "beq         3f                \n"

                    "2:                            \n" // for(; remain != 0; remain--)
                    "vld1.s8     {d2}, [%1]!       \n" // tmpr a00-a07    a(inch)(data)
                    "vld1.s8     {d0}, [%2]        \n" // kptr k00        k(outch)(inch)
                    "vmovl.s8    q1, d2            \n"
                    "vmovl.s8    q0, d0            \n"
                    "add         %2, #1            \n"

                    "vmlal.s16   q6, d2, d0[0]     \n" // (a00-a07) * k00
                    "vmlal.s16   q7, d3, d0[0]     \n"

                    "subs        r4, r4, #1        \n"
                    "bne         2b                \n"

                    "3:                            \n" // store the result to memory
                    "vst1.s32    {d12-d15}, [%0]   \n"

                    : "=r"(output), // %0
                    "=r"(vb),     // %1
                    "=r"(va)      // %2
                    : "0"(output),
                    "1"(vb),
                    "2"(va),
                    "r"(L) // %6
                    : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#endif // __aarch64__
#else
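                // Portable fallback for the 1x8 tile.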
                int sum[8] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum[n] += (int)va[0] * vb[n];
                        sum[n] += (int)va[1] * vb[n + 8];
                        sum[n] += (int)va[2] * vb[n + 16];
                        sum[n] += (int)va[3] * vb[n + 24];
                        sum[n] += (int)va[4] * vb[n + 32];
                        sum[n] += (int)va[5] * vb[n + 40];
                        sum[n] += (int)va[6] * vb[n + 48];
                        sum[n] += (int)va[7] * vb[n + 56];
                    }

                    va += 8;
                    vb += 64;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum[n] += (int)va[0] * vb[n];
                    }

                    va += 1;
                    vb += 8;
                }

                for (int n = 0; n < 8; n++)
                {
                    output[n] = sum[n];
                }
#endif // __ARM_NEON
                output += 8;
            }

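            // Scalar tail: one output channel, one column at a time.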
            for (; j < N; j++)
            {
                int sum = 0;

                signed char* vb = bottom_tm.channel(j / 8 + j % 8);
#if __ARM_NEON && __aarch64__
                const signed char* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
#else
                const signed char* va = kernel_tm.channel(i / 4 + i % 4);
#endif // __ARM_NEON && __aarch64__

                for (int k = 0; k < L; k++)
                {
                    sum += (int)va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = sum;

                output++;
            }
        }
    }
}
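
// A minimal scalar reference for the GEMM the micro-kernels above implement,
// kept under "#if 0" for validation only. The function name and the plain
// row-major layouts are illustrative assumptions, not part of this file:
// a is the unpacked m x k kernel, b is the k x n im2col data, c is m x n
// with row stride ldc.
#if 0
static void gemm_s8s8s32_ref(const signed char* a, const signed char* b, int* c,
                             int m, int k, int n, int ldc)
{
    for (int i = 0; i < m; i++)
    {
        for (int j = 0; j < n; j++)
        {
            int sum = 0;
            for (int p = 0; p < k; p++)
            {
                sum += (int)a[i * k + p] * b[p * n + j];
            }
            c[i * ldc + j] = sum;
        }
    }
}
#endif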
#endif