// BUG1989 is pleased to support the open source community by making ncnn available.
//
// author:BUG1989 (https://github.com/BUG1989/) Long-term support.
// author:FuGuangping (https://github.com/fu1899) Implemented the first version of INT8 quantization on ARMv7.
//
// Copyright (C) 2019 BUG1989. All rights reserved.
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void conv3x3s1_winograd23_transform_kernel_int8_neon(const Mat& kernel, std::vector<Mat>& kernel_tm2, int inch, int outch)
{
    Mat kernel_tm(4 * 4, inch, outch, 2ul);

    // G
    const short ktm[4][3] = {
        {2, 0, 0},
        {1, 1, 1},
        {1, -1, 1},
        {0, 0, 2}
    };
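
    // ktm is the F(2,3) kernel transform matrix G scaled by 2 so that every
    // entry stays integral:
    //
    //     G   = { {1, 0, 0}, {1/2, 1/2, 1/2}, {1/2, -1/2, 1/2}, {0, 0, 1} }
    //     ktm = 2 * G
    //
    // U = G * g * G^T is therefore computed here as 4 * U; the extra factor
    // of 4 is removed by the ">> 2" in the output transform.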

    #pragma omp parallel for
    for (int p = 0; p < outch; p++)
    {
        for (int q = 0; q < inch; q++)
        {
            const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9;
            short* kernel_tm0 = kernel_tm.channel(p).row<short>(q);

            // transform kernel
            const signed char* k0 = kernel0;
            const signed char* k1 = kernel0 + 3;
            const signed char* k2 = kernel0 + 6;

            // h
            short tmp[4][3];
            for (int i = 0; i < 4; i++)
            {
                tmp[i][0] = (short)k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
                tmp[i][1] = (short)k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
                tmp[i][2] = (short)k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
            }

            // U
            for (int j = 0; j < 4; j++)
            {
                short* tmpp = &tmp[j][0];

                for (int i = 0; i < 4; i++)
                {
                    kernel_tm0[j * 4 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
                }
            }
        }
    }

    for (int r = 0; r < 4; r++)
    {
        Mat kernel_tm_test(4 * 8, inch, outch / 8 + (outch % 8) / 4 + outch % 4, 2u);

        int p = 0;
        for (; p + 7 < outch; p += 8)
        {
            const short* kernel0 = (const short*)kernel_tm + (p + 0) * inch * 16;
            const short* kernel1 = (const short*)kernel_tm + (p + 1) * inch * 16;
            const short* kernel2 = (const short*)kernel_tm + (p + 2) * inch * 16;
            const short* kernel3 = (const short*)kernel_tm + (p + 3) * inch * 16;
            const short* kernel4 = (const short*)kernel_tm + (p + 4) * inch * 16;
            const short* kernel5 = (const short*)kernel_tm + (p + 5) * inch * 16;
            const short* kernel6 = (const short*)kernel_tm + (p + 6) * inch * 16;
            const short* kernel7 = (const short*)kernel_tm + (p + 7) * inch * 16;

            short* ktmp = kernel_tm_test.channel(p / 8);

            for (int q = 0; q < inch; q++)
            {
                ktmp[0] = kernel0[r * 4 + 0];
                ktmp[1] = kernel0[r * 4 + 1];
                ktmp[2] = kernel0[r * 4 + 2];
                ktmp[3] = kernel0[r * 4 + 3];

                ktmp[4] = kernel1[r * 4 + 0];
                ktmp[5] = kernel1[r * 4 + 1];
                ktmp[6] = kernel1[r * 4 + 2];
                ktmp[7] = kernel1[r * 4 + 3];

                ktmp[8] = kernel2[r * 4 + 0];
                ktmp[9] = kernel2[r * 4 + 1];
                ktmp[10] = kernel2[r * 4 + 2];
                ktmp[11] = kernel2[r * 4 + 3];

                ktmp[12] = kernel3[r * 4 + 0];
                ktmp[13] = kernel3[r * 4 + 1];
                ktmp[14] = kernel3[r * 4 + 2];
                ktmp[15] = kernel3[r * 4 + 3];

                ktmp[16] = kernel4[r * 4 + 0];
                ktmp[17] = kernel4[r * 4 + 1];
                ktmp[18] = kernel4[r * 4 + 2];
                ktmp[19] = kernel4[r * 4 + 3];

                ktmp[20] = kernel5[r * 4 + 0];
                ktmp[21] = kernel5[r * 4 + 1];
                ktmp[22] = kernel5[r * 4 + 2];
                ktmp[23] = kernel5[r * 4 + 3];

                ktmp[24] = kernel6[r * 4 + 0];
                ktmp[25] = kernel6[r * 4 + 1];
                ktmp[26] = kernel6[r * 4 + 2];
                ktmp[27] = kernel6[r * 4 + 3];

                ktmp[28] = kernel7[r * 4 + 0];
                ktmp[29] = kernel7[r * 4 + 1];
                ktmp[30] = kernel7[r * 4 + 2];
                ktmp[31] = kernel7[r * 4 + 3];

                ktmp += 32;
                kernel0 += 16;
                kernel1 += 16;
                kernel2 += 16;
                kernel3 += 16;
                kernel4 += 16;
                kernel5 += 16;
                kernel6 += 16;
                kernel7 += 16;
            }
        }

        for (; p + 3 < outch; p += 4)
        {
            const short* kernel0 = (const short*)kernel_tm + (p + 0) * inch * 16;
            const short* kernel1 = (const short*)kernel_tm + (p + 1) * inch * 16;
            const short* kernel2 = (const short*)kernel_tm + (p + 2) * inch * 16;
            const short* kernel3 = (const short*)kernel_tm + (p + 3) * inch * 16;

            short* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4);

            for (int q = 0; q < inch; q++)
            {
                ktmp[0] = kernel0[r * 4 + 0];
                ktmp[1] = kernel0[r * 4 + 1];
                ktmp[2] = kernel0[r * 4 + 2];
                ktmp[3] = kernel0[r * 4 + 3];

                ktmp[4] = kernel1[r * 4 + 0];
                ktmp[5] = kernel1[r * 4 + 1];
                ktmp[6] = kernel1[r * 4 + 2];
                ktmp[7] = kernel1[r * 4 + 3];

                ktmp[8] = kernel2[r * 4 + 0];
                ktmp[9] = kernel2[r * 4 + 1];
                ktmp[10] = kernel2[r * 4 + 2];
                ktmp[11] = kernel2[r * 4 + 3];

                ktmp[12] = kernel3[r * 4 + 0];
                ktmp[13] = kernel3[r * 4 + 1];
                ktmp[14] = kernel3[r * 4 + 2];
                ktmp[15] = kernel3[r * 4 + 3];

                ktmp += 16;
                kernel0 += 16;
                kernel1 += 16;
                kernel2 += 16;
                kernel3 += 16;
            }
        }

        for (; p < outch; p++)
        {
            const short* kernel0 = (const short*)kernel_tm + p * inch * 16;

            short* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4 + p % 4);

            for (int q = 0; q < inch; q++)
            {
                ktmp[0] = kernel0[r * 4 + 0];
                ktmp[1] = kernel0[r * 4 + 1];
                ktmp[2] = kernel0[r * 4 + 2];
                ktmp[3] = kernel0[r * 4 + 3];

                ktmp += 4;
                kernel0 += 16;
            }
        }
        kernel_tm2.push_back(kernel_tm_test);
    }
}

static void conv3x3s1_winograd23_int8_neon(const Mat& bottom_blob, Mat& top_blob, const std::vector<Mat>& kernel_tm_test, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 2n+2, winograd F(2,3)
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 1) / 2 * 2;
    outh = (outh + 1) / 2 * 2;

    w = outw + 2;
    h = outh + 2;
    Option opt_b = opt;
    opt_b.blob_allocator = opt.workspace_allocator;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);
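    // e.g. a 7x7 output rounds up to 8x8 (2n), so the input is padded on the
    // bottom/right to 10x10 (2n+2); the surplus is cut from the result at the end.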

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // number of tile rows (same blocking as FeatherCNN)
        int nRowBlocks = w_tm / 4; // number of tile columns

        const int tiles = nColBlocks * nRowBlocks;

        bottom_blob_tm.create(4, inch, tiles * 4, 2u, opt.workspace_allocator);
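        // bottom_blob_tm layout: the r-th 4-short row of each transformed 4x4
        // tile is stored in channel (tiles * r + tileIndex), at row q for
        // input channel q.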

        // BT
        // const float itm[4][4] = {
        //     {1.0f,  0.0f, -1.0f,  0.0f},
        //     {0.0f,  1.0f,  1.00f, 0.0f},
        //     {0.0f, -1.0f,  1.00f, 0.0f},
        //     {0.0f, -1.0f,  0.00f, 1.0f}
        // };
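
        // applied to the rows of d, then to the columns of the result:
        // 0 = r0 - r2
        // 1 = r1 + r2
        // 2 = r2 - r1
        // 3 = r3 - r1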

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const signed char* img = bottom_blob_bordered.channel(q);

            for (int j = 0; j < nColBlocks; j++)
            {
                const signed char* r0 = img + w * j * 2;
                const signed char* r1 = r0 + w;
                const signed char* r2 = r1 + w;
                const signed char* r3 = r2 + w;

                for (int i = 0; i < nRowBlocks; i++)
                {
                    short* out_tm0 = bottom_blob_tm.channel(tiles * 0 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm1 = bottom_blob_tm.channel(tiles * 1 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm2 = bottom_blob_tm.channel(tiles * 2 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm3 = bottom_blob_tm.channel(tiles * 3 + j * nRowBlocks + i).row<short>(q);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // load
                        "prfm   pldl1keep, [%0, #64]    \n"
                        "ld1    {v0.8b}, [%0]           \n"
                        "prfm   pldl1keep, [%1, #64]    \n"
                        "ld1    {v1.8b}, [%1]           \n"
                        "prfm   pldl1keep, [%2, #64]    \n"
                        "ld1    {v2.8b}, [%2]           \n"
                        "prfm   pldl1keep, [%3, #64]    \n"
                        "ld1    {v3.8b}, [%3]           \n"
                        // w = B_t * d, trans int8 to int16
                        "ssubl    v4.8h, v0.8b, v2.8b   \n" // d4
                        "saddl    v5.8h, v1.8b, v2.8b   \n" // d6
                        "ssubl    v6.8h, v2.8b, v1.8b   \n" // d8
                        "ssubl    v7.8h, v3.8b, v1.8b   \n" // d10
                        // transpose w to w_t
                        "trn1   v8.4h, v4.4h, v5.4h    \n"
                        "trn2   v9.4h, v4.4h, v5.4h    \n"
                        "trn1   v10.4h, v6.4h, v7.4h    \n"
                        "trn2   v11.4h, v6.4h, v7.4h    \n"

                        "trn1   v0.2s, v8.2s, v10.2s    \n"
                        "trn2   v2.2s, v8.2s, v10.2s    \n"
                        "trn1   v1.2s, v9.2s, v11.2s    \n"
                        "trn2   v3.2s, v9.2s, v11.2s    \n"
                        // U = B_t * d_t
                        "sub    v4.4h, v0.4h, v2.4h   \n"
                        "add    v5.4h, v1.4h, v2.4h   \n"
                        "sub    v6.4h, v2.4h, v1.4h   \n"
                        "sub    v7.4h, v3.4h, v1.4h   \n"
                        // save
                        "st1    {v4.4h}, [%4]   \n"
                        "st1    {v5.4h}, [%5]   \n"
                        "st1    {v6.4h}, [%6]   \n"
                        "st1    {v7.4h}, [%7]   \n"
                        : "=r"(r0),      // %0
                        "=r"(r1),      // %1
                        "=r"(r2),      // %2
                        "=r"(r3),      // %3
                        "=r"(out_tm0), // %4
                        "=r"(out_tm1), // %5
                        "=r"(out_tm2), // %6
                        "=r"(out_tm3)  // %7
                        : "0"(r0),
                        "1"(r1),
                        "2"(r2),
                        "3"(r3),
                        "4"(out_tm0),
                        "5"(out_tm1),
                        "6"(out_tm2),
                        "7"(out_tm3)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
#else
                    asm volatile(
                        // load
                        "pld         [%0, #64]     \n"
                        "vld1.s8     {d0}, [%0]    \n"
                        "pld         [%1, #64]     \n"
                        "vld1.s8     {d1}, [%1]    \n"
                        "pld         [%2, #64]     \n"
                        "vld1.s8     {d2}, [%2]    \n"
                        "pld         [%3, #64]     \n"
                        "vld1.s8     {d3}, [%3]    \n"
                        // w = B_t * d, trans int8 to int16
                        "vsubl.s8    q2, d0, d2    \n" // d4
                        "vaddl.s8    q3, d1, d2    \n" // d6
                        "vsubl.s8    q4, d2, d1    \n" // d8
                        "vsubl.s8    q5, d3, d1    \n" // d10
                        // transpose w to w_t
                        "vtrn.s16    d4, d6        \n"
                        "vtrn.s16    d8, d10       \n"
                        "vtrn.s32    d4, d8        \n"
                        "vtrn.s32    d6, d10       \n"
                        // U = B_t * d_t
                        "vsub.s16    d11, d4, d8   \n"
                        "vadd.s16    d12, d6, d8   \n"
                        "vsub.s16    d13, d8, d6   \n"
                        "vsub.s16    d14, d10, d6  \n"
                        // save
                        "vst1.s32    {d11}, [%4]   \n"
                        "vst1.s32    {d12}, [%5]   \n"
                        "vst1.s32    {d13}, [%6]   \n"
                        "vst1.s32    {d14}, [%7]   \n"
                        : "=r"(r0),      // %0
                        "=r"(r1),      // %1
                        "=r"(r2),      // %2
                        "=r"(r3),      // %3
                        "=r"(out_tm0), // %4
                        "=r"(out_tm1), // %5
                        "=r"(out_tm2), // %6
                        "=r"(out_tm3)  // %7
                        : "0"(r0),
                        "1"(r1),
                        "2"(r2),
                        "3"(r3),
                        "4"(out_tm0),
                        "5"(out_tm1),
                        "6"(out_tm2),
                        "7"(out_tm3)
                        : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#endif // __aarch64__
#else
                    short d0[4], d1[4], d2[4], d3[4];
                    short w0[4], w1[4], w2[4], w3[4];
                    short t0[4], t1[4], t2[4], t3[4];
                    // load
                    for (int n = 0; n < 4; n++)
                    {
                        d0[n] = r0[n];
                        d1[n] = r1[n];
                        d2[n] = r2[n];
                        d3[n] = r3[n];
                    }
                    // w = B_t * d
                    for (int n = 0; n < 4; n++)
                    {
                        w0[n] = d0[n] - d2[n];
                        w1[n] = d1[n] + d2[n];
                        w2[n] = d2[n] - d1[n];
                        w3[n] = d3[n] - d1[n];
                    }
                    // transpose d to d_t
                    {
                        t0[0] = w0[0];
                        t1[0] = w0[1];
                        t2[0] = w0[2];
                        t3[0] = w0[3];
                        t0[1] = w1[0];
                        t1[1] = w1[1];
                        t2[1] = w1[2];
                        t3[1] = w1[3];
                        t0[2] = w2[0];
                        t1[2] = w2[1];
                        t2[2] = w2[2];
                        t3[2] = w2[3];
                        t0[3] = w3[0];
                        t1[3] = w3[1];
                        t2[3] = w3[2];
                        t3[3] = w3[3];
                    }
                    // U = B_t * d_t
                    for (int n = 0; n < 4; n++)
                    {
                        d0[n] = t0[n] - t2[n];
                        d1[n] = t1[n] + t2[n];
                        d2[n] = t2[n] - t1[n];
                        d3[n] = t3[n] - t1[n];
                    }
                    // save to out_tm
                    for (int n = 0; n < 4; n++)
                    {
                        out_tm0[n] = d0[n];
                        out_tm1[n] = d1[n];
                        out_tm2[n] = d2[n];
                        out_tm3[n] = d3[n];
                    }
#endif
                    r0 += 2;
                    r1 += 2;
                    r2 += 2;
                    r3 += 2;
                }
            }
        }
    }
    bottom_blob_bordered = Mat();

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // number of tile rows (same blocking as FeatherCNN)
        int nRowBlocks = w_tm / 4; // number of tile columns

        const int tiles = nColBlocks * nRowBlocks;

        top_blob_tm.create(16, tiles, outch, 4u, opt.workspace_allocator);
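        // top_blob_tm layout: 16 ints per tile per output channel (one full
        // 4x4 transformed block); pass r of the loop below fills ints
        // [r * 4, r * 4 + 4) of each block.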

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 4; r++)
        {
            int nn_outch = 0;
            int remain_outch_start = 0;

            nn_outch = outch >> 3;
            remain_outch_start = nn_outch << 3;
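            // output channels are consumed in blocks of 8, then 4, then one
            // at a time; kernel_tm_test was packed in the same order, so
            // p / 8 + (p % 8) / 4 + p % 4 recovers the matching packed channel.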

            for (int pp = 0; pp < nn_outch; pp++)
            {
                int p = pp * 8;

                int* output0_tm = top_blob_tm.channel(p);
                int* output1_tm = top_blob_tm.channel(p + 1);
                int* output2_tm = top_blob_tm.channel(p + 2);
                int* output3_tm = top_blob_tm.channel(p + 3);
                int* output4_tm = top_blob_tm.channel(p + 4);
                int* output5_tm = top_blob_tm.channel(p + 5);
                int* output6_tm = top_blob_tm.channel(p + 6);
                int* output7_tm = top_blob_tm.channel(p + 7);

                output0_tm = output0_tm + r * 4;
                output1_tm = output1_tm + r * 4;
                output2_tm = output2_tm + r * 4;
                output3_tm = output3_tm + r * 4;
                output4_tm = output4_tm + r * 4;
                output5_tm = output5_tm + r * 4;
                output6_tm = output6_tm + r * 4;
                output7_tm = output7_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "eor    v1.16b, v1.16b, v1.16b    \n"
                        "eor    v2.16b, v2.16b, v2.16b    \n"
                        "eor    v3.16b, v3.16b, v3.16b    \n"
                        "eor    v4.16b, v4.16b, v4.16b    \n"
                        "eor    v5.16b, v5.16b, v5.16b    \n"
                        "eor    v6.16b, v6.16b, v6.16b    \n"
                        "eor    v7.16b, v7.16b, v7.16b    \n"
                        "mov    w4, %w20                  \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        "prfm    pldl1keep, [%9, #128]    \n" // _r0 = vld1_s16(r0);  // input inch0
                        "ld1     {v8.4h}, [%8]            \n"
                        "ld1     {v9.4h, v10.4h}, [%9]    \n" // _k0 = vld1q_s16(kptr);
                        "add     %9, %9, #16              \n"
                        "ld1     {v11.4h, v12.4h}, [%9]   \n" // _k0n = vld1q_s16(kptr+8);
                        "add     %9, %9, #16              \n"
                        "ld1     {v13.4h, v14.4h}, [%9]   \n" // _k1 = vld1q_s16(kptr+16);
                        "add     %9, %9, #16              \n"
                        "ld1     {v15.4h, v16.4h}, [%9]   \n" // _k1n = vld1q_s16(kptr+24);
                        "add     %8, %8, #8               \n"
                        "add     %9, %9, #16              \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)
                        "smlal   v1.4s, v8.4h, v10.4h     \n" // sum1 += (a00-a03) * (k10-k13)
                        "smlal   v2.4s, v8.4h, v11.4h     \n" // sum2 += (a00-a03) * (k20-k23)
                        "smlal   v3.4s, v8.4h, v12.4h     \n" // sum3 += (a00-a03) * (k30-k33)
                        "smlal   v4.4s, v8.4h, v13.4h     \n" // sum4 += (a00-a03) * (k40-k43)
                        "smlal   v5.4s, v8.4h, v14.4h     \n" // sum5 += (a00-a03) * (k50-k53)
                        "smlal   v6.4s, v8.4h, v15.4h     \n" // sum6 += (a00-a03) * (k60-k63)
                        "smlal   v7.4s, v8.4h, v16.4h     \n" // sum7 += (a00-a03) * (k70-k73)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory
                        "st1     {v1.4s}, [%1]            \n" //
                        "st1     {v2.4s}, [%2]            \n" //
                        "st1     {v3.4s}, [%3]            \n" //
                        "st1     {v4.4s}, [%4]            \n" //
                        "st1     {v5.4s}, [%5]            \n" //
                        "st1     {v6.4s}, [%6]            \n" //
                        "st1     {v7.4s}, [%7]            \n" //

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(output4_tm), // %4
                        "=r"(output5_tm), // %5
                        "=r"(output6_tm), // %6
                        "=r"(output7_tm), // %7
                        "=r"(r0),         // %8
                        "=r"(kptr)        // %9
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(output4_tm),
                        "5"(output5_tm),
                        "6"(output6_tm),
                        "7"(output7_tm),
                        "8"(r0),
                        "9"(kptr),
                        "r"(inch) // %20
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "vmov.s32    q1, #0           \n"
                        "vmov.s32    q2, #0           \n"
                        "vmov.s32    q3, #0           \n"
                        "vmov.s32    q4, #0           \n"
                        "vmov.s32    q5, #0           \n"
                        "vmov.s32    q6, #0           \n"
                        "vmov.s32    q7, #0           \n"
                        "mov         r4, %20          \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%8]!     \n" // _r0 = vld1_s16(r0);  // input inch0
                        "vld1.s16    {d18-d19}, [%9]  \n" // _k0 = vld1q_s16(kptr);
                        "add         %9, #16          \n"
                        "vld1.s16    {d20-d21}, [%9]  \n" // _k0n = vld1q_s16(kptr+8);
                        "add         %9, #16          \n"
                        "vld1.s16    {d22-d23}, [%9]  \n" // _k1 = vld1q_s16(kptr+16);
                        "add         %9, #16          \n"
                        "vld1.s16    {d24-d25}, [%9]  \n" // _k1n = vld1q_s16(kptr+24);
                        "add         %9, #16          \n"

                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)
                        "vmlal.s16   q1, d16, d19     \n" // sum1 += (a00-a03) * (k10-k13)
                        "vmlal.s16   q2, d16, d20     \n" // sum2 += (a00-a03) * (k20-k23)
                        "vmlal.s16   q3, d16, d21     \n" // sum3 += (a00-a03) * (k30-k33)
                        "vmlal.s16   q4, d16, d22     \n" // sum4 += (a00-a03) * (k40-k43)
                        "vmlal.s16   q5, d16, d23     \n" // sum5 += (a00-a03) * (k50-k53)
                        "vmlal.s16   q6, d16, d24     \n" // sum6 += (a00-a03) * (k60-k63)
                        "vmlal.s16   q7, d16, d25     \n" // sum7 += (a00-a03) * (k70-k73)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory
                        "vst1.s32    {d2-d3}, [%1]    \n"
                        "vst1.s32    {d4-d5}, [%2]    \n"
                        "vst1.s32    {d6-d7}, [%3]    \n"
                        "vst1.s32    {d8-d9}, [%4]    \n"
                        "vst1.s32    {d10-d11}, [%5]  \n"
                        "vst1.s32    {d12-d13}, [%6]  \n"
                        "vst1.s32    {d14-d15}, [%7]  \n"

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(output4_tm), // %4
                        "=r"(output5_tm), // %5
                        "=r"(output6_tm), // %6
                        "=r"(output7_tm), // %7
                        "=r"(r0),         // %8
                        "=r"(kptr)        // %9
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(output4_tm),
                        "5"(output5_tm),
                        "6"(output6_tm),
                        "7"(output7_tm),
                        "8"(r0),
                        "9"(kptr),
                        "r"(inch) // %20
                        : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
#endif // __aarch64__
#else
                    int sum0[4] = {0};
                    int sum1[4] = {0};
                    int sum2[4] = {0};
                    int sum3[4] = {0};
                    int sum4[4] = {0};
                    int sum5[4] = {0};
                    int sum6[4] = {0};
                    int sum7[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                            sum1[n] += (int)r0[n] * kptr[n + 4];
                            sum2[n] += (int)r0[n] * kptr[n + 8];
                            sum3[n] += (int)r0[n] * kptr[n + 12];
                            sum4[n] += (int)r0[n] * kptr[n + 16];
                            sum5[n] += (int)r0[n] * kptr[n + 20];
                            sum6[n] += (int)r0[n] * kptr[n + 24];
                            sum7[n] += (int)r0[n] * kptr[n + 28];
                        }
                        kptr += 32;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                        output1_tm[n] = sum1[n];
                        output2_tm[n] = sum2[n];
                        output3_tm[n] = sum3[n];
                        output4_tm[n] = sum4[n];
                        output5_tm[n] = sum5[n];
                        output6_tm[n] = sum6[n];
                        output7_tm[n] = sum7[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 16;
                    output1_tm += 16;
                    output2_tm += 16;
                    output3_tm += 16;
                    output4_tm += 16;
                    output5_tm += 16;
                    output6_tm += 16;
                    output7_tm += 16;
                }
            }

            nn_outch = (outch - remain_outch_start) >> 2;

            for (int pp = 0; pp < nn_outch; pp++)
            {
                int p = remain_outch_start + pp * 4;

                int* output0_tm = top_blob_tm.channel(p);
                int* output1_tm = top_blob_tm.channel(p + 1);
                int* output2_tm = top_blob_tm.channel(p + 2);
                int* output3_tm = top_blob_tm.channel(p + 3);

                output0_tm = output0_tm + r * 4;
                output1_tm = output1_tm + r * 4;
                output2_tm = output2_tm + r * 4;
                output3_tm = output3_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "eor    v1.16b, v1.16b, v1.16b    \n"
                        "eor    v2.16b, v2.16b, v2.16b    \n"
                        "eor    v3.16b, v3.16b, v3.16b    \n"
                        "mov    w4, %w12                  \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        "prfm    pldl1keep, [%5, #128]    \n" // _r0 = vld1_s16(r0);  // input inch0
                        "ld1     {v8.4h}, [%4]            \n"
                        "ld1     {v9.4h, v10.4h}, [%5]    \n" // _k0 = vld1q_s16(kptr);
                        "add     %5, %5, #16              \n"
                        "ld1     {v11.4h, v12.4h}, [%5]   \n" // _k0n = vld1q_s16(kptr+8);
                        "add     %4, %4, #8               \n"
                        "add     %5, %5, #16              \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)
                        "smlal   v1.4s, v8.4h, v10.4h     \n" // sum1 += (a00-a03) * (k10-k13)
                        "smlal   v2.4s, v8.4h, v11.4h     \n" // sum2 += (a00-a03) * (k20-k23)
                        "smlal   v3.4s, v8.4h, v12.4h     \n" // sum3 += (a00-a03) * (k30-k33)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory
                        "st1     {v1.4s}, [%1]            \n" //
                        "st1     {v2.4s}, [%2]            \n" //
                        "st1     {v3.4s}, [%3]            \n" //

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(r0),         // %4
                        "=r"(kptr)        // %5
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(r0),
                        "5"(kptr),
                        "r"(inch) // %12
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "vmov.s32    q1, #0           \n"
                        "vmov.s32    q2, #0           \n"
                        "vmov.s32    q3, #0           \n"
                        "mov         r4, %12          \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%4]!     \n" // _r0 = vld1_s16(r0);  // input inch0
                        "vld1.s16    {d18-d19}, [%5]  \n" // _k0 = vld1q_s16(kptr);
                        "add         %5, #16          \n"
                        "vld1.s16    {d20-d21}, [%5]  \n" // _k0n = vld1q_s16(kptr+8);
                        "add         %5, #16          \n"

                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)
                        "vmlal.s16   q1, d16, d19     \n" // sum1 += (a00-a03) * (k10-k13)
                        "vmlal.s16   q2, d16, d20     \n" // sum2 += (a00-a03) * (k20-k23)
                        "vmlal.s16   q3, d16, d21     \n" // sum3 += (a00-a03) * (k30-k33)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory
                        "vst1.s32    {d2-d3}, [%1]    \n"
                        "vst1.s32    {d4-d5}, [%2]    \n"
                        "vst1.s32    {d6-d7}, [%3]    \n"

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(r0),         // %4
                        "=r"(kptr)        // %5
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(r0),
                        "5"(kptr),
                        "r"(inch) // %12
                        : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q8", "q9", "q10");
#endif // __aarch64__
#else
                    int sum0[4] = {0};
                    int sum1[4] = {0};
                    int sum2[4] = {0};
                    int sum3[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                            sum1[n] += (int)r0[n] * kptr[n + 4];
                            sum2[n] += (int)r0[n] * kptr[n + 8];
                            sum3[n] += (int)r0[n] * kptr[n + 12];
                        }
                        kptr += 16;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                        output1_tm[n] = sum1[n];
                        output2_tm[n] = sum2[n];
                        output3_tm[n] = sum3[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 16;
                    output1_tm += 16;
                    output2_tm += 16;
                    output3_tm += 16;
                }
            }

            remain_outch_start += nn_outch << 2;

            for (int p = remain_outch_start; p < outch; p++)
            {
                int* output0_tm = top_blob_tm.channel(p);

                output0_tm = output0_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4 + p % 4);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "mov    w4, %w6                   \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        //"prfm    pldl1keep, [%2, #128]    \n" // _r0 = vld1_s16(r0);  // input inch0
                        "ld1     {v8.4h}, [%1]            \n"
                        "ld1     {v9.4h}, [%2]            \n" // _k0 = vld1q_s16(kptr);
                        "add     %1, %1, #8               \n"
                        "add     %2, %2, #8               \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory

                        : "=r"(output0_tm), // %0
                        "=r"(r0),         // %1
                        "=r"(kptr)        // %2
                        : "0"(output0_tm),
                        "1"(r0),
                        "2"(kptr),
                        "r"(inch) // %6
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "mov         r4, %6           \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%1]      \n" // _r0 = vld1_s16(r0);  // input inch0
                        "add         %1, #8           \n"
                        "vld1.s16    {d18}, [%2]      \n" // _k0 = vld1q_s16(kptr);
                        "add         %2, #8           \n"
                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory

                        : "=r"(output0_tm), // %0
                        "=r"(r0),         // %1
                        "=r"(kptr)        // %2
                        : "0"(output0_tm),
                        "1"(r0),
                        "2"(kptr),
                        "r"(inch) // %6
                        : "cc", "memory", "r4", "q0", "q8", "q9");
#endif // __aarch64__
#else
                    int sum0[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                        }
                        kptr += 4;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                    }
#endif
                    output0_tm += 16;
                }
            }
        }
    }
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator);
    {
        // AT
        // const float itm[2][4] = {
        //     {1.0f,  1.0f,  1.0f,  0.0f},
        //     {0.0f,  1.0f, -1.0f,  1.0f}
        // };

        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // number of tile rows (same blocking as FeatherCNN)
        int nRowBlocks = w_tm / 4; // number of tile columns

#if __ARM_NEON
        int32x2_t _shift = vdup_n_s32(-2);
#endif
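        // sshl/vshl with a negative shift count is an arithmetic shift right:
        // shifting by -2 divides by 4, removing the scale introduced by using
        // G' = 2 * G in the kernel transform (applied twice, hence a factor of 4).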

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            int* out_tile = top_blob_tm.channel(p);
            int* outRow0 = top_blob_bordered.channel(p);
            int* outRow1 = outRow0 + outw;

            for (int j = 0; j < nColBlocks; j++)
            {
                for (int i = 0; i < nRowBlocks; i++)
                {
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        "prfm   pldl1keep, [%0, #512]  \n"
                        "ld1    {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64    \n"

                        "add    v0.4s, v0.4s, v1.4s    \n" // s0 = s0 + s1 + s2;
                        "sub    v1.4s, v1.4s, v2.4s    \n"
                        "add    v0.4s, v0.4s, v2.4s    \n" // s1 = s1 - s2 + s3;
                        "add    v1.4s, v1.4s, v3.4s    \n"

                        "trn1   v4.4s, v0.4s, v1.4s    \n"
                        "trn2   v5.4s, v0.4s, v1.4s    \n"

                        "dup    v6.2d, v4.d[1]         \n"
                        "dup    v7.2d, v5.d[1]         \n"

                        "add    v0.2s, v4.2s, v5.2s    \n" // o0 = d0 + d1 + d2;
                        "sub    v1.2s, v5.2s, v6.2s    \n"
                        "add    v0.2s, v0.2s, v6.2s    \n" // o1 = d1 - d2 + d3;
                        "add    v1.2s, v1.2s, v7.2s    \n"

                        "sshl    v0.2s, v0.2s, %6.2s   \n" // o0 = o0 >> 2
                        "sshl    v1.2s, v1.2s, %6.2s   \n" // o1 = o1 >> 2

                        "st1     {v0.2s}, [%1], #8     \n"
                        "st1     {v1.2s}, [%2], #8     \n"
                        : "=r"(out_tile), // %0
                        "=r"(outRow0),  // %1
                        "=r"(outRow1)   // %2
                        : "0"(out_tile),
                        "1"(outRow0),
                        "2"(outRow1),
                        "w"(_shift) // %6
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
                    asm volatile(
                        "pld        [%0, #512]      \n"
                        "vldm        %0!, {d0-d7}   \n"

                        "vaddq.s32    q0, q0, q1    \n" // s0 = s0 + s1 + s2;
                        "vsubq.s32    q1, q1, q2    \n"
                        "vaddq.s32    q0, q0, q2    \n" // s1 = s1 - s2 + s3;
                        "vaddq.s32    q1, q1, q3    \n"

                        "vtrn.s32    q0, q1         \n"

                        "vadd.s32    d8, d0, d2     \n" // o0 = d0 + d1 + d2;
                        "vsub.s32    d9, d2, d1     \n"
                        "vadd.s32    d8, d8, d1     \n" // o1 = d1 - d2 + d3;
                        "vadd.s32    d9, d9, d3     \n"

                        "vshl.s32    d8, d8, %P6    \n" // o0 = o0 >> 2
                        "vshl.s32    d9, d9, %P6    \n" // o1 = o1 >> 2

                        "vst1.s32    {d8}, [%1]!    \n"
                        "vst1.s32    {d9}, [%2]!    \n"
                        : "=r"(out_tile), // %0
                        "=r"(outRow0),  // %1
                        "=r"(outRow1)   // %2
                        : "0"(out_tile),
                        "1"(outRow0),
                        "2"(outRow1),
                        "w"(_shift) // %6
                        : "cc", "memory", "q0", "q1", "q2", "q3", "q4");
#endif // __aarch64__
#else
                    int s0[4], s1[4], s2[4], s3[4];
                    int w0[4], w1[4];
                    int d0[2], d1[2], d2[2], d3[2];
                    int o0[2], o1[2];
                    // load
                    for (int n = 0; n < 4; n++)
                    {
                        s0[n] = out_tile[n];
                        s1[n] = out_tile[n + 4];
                        s2[n] = out_tile[n + 8];
                        s3[n] = out_tile[n + 12];
                    }
                    // w = A_T * W
                    for (int n = 0; n < 4; n++)
                    {
                        w0[n] = s0[n] + s1[n] + s2[n];
                        w1[n] = s1[n] - s2[n] + s3[n];
                    }
                    // transpose w to w_t
                    {
                        d0[0] = w0[0];
                        d0[1] = w1[0];
                        d1[0] = w0[1];
                        d1[1] = w1[1];
                        d2[0] = w0[2];
                        d2[1] = w1[2];
                        d3[0] = w0[3];
                        d3[1] = w1[3];
                    }
                    // Y = A_T * w_t
                    for (int n = 0; n < 2; n++)
                    {
                        o0[n] = d0[n] + d1[n] + d2[n];
                        o1[n] = d1[n] - d2[n] + d3[n];
                    }
                    // save to top blob; shift right by 2 because the kernel
                    // transform used G' = 2 * G (applied twice -> factor 4)
                    outRow0[0] = o0[0] >> 2;
                    outRow0[1] = o0[1] >> 2;
                    outRow1[0] = o1[0] >> 2;
                    outRow1[1] = o1[1] >> 2;

                    out_tile += 16;

                    outRow0 += 2;
                    outRow1 += 2;
#endif // __ARM_NEON
                }

                outRow0 += outw;
                outRow1 += outw;
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}
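
// Usage sketch (illustrative only; weight_data, bottom_blob_int8 and
// top_blob_int32 are hypothetical caller-side Mats, not defined in this file).
// The kernel transform runs once at model load time, the convolution once per
// forward pass:
//
//     std::vector<Mat> kernel_tm2;
//     conv3x3s1_winograd23_transform_kernel_int8_neon(weight_data, kernel_tm2, inch, outch);
//
//     // per forward pass, with top_blob_int32 created as outw x outh x outch, elemsize 4u
//     conv3x3s1_winograd23_int8_neon(bottom_blob_int8, top_blob_int32, kernel_tm2, opt);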

static void conv3x3s1_winograd43_transform_kernel_int8_neon(const Mat& kernel, std::vector<Mat>& kernel_tm2, int inch, int outch)
{
    Mat kernel_tm(6 * 6, inch, outch, 2ul);

    // G
    // const float ktm[6][3] = {
    //     {  1.0f/4,     0.0f,    0.0f},
    //     { -1.0f/6,  -1.0f/6, -1.0f/6},
    //     { -1.0f/6,   1.0f/6, -1.0f/6},
    //     { 1.0f/24,  1.0f/12,  1.0f/6},
    //     { 1.0f/24, -1.0f/12,  1.0f/6},
    //     {    0.0f,     0.0f,    1.0f}
    // };
    const short ktm[6][3] = {
        {6, 0, 0},
        {-4, -4, -4},
        {-4, 4, -4},
        {1, 2, 4},
        {1, -2, 4},
        {0, 0, 6}
    };
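
    // ktm is an integer-scaled version of the float G above (rows scaled so
    // that every entry is integral); the scale is compensated when the output
    // is transformed back to the spatial domain.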
1083 
1084     #pragma omp parallel for
1085     for (int p = 0; p < outch; p++)
1086     {
1087         for (int q = 0; q < inch; q++)
1088         {
1089             const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9;
1090             short* kernel_tm0 = kernel_tm.channel(p).row<short>(q);
1091 
1092             // transform kernel
1093             const signed char* k0 = kernel0;
1094             const signed char* k1 = kernel0 + 3;
1095             const signed char* k2 = kernel0 + 6;
1096 
1097             // h
1098             short tmp[6][3];
1099             for (int i = 0; i < 6; i++)
1100             {
1101                 tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
1102                 tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
1103                 tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
1104             }
1105 
1106             // U
1107             for (int j = 0; j < 6; j++)
1108             {
1109                 short* tmpp = &tmp[j][0];
1110 
1111                 for (int i = 0; i < 6; i++)
1112                 {
1113                     kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
1114                 }
1115             }
1116         }
1117     }
1118 
1119     for (int r = 0; r < 9; r++)
1120     {
1121         Mat kernel_tm_test(4 * 8, inch, outch / 8 + (outch % 8) / 4 + outch % 4, 2u);
1122 
1123         int p = 0;
1124         for (; p + 7 < outch; p += 8)
1125         {
1126             const short* kernel0 = (const short*)kernel_tm.channel(p);
1127             const short* kernel1 = (const short*)kernel_tm.channel(p + 1);
1128             const short* kernel2 = (const short*)kernel_tm.channel(p + 2);
1129             const short* kernel3 = (const short*)kernel_tm.channel(p + 3);
1130             const short* kernel4 = (const short*)kernel_tm.channel(p + 4);
1131             const short* kernel5 = (const short*)kernel_tm.channel(p + 5);
1132             const short* kernel6 = (const short*)kernel_tm.channel(p + 6);
1133             const short* kernel7 = (const short*)kernel_tm.channel(p + 7);
1134 
1135             short* ktmp = kernel_tm_test.channel(p / 8);
1136 
1137             for (int q = 0; q < inch; q++)
1138             {
1139                 ktmp[0] = kernel0[r * 4 + 0];
1140                 ktmp[1] = kernel0[r * 4 + 1];
1141                 ktmp[2] = kernel0[r * 4 + 2];
1142                 ktmp[3] = kernel0[r * 4 + 3];
1143 
1144                 ktmp[4] = kernel1[r * 4 + 0];
1145                 ktmp[5] = kernel1[r * 4 + 1];
1146                 ktmp[6] = kernel1[r * 4 + 2];
1147                 ktmp[7] = kernel1[r * 4 + 3];
1148 
1149                 ktmp[8] = kernel2[r * 4 + 0];
                ktmp[9] = kernel2[r * 4 + 1];
                ktmp[10] = kernel2[r * 4 + 2];
                ktmp[11] = kernel2[r * 4 + 3];

                ktmp[12] = kernel3[r * 4 + 0];
                ktmp[13] = kernel3[r * 4 + 1];
                ktmp[14] = kernel3[r * 4 + 2];
                ktmp[15] = kernel3[r * 4 + 3];

                ktmp[16] = kernel4[r * 4 + 0];
                ktmp[17] = kernel4[r * 4 + 1];
                ktmp[18] = kernel4[r * 4 + 2];
                ktmp[19] = kernel4[r * 4 + 3];

                ktmp[20] = kernel5[r * 4 + 0];
                ktmp[21] = kernel5[r * 4 + 1];
                ktmp[22] = kernel5[r * 4 + 2];
                ktmp[23] = kernel5[r * 4 + 3];

                ktmp[24] = kernel6[r * 4 + 0];
                ktmp[25] = kernel6[r * 4 + 1];
                ktmp[26] = kernel6[r * 4 + 2];
                ktmp[27] = kernel6[r * 4 + 3];

                ktmp[28] = kernel7[r * 4 + 0];
                ktmp[29] = kernel7[r * 4 + 1];
                ktmp[30] = kernel7[r * 4 + 2];
                ktmp[31] = kernel7[r * 4 + 3];

                ktmp += 32;
                kernel0 += 36;
                kernel1 += 36;
                kernel2 += 36;
                kernel3 += 36;
                kernel4 += 36;
                kernel5 += 36;
                kernel6 += 36;
                kernel7 += 36;
            }
        }

        for (; p + 3 < outch; p += 4)
        {
            const short* kernel0 = (const short*)kernel_tm.channel(p);
            const short* kernel1 = (const short*)kernel_tm.channel(p + 1);
            const short* kernel2 = (const short*)kernel_tm.channel(p + 2);
            const short* kernel3 = (const short*)kernel_tm.channel(p + 3);

            short* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4);

            for (int q = 0; q < inch; q++)
            {
                ktmp[0] = kernel0[r * 4 + 0];
                ktmp[1] = kernel0[r * 4 + 1];
                ktmp[2] = kernel0[r * 4 + 2];
                ktmp[3] = kernel0[r * 4 + 3];

                ktmp[4] = kernel1[r * 4 + 0];
                ktmp[5] = kernel1[r * 4 + 1];
                ktmp[6] = kernel1[r * 4 + 2];
                ktmp[7] = kernel1[r * 4 + 3];

                ktmp[8] = kernel2[r * 4 + 0];
                ktmp[9] = kernel2[r * 4 + 1];
                ktmp[10] = kernel2[r * 4 + 2];
                ktmp[11] = kernel2[r * 4 + 3];

                ktmp[12] = kernel3[r * 4 + 0];
                ktmp[13] = kernel3[r * 4 + 1];
                ktmp[14] = kernel3[r * 4 + 2];
                ktmp[15] = kernel3[r * 4 + 3];

                ktmp += 16;
                kernel0 += 36;
                kernel1 += 36;
                kernel2 += 36;
                kernel3 += 36;
            }
        }

        for (; p < outch; p++)
        {
            const short* kernel0 = (const short*)kernel_tm.channel(p);

            short* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4 + p % 4);

            for (int q = 0; q < inch; q++)
            {
                ktmp[0] = kernel0[r * 4 + 0];
                ktmp[1] = kernel0[r * 4 + 1];
                ktmp[2] = kernel0[r * 4 + 2];
                ktmp[3] = kernel0[r * 4 + 3];

                ktmp += 4;
                kernel0 += 36;
            }
        }
        kernel_tm2.push_back(kernel_tm_test);
    }
}
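
// Packing note (a reading of the code above, not an authoritative spec):
// kernel_tm2 ends up holding one Mat per coefficient group r -- nine groups of
// four 16-bit coefficients per 6x6 transformed kernel (kernel0 advances by 36
// per input channel). Within each group's Mat, output channels are packed
// 8-wide, then 4-wide, then singly, so the dot-product stage below can stream
// 32, 16 or 4 shorts per input channel from one contiguous pointer.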

static void conv3x3s1_winograd43_int8_neon(const Mat& bottom_blob, Mat& top_blob, const std::vector<Mat>& kernel_tm_test, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 4n+2, winograd F(4,3)
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 3) / 4 * 4;
    outh = (outh + 3) / 4 * 4;

    w = outw + 2;
    h = outh + 2;

    Option opt_b = opt;
    opt_b.blob_allocator = opt.workspace_allocator;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);
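    // Size example: a 13x13 output is rounded up to 16x16 and the input padded
    // to 18x18 (4n+2), so each 4x4 output tile maps onto an overlapping 6x6
    // input tile.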

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;

        int nColBlocks = h_tm / 6; // tile block counts per column/row, following FeatherCNN's tiling scheme
        int nRowBlocks = w_tm / 6;

        const int tiles = nColBlocks * nRowBlocks;

        bottom_blob_tm.create(4, inch, tiles * 9, 2u, opt.workspace_allocator);

        // BT
        // const float itm[6][6] = {
        //     {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
        //     {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
        //     {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
        //     {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
        //     {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
        //     {0.0f, 4.0f,  0.0f,-5.0f, 0.0f, 1.0f}
        // };

        // 0 = 4 * r00 - 5 * r02 + r04
        // 1 = -4 * (r01 + r02) + r03 + r04
        // 2 = 4 * (r01 - r02) - r03 + r04
        // 3 = -2 * r01 - r02 + 2 * r03 + r04
        // 4 = 2 * r01 - r02 - 2 * r03 + r04
        // 5 = 4 * r01 - 5 * r03 + r05
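
        // Range note: each BT row's coefficients sum to at most 10 in absolute
        // value (e.g. 4+5+1), so one pass over int8 data stays within +/-1270
        // and the second pass within +/-12700 -- safely inside int16, which is
        // why the transformed input is stored as shorts.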

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const signed char* img = bottom_blob_bordered.channel(q);

            for (int j = 0; j < nColBlocks; j++)
            {
                const signed char* r0 = img + w * j * 4;
                const signed char* r1 = r0 + w;
                const signed char* r2 = r1 + w;
                const signed char* r3 = r2 + w;
                const signed char* r4 = r3 + w;
                const signed char* r5 = r4 + w;

                for (int i = 0; i < nRowBlocks; i++)
                {
                    short* out_tm0 = bottom_blob_tm.channel(tiles * 0 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm1 = bottom_blob_tm.channel(tiles * 1 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm2 = bottom_blob_tm.channel(tiles * 2 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm3 = bottom_blob_tm.channel(tiles * 3 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm4 = bottom_blob_tm.channel(tiles * 4 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm5 = bottom_blob_tm.channel(tiles * 5 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm6 = bottom_blob_tm.channel(tiles * 6 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm7 = bottom_blob_tm.channel(tiles * 7 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm8 = bottom_blob_tm.channel(tiles * 8 + j * nRowBlocks + i).row<short>(q);
#if __ARM_NEON
                    int8x8_t _d0, _d1, _d2, _d3, _d4, _d5;
                    int16x8_t _w0, _w1, _w2, _w3, _w4, _w5;
                    int16x8_t _t0, _t1, _t2, _t3, _t4, _t5;
                    int16x8_t _n0, _n1, _n2, _n3, _n4, _n5;
                    // load
                    _d0 = vld1_s8(r0);
                    _d1 = vld1_s8(r1);
                    _d2 = vld1_s8(r2);
                    _d3 = vld1_s8(r3);
                    _d4 = vld1_s8(r4);
                    _d5 = vld1_s8(r5);

                    int8x8_t _1_n = vdup_n_s8(-1);
                    int8x8_t _2_p = vdup_n_s8(2);
                    int8x8_t _2_n = vdup_n_s8(-2);
                    int8x8_t _4_p = vdup_n_s8(4);
                    int8x8_t _4_n = vdup_n_s8(-4);
                    int8x8_t _5_n = vdup_n_s8(-5);

                    int16x8_t _1_n_s16 = vdupq_n_s16(-1);
                    int16x8_t _2_p_s16 = vdupq_n_s16(2);
                    int16x8_t _2_n_s16 = vdupq_n_s16(-2);
                    int16x8_t _4_p_s16 = vdupq_n_s16(4);
                    int16x8_t _4_n_s16 = vdupq_n_s16(-4);
                    int16x8_t _5_n_s16 = vdupq_n_s16(-5);
                    // w = B_t * d
                    _w0 = vmull_s8(_d0, _4_p);
                    _w0 = vmlal_s8(_w0, _d2, _5_n);
                    _w0 = vaddw_s8(_w0, _d4);

                    _w1 = vmull_s8(_d1, _4_n);
                    _w1 = vmlal_s8(_w1, _d2, _4_n);
                    _w1 = vaddw_s8(_w1, _d3);
                    _w1 = vaddw_s8(_w1, _d4);

                    _w2 = vmull_s8(_d1, _4_p);
                    _w2 = vmlal_s8(_w2, _d2, _4_n);
                    _w2 = vmlal_s8(_w2, _d3, _1_n);
                    _w2 = vaddw_s8(_w2, _d4);

                    _w3 = vmull_s8(_d1, _2_n);
                    _w3 = vmlal_s8(_w3, _d2, _1_n);
                    _w3 = vmlal_s8(_w3, _d3, _2_p);
                    _w3 = vaddw_s8(_w3, _d4);

                    _w4 = vmull_s8(_d1, _2_p);
                    _w4 = vmlal_s8(_w4, _d2, _1_n);
                    _w4 = vmlal_s8(_w4, _d3, _2_n);
                    _w4 = vaddw_s8(_w4, _d4);

                    _w5 = vmull_s8(_d1, _4_p);
                    _w5 = vmlal_s8(_w5, _d3, _5_n);
                    _w5 = vaddw_s8(_w5, _d5);
                    // transpose d to d_t
                    {
                        _t0[0] = _w0[0];
                        _t1[0] = _w0[1];
                        _t2[0] = _w0[2];
                        _t3[0] = _w0[3];
                        _t4[0] = _w0[4];
                        _t5[0] = _w0[5];
                        _t0[1] = _w1[0];
                        _t1[1] = _w1[1];
                        _t2[1] = _w1[2];
                        _t3[1] = _w1[3];
                        _t4[1] = _w1[4];
                        _t5[1] = _w1[5];
                        _t0[2] = _w2[0];
                        _t1[2] = _w2[1];
                        _t2[2] = _w2[2];
                        _t3[2] = _w2[3];
                        _t4[2] = _w2[4];
                        _t5[2] = _w2[5];
                        _t0[3] = _w3[0];
                        _t1[3] = _w3[1];
                        _t2[3] = _w3[2];
                        _t3[3] = _w3[3];
                        _t4[3] = _w3[4];
                        _t5[3] = _w3[5];
                        _t0[4] = _w4[0];
                        _t1[4] = _w4[1];
                        _t2[4] = _w4[2];
                        _t3[4] = _w4[3];
                        _t4[4] = _w4[4];
                        _t5[4] = _w4[5];
                        _t0[5] = _w5[0];
                        _t1[5] = _w5[1];
                        _t2[5] = _w5[2];
                        _t3[5] = _w5[3];
                        _t4[5] = _w5[4];
                        _t5[5] = _w5[5];
                    }
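                    // (Only lanes 0..5 of each 8-wide vector carry data; the
                    // 8-byte loads above over-read two lanes that are never
                    // stored.)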
                    // d = B_t * d_t
                    _n0 = vmulq_s16(_t0, _4_p_s16);
                    _n0 = vmlaq_s16(_n0, _t2, _5_n_s16);
                    _n0 = vaddq_s16(_n0, _t4);

                    _n1 = vmulq_s16(_t1, _4_n_s16);
                    _n1 = vmlaq_s16(_n1, _t2, _4_n_s16);
                    _n1 = vaddq_s16(_n1, _t3);
                    _n1 = vaddq_s16(_n1, _t4);

                    _n2 = vmulq_s16(_t1, _4_p_s16);
                    _n2 = vmlaq_s16(_n2, _t2, _4_n_s16);
                    _n2 = vmlaq_s16(_n2, _t3, _1_n_s16);
                    _n2 = vaddq_s16(_n2, _t4);

                    _n3 = vmulq_s16(_t1, _2_n_s16);
                    _n3 = vmlaq_s16(_n3, _t2, _1_n_s16);
                    _n3 = vmlaq_s16(_n3, _t3, _2_p_s16);
                    _n3 = vaddq_s16(_n3, _t4);

                    _n4 = vmulq_s16(_t1, _2_p_s16);
                    _n4 = vmlaq_s16(_n4, _t2, _1_n_s16);
                    _n4 = vmlaq_s16(_n4, _t3, _2_n_s16);
                    _n4 = vaddq_s16(_n4, _t4);

                    _n5 = vmulq_s16(_t1, _4_p_s16);
                    _n5 = vmlaq_s16(_n5, _t3, _5_n_s16);
                    _n5 = vaddq_s16(_n5, _t5);
                    // save to out_tm
                    out_tm0[0] = _n0[0];
                    out_tm0[1] = _n0[1];
                    out_tm0[2] = _n0[2];
                    out_tm0[3] = _n0[3];
                    out_tm1[0] = _n0[4];
                    out_tm1[1] = _n0[5];
                    out_tm1[2] = _n1[0];
                    out_tm1[3] = _n1[1];
                    out_tm2[0] = _n1[2];
                    out_tm2[1] = _n1[3];
                    out_tm2[2] = _n1[4];
                    out_tm2[3] = _n1[5];

                    out_tm3[0] = _n2[0];
                    out_tm3[1] = _n2[1];
                    out_tm3[2] = _n2[2];
                    out_tm3[3] = _n2[3];
                    out_tm4[0] = _n2[4];
                    out_tm4[1] = _n2[5];
                    out_tm4[2] = _n3[0];
                    out_tm4[3] = _n3[1];
                    out_tm5[0] = _n3[2];
                    out_tm5[1] = _n3[3];
                    out_tm5[2] = _n3[4];
                    out_tm5[3] = _n3[5];

                    out_tm6[0] = _n4[0];
                    out_tm6[1] = _n4[1];
                    out_tm6[2] = _n4[2];
                    out_tm6[3] = _n4[3];
                    out_tm7[0] = _n4[4];
                    out_tm7[1] = _n4[5];
                    out_tm7[2] = _n5[0];
                    out_tm7[3] = _n5[1];
                    out_tm8[0] = _n5[2];
                    out_tm8[1] = _n5[3];
                    out_tm8[2] = _n5[4];
                    out_tm8[3] = _n5[5];
#else
                    short d0[6], d1[6], d2[6], d3[6], d4[6], d5[6];
                    short w0[6], w1[6], w2[6], w3[6], w4[6], w5[6];
                    short t0[6], t1[6], t2[6], t3[6], t4[6], t5[6];

                    // load
                    for (int n = 0; n < 6; n++)
                    {
                        d0[n] = r0[n];
                        d1[n] = r1[n];
                        d2[n] = r2[n];
                        d3[n] = r3[n];
                        d4[n] = r4[n];
                        d5[n] = r5[n];
                    }
                    // w = B_t * d
                    for (int n = 0; n < 6; n++)
                    {
                        w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n];
                        w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n];
                        w2[n] = 4 * d1[n] - 4 * d2[n] - d3[n] + d4[n];
                        w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n];
                        w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n];
                        w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n];
                    }
                    // transpose d to d_t
                    {
                        t0[0] = w0[0];
                        t1[0] = w0[1];
                        t2[0] = w0[2];
                        t3[0] = w0[3];
                        t4[0] = w0[4];
                        t5[0] = w0[5];
                        t0[1] = w1[0];
                        t1[1] = w1[1];
                        t2[1] = w1[2];
                        t3[1] = w1[3];
                        t4[1] = w1[4];
                        t5[1] = w1[5];
                        t0[2] = w2[0];
                        t1[2] = w2[1];
                        t2[2] = w2[2];
                        t3[2] = w2[3];
                        t4[2] = w2[4];
                        t5[2] = w2[5];
                        t0[3] = w3[0];
                        t1[3] = w3[1];
                        t2[3] = w3[2];
                        t3[3] = w3[3];
                        t4[3] = w3[4];
                        t5[3] = w3[5];
                        t0[4] = w4[0];
                        t1[4] = w4[1];
                        t2[4] = w4[2];
                        t3[4] = w4[3];
                        t4[4] = w4[4];
                        t5[4] = w4[5];
                        t0[5] = w5[0];
                        t1[5] = w5[1];
                        t2[5] = w5[2];
                        t3[5] = w5[3];
                        t4[5] = w5[4];
                        t5[5] = w5[5];
                    }
                    // d = B_t * d_t
                    for (int n = 0; n < 6; n++)
                    {
                        d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n];
                        d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n];
                        d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n];
                        d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n];
                        d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n];
                        d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n];
                    }
                    // save to out_tm
                    {
                        out_tm0[0] = d0[0];
                        out_tm0[1] = d0[1];
                        out_tm0[2] = d0[2];
                        out_tm0[3] = d0[3];
                        out_tm1[0] = d0[4];
                        out_tm1[1] = d0[5];
                        out_tm1[2] = d1[0];
                        out_tm1[3] = d1[1];
                        out_tm2[0] = d1[2];
                        out_tm2[1] = d1[3];
                        out_tm2[2] = d1[4];
                        out_tm2[3] = d1[5];

                        out_tm3[0] = d2[0];
                        out_tm3[1] = d2[1];
                        out_tm3[2] = d2[2];
                        out_tm3[3] = d2[3];
                        out_tm4[0] = d2[4];
                        out_tm4[1] = d2[5];
                        out_tm4[2] = d3[0];
                        out_tm4[3] = d3[1];
                        out_tm5[0] = d3[2];
                        out_tm5[1] = d3[3];
                        out_tm5[2] = d3[4];
                        out_tm5[3] = d3[5];

                        out_tm6[0] = d4[0];
                        out_tm6[1] = d4[1];
                        out_tm6[2] = d4[2];
                        out_tm6[3] = d4[3];
                        out_tm7[0] = d4[4];
                        out_tm7[1] = d4[5];
                        out_tm7[2] = d5[0];
                        out_tm7[3] = d5[1];
                        out_tm8[0] = d5[2];
                        out_tm8[1] = d5[3];
                        out_tm8[2] = d5[4];
                        out_tm8[3] = d5[5];
                    }
#endif // __ARM_NEON
                    r0 += 4;
                    r1 += 4;
                    r2 += 4;
                    r3 += 4;
                    r4 += 4;
                    r5 += 4;
                }
            }
        }
    }
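    // bottom_blob_tm now holds, for each coefficient group r (0..8) and each
    // tile, one channel of inch rows with 4 shorts each -- the layout the
    // dot-product stage below consumes.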
    bottom_blob_bordered = Mat();

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;

        int nColBlocks = h_tm / 6; // tile block counts, as in the input transform
        int nRowBlocks = w_tm / 6;

        const int tiles = nColBlocks * nRowBlocks;

        top_blob_tm.create(36, tiles, outch, 4u, opt.workspace_allocator);
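
        // Layout sketch: each top_blob_tm channel is one output channel and
        // each row one tile of 36 ints; pass r of the loop below fills ints
        // [r*4, r*4+4) of every tile with an int16 x int16 -> int32
        // multiply-accumulate over inch.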

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 9; r++)
        {
            int nn_outch = 0;
            int remain_outch_start = 0;

            nn_outch = outch >> 3;
            remain_outch_start = nn_outch << 3;

            for (int pp = 0; pp < nn_outch; pp++)
            {
                int p = pp * 8;

                int* output0_tm = top_blob_tm.channel(p);
                int* output1_tm = top_blob_tm.channel(p + 1);
                int* output2_tm = top_blob_tm.channel(p + 2);
                int* output3_tm = top_blob_tm.channel(p + 3);
                int* output4_tm = top_blob_tm.channel(p + 4);
                int* output5_tm = top_blob_tm.channel(p + 5);
                int* output6_tm = top_blob_tm.channel(p + 6);
                int* output7_tm = top_blob_tm.channel(p + 7);

                output0_tm = output0_tm + r * 4;
                output1_tm = output1_tm + r * 4;
                output2_tm = output2_tm + r * 4;
                output3_tm = output3_tm + r * 4;
                output4_tm = output4_tm + r * 4;
                output5_tm = output5_tm + r * 4;
                output6_tm = output6_tm + r * 4;
                output7_tm = output7_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "eor    v1.16b, v1.16b, v1.16b    \n"
                        "eor    v2.16b, v2.16b, v2.16b    \n"
                        "eor    v3.16b, v3.16b, v3.16b    \n"
                        "eor    v4.16b, v4.16b, v4.16b    \n"
                        "eor    v5.16b, v5.16b, v5.16b    \n"
                        "eor    v6.16b, v6.16b, v6.16b    \n"
                        "eor    v7.16b, v7.16b, v7.16b    \n"
                        "mov    w4, %w20                  \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        "prfm    pldl1keep, [%9, #128]    \n" // _r0 = vld1_s16(r0);
                        "ld1     {v8.4h}, [%8]            \n"
                        "ld1     {v9.4h, v10.4h}, [%9]    \n" // _k01 = vld1q_s16(kptr);
                        "add     %9, %9, #16              \n"
                        "ld1     {v11.4h, v12.4h}, [%9]   \n" // _k23 = vld1q_s16(kptr+8);
                        "add     %9, %9, #16              \n"
                        "ld1     {v13.4h, v14.4h}, [%9]   \n" // _k45 = vld1q_s16(kptr+16);
                        "add     %9, %9, #16              \n"
                        "ld1     {v15.4h, v16.4h}, [%9]   \n" // _k67 = vld1q_s16(kptr+24);
                        "add     %8, %8, #8               \n"
                        "add     %9, %9, #16              \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)
                        "smlal   v1.4s, v8.4h, v10.4h     \n" // sum1 += (a00-a03) * (k10-k13)
                        "smlal   v2.4s, v8.4h, v11.4h     \n" // sum2 += (a00-a03) * (k20-k23)
                        "smlal   v3.4s, v8.4h, v12.4h     \n" // sum3 += (a00-a03) * (k30-k33)
                        "smlal   v4.4s, v8.4h, v13.4h     \n" // sum4 += (a00-a03) * (k40-k43)
                        "smlal   v5.4s, v8.4h, v14.4h     \n" // sum5 += (a00-a03) * (k50-k53)
                        "smlal   v6.4s, v8.4h, v15.4h     \n" // sum6 += (a00-a03) * (k60-k63)
                        "smlal   v7.4s, v8.4h, v16.4h     \n" // sum7 += (a00-a03) * (k70-k73)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory
                        "st1     {v1.4s}, [%1]            \n" //
                        "st1     {v2.4s}, [%2]            \n" //
                        "st1     {v3.4s}, [%3]            \n" //
                        "st1     {v4.4s}, [%4]            \n" //
                        "st1     {v5.4s}, [%5]            \n" //
                        "st1     {v6.4s}, [%6]            \n" //
                        "st1     {v7.4s}, [%7]            \n" //

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(output4_tm), // %4
                        "=r"(output5_tm), // %5
                        "=r"(output6_tm), // %6
                        "=r"(output7_tm), // %7
                        "=r"(r0),         // %8
                        "=r"(kptr)        // %9
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(output4_tm),
                        "5"(output5_tm),
                        "6"(output6_tm),
                        "7"(output7_tm),
                        "8"(r0),
                        "9"(kptr),
                        "r"(inch) // %20
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "vmov.s32    q1, #0           \n"
                        "vmov.s32    q2, #0           \n"
                        "vmov.s32    q3, #0           \n"
                        "vmov.s32    q4, #0           \n"
                        "vmov.s32    q5, #0           \n"
                        "vmov.s32    q6, #0           \n"
                        "vmov.s32    q7, #0           \n"
                        "mov         r4, %20          \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%8]!     \n" // _r0 = vld1_s16(r0);  // input inch0
                        "vld1.s16    {d18-d19}, [%9]  \n" // _k01 = vld1q_s16(kptr);
                        "add         %9, #16          \n"
                        "vld1.s16    {d20-d21}, [%9]  \n" // _k23 = vld1q_s16(kptr+8);
                        "add         %9, #16          \n"
                        "vld1.s16    {d22-d23}, [%9]  \n" // _k45 = vld1q_s16(kptr+16);
                        "add         %9, #16          \n"
                        "vld1.s16    {d24-d25}, [%9]  \n" // _k67 = vld1q_s16(kptr+24);
                        "add         %9, #16          \n"

                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)
                        "vmlal.s16   q1, d16, d19     \n" // sum1 += (a00-a03) * (k10-k13)
                        "vmlal.s16   q2, d16, d20     \n" // sum2 += (a00-a03) * (k20-k23)
                        "vmlal.s16   q3, d16, d21     \n" // sum3 += (a00-a03) * (k30-k33)
                        "vmlal.s16   q4, d16, d22     \n" // sum4 += (a00-a03) * (k40-k43)
                        "vmlal.s16   q5, d16, d23     \n" // sum5 += (a00-a03) * (k50-k53)
                        "vmlal.s16   q6, d16, d24     \n" // sum6 += (a00-a03) * (k60-k63)
                        "vmlal.s16   q7, d16, d25     \n" // sum7 += (a00-a03) * (k70-k73)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory
                        "vst1.s32    {d2-d3}, [%1]    \n"
                        "vst1.s32    {d4-d5}, [%2]    \n"
                        "vst1.s32    {d6-d7}, [%3]    \n"
                        "vst1.s32    {d8-d9}, [%4]    \n"
                        "vst1.s32    {d10-d11}, [%5]  \n"
                        "vst1.s32    {d12-d13}, [%6]  \n"
                        "vst1.s32    {d14-d15}, [%7]  \n"

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(output4_tm), // %4
                        "=r"(output5_tm), // %5
                        "=r"(output6_tm), // %6
                        "=r"(output7_tm), // %7
                        "=r"(r0),         // %8
                        "=r"(kptr)        // %9
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(output4_tm),
                        "5"(output5_tm),
                        "6"(output6_tm),
                        "7"(output7_tm),
                        "8"(r0),
                        "9"(kptr),
                        "r"(inch) // %20
                        : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
#endif // __aarch64__
#else
                    int sum0[4] = {0};
                    int sum1[4] = {0};
                    int sum2[4] = {0};
                    int sum3[4] = {0};
                    int sum4[4] = {0};
                    int sum5[4] = {0};
                    int sum6[4] = {0};
                    int sum7[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                            sum1[n] += (int)r0[n] * kptr[n + 4];
                            sum2[n] += (int)r0[n] * kptr[n + 8];
                            sum3[n] += (int)r0[n] * kptr[n + 12];
                            sum4[n] += (int)r0[n] * kptr[n + 16];
                            sum5[n] += (int)r0[n] * kptr[n + 20];
                            sum6[n] += (int)r0[n] * kptr[n + 24];
                            sum7[n] += (int)r0[n] * kptr[n + 28];
                        }
                        kptr += 32;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                        output1_tm[n] = sum1[n];
                        output2_tm[n] = sum2[n];
                        output3_tm[n] = sum3[n];
                        output4_tm[n] = sum4[n];
                        output5_tm[n] = sum5[n];
                        output6_tm[n] = sum6[n];
                        output7_tm[n] = sum7[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 36;
                    output1_tm += 36;
                    output2_tm += 36;
                    output3_tm += 36;
                    output4_tm += 36;
                    output5_tm += 36;
                    output6_tm += 36;
                    output7_tm += 36;
                }
            }

            nn_outch = (outch - remain_outch_start) >> 2;

            for (int pp = 0; pp < nn_outch; pp++)
            {
                int p = remain_outch_start + pp * 4;

                int* output0_tm = top_blob_tm.channel(p);
                int* output1_tm = top_blob_tm.channel(p + 1);
                int* output2_tm = top_blob_tm.channel(p + 2);
                int* output3_tm = top_blob_tm.channel(p + 3);

                output0_tm = output0_tm + r * 4;
                output1_tm = output1_tm + r * 4;
                output2_tm = output2_tm + r * 4;
                output3_tm = output3_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "eor    v1.16b, v1.16b, v1.16b    \n"
                        "eor    v2.16b, v2.16b, v2.16b    \n"
                        "eor    v3.16b, v3.16b, v3.16b    \n"
                        "mov    w4, %w12                  \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        "prfm    pldl1keep, [%5, #128]    \n" // _r0 = vld1_s16(r0);  // input inch0
                        "ld1     {v8.4h}, [%4]            \n"
                        "ld1     {v9.4h, v10.4h}, [%5]    \n" // _k01 = vld1q_s16(kptr);
                        "add     %5, %5, #16              \n"
                        "ld1     {v11.4h, v12.4h}, [%5]   \n" // _k23 = vld1q_s16(kptr+8);
                        "add     %4, %4, #8               \n"
                        "add     %5, %5, #16              \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)
                        "smlal   v1.4s, v8.4h, v10.4h     \n" // sum1 += (a00-a03) * (k10-k13)
                        "smlal   v2.4s, v8.4h, v11.4h     \n" // sum2 += (a00-a03) * (k20-k23)
                        "smlal   v3.4s, v8.4h, v12.4h     \n" // sum3 += (a00-a03) * (k30-k33)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory
                        "st1     {v1.4s}, [%1]            \n" //
                        "st1     {v2.4s}, [%2]            \n" //
                        "st1     {v3.4s}, [%3]            \n" //

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(r0),         // %4
                        "=r"(kptr)        // %5
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(r0),
                        "5"(kptr),
                        "r"(inch) // %12
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "vmov.s32    q1, #0           \n"
                        "vmov.s32    q2, #0           \n"
                        "vmov.s32    q3, #0           \n"
                        "mov         r4, %12          \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%4]!     \n" // _r0 = vld1_s16(r0);  // input inch0
                        "vld1.s16    {d18-d19}, [%5]  \n" // _k01 = vld1q_s16(kptr);
                        "add         %5, #16          \n"
                        "vld1.s16    {d20-d21}, [%5]  \n" // _k23 = vld1q_s16(kptr+8);
                        "add         %5, #16          \n"

                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)
                        "vmlal.s16   q1, d16, d19     \n" // sum1 += (a00-a03) * (k10-k13)
                        "vmlal.s16   q2, d16, d20     \n" // sum2 += (a00-a03) * (k20-k23)
                        "vmlal.s16   q3, d16, d21     \n" // sum3 += (a00-a03) * (k30-k33)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory
                        "vst1.s32    {d2-d3}, [%1]    \n"
                        "vst1.s32    {d4-d5}, [%2]    \n"
                        "vst1.s32    {d6-d7}, [%3]    \n"

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(r0),         // %4
                        "=r"(kptr)        // %5
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(r0),
                        "5"(kptr),
                        "r"(inch) // %12
                        : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q8", "q9", "q10");
#endif // __aarch64__
#else
                    int sum0[4] = {0};
                    int sum1[4] = {0};
                    int sum2[4] = {0};
                    int sum3[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                            sum1[n] += (int)r0[n] * kptr[n + 4];
                            sum2[n] += (int)r0[n] * kptr[n + 8];
                            sum3[n] += (int)r0[n] * kptr[n + 12];
                        }
                        kptr += 16;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                        output1_tm[n] = sum1[n];
                        output2_tm[n] = sum2[n];
                        output3_tm[n] = sum3[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 36;
                    output1_tm += 36;
                    output2_tm += 36;
                    output3_tm += 36;
                }
            }

            remain_outch_start += nn_outch << 2;

            for (int p = remain_outch_start; p < outch; p++)
            {
                int* output0_tm = top_blob_tm.channel(p);

                output0_tm = output0_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4 + p % 4);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "mov    w4, %w6                   \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        "ld1     {v8.4h}, [%1]            \n" // _r0 = vld1_s16(r0);  // input inch0
                        "ld1     {v9.4h}, [%2]            \n" // _k0 = vld1q_s16(kptr);
                        "add     %1, %1, #8               \n"
                        "add     %2, %2, #8               \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory

                        : "=r"(output0_tm), // %0
                        "=r"(r0),         // %1
                        "=r"(kptr)        // %2
                        : "0"(output0_tm),
                        "1"(r0),
                        "2"(kptr),
                        "r"(inch) // %6
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "mov         r4, %6           \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%1]      \n" // _r0 = vld1_s16(r0);  // input inch0
                        "add         %1, #8           \n"
                        "vld1.s16    {d18}, [%2]      \n" // _k0 = vld1q_s16(kptr);
                        "add         %2, #8           \n"
                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory

                        : "=r"(output0_tm), // %0
                        "=r"(r0),         // %1
                        "=r"(kptr)        // %2
                        : "0"(output0_tm),
                        "1"(r0),
                        "2"(kptr),
                        "r"(inch) // %6
                        : "cc", "memory", "r4", "q0", "q8", "q9");
#endif // __aarch64__
#else  // __ARM_NEON
                    int sum0[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                        }
                        kptr += 4;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 36;
                }
            }

            // for (int p=0; p<outch; p++)
            // {
            //     Mat out0_tm = top_blob_tm.channel(p);
            //     const Mat kernel0_tm = kernel_tm.channel(p);

            //     for (int i=0; i<tiles; i++)
            //     {
            //         int* output0_tm = out0_tm.row<int>(i);

            //         int sum0[36] = {0};

            //         for (int q=0; q<inch; q++)
            //         {
            //             const short* r0 = bottom_blob_tm.channel(q).row<short>(i);
            //             const short* k0 = kernel0_tm.row<short>(q);

            //             for (int n=0; n<36; n++)
            //             {
            //                 sum0[n] += (int)r0[n] * k0[n];
            //             }
            //         }

            //         for (int n=0; n<36; n++)
            //         {
            //             output0_tm[n] = sum0[n];
            //         }
            //     }
            // }
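            // (The commented-out loop above appears to be kept as the plain
            // scalar reference for this packed dot-product stage.)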
        }
    }
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator);
    {
        // AT
        // const float itm[4][6] = {
        //     {1.0f, 1.0f,  1.0f, 1.0f,  1.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
        //     {0.0f, 1.0f,  1.0f, 4.0f,  4.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
        // };

        // 0 = r00 + r01 + r02 + r03 + r04
        // 1 = r01 - r02 + 2 * (r03 - r04)
        // 2 = r01 + r02 + 4 * (r03 + r04)
        // 3 = r01 - r02 + 8 * (r03 - r04) + r05

        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;

        int nColBlocks = h_tm / 6; // tile block counts, as in the input transform
        int nRowBlocks = w_tm / 6;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            int* out_tile = top_blob_tm.channel(p);
            int* outRow0 = top_blob_bordered.channel(p);
            int* outRow1 = outRow0 + outw;
            int* outRow2 = outRow0 + outw * 2;
            int* outRow3 = outRow0 + outw * 3;

            for (int j = 0; j < nColBlocks; j++)
            {
                for (int i = 0; i < nRowBlocks; i++)
                {
#if __ARM_NEON
                    int32x4_t _s0, _s1, _s2, _s3, _s4, _s5;
                    int32x2_t _s0n, _s1n, _s2n, _s3n, _s4n, _s5n;
                    int32x4_t _w0, _w3;
                    int32x2_t _w0n, _w3n;
                    int32x4_t _d0, _d1, _d2, _d3, _d4, _d5;
                    int32x4_t _o0, _o1, _o2, _o3;
                    // load
                    _s0 = vld1q_s32(out_tile);
                    _s0n = vld1_s32(out_tile + 4);
                    _s1 = vld1q_s32(out_tile + 6);
                    _s1n = vld1_s32(out_tile + 10);
                    _s2 = vld1q_s32(out_tile + 12);
                    _s2n = vld1_s32(out_tile + 16);
                    _s3 = vld1q_s32(out_tile + 18);
                    _s3n = vld1_s32(out_tile + 22);
                    _s4 = vld1q_s32(out_tile + 24);
                    _s4n = vld1_s32(out_tile + 28);
                    _s5 = vld1q_s32(out_tile + 30);
                    _s5n = vld1_s32(out_tile + 34);
                    // w = A_T * W
                    int32x2_t _tp0 = {1, 4};
                    int32x2_t _tp1 = {2, 8};

                    // 4*s5[n]
                    int32x4_t _s5x4 = vshlq_n_s32(_s5, 2);
                    int32x2_t _s5x4n = vshl_n_s32(_s5n, 2);

                    int32x4_t _t1p2 = vaddq_s32(_s1, _s2);
                    int32x2_t _t1p2n = vadd_s32(_s1n, _s2n);
                    int32x4_t _t3p4 = vaddq_s32(_s3, _s4);
                    int32x2_t _t3p4n = vadd_s32(_s3n, _s4n);
                    int32x4_t _t1s2 = vsubq_s32(_s1, _s2);
                    int32x2_t _t1s2n = vsub_s32(_s1n, _s2n);
                    int32x4_t _t3s4 = vsubq_s32(_s3, _s4);
                    int32x2_t _t3s4n = vsub_s32(_s3n, _s4n);

                    _w0 = vaddq_s32(_s0, _t1p2);
                    _w0n = vadd_s32(_s0n, _t1p2n);
                    _w0 = vaddq_s32(_w0, _t3p4);
                    _w0n = vadd_s32(_w0n, _t3p4n);
                    _w0n = vmul_s32(_w0n, _tp0);

                    // _w2,_w2n
                    _t1p2 = vmlaq_lane_s32(_t1p2, _t3p4, _tp0, 1);
                    _t1p2n = vmla_lane_s32(_t1p2n, _t3p4n, _tp0, 1);
                    _t1p2n = vmul_s32(_t1p2n, _tp0);

                    _w3 = vaddq_s32(_s5x4, _t1s2);
                    _w3n = vadd_s32(_s5x4n, _t1s2n);
                    _w3 = vmlaq_lane_s32(_w3, _t3s4, _tp1, 1);
                    _w3n = vmla_lane_s32(_w3n, _t3s4n, _tp1, 1);
                    _w3n = vmul_s32(_w3n, _tp0);

                    // _w1, _w1n
                    _t1s2 = vmlaq_lane_s32(_t1s2, _t3s4, _tp1, 0);
                    _t1s2n = vmla_lane_s32(_t1s2n, _t3s4n, _tp1, 0);
                    _t1s2n = vmul_s32(_t1s2n, _tp0);

                    int32x4_t _w02n = vcombine_s32(_w0n, _t1p2n);
                    int32x4_t _w13n = vcombine_s32(_t1s2n, _w3n);

                    // transpose w to w_t
#if __aarch64__
                    int32x4_t _wt0 = vtrn1q_s32(_w0, _t1s2);
                    int32x4_t _wt1 = vtrn2q_s32(_w0, _t1s2);
                    int32x4_t _wt2 = vtrn1q_s32(_t1p2, _w3);
                    int32x4_t _wt3 = vtrn2q_s32(_t1p2, _w3);
                    int64x2_t _dt0 = vtrn1q_s64(vreinterpretq_s64_s32(_wt0), vreinterpretq_s64_s32(_wt2));
                    int64x2_t _dt2 = vtrn2q_s64(vreinterpretq_s64_s32(_wt0), vreinterpretq_s64_s32(_wt2));
                    int64x2_t _dt1 = vtrn1q_s64(vreinterpretq_s64_s32(_wt1), vreinterpretq_s64_s32(_wt3));
                    int64x2_t _dt3 = vtrn2q_s64(vreinterpretq_s64_s32(_wt1), vreinterpretq_s64_s32(_wt3));
                    _d0 = vreinterpretq_s32_s64(_dt0);
                    _d1 = vreinterpretq_s32_s64(_dt1);
                    _d2 = vreinterpretq_s32_s64(_dt2);
                    _d3 = vreinterpretq_s32_s64(_dt3);
                    _d4 = vtrn1q_s32(_w02n, _w13n);
                    _d5 = vtrn2q_s32(_w02n, _w13n);
#else
                    asm volatile(
                        "vtrn.32    %q[_w0], %q[_w1]        \n"
                        "vtrn.32    %q[_w2], %q[_w3]        \n"
                        "vswp       %f[_w0], %e[_w2]        \n"
                        "vswp       %f[_w1], %e[_w3]        \n"
                        "vtrn.32    %q[_w02n], %q[_w13n]    \n"
                        : [_w0] "+w"(_w0),
                        [_w1] "+w"(_t1s2),
                        [_w2] "+w"(_t1p2),
                        [_w3] "+w"(_w3),
                        [_w02n] "+w"(_w02n),
                        [_w13n] "+w"(_w13n)
                        :
                        : "cc", "memory");
                    _d0 = _w0;
                    _d1 = _t1s2;
                    _d2 = _t1p2;
                    _d3 = _w3;
                    _d4 = _w02n;
                    _d5 = _w13n;
#endif
                    // Y = A_T * w_t
                    _t1p2 = vaddq_s32(_d1, _d2);
                    _t3p4 = vaddq_s32(_d3, _d4);
                    _t1s2 = vsubq_s32(_d1, _d2);
                    _t3s4 = vsubq_s32(_d3, _d4);

                    _o0 = vaddq_s32(_d0, _t1p2);
                    _o0 = vaddq_s32(_o0, _t3p4);

                    // _o2
                    _t1p2 = vmlaq_lane_s32(_t1p2, _t3p4, _tp0, 1);

                    _o3 = vaddq_s32(_d5, _t1s2);
                    _o3 = vmlaq_lane_s32(_o3, _t3s4, _tp1, 1);

                    // _o1
                    _t1s2 = vmlaq_lane_s32(_t1s2, _t3s4, _tp1, 0);

                    // save to top blob tm
                    float32x4_t _ot0 = vcvtq_f32_s32(_o0);
                    float32x4_t _ot1 = vcvtq_f32_s32(_t1s2);
                    float32x4_t _ot2 = vcvtq_f32_s32(_t1p2);
                    float32x4_t _ot3 = vcvtq_f32_s32(_o3);

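                    // 0.0017361112f ~= 1/576 = 1/(24*24), matching the /576 in
                    // the scalar fallback below -- presumably the combined
                    // integer scaling of the kernel and input transforms.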
                    _ot0 = vmulq_n_f32(_ot0, 0.0017361112);
                    _ot1 = vmulq_n_f32(_ot1, 0.0017361112);
                    _ot2 = vmulq_n_f32(_ot2, 0.0017361112);
                    _ot3 = vmulq_n_f32(_ot3, 0.0017361112);

                    _o0 = vcvtq_s32_f32(_ot0);
                    _o1 = vcvtq_s32_f32(_ot1);
                    _o2 = vcvtq_s32_f32(_ot2);
                    _o3 = vcvtq_s32_f32(_ot3);

                    vst1q_s32(outRow0, _o0);
                    vst1q_s32(outRow1, _o1);
                    vst1q_s32(outRow2, _o2);
                    vst1q_s32(outRow3, _o3);
#else
                    int s0[6], s1[6], s2[6], s3[6], s4[6], s5[6];
                    int w0[6], w1[6], w2[6], w3[6];
                    int d0[4], d1[4], d2[4], d3[4], d4[4], d5[4];
                    int o0[4], o1[4], o2[4], o3[4];

                    // load
                    for (int n = 0; n < 6; n++)
                    {
                        s0[n] = out_tile[n];
                        s1[n] = out_tile[n + 6];
                        s2[n] = out_tile[n + 12];
                        s3[n] = out_tile[n + 18];
                        s4[n] = out_tile[n + 24];
                        s5[n] = out_tile[n + 30];
                    }
                    // w = A_T * W
                    for (int n = 0; n < 5; n++)
                    {
                        w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n];
                        w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n];
                        w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n];
                        w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + 4 * s5[n];
                    }
                    for (int n = 5; n < 6; n++)
                    {
                        w0[n] = 4 * (s0[n] + s1[n] + s2[n] + s3[n] + s4[n]);
                        w1[n] = 4 * (s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n]);
                        w2[n] = 4 * (s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n]);
                        w3[n] = 4 * (s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + 4 * s5[n]);
                    }
                    // transpose w to w_t
                    {
                        d0[0] = w0[0];
                        d0[1] = w1[0];
                        d0[2] = w2[0];
                        d0[3] = w3[0];
                        d1[0] = w0[1];
                        d1[1] = w1[1];
                        d1[2] = w2[1];
                        d1[3] = w3[1];
                        d2[0] = w0[2];
                        d2[1] = w1[2];
                        d2[2] = w2[2];
                        d2[3] = w3[2];
                        d3[0] = w0[3];
                        d3[1] = w1[3];
                        d3[2] = w2[3];
                        d3[3] = w3[3];
                        d4[0] = w0[4];
                        d4[1] = w1[4];
                        d4[2] = w2[4];
                        d4[3] = w3[4];
                        d5[0] = w0[5];
                        d5[1] = w1[5];
                        d5[2] = w2[5];
                        d5[3] = w3[5];
                    }
                    // Y = A_T * w_t
                    for (int n = 0; n < 4; n++)
                    {
                        o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n];
                        o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n];
                        o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n];
                        o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n];
                    }
                    // save to top blob tm
                    for (int n = 0; n < 4; n++)
                    {
                        outRow0[n] = o0[n] / 576;
                        outRow1[n] = o1[n] / 576;
                        outRow2[n] = o2[n] / 576;
                        outRow3[n] = o3[n] / 576;
                    }
#endif // __ARM_NEON
                    out_tile += 36;

                    outRow0 += 4;
                    outRow1 += 4;
                    outRow2 += 4;
                    outRow3 += 4;
                }

                outRow0 += outw * 3;
                outRow1 += outw * 3;
                outRow2 += outw * 3;
                outRow3 += outw * 3;
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}

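// The _dequant variant below repeats the same pipeline; it additionally takes
// a bias and per-output-channel dequantization scales, presumably fused into
// the output transform.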
static void conv3x3s1_winograd43_dequant_int8_neon(const Mat& bottom_blob, Mat& top_blob, const std::vector<Mat>& kernel_tm_test, const Mat& _bias, std::vector<float> scales_dequant, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const float* bias = _bias;

    // pad to 4n+2, winograd F(4,3)
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 3) / 4 * 4;
    outh = (outh + 3) / 4 * 4;

    w = outw + 2;
    h = outh + 2;
    Option opt_b = opt;
    opt_b.blob_allocator = opt.workspace_allocator;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;

2422         int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
2423         int nRowBlocks = w_tm / 6;
2424 
2425         const int tiles = nColBlocks * nRowBlocks;
2426 
2427         bottom_blob_tm.create(4, inch, tiles * 9, 2u, opt.workspace_allocator);
2428 
2429         // BT
2430         // const float itm[4][4] = {
2431         //     {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
2432         //     {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
2433         //     {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
2434         //     {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
2435         //     {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
2436         //     {0.0f, 4.0f,  0.0f,-5.0f, 0.0f, 1.0f}
2437         // };
2438 
2439         // 0 =	4 * r00  - 5 * r02	+ r04
2440         // 1 = -4 * (r01 + r02)  + r03 + r04
2441         // 2 =	4 * (r01 - r02)  - r03 + r04
2442         // 3 = -2 * r01 - r02 + 2 * r03 + r04
2443         // 4 =	2 * r01 - r02 - 2 * r03 + r04
2444         // 5 =	4 * r01 - 5 * r03 + r05
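        // the transformed input is stored as int16 (elemsize 2u): with int8
        // input (|d| <= 127) and BT row coefficients summing to at most 10 in
        // absolute value, one 1-D pass is bounded by 1270 and the second by
        // 12700, so the 6x6 transform cannot overflow int16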

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const signed char* img = bottom_blob_bordered.channel(q);

            for (int j = 0; j < nColBlocks; j++)
            {
                const signed char* r0 = img + w * j * 4;
                const signed char* r1 = r0 + w;
                const signed char* r2 = r1 + w;
                const signed char* r3 = r2 + w;
                const signed char* r4 = r3 + w;
                const signed char* r5 = r4 + w;

                for (int i = 0; i < nRowBlocks; i++)
                {
                    short* out_tm0 = bottom_blob_tm.channel(tiles * 0 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm1 = bottom_blob_tm.channel(tiles * 1 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm2 = bottom_blob_tm.channel(tiles * 2 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm3 = bottom_blob_tm.channel(tiles * 3 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm4 = bottom_blob_tm.channel(tiles * 4 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm5 = bottom_blob_tm.channel(tiles * 5 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm6 = bottom_blob_tm.channel(tiles * 6 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm7 = bottom_blob_tm.channel(tiles * 7 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm8 = bottom_blob_tm.channel(tiles * 8 + j * nRowBlocks + i).row<short>(q);
#if __ARM_NEON
                    int8x8_t _d0, _d1, _d2, _d3, _d4, _d5;
                    int16x8_t _w0, _w1, _w2, _w3, _w4, _w5;
                    int16x8_t _t0, _t1, _t2, _t3, _t4, _t5;
                    int16x8_t _n0, _n1, _n2, _n3, _n4, _n5;
                    // load
                    _d0 = vld1_s8(r0);
                    _d1 = vld1_s8(r1);
                    _d2 = vld1_s8(r2);
                    _d3 = vld1_s8(r3);
                    _d4 = vld1_s8(r4);
                    _d5 = vld1_s8(r5);

                    int8x8_t _1_n = vdup_n_s8(-1);
                    int8x8_t _2_p = vdup_n_s8(2);
                    int8x8_t _2_n = vdup_n_s8(-2);
                    int8x8_t _4_p = vdup_n_s8(4);
                    int8x8_t _4_n = vdup_n_s8(-4);
                    int8x8_t _5_n = vdup_n_s8(-5);

                    int16x8_t _1_n_s16 = vdupq_n_s16(-1);
                    int16x8_t _2_p_s16 = vdupq_n_s16(2);
                    int16x8_t _2_n_s16 = vdupq_n_s16(-2);
                    int16x8_t _4_p_s16 = vdupq_n_s16(4);
                    int16x8_t _4_n_s16 = vdupq_n_s16(-4);
                    int16x8_t _5_n_s16 = vdupq_n_s16(-5);
                    // w = B_t * d
                    _w0 = vmull_s8(_d0, _4_p);
                    _w0 = vmlal_s8(_w0, _d2, _5_n);
                    _w0 = vaddw_s8(_w0, _d4);

                    _w1 = vmull_s8(_d1, _4_n);
                    _w1 = vmlal_s8(_w1, _d2, _4_n);
                    _w1 = vaddw_s8(_w1, _d3);
                    _w1 = vaddw_s8(_w1, _d4);

                    _w2 = vmull_s8(_d1, _4_p);
                    _w2 = vmlal_s8(_w2, _d2, _4_n);
                    _w2 = vmlal_s8(_w2, _d3, _1_n);
                    _w2 = vaddw_s8(_w2, _d4);

                    _w3 = vmull_s8(_d1, _2_n);
                    _w3 = vmlal_s8(_w3, _d2, _1_n);
                    _w3 = vmlal_s8(_w3, _d3, _2_p);
                    _w3 = vaddw_s8(_w3, _d4);

                    _w4 = vmull_s8(_d1, _2_p);
                    _w4 = vmlal_s8(_w4, _d2, _1_n);
                    _w4 = vmlal_s8(_w4, _d3, _2_n);
                    _w4 = vaddw_s8(_w4, _d4);

                    _w5 = vmull_s8(_d1, _4_p);
                    _w5 = vmlal_s8(_w5, _d3, _5_n);
                    _w5 = vaddw_s8(_w5, _d5);
                    // transpose w to w_t
                    {
                        _t0[0] = _w0[0];
                        _t1[0] = _w0[1];
                        _t2[0] = _w0[2];
                        _t3[0] = _w0[3];
                        _t4[0] = _w0[4];
                        _t5[0] = _w0[5];
                        _t0[1] = _w1[0];
                        _t1[1] = _w1[1];
                        _t2[1] = _w1[2];
                        _t3[1] = _w1[3];
                        _t4[1] = _w1[4];
                        _t5[1] = _w1[5];
                        _t0[2] = _w2[0];
                        _t1[2] = _w2[1];
                        _t2[2] = _w2[2];
                        _t3[2] = _w2[3];
                        _t4[2] = _w2[4];
                        _t5[2] = _w2[5];
                        _t0[3] = _w3[0];
                        _t1[3] = _w3[1];
                        _t2[3] = _w3[2];
                        _t3[3] = _w3[3];
                        _t4[3] = _w3[4];
                        _t5[3] = _w3[5];
                        _t0[4] = _w4[0];
                        _t1[4] = _w4[1];
                        _t2[4] = _w4[2];
                        _t3[4] = _w4[3];
                        _t4[4] = _w4[4];
                        _t5[4] = _w4[5];
                        _t0[5] = _w5[0];
                        _t1[5] = _w5[1];
                        _t2[5] = _w5[2];
                        _t3[5] = _w5[3];
                        _t4[5] = _w5[4];
                        _t5[5] = _w5[5];
                    }
                    // n = B_t * w_t
                    _n0 = vmulq_s16(_t0, _4_p_s16);
                    _n0 = vmlaq_s16(_n0, _t2, _5_n_s16);
                    _n0 = vaddq_s16(_n0, _t4);

                    _n1 = vmulq_s16(_t1, _4_n_s16);
                    _n1 = vmlaq_s16(_n1, _t2, _4_n_s16);
                    _n1 = vaddq_s16(_n1, _t3);
                    _n1 = vaddq_s16(_n1, _t4);

                    _n2 = vmulq_s16(_t1, _4_p_s16);
                    _n2 = vmlaq_s16(_n2, _t2, _4_n_s16);
                    _n2 = vmlaq_s16(_n2, _t3, _1_n_s16);
                    _n2 = vaddq_s16(_n2, _t4);

                    _n3 = vmulq_s16(_t1, _2_n_s16);
                    _n3 = vmlaq_s16(_n3, _t2, _1_n_s16);
                    _n3 = vmlaq_s16(_n3, _t3, _2_p_s16);
                    _n3 = vaddq_s16(_n3, _t4);

                    _n4 = vmulq_s16(_t1, _2_p_s16);
                    _n4 = vmlaq_s16(_n4, _t2, _1_n_s16);
                    _n4 = vmlaq_s16(_n4, _t3, _2_n_s16);
                    _n4 = vaddq_s16(_n4, _t4);

                    _n5 = vmulq_s16(_t1, _4_p_s16);
                    _n5 = vmlaq_s16(_n5, _t3, _5_n_s16);
                    _n5 = vaddq_s16(_n5, _t5);
                    // save to out_tm
                    out_tm0[0] = _n0[0];
                    out_tm0[1] = _n0[1];
                    out_tm0[2] = _n0[2];
                    out_tm0[3] = _n0[3];
                    out_tm1[0] = _n0[4];
                    out_tm1[1] = _n0[5];
                    out_tm1[2] = _n1[0];
                    out_tm1[3] = _n1[1];
                    out_tm2[0] = _n1[2];
                    out_tm2[1] = _n1[3];
                    out_tm2[2] = _n1[4];
                    out_tm2[3] = _n1[5];

                    out_tm3[0] = _n2[0];
                    out_tm3[1] = _n2[1];
                    out_tm3[2] = _n2[2];
                    out_tm3[3] = _n2[3];
                    out_tm4[0] = _n2[4];
                    out_tm4[1] = _n2[5];
                    out_tm4[2] = _n3[0];
                    out_tm4[3] = _n3[1];
                    out_tm5[0] = _n3[2];
                    out_tm5[1] = _n3[3];
                    out_tm5[2] = _n3[4];
                    out_tm5[3] = _n3[5];

                    out_tm6[0] = _n4[0];
                    out_tm6[1] = _n4[1];
                    out_tm6[2] = _n4[2];
                    out_tm6[3] = _n4[3];
                    out_tm7[0] = _n4[4];
                    out_tm7[1] = _n4[5];
                    out_tm7[2] = _n5[0];
                    out_tm7[3] = _n5[1];
                    out_tm8[0] = _n5[2];
                    out_tm8[1] = _n5[3];
                    out_tm8[2] = _n5[4];
                    out_tm8[3] = _n5[5];
#else
                    short d0[6], d1[6], d2[6], d3[6], d4[6], d5[6];
                    short w0[6], w1[6], w2[6], w3[6], w4[6], w5[6];
                    short t0[6], t1[6], t2[6], t3[6], t4[6], t5[6];

                    // load
                    for (int n = 0; n < 6; n++)
                    {
                        d0[n] = r0[n];
                        d1[n] = r1[n];
                        d2[n] = r2[n];
                        d3[n] = r3[n];
                        d4[n] = r4[n];
                        d5[n] = r5[n];
                    }
                    // w = B_t * d
                    for (int n = 0; n < 6; n++)
                    {
                        w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n];
                        w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n];
                        w2[n] = 4 * d1[n] - 4 * d2[n] - d3[n] + d4[n];
                        w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n];
                        w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n];
                        w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n];
                    }
                    // transpose w to w_t
                    {
                        t0[0] = w0[0];
                        t1[0] = w0[1];
                        t2[0] = w0[2];
                        t3[0] = w0[3];
                        t4[0] = w0[4];
                        t5[0] = w0[5];
                        t0[1] = w1[0];
                        t1[1] = w1[1];
                        t2[1] = w1[2];
                        t3[1] = w1[3];
                        t4[1] = w1[4];
                        t5[1] = w1[5];
                        t0[2] = w2[0];
                        t1[2] = w2[1];
                        t2[2] = w2[2];
                        t3[2] = w2[3];
                        t4[2] = w2[4];
                        t5[2] = w2[5];
                        t0[3] = w3[0];
                        t1[3] = w3[1];
                        t2[3] = w3[2];
                        t3[3] = w3[3];
                        t4[3] = w3[4];
                        t5[3] = w3[5];
                        t0[4] = w4[0];
                        t1[4] = w4[1];
                        t2[4] = w4[2];
                        t3[4] = w4[3];
                        t4[4] = w4[4];
                        t5[4] = w4[5];
                        t0[5] = w5[0];
                        t1[5] = w5[1];
                        t2[5] = w5[2];
                        t3[5] = w5[3];
                        t4[5] = w5[4];
                        t5[5] = w5[5];
                    }
                    // d = B_t * w_t
                    for (int n = 0; n < 6; n++)
                    {
                        d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n];
                        d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n];
                        d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n];
                        d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n];
                        d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n];
                        d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n];
                    }
                    // save to out_tm
                    {
                        out_tm0[0] = d0[0];
                        out_tm0[1] = d0[1];
                        out_tm0[2] = d0[2];
                        out_tm0[3] = d0[3];
                        out_tm1[0] = d0[4];
                        out_tm1[1] = d0[5];
                        out_tm1[2] = d1[0];
                        out_tm1[3] = d1[1];
                        out_tm2[0] = d1[2];
                        out_tm2[1] = d1[3];
                        out_tm2[2] = d1[4];
                        out_tm2[3] = d1[5];

                        out_tm3[0] = d2[0];
                        out_tm3[1] = d2[1];
                        out_tm3[2] = d2[2];
                        out_tm3[3] = d2[3];
                        out_tm4[0] = d2[4];
                        out_tm4[1] = d2[5];
                        out_tm4[2] = d3[0];
                        out_tm4[3] = d3[1];
                        out_tm5[0] = d3[2];
                        out_tm5[1] = d3[3];
                        out_tm5[2] = d3[4];
                        out_tm5[3] = d3[5];

                        out_tm6[0] = d4[0];
                        out_tm6[1] = d4[1];
                        out_tm6[2] = d4[2];
                        out_tm6[3] = d4[3];
                        out_tm7[0] = d4[4];
                        out_tm7[1] = d4[5];
                        out_tm7[2] = d5[0];
                        out_tm7[3] = d5[1];
                        out_tm8[0] = d5[2];
                        out_tm8[1] = d5[3];
                        out_tm8[2] = d5[4];
                        out_tm8[3] = d5[5];
                    }
#endif // __ARM_NEON
                    r0 += 4;
                    r1 += 4;
                    r2 += 4;
                    r3 += 4;
                    r4 += 4;
                    r5 += 4;
                }
            }
        }
    }
    bottom_blob_bordered = Mat();

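    // bottom_blob_tm layout: channel (tiles * r + tile) stores, for each
    // input channel q, 4 consecutive int16 coefficients; the 36 coefficients
    // of a 6x6 tile are split into 9 groups of 4 so the dot product below is
    // one 4-lane int16 multiply-accumulate per group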
    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;

        int nColBlocks = h_tm / 6; // number of 6x6 tile rows, as in FeatherCNN
        int nRowBlocks = w_tm / 6;

        const int tiles = nColBlocks * nRowBlocks;

        top_blob_tm.create(36, tiles, outch, 4u, opt.workspace_allocator);

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 9; r++)
        {
            int nn_outch = 0;
            int remain_outch_start = 0;

            nn_outch = outch >> 3;
            remain_outch_start = nn_outch << 3;

            for (int pp = 0; pp < nn_outch; pp++)
            {
                int p = pp * 8;

                int* output0_tm = top_blob_tm.channel(p);
                int* output1_tm = top_blob_tm.channel(p + 1);
                int* output2_tm = top_blob_tm.channel(p + 2);
                int* output3_tm = top_blob_tm.channel(p + 3);
                int* output4_tm = top_blob_tm.channel(p + 4);
                int* output5_tm = top_blob_tm.channel(p + 5);
                int* output6_tm = top_blob_tm.channel(p + 6);
                int* output7_tm = top_blob_tm.channel(p + 7);

                output0_tm = output0_tm + r * 4;
                output1_tm = output1_tm + r * 4;
                output2_tm = output2_tm + r * 4;
                output3_tm = output3_tm + r * 4;
                output4_tm = output4_tm + r * 4;
                output5_tm = output5_tm + r * 4;
                output6_tm = output6_tm + r * 4;
                output7_tm = output7_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
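                    // accumulate over the input channels: for each q, one
                    // 4-lane int16 MAC per output channel, i.e.
                    // sum_k[0..3] += r0[0..3] * kptr[4 * k .. 4 * k + 3], k = 0..7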
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "eor    v1.16b, v1.16b, v1.16b    \n"
                        "eor    v2.16b, v2.16b, v2.16b    \n"
                        "eor    v3.16b, v3.16b, v3.16b    \n"
                        "eor    v4.16b, v4.16b, v4.16b    \n"
                        "eor    v5.16b, v5.16b, v5.16b    \n"
                        "eor    v6.16b, v6.16b, v6.16b    \n"
                        "eor    v7.16b, v7.16b, v7.16b    \n"
                        "mov    w4, %w20                  \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        "prfm    pldl1keep, [%9, #128]    \n" // prefetch kptr
                        "ld1     {v8.4h}, [%8]            \n" // _r0 = vld1_s16(r0);
                        "ld1     {v9.4h, v10.4h}, [%9]    \n" // _k01 = vld1q_s16(kptr);
                        "add     %9, %9, #16              \n"
                        "ld1     {v11.4h, v12.4h}, [%9]   \n" // _k23 = vld1q_s16(kptr+8);
                        "add     %9, %9, #16              \n"
                        "ld1     {v13.4h, v14.4h}, [%9]   \n" // _k45 = vld1q_s16(kptr+16);
                        "add     %9, %9, #16              \n"
                        "ld1     {v15.4h, v16.4h}, [%9]   \n" // _k67 = vld1q_s16(kptr+24);
                        "add     %8, %8, #8               \n"
                        "add     %9, %9, #16              \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)
                        "smlal   v1.4s, v8.4h, v10.4h     \n" // sum1 += (a00-a03) * (k10-k13)
                        "smlal   v2.4s, v8.4h, v11.4h     \n" // sum2 += (a00-a03) * (k20-k23)
                        "smlal   v3.4s, v8.4h, v12.4h     \n" // sum3 += (a00-a03) * (k30-k33)
                        "smlal   v4.4s, v8.4h, v13.4h     \n" // sum4 += (a00-a03) * (k40-k43)
                        "smlal   v5.4s, v8.4h, v14.4h     \n" // sum5 += (a00-a03) * (k50-k53)
                        "smlal   v6.4s, v8.4h, v15.4h     \n" // sum6 += (a00-a03) * (k60-k63)
                        "smlal   v7.4s, v8.4h, v16.4h     \n" // sum7 += (a00-a03) * (k70-k73)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory
                        "st1     {v1.4s}, [%1]            \n" //
                        "st1     {v2.4s}, [%2]            \n" //
                        "st1     {v3.4s}, [%3]            \n" //
                        "st1     {v4.4s}, [%4]            \n" //
                        "st1     {v5.4s}, [%5]            \n" //
                        "st1     {v6.4s}, [%6]            \n" //
                        "st1     {v7.4s}, [%7]            \n" //

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(output4_tm), // %4
                        "=r"(output5_tm), // %5
                        "=r"(output6_tm), // %6
                        "=r"(output7_tm), // %7
                        "=r"(r0),         // %8
                        "=r"(kptr)        // %9
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(output4_tm),
                        "5"(output5_tm),
                        "6"(output6_tm),
                        "7"(output7_tm),
                        "8"(r0),
                        "9"(kptr),
                        "r"(inch) // %20
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "vmov.s32    q1, #0           \n"
                        "vmov.s32    q2, #0           \n"
                        "vmov.s32    q3, #0           \n"
                        "vmov.s32    q4, #0           \n"
                        "vmov.s32    q5, #0           \n"
                        "vmov.s32    q6, #0           \n"
                        "vmov.s32    q7, #0           \n"
                        "mov         r4, %20          \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%8]!     \n" // _r0 = vld1_s16(r0);  // input inch0
                        "vld1.s16    {d18-d19}, [%9]  \n" // _k01 = vld1q_s16(kptr);
                        "add         %9, #16          \n"
                        "vld1.s16    {d20-d21}, [%9]  \n" // _k23 = vld1q_s16(kptr+8);
                        "add         %9, #16          \n"
                        "vld1.s16    {d22-d23}, [%9]  \n" // _k45 = vld1q_s16(kptr+16);
                        "add         %9, #16          \n"
                        "vld1.s16    {d24-d25}, [%9]  \n" // _k67 = vld1q_s16(kptr+24);
                        "add         %9, #16          \n"

                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)
                        "vmlal.s16   q1, d16, d19     \n" // sum1 += (a00-a03) * (k10-k13)
                        "vmlal.s16   q2, d16, d20     \n" // sum2 += (a00-a03) * (k20-k23)
                        "vmlal.s16   q3, d16, d21     \n" // sum3 += (a00-a03) * (k30-k33)
                        "vmlal.s16   q4, d16, d22     \n" // sum4 += (a00-a03) * (k40-k43)
                        "vmlal.s16   q5, d16, d23     \n" // sum5 += (a00-a03) * (k50-k53)
                        "vmlal.s16   q6, d16, d24     \n" // sum6 += (a00-a03) * (k60-k63)
                        "vmlal.s16   q7, d16, d25     \n" // sum7 += (a00-a03) * (k70-k73)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory
                        "vst1.s32    {d2-d3}, [%1]    \n"
                        "vst1.s32    {d4-d5}, [%2]    \n"
                        "vst1.s32    {d6-d7}, [%3]    \n"
                        "vst1.s32    {d8-d9}, [%4]    \n"
                        "vst1.s32    {d10-d11}, [%5]  \n"
                        "vst1.s32    {d12-d13}, [%6]  \n"
                        "vst1.s32    {d14-d15}, [%7]  \n"

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(output4_tm), // %4
                        "=r"(output5_tm), // %5
                        "=r"(output6_tm), // %6
                        "=r"(output7_tm), // %7
                        "=r"(r0),         // %8
                        "=r"(kptr)        // %9
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(output4_tm),
                        "5"(output5_tm),
                        "6"(output6_tm),
                        "7"(output7_tm),
                        "8"(r0),
                        "9"(kptr),
                        "r"(inch) // %20
                        : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
#endif // __aarch64__
#else
                    int sum0[4] = {0};
                    int sum1[4] = {0};
                    int sum2[4] = {0};
                    int sum3[4] = {0};
                    int sum4[4] = {0};
                    int sum5[4] = {0};
                    int sum6[4] = {0};
                    int sum7[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                            sum1[n] += (int)r0[n] * kptr[n + 4];
                            sum2[n] += (int)r0[n] * kptr[n + 8];
                            sum3[n] += (int)r0[n] * kptr[n + 12];
                            sum4[n] += (int)r0[n] * kptr[n + 16];
                            sum5[n] += (int)r0[n] * kptr[n + 20];
                            sum6[n] += (int)r0[n] * kptr[n + 24];
                            sum7[n] += (int)r0[n] * kptr[n + 28];
                        }
                        kptr += 32;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                        output1_tm[n] = sum1[n];
                        output2_tm[n] = sum2[n];
                        output3_tm[n] = sum3[n];
                        output4_tm[n] = sum4[n];
                        output5_tm[n] = sum5[n];
                        output6_tm[n] = sum6[n];
                        output7_tm[n] = sum7[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 36;
                    output1_tm += 36;
                    output2_tm += 36;
                    output3_tm += 36;
                    output4_tm += 36;
                    output5_tm += 36;
                    output6_tm += 36;
                    output7_tm += 36;
                }
            }

            nn_outch = (outch - remain_outch_start) >> 2;

            for (int pp = 0; pp < nn_outch; pp++)
            {
                int p = remain_outch_start + pp * 4;

                int* output0_tm = top_blob_tm.channel(p);
                int* output1_tm = top_blob_tm.channel(p + 1);
                int* output2_tm = top_blob_tm.channel(p + 2);
                int* output3_tm = top_blob_tm.channel(p + 3);

                output0_tm = output0_tm + r * 4;
                output1_tm = output1_tm + r * 4;
                output2_tm = output2_tm + r * 4;
                output3_tm = output3_tm + r * 4;

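                // kernel_tm_test channels are packed as all 8-wide output
                // channel groups first, then 4-wide groups, then single
                // channels - hence the p / 8 + (p % 8) / 4 indexing below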
                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "eor    v1.16b, v1.16b, v1.16b    \n"
                        "eor    v2.16b, v2.16b, v2.16b    \n"
                        "eor    v3.16b, v3.16b, v3.16b    \n"
                        "mov    w4, %w12                  \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        "prfm    pldl1keep, [%5, #128]    \n" // prefetch kptr
                        "ld1     {v8.4h}, [%4]            \n" // _r0 = vld1_s16(r0);  // input inch0
                        "ld1     {v9.4h, v10.4h}, [%5]    \n" // _k01 = vld1q_s16(kptr);
                        "add     %5, %5, #16              \n"
                        "ld1     {v11.4h, v12.4h}, [%5]   \n" // _k23 = vld1q_s16(kptr+8);
                        "add     %4, %4, #8               \n"
                        "add     %5, %5, #16              \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)
                        "smlal   v1.4s, v8.4h, v10.4h     \n" // sum1 += (a00-a03) * (k10-k13)
                        "smlal   v2.4s, v8.4h, v11.4h     \n" // sum2 += (a00-a03) * (k20-k23)
                        "smlal   v3.4s, v8.4h, v12.4h     \n" // sum3 += (a00-a03) * (k30-k33)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory
                        "st1     {v1.4s}, [%1]            \n" //
                        "st1     {v2.4s}, [%2]            \n" //
                        "st1     {v3.4s}, [%3]            \n" //

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(r0),         // %4
                        "=r"(kptr)        // %5
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(r0),
                        "5"(kptr),
                        "r"(inch) // %12
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "vmov.s32    q1, #0           \n"
                        "vmov.s32    q2, #0           \n"
                        "vmov.s32    q3, #0           \n"
                        "mov         r4, %12          \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%4]!     \n" // _r0 = vld1_s16(r0);  // input inch0
                        "vld1.s16    {d18-d19}, [%5]  \n" // _k01 = vld1q_s16(kptr);
                        "add         %5, #16          \n"
                        "vld1.s16    {d20-d21}, [%5]  \n" // _k23 = vld1q_s16(kptr+8);
                        "add         %5, #16          \n"

                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)
                        "vmlal.s16   q1, d16, d19     \n" // sum1 += (a00-a03) * (k10-k13)
                        "vmlal.s16   q2, d16, d20     \n" // sum2 += (a00-a03) * (k20-k23)
                        "vmlal.s16   q3, d16, d21     \n" // sum3 += (a00-a03) * (k30-k33)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory
                        "vst1.s32    {d2-d3}, [%1]    \n"
                        "vst1.s32    {d4-d5}, [%2]    \n"
                        "vst1.s32    {d6-d7}, [%3]    \n"

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(r0),         // %4
                        "=r"(kptr)        // %5
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(r0),
                        "5"(kptr),
                        "r"(inch) // %12
                        : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q8", "q9", "q10");
#endif // __aarch64__
#else
                    int sum0[4] = {0};
                    int sum1[4] = {0};
                    int sum2[4] = {0};
                    int sum3[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                            sum1[n] += (int)r0[n] * kptr[n + 4];
                            sum2[n] += (int)r0[n] * kptr[n + 8];
                            sum3[n] += (int)r0[n] * kptr[n + 12];
                        }
                        kptr += 16;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                        output1_tm[n] = sum1[n];
                        output2_tm[n] = sum2[n];
                        output3_tm[n] = sum3[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 36;
                    output1_tm += 36;
                    output2_tm += 36;
                    output3_tm += 36;
                }
            }

            remain_outch_start += nn_outch << 2;

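            // remaining output channels are handled one at a time, still
            // with a 4-lane int16 multiply-accumulate per input channel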
            for (int p = remain_outch_start; p < outch; p++)
            {
                int* output0_tm = top_blob_tm.channel(p);

                output0_tm = output0_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4 + p % 4);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor    v0.16b, v0.16b, v0.16b    \n"
                        "mov    w4, %w6                   \n"

                        "0:                               \n" // for (int q=0; q<inch; q++)
                        "ld1     {v8.4h}, [%1]            \n" // _r0 = vld1_s16(r0);  // input inch0
                        "ld1     {v9.4h}, [%2]            \n" // _k0 = vld1_s16(kptr);
                        "add     %1, %1, #8               \n"
                        "add     %2, %2, #8               \n"

                        "subs    w4, w4, #1               \n"

                        "smlal   v0.4s, v8.4h, v9.4h      \n" // sum0 += (a00-a03) * (k00-k03)

                        "bne     0b                       \n" // end for

                        "st1     {v0.4s}, [%0]            \n" // store the result to memory

                        : "=r"(output0_tm), // %0
                        "=r"(r0),         // %1
                        "=r"(kptr)        // %2
                        : "0"(output0_tm),
                        "1"(r0),
                        "2"(kptr),
                        "r"(inch) // %6
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32    q0, #0           \n"
                        "mov         r4, %6           \n"

                        "0:                           \n" // for (int q=0; q<inch; q++)
                        "vld1.s16    {d16}, [%1]      \n" // _r0 = vld1_s16(r0);  // input inch0
                        "add         %1, #8           \n"
                        "vld1.s16    {d18}, [%2]      \n" // _k0 = vld1_s16(kptr);
                        "add         %2, #8           \n"
                        "vmlal.s16   q0, d16, d18     \n" // sum0 += (a00-a03) * (k00-k03)

                        "subs        r4, r4, #1       \n"
                        "bne         0b               \n" // end for

                        "vst1.s32    {d0-d1}, [%0]    \n" // store the result to memory

                        : "=r"(output0_tm), // %0
                        "=r"(r0),         // %1
                        "=r"(kptr)        // %2
                        : "0"(output0_tm),
                        "1"(r0),
                        "2"(kptr),
                        "r"(inch) // %6
                        : "cc", "memory", "r4", "q0", "q8", "q9");
#endif // __aarch64__
#else  // __ARM_NEON
                    int sum0[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                        }
                        kptr += 4;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 36;
                }
            }

            // for (int p=0; p<outch; p++)
            // {
            //     Mat out0_tm = top_blob_tm.channel(p);
            //     const Mat kernel0_tm = kernel_tm.channel(p);

            //     for (int i=0; i<tiles; i++)
            //     {
            //         int* output0_tm = out0_tm.row<int>(i);

            //         int sum0[36] = {0};

            //         for (int q=0; q<inch; q++)
            //         {
            //             const short* r0 = bottom_blob_tm.channel(q).row<short>(i);
            //             const short* k0 = kernel0_tm.row<short>(q);

            //             for (int n=0; n<36; n++)
            //             {
            //                 sum0[n] += (int)r0[n] * k0[n];
            //             }
            //         }

            //         for (int n=0; n<36; n++)
            //         {
            //             output0_tm[n] = sum0[n];
            //         }
            //     }
            // }
        }
    }
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator);
    {
        // AT
        // const float itm[4][6] = {
        //     {1.0f, 1.0f,  1.0f, 1.0f,  1.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
        //     {0.0f, 1.0f,  1.0f, 4.0f,  4.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
        // };

        // 0 = r00 + r01 + r02 + r03 + r04
        // 1 =       r01 - r02 + 2 * (r03 - r04)
        // 2 =       r01 + r02 + 4 * (r03 + r04)
        // 3 =       r01 - r02 + 8 * (r03 - r04) + r05
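        // the integer transform matrices are scaled relative to the float
        // ones, leaving an overall factor of 576 on the accumulated result;
        // it is cancelled by folding 1/576 into the per-channel
        // dequantization scale below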

        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;

        int nColBlocks = h_tm / 6; // number of 6x6 tile rows, as in FeatherCNN
        int nRowBlocks = w_tm / 6;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            int* out_tile = top_blob_tm.channel(p);
            float* outRow0 = top_blob_bordered.channel(p);
            float* outRow1 = outRow0 + outw;
            float* outRow2 = outRow0 + outw * 2;
            float* outRow3 = outRow0 + outw * 3;

            const float bias0 = bias ? bias[p] : 0.f;

            const float scale_dequant0 = scales_dequant[p];

            const float scale0 = scale_dequant0 / 576.f;

            for (int j = 0; j < nColBlocks; j++)
            {
                for (int i = 0; i < nRowBlocks; i++)
                {
#if __ARM_NEON
                    int32x4_t _s0, _s1, _s2, _s3, _s4, _s5;
                    int32x2_t _s0n, _s1n, _s2n, _s3n, _s4n, _s5n;
                    int32x4_t _w0, _w3;
                    int32x2_t _w0n, _w3n;
                    int32x4_t _d0, _d1, _d2, _d3, _d4, _d5;
                    int32x4_t _o0, _o3;
                    // load
                    _s0 = vld1q_s32(out_tile);
                    _s0n = vld1_s32(out_tile + 4);
                    _s1 = vld1q_s32(out_tile + 6);
                    _s1n = vld1_s32(out_tile + 10);
                    _s2 = vld1q_s32(out_tile + 12);
                    _s2n = vld1_s32(out_tile + 16);
                    _s3 = vld1q_s32(out_tile + 18);
                    _s3n = vld1_s32(out_tile + 22);
                    _s4 = vld1q_s32(out_tile + 24);
                    _s4n = vld1_s32(out_tile + 28);
                    _s5 = vld1q_s32(out_tile + 30);
                    _s5n = vld1_s32(out_tile + 34);
                    // w = A_T * W
                    int32x2_t _tp0 = {1, 4};
                    int32x2_t _tp1 = {2, 8};

                    // 4*s5[n]
                    int32x4_t _s5x4 = vshlq_n_s32(_s5, 2);
                    int32x2_t _s5x4n = vshl_n_s32(_s5n, 2);

                    int32x4_t _t1p2 = vaddq_s32(_s1, _s2);
                    int32x2_t _t1p2n = vadd_s32(_s1n, _s2n);
                    int32x4_t _t3p4 = vaddq_s32(_s3, _s4);
                    int32x2_t _t3p4n = vadd_s32(_s3n, _s4n);
                    int32x4_t _t1s2 = vsubq_s32(_s1, _s2);
                    int32x2_t _t1s2n = vsub_s32(_s1n, _s2n);
                    int32x4_t _t3s4 = vsubq_s32(_s3, _s4);
                    int32x2_t _t3s4n = vsub_s32(_s3n, _s4n);

                    _w0 = vaddq_s32(_s0, _t1p2);
                    _w0n = vadd_s32(_s0n, _t1p2n);
                    _w0 = vaddq_s32(_w0, _t3p4);
                    _w0n = vadd_s32(_w0n, _t3p4n);
                    _w0n = vmul_s32(_w0n, _tp0);

                    // _w2,_w2n
                    _t1p2 = vmlaq_lane_s32(_t1p2, _t3p4, _tp0, 1);
                    _t1p2n = vmla_lane_s32(_t1p2n, _t3p4n, _tp0, 1);
                    _t1p2n = vmul_s32(_t1p2n, _tp0);

                    _w3 = vaddq_s32(_s5x4, _t1s2);
                    _w3n = vadd_s32(_s5x4n, _t1s2n);
                    _w3 = vmlaq_lane_s32(_w3, _t3s4, _tp1, 1);
                    _w3n = vmla_lane_s32(_w3n, _t3s4n, _tp1, 1);
                    _w3n = vmul_s32(_w3n, _tp0);

                    // _w1, _w1n
                    _t1s2 = vmlaq_lane_s32(_t1s2, _t3s4, _tp1, 0);
                    _t1s2n = vmla_lane_s32(_t1s2n, _t3s4n, _tp1, 0);
                    _t1s2n = vmul_s32(_t1s2n, _tp0);

                    int32x4_t _w02n = vcombine_s32(_w0n, _t1p2n);
                    int32x4_t _w13n = vcombine_s32(_t1s2n, _w3n);

                    // transpose w to w_t
#if __aarch64__
                    int32x4_t _wt0 = vtrn1q_s32(_w0, _t1s2);
                    int32x4_t _wt1 = vtrn2q_s32(_w0, _t1s2);
                    int32x4_t _wt2 = vtrn1q_s32(_t1p2, _w3);
                    int32x4_t _wt3 = vtrn2q_s32(_t1p2, _w3);
                    int64x2_t _dt0 = vtrn1q_s64(vreinterpretq_s64_s32(_wt0), vreinterpretq_s64_s32(_wt2));
                    int64x2_t _dt2 = vtrn2q_s64(vreinterpretq_s64_s32(_wt0), vreinterpretq_s64_s32(_wt2));
                    int64x2_t _dt1 = vtrn1q_s64(vreinterpretq_s64_s32(_wt1), vreinterpretq_s64_s32(_wt3));
                    int64x2_t _dt3 = vtrn2q_s64(vreinterpretq_s64_s32(_wt1), vreinterpretq_s64_s32(_wt3));
                    _d0 = vreinterpretq_s32_s64(_dt0);
                    _d1 = vreinterpretq_s32_s64(_dt1);
                    _d2 = vreinterpretq_s32_s64(_dt2);
                    _d3 = vreinterpretq_s32_s64(_dt3);
                    _d4 = vtrn1q_s32(_w02n, _w13n);
                    _d5 = vtrn2q_s32(_w02n, _w13n);
#else
                    asm volatile(
                        "vtrn.32    %q[_w0], %q[_w1]        \n"
                        "vtrn.32    %q[_w2], %q[_w3]        \n"
                        "vswp       %f[_w0], %e[_w2]        \n"
                        "vswp       %f[_w1], %e[_w3]        \n"
                        "vtrn.32    %q[_w02n], %q[_w13n]    \n"
                        : [_w0] "+w"(_w0),
                        [_w1] "+w"(_t1s2),
                        [_w2] "+w"(_t1p2),
                        [_w3] "+w"(_w3),
                        [_w02n] "+w"(_w02n),
                        [_w13n] "+w"(_w13n)
                        :
                        : "cc", "memory");
                    _d0 = _w0;
                    _d1 = _t1s2;
                    _d2 = _t1p2;
                    _d3 = _w3;
                    _d4 = _w02n;
                    _d5 = _w13n;
#endif
                    // Y = A_T * w_t
                    _t1p2 = vaddq_s32(_d1, _d2);
                    _t3p4 = vaddq_s32(_d3, _d4);
                    _t1s2 = vsubq_s32(_d1, _d2);
                    _t3s4 = vsubq_s32(_d3, _d4);

                    _o0 = vaddq_s32(_d0, _t1p2);
                    _o0 = vaddq_s32(_o0, _t3p4);

                    // _o2
                    _t1p2 = vmlaq_lane_s32(_t1p2, _t3p4, _tp0, 1);

                    _o3 = vaddq_s32(_d5, _t1s2);
                    _o3 = vmlaq_lane_s32(_o3, _t3s4, _tp1, 1);

                    // _o1
                    _t1s2 = vmlaq_lane_s32(_t1s2, _t3s4, _tp1, 0);

                    // dequantize, add bias and save
                    float32x4_t _scale0 = vdupq_n_f32(scale0);
                    float32x4_t _out0_f32 = vdupq_n_f32(bias0);
                    float32x4_t _out1_f32 = vdupq_n_f32(bias0);
                    float32x4_t _out2_f32 = vdupq_n_f32(bias0);
                    float32x4_t _out3_f32 = vdupq_n_f32(bias0);

                    _out0_f32 = vmlaq_f32(_out0_f32, vcvtq_f32_s32(_o0), _scale0);
                    _out1_f32 = vmlaq_f32(_out1_f32, vcvtq_f32_s32(_t1s2), _scale0);
                    _out2_f32 = vmlaq_f32(_out2_f32, vcvtq_f32_s32(_t1p2), _scale0);
                    _out3_f32 = vmlaq_f32(_out3_f32, vcvtq_f32_s32(_o3), _scale0);

                    vst1q_f32(outRow0, _out0_f32);
                    vst1q_f32(outRow1, _out1_f32);
                    vst1q_f32(outRow2, _out2_f32);
                    vst1q_f32(outRow3, _out3_f32);
#else
                    int s0[6], s1[6], s2[6], s3[6], s4[6], s5[6];
                    int w0[6], w1[6], w2[6], w3[6];
                    int d0[4], d1[4], d2[4], d3[4], d4[4], d5[4];
                    int o0[4], o1[4], o2[4], o3[4];

                    // load
                    for (int n = 0; n < 6; n++)
                    {
                        s0[n] = out_tile[n];
                        s1[n] = out_tile[n + 6];
                        s2[n] = out_tile[n + 12];
                        s3[n] = out_tile[n + 18];
                        s4[n] = out_tile[n + 24];
                        s5[n] = out_tile[n + 30];
                    }
                    // w = A_T * W
                    for (int n = 0; n < 5; n++)
                    {
                        w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n];
                        w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n];
                        w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n];
                        w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + 4 * s5[n];
                    }
                    for (int n = 5; n < 6; n++)
                    {
                        w0[n] = 4 * (s0[n] + s1[n] + s2[n] + s3[n] + s4[n]);
                        w1[n] = 4 * (s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n]);
                        w2[n] = 4 * (s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n]);
                        w3[n] = 4 * (s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + 4 * s5[n]);
                    }
                    // transpose w to w_t
                    {
                        d0[0] = w0[0];
                        d0[1] = w1[0];
                        d0[2] = w2[0];
                        d0[3] = w3[0];
                        d1[0] = w0[1];
                        d1[1] = w1[1];
                        d1[2] = w2[1];
                        d1[3] = w3[1];
                        d2[0] = w0[2];
                        d2[1] = w1[2];
                        d2[2] = w2[2];
                        d2[3] = w3[2];
                        d3[0] = w0[3];
                        d3[1] = w1[3];
                        d3[2] = w2[3];
                        d3[3] = w3[3];
                        d4[0] = w0[4];
                        d4[1] = w1[4];
                        d4[2] = w2[4];
                        d4[3] = w3[4];
                        d5[0] = w0[5];
                        d5[1] = w1[5];
                        d5[2] = w2[5];
                        d5[3] = w3[5];
                    }
                    // Y = A_T * w_t
                    for (int n = 0; n < 4; n++)
                    {
                        o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n];
                        o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n];
                        o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n];
                        o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n];
                    }
                    // dequantize, add bias and save
                    for (int n = 0; n < 4; n++)
                    {
                        outRow0[n] = (float)o0[n] * scale0 + bias0;
                        outRow1[n] = (float)o1[n] * scale0 + bias0;
                        outRow2[n] = (float)o2[n] * scale0 + bias0;
                        outRow3[n] = (float)o3[n] * scale0 + bias0;
                    }
#endif // __ARM_NEON
                    out_tile += 36;

                    outRow0 += 4;
                    outRow1 += 4;
                    outRow2 += 4;
                    outRow3 += 4;
                }

                outRow0 += outw * 3;
                outRow1 += outw * 3;
                outRow2 += outw * 3;
                outRow3 += outw * 3;
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}

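// repack the 3x3 int8 kernels for the stride-2 path: output channels are
// grouped by 8 (leftovers stored one channel at a time), interleaving the 8
// channels' weights element by element for the NEON inner loop below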
conv3x3s2_transform_kernel_int8_neon(const Mat & _kernel,Mat & kernel_tm,int inch,int outch)3536 static void conv3x3s2_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch)
3537 {
3538     kernel_tm.create(8 * 9, inch, outch / 8 + outch % 8, (size_t)1u);
3539 
3540     const signed char* kernel = _kernel;
3541 
3542     int p = 0;
3543     for (; p + 7 < outch; p += 8)
3544     {
3545         const signed char* k0 = kernel + (p + 0) * inch * 9;
3546         const signed char* k1 = kernel + (p + 1) * inch * 9;
3547         const signed char* k2 = kernel + (p + 2) * inch * 9;
3548         const signed char* k3 = kernel + (p + 3) * inch * 9;
3549         const signed char* k4 = kernel + (p + 4) * inch * 9;
3550         const signed char* k5 = kernel + (p + 5) * inch * 9;
3551         const signed char* k6 = kernel + (p + 6) * inch * 9;
3552         const signed char* k7 = kernel + (p + 7) * inch * 9;
3553 
3554         signed char* ktmp = kernel_tm.channel(p / 8);
3555 
3556         for (int q = 0; q < inch; q++)
3557         {
3558             for (int k = 0; k < 9; k++)
3559             {
3560                 ktmp[0] = k0[k];
3561                 ktmp[1] = k1[k];
3562                 ktmp[2] = k2[k];
3563                 ktmp[3] = k3[k];
3564                 ktmp[4] = k4[k];
3565                 ktmp[5] = k5[k];
3566                 ktmp[6] = k6[k];
3567                 ktmp[7] = k7[k];
3568                 ktmp += 8;
3569             }
3570 
3571             k0 += 9;
3572             k1 += 9;
3573             k2 += 9;
3574             k3 += 9;
3575             k4 += 9;
3576             k5 += 9;
3577             k6 += 9;
3578             k7 += 9;
3579         }
3580     }
3581     for (; p < outch; p++)
3582     {
3583         const signed char* k0 = kernel + (p + 0) * inch * 9;
3584 
        signed char* ktmp = kernel_tm.channel(p / 8 + p % 8);

        for (int q = 0; q < inch; q++)
        {
            for (int k = 0; k < 9; k++)
            {
                ktmp[k] = k0[k];
            }
            ktmp += 9;

            k0 += 9;
        }
    }
}

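// conv3x3s2_packed_int8_neon: 3x3 convolution, stride 2, int8 input and
// weights, int32 accumulation into top_blob. It expects the weights in the
// packed layout produced by conv3x3s2_transform_kernel_int8_neon. Typical
// pairing (a minimal sketch; the variable names are illustrative, not from
// this file):
//
//   Mat kernel_tm;
//   conv3x3s2_transform_kernel_int8_neon(weight_s8, kernel_tm, inch, outch); // once, at load time
//   conv3x3s2_packed_int8_neon(bottom_s8, top_s32, kernel_tm, opt);          // per inference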
static void conv3x3s2_packed_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const int tailstep = w - 2 * outw + w;
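    // stride is 2: each output row consumes 2*outw input columns, then one
    // whole input row is skipped before the next output row starts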

    int nn_outch = outch >> 3;
    int remain_outch_start = nn_outch << 3;
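    // output channels are processed 8 at a time; the tail (outch % 8) is
    // handled one channel at a time in the second loop below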

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = pp * 8;

        Mat out0 = top_blob.channel(p + 0);
        Mat out1 = top_blob.channel(p + 1);
        Mat out2 = top_blob.channel(p + 2);
        Mat out3 = top_blob.channel(p + 3);
        Mat out4 = top_blob.channel(p + 4);
        Mat out5 = top_blob.channel(p + 5);
        Mat out6 = top_blob.channel(p + 6);
        Mat out7 = top_blob.channel(p + 7);

        out0.fill(0);
        out1.fill(0);
        out2.fill(0);
        out3.fill(0);
        out4.fill(0);
        out5.fill(0);
        out6.fill(0);
        out7.fill(0);

        const signed char* ktmp = _kernel.channel(p / 8);

        for (int q = 0; q < inch; q++)
        {
            int* outptr0 = out0;
            int* outptr1 = out1;
            int* outptr2 = out2;
            int* outptr3 = out3;
            int* outptr4 = out4;
            int* outptr5 = out5;
            int* outptr6 = out6;
            int* outptr7 = out7;

            const signed char* img0 = bottom_blob.channel(q);

            const signed char* r0 = img0;
            const signed char* r1 = img0 + w;
            const signed char* r2 = img0 + w * 2;

            int i = 0;

            for (; i < outh; i++)
            {
#if __ARM_NEON
#if __aarch64__
                int nn = outw >> 3;
                int remain = outw & 7;
#else
                int nn = outw >> 2;
                int remain = outw & 3;
#endif // __aarch64__
#else
                int remain = outw;
#endif // __ARM_NEON

#if __ARM_NEON
#if __aarch64__
                if (nn > 0)
                {
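                    // 8 output pixels x 8 output channels per iteration:
                    // ld2 de-interleaves the stride-2 samples (even/odd
                    // columns), ext builds the third window column, sshll
                    // widens s8 to s16, and smlal/smlal2 accumulate into the
                    // eight int32 accumulator pairs v8..v23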
                    asm volatile(
                        "0:                                   \n"

                        "ld1    {v0.8b, v1.8b, v2.8b}, [%12], #24  \n" //ktmp
                        "ld2    {v3.8b, v4.8b}, [%9], #16     \n"      //r0-r2
                        "ld2    {v5.8b, v6.8b}, [%9]          \n"

                        "ld1    {v8.4s, v9.4s}, [%1]          \n" //out0
                        "ld1    {v10.4s, v11.4s}, [%2]        \n" //out1
                        "ld1    {v12.4s, v13.4s}, [%3]        \n" //out2
                        "ld1    {v14.4s, v15.4s}, [%4]        \n" //out3
                        "ld1    {v16.4s, v17.4s}, [%5]        \n" //out4
                        "ld1    {v18.4s, v19.4s}, [%6]        \n" //out5
                        "ld1    {v20.4s, v21.4s}, [%7]        \n" //out6
                        "ld1    {v22.4s, v23.4s}, [%8]        \n" //out7

                        "ext    v7.8b, v3.8b, v5.8b, #1       \n"

                        "sshll  v0.8h, v0.8b, #0              \n" //(k00-k70)
                        "sshll  v1.8h, v1.8b, #0              \n" //(k01-k71)
                        "sshll  v2.8h, v2.8b, #0              \n" //(k02-k72)
                        "sshll  v3.8h, v3.8b, #0              \n" // r0
                        "sshll  v4.8h, v4.8b, #0              \n" // r1
                        "sshll  v7.8h, v7.8b, #0              \n" // r2

                        // r0
                        "smlal  v8.4s, v3.4h, v0.h[0]         \n" // out0 += (r00-r07)*k00
                        "smlal2  v9.4s, v3.8h, v0.h[0]        \n"
                        "smlal  v10.4s, v3.4h, v0.h[1]        \n" // out1 += (r00-r07)*k10
                        "smlal2  v11.4s, v3.8h, v0.h[1]       \n"
                        "smlal  v12.4s, v3.4h, v0.h[2]        \n" // out2 += (r00-r07)*k20
                        "smlal2  v13.4s, v3.8h, v0.h[2]       \n"
                        "smlal  v14.4s, v3.4h, v0.h[3]        \n" // out3 += (r00-r07)*k30
                        "smlal2  v15.4s, v3.8h, v0.h[3]       \n"
                        "smlal  v16.4s, v3.4h, v0.h[4]        \n" // out4 += (r00-r07)*k40
                        "smlal2  v17.4s, v3.8h, v0.h[4]       \n"
                        "smlal  v18.4s, v3.4h, v0.h[5]        \n" // out5 += (r00-r07)*k50
                        "smlal2  v19.4s, v3.8h, v0.h[5]       \n"
                        "smlal  v20.4s, v3.4h, v0.h[6]        \n" // out6 += (r00-r07)*k60
                        "smlal2  v21.4s, v3.8h, v0.h[6]       \n"
                        "smlal  v22.4s, v3.4h, v0.h[7]        \n" // out7 += (r00-r07)*k70
                        "smlal2  v23.4s, v3.8h, v0.h[7]       \n"
                        // r1
                        "smlal  v8.4s, v4.4h, v1.h[0]         \n" // out0 += (r10-r17)*k01
                        "smlal2  v9.4s, v4.8h, v1.h[0]        \n"
                        "smlal  v10.4s, v4.4h, v1.h[1]        \n" // out1 += (r10-r17)*k11
                        "smlal2  v11.4s, v4.8h, v1.h[1]       \n"
                        "smlal  v12.4s, v4.4h, v1.h[2]        \n" // out2 += (r10-r17)*k21
                        "smlal2  v13.4s, v4.8h, v1.h[2]       \n"
                        "smlal  v14.4s, v4.4h, v1.h[3]        \n" // out3 += (r10-r17)*k31
                        "smlal2  v15.4s, v4.8h, v1.h[3]       \n"
                        "smlal  v16.4s, v4.4h, v1.h[4]        \n" // out4 += (r10-r17)*k41
                        "smlal2  v17.4s, v4.8h, v1.h[4]       \n"
                        "smlal  v18.4s, v4.4h, v1.h[5]        \n" // out5 += (r10-r17)*k51
                        "smlal2  v19.4s, v4.8h, v1.h[5]       \n"
                        "smlal  v20.4s, v4.4h, v1.h[6]        \n" // out6 += (r10-r17)*k61
                        "smlal2  v21.4s, v4.8h, v1.h[6]       \n"
                        "smlal  v22.4s, v4.4h, v1.h[7]        \n" // out7 += (r10-r17)*k71
                        "smlal2  v23.4s, v4.8h, v1.h[7]       \n"
                        // r2
                        "smlal  v8.4s, v7.4h, v2.h[0]         \n" // out0 += (r20-r27)*k02
                        "smlal2  v9.4s, v7.8h, v2.h[0]        \n"
                        "smlal  v10.4s, v7.4h, v2.h[1]        \n" // out1 += (r20-r27)*k12
                        "smlal2  v11.4s, v7.8h, v2.h[1]       \n"
                        "smlal  v12.4s, v7.4h, v2.h[2]        \n" // out2 += (r20-r27)*k22
                        "smlal2  v13.4s, v7.8h, v2.h[2]       \n"
                        "smlal  v14.4s, v7.4h, v2.h[3]        \n" // out3 += (r20-r27)*k32
                        "smlal2  v15.4s, v7.8h, v2.h[3]       \n"
                        "smlal  v16.4s, v7.4h, v2.h[4]        \n" // out4 += (r20-r27)*k42
                        "smlal2  v17.4s, v7.8h, v2.h[4]       \n"
                        "smlal  v18.4s, v7.4h, v2.h[5]        \n" // out5 += (r20-r27)*k52
                        "smlal2  v19.4s, v7.8h, v2.h[5]       \n"
                        "smlal  v20.4s, v7.4h, v2.h[6]        \n" // out6 += (r20-r27)*k62
                        "smlal2  v21.4s, v7.8h, v2.h[6]       \n"
                        "smlal  v22.4s, v7.4h, v2.h[7]        \n" // out7 += (r20-r27)*k72
                        "smlal2  v23.4s, v7.8h, v2.h[7]       \n"

                        "ld1    {v0.8b, v1.8b, v2.8b}, [%12], #24  \n" //ktmp
                        "ld2    {v3.8b, v4.8b}, [%10], #16    \n"      //r3-r5
                        "ld2    {v5.8b, v6.8b}, [%10]         \n"

                        "ext    v7.8b, v3.8b, v5.8b, #1       \n"

                        "sshll  v0.8h, v0.8b, #0              \n" //(k03-k73)
                        "sshll  v1.8h, v1.8b, #0              \n" //(k04-k74)
                        "sshll  v2.8h, v2.8b, #0              \n" //(k05-k75)
                        "sshll  v3.8h, v3.8b, #0              \n" // r3
                        "sshll  v4.8h, v4.8b, #0              \n" // r4
                        "sshll  v7.8h, v7.8b, #0              \n" // r5

                        // r3
                        "smlal  v8.4s, v3.4h, v0.h[0]         \n" // out0 += (r30-r37)*k03
                        "smlal2  v9.4s, v3.8h, v0.h[0]        \n"
                        "smlal  v10.4s, v3.4h, v0.h[1]        \n" // out1 += (r30-r37)*k13
                        "smlal2  v11.4s, v3.8h, v0.h[1]       \n"
                        "smlal  v12.4s, v3.4h, v0.h[2]        \n" // out2 += (r30-r37)*k23
                        "smlal2  v13.4s, v3.8h, v0.h[2]       \n"
                        "smlal  v14.4s, v3.4h, v0.h[3]        \n" // out3 += (r30-r37)*k33
                        "smlal2  v15.4s, v3.8h, v0.h[3]       \n"
                        "smlal  v16.4s, v3.4h, v0.h[4]        \n" // out4 += (r30-r37)*k43
                        "smlal2  v17.4s, v3.8h, v0.h[4]       \n"
                        "smlal  v18.4s, v3.4h, v0.h[5]        \n" // out5 += (r30-r37)*k53
                        "smlal2  v19.4s, v3.8h, v0.h[5]       \n"
                        "smlal  v20.4s, v3.4h, v0.h[6]        \n" // out6 += (r30-r37)*k63
                        "smlal2  v21.4s, v3.8h, v0.h[6]       \n"
                        "smlal  v22.4s, v3.4h, v0.h[7]        \n" // out7 += (r30-r37)*k73
                        "smlal2  v23.4s, v3.8h, v0.h[7]       \n"
                        // r4
                        "smlal  v8.4s, v4.4h, v1.h[0]         \n" // out0 += (r40-r47)*k04
                        "smlal2  v9.4s, v4.8h, v1.h[0]        \n"
                        "smlal  v10.4s, v4.4h, v1.h[1]        \n" // out1 += (r40-r47)*k14
                        "smlal2  v11.4s, v4.8h, v1.h[1]       \n"
                        "smlal  v12.4s, v4.4h, v1.h[2]        \n" // out2 += (r40-r47)*k24
                        "smlal2  v13.4s, v4.8h, v1.h[2]       \n"
                        "smlal  v14.4s, v4.4h, v1.h[3]        \n" // out3 += (r40-r47)*k34
                        "smlal2  v15.4s, v4.8h, v1.h[3]       \n"
                        "smlal  v16.4s, v4.4h, v1.h[4]        \n" // out4 += (r40-r47)*k44
                        "smlal2  v17.4s, v4.8h, v1.h[4]       \n"
                        "smlal  v18.4s, v4.4h, v1.h[5]        \n" // out5 += (r40-r47)*k54
                        "smlal2  v19.4s, v4.8h, v1.h[5]       \n"
                        "smlal  v20.4s, v4.4h, v1.h[6]        \n" // out6 += (r40-r47)*k64
                        "smlal2  v21.4s, v4.8h, v1.h[6]       \n"
                        "smlal  v22.4s, v4.4h, v1.h[7]        \n" // out7 += (r40-r47)*k74
                        "smlal2  v23.4s, v4.8h, v1.h[7]       \n"
                        // r5
                        "smlal  v8.4s, v7.4h, v2.h[0]         \n" // out0 += (r50-r57)*k05
                        "smlal2  v9.4s, v7.8h, v2.h[0]        \n"
                        "smlal  v10.4s, v7.4h, v2.h[1]        \n" // out1 += (r50-r57)*k15
                        "smlal2  v11.4s, v7.8h, v2.h[1]       \n"
                        "smlal  v12.4s, v7.4h, v2.h[2]        \n" // out2 += (r50-r57)*k25
                        "smlal2  v13.4s, v7.8h, v2.h[2]       \n"
                        "smlal  v14.4s, v7.4h, v2.h[3]        \n" // out3 += (r50-r57)*k35
                        "smlal2  v15.4s, v7.8h, v2.h[3]       \n"
                        "smlal  v16.4s, v7.4h, v2.h[4]        \n" // out4 += (r50-r57)*k45
                        "smlal2  v17.4s, v7.8h, v2.h[4]       \n"
                        "smlal  v18.4s, v7.4h, v2.h[5]        \n" // out5 += (r50-r57)*k55
                        "smlal2  v19.4s, v7.8h, v2.h[5]       \n"
                        "smlal  v20.4s, v7.4h, v2.h[6]        \n" // out6 += (r50-r57)*k65
                        "smlal2  v21.4s, v7.8h, v2.h[6]       \n"
                        "smlal  v22.4s, v7.4h, v2.h[7]        \n" // out7 += (r50-r57)*k75
                        "smlal2  v23.4s, v7.8h, v2.h[7]       \n"

                        "ld1    {v0.8b, v1.8b, v2.8b}, [%12], #24  \n" //ktmp
                        "ld2    {v3.8b, v4.8b}, [%11], #16    \n"      //r6-r8
                        "ld2    {v5.8b, v6.8b}, [%11]         \n"

                        "ext    v7.8b, v3.8b, v5.8b, #1       \n"

                        "sshll  v0.8h, v0.8b, #0              \n" //(k06-k76)
                        "sshll  v1.8h, v1.8b, #0              \n" //(k07-k77)
                        "sshll  v2.8h, v2.8b, #0              \n" //(k08-k78)
                        "sshll  v3.8h, v3.8b, #0              \n" // r6
                        "sshll  v4.8h, v4.8b, #0              \n" // r7
                        "sshll  v7.8h, v7.8b, #0              \n" // r8

                        // r6
                        "smlal  v8.4s, v3.4h, v0.h[0]         \n" // out0 += (r60-r67)*k06
                        "smlal2  v9.4s, v3.8h, v0.h[0]        \n"
                        "smlal  v10.4s, v3.4h, v0.h[1]        \n" // out1 += (r60-r67)*k16
                        "smlal2  v11.4s, v3.8h, v0.h[1]       \n"
                        "smlal  v12.4s, v3.4h, v0.h[2]        \n" // out2 += (r60-r67)*k26
                        "smlal2  v13.4s, v3.8h, v0.h[2]       \n"
                        "smlal  v14.4s, v3.4h, v0.h[3]        \n" // out3 += (r60-r67)*k36
                        "smlal2  v15.4s, v3.8h, v0.h[3]       \n"
                        "smlal  v16.4s, v3.4h, v0.h[4]        \n" // out4 += (r60-r67)*k46
                        "smlal2  v17.4s, v3.8h, v0.h[4]       \n"
                        "smlal  v18.4s, v3.4h, v0.h[5]        \n" // out5 += (r60-r67)*k56
                        "smlal2  v19.4s, v3.8h, v0.h[5]       \n"
                        "smlal  v20.4s, v3.4h, v0.h[6]        \n" // out6 += (r60-r67)*k66
                        "smlal2  v21.4s, v3.8h, v0.h[6]       \n"
                        "smlal  v22.4s, v3.4h, v0.h[7]        \n" // out7 += (r60-r67)*k76
                        "smlal2  v23.4s, v3.8h, v0.h[7]       \n"
                        // r7
                        "smlal  v8.4s, v4.4h, v1.h[0]         \n" // out0 += (r70-r77)*k07
                        "smlal2  v9.4s, v4.8h, v1.h[0]        \n"
                        "smlal  v10.4s, v4.4h, v1.h[1]        \n" // out1 += (r70-r77)*k17
                        "smlal2  v11.4s, v4.8h, v1.h[1]       \n"
                        "smlal  v12.4s, v4.4h, v1.h[2]        \n" // out2 += (r70-r77)*k27
                        "smlal2  v13.4s, v4.8h, v1.h[2]       \n"
                        "smlal  v14.4s, v4.4h, v1.h[3]        \n" // out3 += (r70-r77)*k37
                        "smlal2  v15.4s, v4.8h, v1.h[3]       \n"
                        "smlal  v16.4s, v4.4h, v1.h[4]        \n" // out4 += (r70-r77)*k47
                        "smlal2  v17.4s, v4.8h, v1.h[4]       \n"
                        "smlal  v18.4s, v4.4h, v1.h[5]        \n" // out5 += (r70-r77)*k57
                        "smlal2  v19.4s, v4.8h, v1.h[5]       \n"
                        "smlal  v20.4s, v4.4h, v1.h[6]        \n" // out6 += (r70-r77)*k67
                        "smlal2  v21.4s, v4.8h, v1.h[6]       \n"
                        "smlal  v22.4s, v4.4h, v1.h[7]        \n" // out7 += (r70-r77)*k77
                        "smlal2  v23.4s, v4.8h, v1.h[7]       \n"
                        // r8
                        "smlal  v8.4s, v7.4h, v2.h[0]         \n" // out0 += (r80-r87)*k08
                        "smlal2  v9.4s, v7.8h, v2.h[0]        \n"
                        "smlal  v10.4s, v7.4h, v2.h[1]        \n" // out1 += (r80-r87)*k18
                        "smlal2  v11.4s, v7.8h, v2.h[1]       \n"
                        "smlal  v12.4s, v7.4h, v2.h[2]        \n" // out2 += (r80-r87)*k28
                        "smlal2  v13.4s, v7.8h, v2.h[2]       \n"
                        "smlal  v14.4s, v7.4h, v2.h[3]        \n" // out3 += (r80-r87)*k38
                        "smlal2  v15.4s, v7.8h, v2.h[3]       \n"
                        "smlal  v16.4s, v7.4h, v2.h[4]        \n" // out4 += (r80-r87)*k48
                        "smlal2  v17.4s, v7.8h, v2.h[4]       \n"
                        "smlal  v18.4s, v7.4h, v2.h[5]        \n" // out5 += (r80-r87)*k58
                        "smlal2  v19.4s, v7.8h, v2.h[5]       \n"
                        "smlal  v20.4s, v7.4h, v2.h[6]        \n" // out6 += (r80-r87)*k68
                        "smlal2  v21.4s, v7.8h, v2.h[6]       \n"
                        "smlal  v22.4s, v7.4h, v2.h[7]        \n" // out7 += (r80-r87)*k78
                        "smlal2  v23.4s, v7.8h, v2.h[7]       \n"

                        "st1    {v8.4s, v9.4s}, [%1], #32     \n"
                        "st1    {v10.4s, v11.4s}, [%2], #32   \n"
                        "st1    {v12.4s, v13.4s}, [%3], #32   \n"
                        "st1    {v14.4s, v15.4s}, [%4], #32   \n"
                        "st1    {v16.4s, v17.4s}, [%5], #32   \n"
                        "st1    {v18.4s, v19.4s}, [%6], #32   \n"
                        "st1    {v20.4s, v21.4s}, [%7], #32   \n"
                        "st1    {v22.4s, v23.4s}, [%8], #32   \n"

                        "subs   %w0, %w0, #1                  \n"
                        "sub    %12, %12, #72                 \n" // reset ktmp

                        "bne    0b                            \n"

                        : "=r"(nn),      // %0
                        "=r"(outptr0), // %1
                        "=r"(outptr1), // %2
                        "=r"(outptr2), // %3
                        "=r"(outptr3), // %4
                        "=r"(outptr4), // %5
                        "=r"(outptr5), // %6
                        "=r"(outptr6), // %7
                        "=r"(outptr7), // %8
                        "=r"(r0),      // %9
                        "=r"(r1),      // %10
                        "=r"(r2),      // %11
                        "=r"(ktmp)     // %12
                        : "0"(nn),
                        "1"(outptr0),
                        "2"(outptr1),
                        "3"(outptr2),
                        "4"(outptr3),
                        "5"(outptr4),
                        "6"(outptr5),
                        "7"(outptr6),
                        "8"(outptr7),
                        "9"(r0),
                        "10"(r1),
                        "11"(r2),
                        "12"(ktmp)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
                }
#else  // __aarch64__
                if (nn > 0)
                {
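                    // armv7 variant: 4 output pixels x 8 output channels per
                    // iteration, accumulating into q8..q15 (one q register
                    // per output channel)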
                    asm volatile(
                        "0:                             \n"
                        "pld        [%1, #128]          \n"
                        "vld1.s32   {d16-d17}, [%1]     \n" // out0
                        "pld        [%2, #128]          \n"
                        "vld1.s32   {d18-d19}, [%2]     \n" // out1
                        "pld        [%3, #128]          \n"
                        "vld1.s32   {d20-d21}, [%3]     \n" // out2
                        "pld        [%4, #128]          \n"
                        "vld1.s32   {d22-d23}, [%4]     \n" // out3

                        // r0
                        "pld        [%9, #64]          \n"
                        "vld2.s8    {d8-d9}, [%9]       \n" // d8(a00 a02 a04 a06 a08 a010 a012 a014), d9(a01 a03 a05 a07 a09 a011 a013 a015)
                        "add        %9, #8              \n"
                        "pld        [%12, #64]         \n"
                        "vld1.s8    {d0-d2}, [%12]!     \n" // d0(k00-k70) d1(k01-k71) d2(k02-k72)

                        "pld        [%5, #128]          \n"
                        "vld1.s32   {d24-d25}, [%5]     \n" // out4
                        "pld        [%6, #128]          \n"
                        "vld1.s32   {d26-d27}, [%6]     \n" // out5

                        "vmovl.s8   q2, d2              \n" // q2(k02-k72)
                        "vmovl.s8   q1, d1              \n" // q1(k01-k71)
                        "vmovl.s8   q0, d0              \n" // q0(k00-k70)
                        "vext.s8    d12, d8, d8, #1     \n" // d12(a02 a04 a06 a08 x x x x)

                        "pld        [%7, #128]          \n"
                        "vld1.s32   {d28-d29}, [%7]     \n" // out6

                        "vmovl.s8   q5, d9              \n" // q5(a01 a03 a05 a07 a09 a011 a013 a015) d11
                        "vmovl.s8   q4, d8              \n" // q4(a00 a02 a04 a06 a08 a010 a012 a014) d9
                        "vmovl.s8   q6, d12             \n" // q6(a02 a04 a06 a08 a010 a012 a014 a016) d13

                        "pld        [%8, #128]          \n"
                        "vld1.s32   {d30-d31}, [%8]     \n" // out7

                        "vmlal.s16  q8, d8, d0[0]       \n" // sum0 += (a00 a02 a04 a06) * k00
                        "vmlal.s16  q9, d8, d0[1]       \n" // sum1 += (a00 a02 a04 a06) * k10
                        "vmlal.s16  q10, d8, d0[2]      \n" // sum2 += (a00 a02 a04 a06) * k20
                        "vmlal.s16  q11, d8, d0[3]      \n" // sum3 += (a00 a02 a04 a06) * k30
                        "vmlal.s16  q12, d8, d1[0]      \n" // sum4 += (a00 a02 a04 a06) * k40
                        "vmlal.s16  q13, d8, d1[1]      \n" // sum5 += (a00 a02 a04 a06) * k50
                        "vmlal.s16  q14, d8, d1[2]      \n" // sum6 += (a00 a02 a04 a06) * k60
                        "vmlal.s16  q15, d8, d1[3]      \n" // sum7 += (a00 a02 a04 a06) * k70

                        "vmlal.s16  q8, d10, d2[0]      \n" // sum0 += (a01-a07) * k01
                        "vmlal.s16  q9, d10, d2[1]      \n" // sum1 += (a01-a07) * k11
                        "vmlal.s16  q10, d10, d2[2]     \n" // sum2 += (a01-a07) * k21
                        "vmlal.s16  q11, d10, d2[3]     \n" // sum3 += (a01-a07) * k31
                        "vmlal.s16  q12, d10, d3[0]     \n" // sum4 += (a01-a07) * k41
                        "vmlal.s16  q13, d10, d3[1]     \n" // sum5 += (a01-a07) * k51
                        "vmlal.s16  q14, d10, d3[2]     \n" // sum6 += (a01-a07) * k61
                        "vmlal.s16  q15, d10, d3[3]     \n" // sum7 += (a01-a07) * k71

                        "pld        [%10, #64]         \n"
                        "vld2.s8    {d8-d9}, [%10]      \n" // d8(a10 a12 a14 a16 a18 a110 a112 a114), d9(a11 a13 a15 a17 a19 a111 a113 a115)
                        "add        %10, #8             \n"

                        "vmlal.s16  q8, d12, d4[0]      \n" // sum0 += (a02-a08) * k02
                        "vmlal.s16  q9, d12, d4[1]      \n" // sum1 += (a02-a08) * k12
                        "vmlal.s16  q10, d12, d4[2]     \n" // sum2 += (a02-a08) * k22
                        "vmlal.s16  q11, d12, d4[3]     \n" // sum3 += (a02-a08) * k32

                        "pld        [%12, #64]         \n"
                        "vld1.s8    {d0-d2}, [%12]!     \n" // d0(k03-k73) d1(k04-k74) d2(k05-k75)

                        "vmlal.s16  q12, d12, d5[0]     \n" // sum4 += (a02-a08) * k42
                        "vmlal.s16  q13, d12, d5[1]     \n" // sum5 += (a02-a08) * k52
                        "vmlal.s16  q14, d12, d5[2]     \n" // sum6 += (a02-a08) * k62
                        "vmlal.s16  q15, d12, d5[3]     \n" // sum7 += (a02-a08) * k72

                        // r1
                        "vext.s8    d12, d8, d8, #1     \n" // d12(a12 a14 a16 a18 x x x x)

                        "vmovl.s8   q2, d2              \n" // q2(k05-k75)
                        "vmovl.s8   q1, d1              \n" // q1(k04-k74)
                        "vmovl.s8   q0, d0              \n" // q0(k03-k73)
                        "vmovl.s8   q5, d9              \n" // q5(a11-a115)
                        "vmovl.s8   q4, d8              \n" // q4(a10-a114)
                        "vmovl.s8   q6, d12             \n" // q6(a12-a116)

                        "vmlal.s16  q8, d8, d0[0]       \n" // sum0 += (a10-a16) * k03
                        "vmlal.s16  q9, d8, d0[1]       \n" // sum1 += (a10-a16) * k13
                        "vmlal.s16  q10, d8, d0[2]      \n" // sum2 += (a10-a16) * k23
                        "vmlal.s16  q11, d8, d0[3]      \n" // sum3 += (a10-a16) * k33
                        "vmlal.s16  q12, d8, d1[0]      \n" // sum4 += (a10-a16) * k43
                        "vmlal.s16  q13, d8, d1[1]      \n" // sum5 += (a10-a16) * k53
                        "vmlal.s16  q14, d8, d1[2]      \n" // sum6 += (a10-a16) * k63
                        "vmlal.s16  q15, d8, d1[3]      \n" // sum7 += (a10-a16) * k73

                        "vmlal.s16  q8, d10, d2[0]      \n" // sum0 += (a11-a17) * k04
                        "vmlal.s16  q9, d10, d2[1]      \n" // sum1 += (a11-a17) * k14
                        "vmlal.s16  q10, d10, d2[2]     \n" // sum2 += (a11-a17) * k24
                        "vmlal.s16  q11, d10, d2[3]     \n" // sum3 += (a11-a17) * k34
                        "vmlal.s16  q12, d10, d3[0]     \n" // sum4 += (a11-a17) * k44
                        "vmlal.s16  q13, d10, d3[1]     \n" // sum5 += (a11-a17) * k54
                        "vmlal.s16  q14, d10, d3[2]     \n" // sum6 += (a11-a17) * k64
                        "vmlal.s16  q15, d10, d3[3]     \n" // sum7 += (a11-a17) * k74

                        "pld        [%11, #64]         \n"
                        "vld2.s8    {d8-d9}, [%11]      \n" // d8(a20 a22 a24 a26 a28 a210 a212 a214), d9(a21 a23 a25 a27 a29 a211 a213 a215)
                        "add        %11, #8             \n"

                        "vmlal.s16  q8, d12, d4[0]      \n" // sum0 += (a12-a18) * k05
                        "vmlal.s16  q9, d12, d4[1]      \n" // sum1 += (a12-a18) * k15
                        "vmlal.s16  q10, d12, d4[2]     \n" // sum2 += (a12-a18) * k25
                        "vmlal.s16  q11, d12, d4[3]     \n" // sum3 += (a12-a18) * k35

                        "pld        [%12, #64]         \n"
                        "vld1.s8    {d0-d2}, [%12]!     \n" // d0(k06-k76) d1(k07-k77) d2(k08-k78)

                        "vmlal.s16  q12, d12, d5[0]     \n" // sum4 += (a12-a18) * k45
                        "vmlal.s16  q13, d12, d5[1]     \n" // sum5 += (a12-a18) * k55
                        "vmlal.s16  q14, d12, d5[2]     \n" // sum6 += (a12-a18) * k65
                        "vmlal.s16  q15, d12, d5[3]     \n" // sum7 += (a12-a18) * k75

                        // r2
                        "vext.s8    d12, d8, d8, #1     \n" // d12(a22 a24 a26 a28 x x x x)

                        "vmovl.s8   q2, d2              \n" // q2(k08-k78)
                        "vmovl.s8   q1, d1              \n" // q1(k07-k77)
                        "vmovl.s8   q0, d0              \n" // q0(k06-k76)
                        "vmovl.s8   q5, d9              \n" // q5(a21-a215)
                        "vmovl.s8   q4, d8              \n" // q4(a20-a214)
                        "vmovl.s8   q6, d12             \n" // q6(a22-a216)

                        "vmlal.s16  q8, d8, d0[0]       \n" // sum0 += (a20-a26) * k06
                        "vmlal.s16  q9, d8, d0[1]       \n" // sum1 += (a20-a26) * k16
                        "vmlal.s16  q10, d8, d0[2]      \n" // sum2 += (a20-a26) * k26
                        "vmlal.s16  q11, d8, d0[3]      \n" // sum3 += (a20-a26) * k36
                        "vmlal.s16  q12, d8, d1[0]      \n" // sum4 += (a20-a26) * k46
                        "vmlal.s16  q13, d8, d1[1]      \n" // sum5 += (a20-a26) * k56
                        "vmlal.s16  q14, d8, d1[2]      \n" // sum6 += (a20-a26) * k66
                        "vmlal.s16  q15, d8, d1[3]      \n" // sum7 += (a20-a26) * k76

                        "vmlal.s16  q8, d10, d2[0]      \n" // sum0 += (a21-a27) * k07
                        "vmlal.s16  q9, d10, d2[1]      \n" // sum1 += (a21-a27) * k17
                        "vmlal.s16  q10, d10, d2[2]     \n" // sum2 += (a21-a27) * k27
                        "vmlal.s16  q11, d10, d2[3]     \n" // sum3 += (a21-a27) * k37
                        "vmlal.s16  q12, d10, d3[0]     \n" // sum4 += (a21-a27) * k47
                        "vmlal.s16  q13, d10, d3[1]     \n" // sum5 += (a21-a27) * k57
                        "vmlal.s16  q14, d10, d3[2]     \n" // sum6 += (a21-a27) * k67
                        "vmlal.s16  q15, d10, d3[3]     \n" // sum7 += (a21-a27) * k77

                        "vmlal.s16  q8, d12, d4[0]      \n" // sum0 += (a22-a28) * k08
                        "vmlal.s16  q9, d12, d4[1]      \n" // sum1 += (a22-a28) * k18
                        "vmlal.s16  q10, d12, d4[2]     \n" // sum2 += (a22-a28) * k28
                        "vmlal.s16  q11, d12, d4[3]     \n" // sum3 += (a22-a28) * k38
                        "vmlal.s16  q12, d12, d5[0]     \n" // sum4 += (a22-a28) * k48
                        "vmlal.s16  q13, d12, d5[1]     \n" // sum5 += (a22-a28) * k58
                        "vmlal.s16  q14, d12, d5[2]     \n" // sum6 += (a22-a28) * k68
                        "vmlal.s16  q15, d12, d5[3]     \n" // sum7 += (a22-a28) * k78

                        // save s32 to memory
                        "sub        %12, %12, #72       \n"
                        "vst1.s32   {d16-d17}, [%1]!    \n" // out0
                        "vst1.s32   {d18-d19}, [%2]!    \n" // out1
                        "vst1.s32   {d20-d21}, [%3]!    \n" // out2
                        "vst1.s32   {d22-d23}, [%4]!    \n" // out3
                        "subs       %0, #1              \n"
                        "vst1.s32   {d24-d25}, [%5]!    \n" // out4
                        "vst1.s32   {d26-d27}, [%6]!    \n" // out5
                        "vst1.s32   {d28-d29}, [%7]!    \n" // out6
                        "vst1.s32   {d30-d31}, [%8]!    \n" // out7

                        "bne        0b                  \n"
                        : "=r"(nn),      // %0
                        "=r"(outptr0), // %1
                        "=r"(outptr1), // %2
                        "=r"(outptr2), // %3
                        "=r"(outptr3), // %4
                        "=r"(outptr4), // %5
                        "=r"(outptr5), // %6
                        "=r"(outptr6), // %7
                        "=r"(outptr7), // %8
                        "=r"(r0),      // %9
                        "=r"(r1),      // %10
                        "=r"(r2),      // %11
                        "=r"(ktmp)     // %12
                        : "0"(nn),
                        "1"(outptr0),
                        "2"(outptr1),
                        "3"(outptr2),
                        "4"(outptr3),
                        "5"(outptr4),
                        "6"(outptr5),
                        "7"(outptr6),
                        "8"(outptr7),
                        "9"(r0),
                        "10"(r1),
                        "11"(r2),
                        "12"(ktmp)
                        : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
                }
#endif // __aarch64__
#endif // __ARM_NEON
                for (; remain > 0; remain--)
                {
#if __ARM_NEON
#if __aarch64__
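                    // aarch64 remainder path (intrinsics): one output pixel
                    // for all 8 output channels; the lanes of _sum03/_sum47
                    // map to out0..out3 / out4..out7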
                    int8x8_t _r0_s8 = vld1_s8(r0); // (a00 a01 a02 ....)
                    int8x8_t _r1_s8 = vld1_s8(r1); // (a10 a11 a12 ....)
                    int8x8_t _r2_s8 = vld1_s8(r2); // (a20 a21 a22 ....)

                    int16x8_t _r0 = vmovl_s8(_r0_s8);
                    int16x8_t _r1 = vmovl_s8(_r1_s8);
                    int16x8_t _r2 = vmovl_s8(_r2_s8);

                    int32x4_t _sum03 = {};
                    int32x4_t _sum47 = {};

                    _sum03 = vld1q_lane_s32(outptr0, _sum03, 0); // out0
                    _sum03 = vld1q_lane_s32(outptr1, _sum03, 1); // out1
                    _sum03 = vld1q_lane_s32(outptr2, _sum03, 2); // out2
                    _sum03 = vld1q_lane_s32(outptr3, _sum03, 3); // out3
                    _sum47 = vld1q_lane_s32(outptr4, _sum47, 0); // out4
                    _sum47 = vld1q_lane_s32(outptr5, _sum47, 1); // out5
                    _sum47 = vld1q_lane_s32(outptr6, _sum47, 2); // out6
                    _sum47 = vld1q_lane_s32(outptr7, _sum47, 3); // out7

                    // k0 - k2
                    int8x8_t _k0_8 = vld1_s8(ktmp);      //(k00-k70)
                    int8x8_t _k1_8 = vld1_s8(ktmp + 8);  //(k01-k71)
                    int8x8_t _k2_8 = vld1_s8(ktmp + 16); //(k02-k72)

                    int16x8_t _k0 = vmovl_s8(_k0_8);
                    int16x8_t _k1 = vmovl_s8(_k1_8);
                    int16x8_t _k2 = vmovl_s8(_k2_8);

                    int32x4_t _sum0 = vmull_laneq_s16(vget_low_s16(_k0), _r0, 0);
                    int32x4_t _sum0n = vmull_laneq_s16(vget_high_s16(_k0), _r0, 0);
                    int32x4_t _sum1 = vmull_laneq_s16(vget_low_s16(_k1), _r0, 1);
                    int32x4_t _sum1n = vmull_laneq_s16(vget_high_s16(_k1), _r0, 1);
                    _sum03 = vmlal_laneq_s16(_sum03, vget_low_s16(_k2), _r0, 2);
                    _sum47 = vmlal_laneq_s16(_sum47, vget_high_s16(_k2), _r0, 2);

                    // k3 - k5
                    _k0_8 = vld1_s8(ktmp + 24); //(k03-k73)
                    _k1_8 = vld1_s8(ktmp + 32); //(k04-k74)
                    _k2_8 = vld1_s8(ktmp + 40); //(k05-k75)

                    _k0 = vmovl_s8(_k0_8);
                    _k1 = vmovl_s8(_k1_8);
                    _k2 = vmovl_s8(_k2_8);

                    _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_k0), _r1, 0);
                    _sum0n = vmlal_laneq_s16(_sum0n, vget_high_s16(_k0), _r1, 0);
                    _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_k1), _r1, 1);
                    _sum1n = vmlal_laneq_s16(_sum1n, vget_high_s16(_k1), _r1, 1);
                    _sum03 = vmlal_laneq_s16(_sum03, vget_low_s16(_k2), _r1, 2);
                    _sum47 = vmlal_laneq_s16(_sum47, vget_high_s16(_k2), _r1, 2);

                    // k6 - k8
                    _k0_8 = vld1_s8(ktmp + 48); //(k06-k76)
                    _k1_8 = vld1_s8(ktmp + 56); //(k07-k77)
                    _k2_8 = vld1_s8(ktmp + 64); //(k08-k78)

                    _k0 = vmovl_s8(_k0_8);
                    _k1 = vmovl_s8(_k1_8);
                    _k2 = vmovl_s8(_k2_8);

                    _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_k0), _r2, 0);
                    _sum0n = vmlal_laneq_s16(_sum0n, vget_high_s16(_k0), _r2, 0);
                    _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_k1), _r2, 1);
                    _sum1n = vmlal_laneq_s16(_sum1n, vget_high_s16(_k1), _r2, 1);
                    _sum03 = vmlal_laneq_s16(_sum03, vget_low_s16(_k2), _r2, 2);
                    _sum47 = vmlal_laneq_s16(_sum47, vget_high_s16(_k2), _r2, 2);

                    _sum0 = vaddq_s32(_sum0, _sum1);
                    _sum0n = vaddq_s32(_sum0n, _sum1n);
                    _sum03 = vaddq_s32(_sum03, _sum0);
                    _sum47 = vaddq_s32(_sum47, _sum0n);

                    vst1q_lane_s32(outptr0, _sum03, 0);
                    vst1q_lane_s32(outptr1, _sum03, 1);
                    vst1q_lane_s32(outptr2, _sum03, 2);
                    vst1q_lane_s32(outptr3, _sum03, 3);
                    vst1q_lane_s32(outptr4, _sum47, 0);
                    vst1q_lane_s32(outptr5, _sum47, 1);
                    vst1q_lane_s32(outptr6, _sum47, 2);
                    vst1q_lane_s32(outptr7, _sum47, 3);

                    outptr0++;
                    outptr1++;
                    outptr2++;
                    outptr3++;
                    outptr4++;
                    outptr5++;
                    outptr6++;
                    outptr7++;
#else  // __aarch64__
                    asm volatile(
                        "pld        [%8, #64]          \n"
                        "vld1.s8    {d0}, [%8]         \n" // d0(a00 a01 a02 ....)
                        "pld        [%9, #64]          \n"
                        "vld1.s8    {d2}, [%9]         \n" // d2(a10 a11 a12 ....)
                        "pld        [%10, #64]         \n"
                        "vld1.s8    {d4}, [%10]        \n" // d4(a20 a21 a22 ....)

                        "pld        [%11, #64]         \n"
                        "vld1.s8    {d6-d8}, [%11]!    \n" // d6(k00-k70) d7(k01-k71) d8(k02-k72)

                        "vmovl.s8   q0, d0             \n" // d0(a00 a01 a02 x)
                        "vmovl.s8   q1, d2             \n" // d2(a10 a11 a12 x)
                        "vmovl.s8   q2, d4             \n" // d4(a20 a21 a22 x)

                        "vmovl.s8   q5, d8             \n" // d10(k02-k32) d11(k42-k72)
                        "vmovl.s8   q4, d7             \n" // d8(k01-k31) d9(k41-k71)
                        "vmovl.s8   q3, d6             \n" // d6(k00-k30) d7(k40-k70)

                        "vld1.s32   {d20[0]}, [%0]     \n" // out0 q10
                        "vld1.s32   {d20[1]}, [%1]     \n" // out1
                        "vld1.s32   {d21[0]}, [%2]     \n" // out2
                        "vld1.s32   {d21[1]}, [%3]     \n" // out3

                        "pld        [%11, #64]         \n"
                        "vld1.s8    {d24-d26}, [%11]!  \n"
                        "vmovl.s8   q14, d26           \n" // d28(k05-k35) d29(k45-k75)
                        "vmovl.s8   q13, d25           \n" // d26(k04-k34) d27(k44-k74)
                        "vmovl.s8   q12, d24           \n" // d24(k03-k33) d25(k43-k73)

                        "vld1.s32   {d22[0]}, [%4]     \n" // out4 q11
                        "vld1.s32   {d22[1]}, [%5]     \n" // out5
                        "vld1.s32   {d23[0]}, [%6]     \n" // out6
                        "vld1.s32   {d23[1]}, [%7]     \n" // out7

                        "vmull.s16  q6, d6, d0[0]      \n" // a00 x (k00-k30)
                        "vmull.s16  q7, d7, d0[0]      \n" // a00 x (k40-k70)
                        "vmull.s16  q8, d8, d0[1]      \n" // a01 x (k01-k31)
                        "vmull.s16  q9, d9, d0[1]      \n" // a01 x (k41-k71)
                        "vmlal.s16  q10, d10, d0[2]    \n" // a02 x (k02-k32)
                        "vmlal.s16  q11, d11, d0[2]    \n" // a02 x (k42-k72)

                        "pld        [%11, #64]         \n"
                        "vld1.s8    {d6-d8}, [%11]!    \n"
                        "vmovl.s8   q5, d8             \n" // d10(k08-k38) d11(k48-k78)
                        "vmovl.s8   q4, d7             \n" // d8(k07-k37) d9(k47-k77)
                        "vmovl.s8   q3, d6             \n" // d6(k06-k36) d7(k46-k76)

                        "vmlal.s16  q6, d24, d2[0]     \n" // a10 x (k03-k33)
                        "vmlal.s16  q7, d25, d2[0]     \n" // a10 x (k43-k73)
                        "vmlal.s16  q8, d26, d2[1]     \n" // a11 x (k04-k34)
                        "vmlal.s16  q9, d27, d2[1]     \n" // a11 x (k44-k74)
                        "vmlal.s16  q10, d28, d2[2]    \n" // a12 x (k05-k35)
                        "vmlal.s16  q11, d29, d2[2]    \n" // a12 x (k45-k75)

                        "vmlal.s16  q6, d6, d4[0]      \n" // a20 x (k06-k36)
                        "vmlal.s16  q7, d7, d4[0]      \n" // a20 x (k46-k76)
                        "vmlal.s16  q8, d8, d4[1]      \n" // a21 x (k07-k37)
                        "vmlal.s16  q9, d9, d4[1]      \n" // a21 x (k47-k77)
                        "vmlal.s16  q10, d10, d4[2]    \n" // a22 x (k08-k38)
                        "vmlal.s16  q11, d11, d4[2]    \n" // a22 x (k48-k78)

                        "vadd.s32   q8, q8, q6         \n"
                        "vadd.s32   q9, q9, q7         \n"

                        "sub        %11, %11, #72      \n"

                        "vadd.s32   q10, q10, q8       \n"
                        "vadd.s32   q11, q11, q9       \n"

                        "vst1.s32   {d20[0]}, [%0]!    \n" // out0
                        "vst1.s32   {d20[1]}, [%1]!    \n" // out1
                        "vst1.s32   {d21[0]}, [%2]!    \n" // out2
                        "vst1.s32   {d21[1]}, [%3]!    \n" // out3
                        "vst1.s32   {d22[0]}, [%4]!    \n" // out4
                        "vst1.s32   {d22[1]}, [%5]!    \n" // out5
                        "vst1.s32   {d23[0]}, [%6]!    \n" // out6
                        "vst1.s32   {d23[1]}, [%7]!    \n" // out7

                        : "=r"(outptr0), // %0
                        "=r"(outptr1), // %1
                        "=r"(outptr2), // %2
                        "=r"(outptr3), // %3
                        "=r"(outptr4), // %4
                        "=r"(outptr5), // %5
                        "=r"(outptr6), // %6
                        "=r"(outptr7), // %7
                        "=r"(r0),      // %8
                        "=r"(r1),      // %9
                        "=r"(r2),      // %10
                        "=r"(ktmp)     // %11
                        : "0"(outptr0),
                        "1"(outptr1),
                        "2"(outptr2),
                        "3"(outptr3),
                        "4"(outptr4),
                        "5"(outptr5),
                        "6"(outptr6),
                        "7"(outptr7),
                        "8"(r0),
                        "9"(r1),
                        "10"(r2),
                        "11"(ktmp)
                        : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif // __aarch64__
#else  // __ARM_NEON
                    int sum0 = 0;
                    int sum1 = 0;
                    int sum2 = 0;
                    int sum3 = 0;
                    int sum4 = 0;
                    int sum5 = 0;
                    int sum6 = 0;
                    int sum7 = 0;

                    sum0 += (int)r0[0] * ktmp[0];
                    sum1 += (int)r0[0] * ktmp[1];
                    sum2 += (int)r0[0] * ktmp[2];
                    sum3 += (int)r0[0] * ktmp[3];
                    sum4 += (int)r0[0] * ktmp[4];
                    sum5 += (int)r0[0] * ktmp[5];
                    sum6 += (int)r0[0] * ktmp[6];
                    sum7 += (int)r0[0] * ktmp[7];
                    ktmp += 8;

                    sum0 += (int)r0[1] * ktmp[0];
                    sum1 += (int)r0[1] * ktmp[1];
                    sum2 += (int)r0[1] * ktmp[2];
                    sum3 += (int)r0[1] * ktmp[3];
                    sum4 += (int)r0[1] * ktmp[4];
                    sum5 += (int)r0[1] * ktmp[5];
                    sum6 += (int)r0[1] * ktmp[6];
                    sum7 += (int)r0[1] * ktmp[7];
                    ktmp += 8;

                    sum0 += (int)r0[2] * ktmp[0];
                    sum1 += (int)r0[2] * ktmp[1];
                    sum2 += (int)r0[2] * ktmp[2];
                    sum3 += (int)r0[2] * ktmp[3];
                    sum4 += (int)r0[2] * ktmp[4];
                    sum5 += (int)r0[2] * ktmp[5];
                    sum6 += (int)r0[2] * ktmp[6];
                    sum7 += (int)r0[2] * ktmp[7];
                    ktmp += 8;

                    sum0 += (int)r1[0] * ktmp[0];
                    sum1 += (int)r1[0] * ktmp[1];
                    sum2 += (int)r1[0] * ktmp[2];
                    sum3 += (int)r1[0] * ktmp[3];
                    sum4 += (int)r1[0] * ktmp[4];
                    sum5 += (int)r1[0] * ktmp[5];
                    sum6 += (int)r1[0] * ktmp[6];
                    sum7 += (int)r1[0] * ktmp[7];
                    ktmp += 8;

                    sum0 += (int)r1[1] * ktmp[0];
                    sum1 += (int)r1[1] * ktmp[1];
                    sum2 += (int)r1[1] * ktmp[2];
                    sum3 += (int)r1[1] * ktmp[3];
                    sum4 += (int)r1[1] * ktmp[4];
                    sum5 += (int)r1[1] * ktmp[5];
                    sum6 += (int)r1[1] * ktmp[6];
                    sum7 += (int)r1[1] * ktmp[7];
                    ktmp += 8;

                    sum0 += (int)r1[2] * ktmp[0];
                    sum1 += (int)r1[2] * ktmp[1];
                    sum2 += (int)r1[2] * ktmp[2];
                    sum3 += (int)r1[2] * ktmp[3];
                    sum4 += (int)r1[2] * ktmp[4];
                    sum5 += (int)r1[2] * ktmp[5];
                    sum6 += (int)r1[2] * ktmp[6];
                    sum7 += (int)r1[2] * ktmp[7];
                    ktmp += 8;

                    sum0 += (int)r2[0] * ktmp[0];
                    sum1 += (int)r2[0] * ktmp[1];
                    sum2 += (int)r2[0] * ktmp[2];
                    sum3 += (int)r2[0] * ktmp[3];
                    sum4 += (int)r2[0] * ktmp[4];
                    sum5 += (int)r2[0] * ktmp[5];
                    sum6 += (int)r2[0] * ktmp[6];
                    sum7 += (int)r2[0] * ktmp[7];
                    ktmp += 8;

                    sum0 += (int)r2[1] * ktmp[0];
                    sum1 += (int)r2[1] * ktmp[1];
                    sum2 += (int)r2[1] * ktmp[2];
                    sum3 += (int)r2[1] * ktmp[3];
                    sum4 += (int)r2[1] * ktmp[4];
                    sum5 += (int)r2[1] * ktmp[5];
                    sum6 += (int)r2[1] * ktmp[6];
                    sum7 += (int)r2[1] * ktmp[7];
                    ktmp += 8;

                    sum0 += (int)r2[2] * ktmp[0];
                    sum1 += (int)r2[2] * ktmp[1];
                    sum2 += (int)r2[2] * ktmp[2];
                    sum3 += (int)r2[2] * ktmp[3];
                    sum4 += (int)r2[2] * ktmp[4];
                    sum5 += (int)r2[2] * ktmp[5];
                    sum6 += (int)r2[2] * ktmp[6];
                    sum7 += (int)r2[2] * ktmp[7];
                    ktmp += 8;

                    *outptr0 += sum0;
                    *outptr1 += sum1;
                    *outptr2 += sum2;
                    *outptr3 += sum3;
                    *outptr4 += sum4;
                    *outptr5 += sum5;
                    *outptr6 += sum6;
                    *outptr7 += sum7;

                    ktmp -= 8 * 9;

                    outptr0++;
                    outptr1++;
                    outptr2++;
                    outptr3++;
                    outptr4++;
                    outptr5++;
                    outptr6++;
                    outptr7++;
#endif // __ARM_NEON
                    r0 += 2;
                    r1 += 2;
                    r2 += 2;
                }

                r0 += tailstep;
                r1 += tailstep;
                r2 += tailstep;
            }

            ktmp += 8 * 9;
        }
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = remain_outch_start; p < outch; p++)
    {
        Mat out = top_blob.channel(p);

        out.fill(0);

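        // same remainder-channel indexing quirk as in the transform step above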
        const signed char* ktmp = _kernel.channel(p / 8 + p % 8);

        for (int q = 0; q < inch; q++)
        {
            int* outptr = out;

            const signed char* img0 = bottom_blob.channel(q);

            const signed char* r0 = img0;
            const signed char* r1 = img0 + w;
            const signed char* r2 = img0 + w * 2;

            int i = 0;

            for (; i < outh; i++)
            {
#if __ARM_NEON
                int nn = outw >> 3;
                int remain = outw & 7;
#else
                int remain = outw;
#endif // __ARM_NEON

#if __ARM_NEON
#if __aarch64__
                if (nn > 0)
                {
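                    // single output channel, 8 pixels per iteration; the
                    // partial sums alternate between v16/v17 and v18/v19 to
                    // shorten the dependency chain, then get combined and
                    // accumulated with the previous out0 before the store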
                    asm volatile(
                        "0:                                   \n"

                        "ld1    {v0.8b, v1.8b}, [%5]          \n" //ktmp
                        "ld2    {v2.8b, v3.8b}, [%2], #16     \n" //r0-r2
                        "ld2    {v4.8b, v5.8b}, [%2]          \n"

                        "ld2    {v6.8b, v7.8b}, [%3], #16     \n" //r3-r5
                        "ld2    {v8.8b, v9.8b}, [%3]          \n"

                        "ld2    {v10.8b, v11.8b}, [%4], #16   \n" //r6-r8
                        "ld2    {v12.8b, v13.8b}, [%4]        \n"

                        "ld1    {v14.4s, v15.4s}, [%1]        \n" //out0

                        "ext    v4.8b, v2.8b, v4.8b, #1       \n"
                        "ext    v8.8b, v6.8b, v8.8b, #1       \n"
                        "ext    v12.8b, v10.8b, v12.8b, #1    \n"

                        "sshll  v0.8h, v0.8b, #0              \n" //(k0-k7)
                        "sshll  v1.8h, v1.8b, #0              \n" //(k8)
                        "sshll  v2.8h, v2.8b, #0              \n" // r0
                        "sshll  v3.8h, v3.8b, #0              \n" // r1
                        "sshll  v4.8h, v4.8b, #0              \n" // r2
                        "sshll  v6.8h, v6.8b, #0              \n" // r3
                        "sshll  v7.8h, v7.8b, #0              \n" // r4
                        "sshll  v8.8h, v8.8b, #0              \n" // r5
                        "sshll  v10.8h, v10.8b, #0            \n" // r6
                        "sshll  v11.8h, v11.8b, #0            \n" // r7
                        "sshll  v12.8h, v12.8b, #0            \n" // r8

                        // r0
                        "smull  v16.4s, v2.4h, v0.h[0]        \n" // out = r0*k0
                        "smull2  v17.4s, v2.8h, v0.h[0]       \n"
                        "smull  v18.4s, v3.4h, v0.h[1]        \n" // outn = r1*k1
                        "smull2  v19.4s, v3.8h, v0.h[1]       \n"
                        "smlal  v16.4s, v4.4h, v0.h[2]        \n" // out = r2*k2
                        "smlal2  v17.4s, v4.8h, v0.h[2]       \n"
                        "smlal  v18.4s, v6.4h, v0.h[3]        \n" // outn = r3*k3
                        "smlal2  v19.4s, v6.8h, v0.h[3]       \n"
                        "smlal  v16.4s, v7.4h, v0.h[4]        \n" // out = r4*k4
                        "smlal2  v17.4s, v7.8h, v0.h[4]       \n"
                        "smlal  v18.4s, v8.4h, v0.h[5]        \n" // outn = r5*k5
                        "smlal2  v19.4s, v8.8h, v0.h[5]       \n"
                        "smlal  v16.4s, v10.4h, v0.h[6]       \n" // out = r6*k6
                        "smlal2  v17.4s, v10.8h, v0.h[6]      \n"
                        "smlal  v18.4s, v11.4h, v0.h[7]       \n" // outn = r7*k7
                        "smlal2  v19.4s, v11.8h, v0.h[7]      \n"
                        "smlal  v16.4s, v12.4h, v1.h[0]       \n" // out = r8*k8
                        "smlal2  v17.4s, v12.8h, v1.h[0]      \n"

                        "add    v8.4s, v16.4s, v18.4s         \n"
                        "add    v9.4s, v17.4s, v19.4s         \n"

                        // accumulate with the out0 values loaded above (v14/v15);
                        // without this, each input channel would overwrite the
                        // previous channel's partial sums instead of summing them
                        "add    v8.4s, v8.4s, v14.4s          \n"
                        "add    v9.4s, v9.4s, v15.4s          \n"

                        "st1    {v8.4s, v9.4s}, [%1], #32     \n"
4549 
4550                         "subs   %w0, %w0, #1                  \n"
4551 
4552                         "bne    0b                            \n"
4553 
4554                         : "=r"(nn),     // %0
4555                         "=r"(outptr), // %1
4556                         "=r"(r0),     // %2
4557                         "=r"(r1),     // %3
4558                         "=r"(r2),     // %4
4559                         "=r"(ktmp)    // %5
4560                         : "0"(nn),
4561                         "1"(outptr),
4562                         "2"(r0),
4563                         "3"(r1),
4564                         "4"(r2),
4565                         "5"(ktmp)
4566                         : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
                }
#else
                if (nn > 0)
                {
                    asm volatile(
                        "vld1.s8    {d0-d1}, [%5]       \n" // d0(k0 - k7) d1(k8 ...)
                        "vmovl.s8   q1, d1              \n" // d2(k8 ...)
                        "vmovl.s8   q0, d0              \n" // d0(k0 - k3) d1(k4 - k7)
                        "0:                             \n"
                        "pld        [%2, #192]          \n"
                        "vld2.s8    {d4-d5}, [%2]!      \n" // r0 d4(a00 a02 ... a014) d5(a01 a03 ... a015)
                        "vld2.s8    {d8-d9}, [%2]       \n" //    d8(a016 ....)
                        "vld2.s8    {d10-d11}, [%3]!    \n" // r1 d10(a10 a12 ... a114) d11(a11 a13 ... a115)
                        "vld2.s8    {d14-d15}, [%3]     \n" //    d14(a116 ....)
                        "vld2.s8    {d16-d17}, [%4]!    \n" // r2 d16(a20 a22 ... a214) d17(a21 a23 ... a215)
                        "vld2.s8    {d20-d21}, [%4]     \n" //    d20(a216 ....)
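                        // vld2 deinterleaves each row into even/odd columns: with a
                        // stride of 2, tap k0 reads the even columns, k1 the odd ones,
                        // and k2 the even columns shifted by one (built with vext below)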
                        "vld1.s32   {d22-d25}, [%1]     \n" // q11(out0 - out3) q12(out4 - out7)

                        "vext.s8    d8, d4, d8, #1      \n" //  d8(a02 a04 ... a016)
                        "vext.s8    d14, d10, d14, #1   \n" // d14(a12 a14 ... a116)
                        "vext.s8    d20, d16, d20, #1   \n" // d20(a22 a24 ... a216)

                        "vmovl.s8   q3, d5              \n" // q3(a01 a03 ... a015)
                        "vmovl.s8   q2, d4              \n" // q2(a00 a02 ... a014)
                        "vmovl.s8   q4, d8              \n" // q4(a02 a04 ... a016)

                        "vmovl.s8   q6, d11             \n" // q6(a11 a13 ... a115)
                        "vmovl.s8   q5, d10             \n" // q5(a10 a12 ... a114)
                        "vmovl.s8   q7, d14             \n" // q7(a12 a14 ... a116)

                        "vmovl.s8   q9, d17             \n" // q9(a21 a23 ... a215)
                        "vmovl.s8   q8, d16             \n" // q8(a20 a22 ... a214)
                        "vmovl.s8   q10, d20            \n" // q10(a22 a24 ... a216)
                        "vmlal.s16  q11, d4, d0[0]      \n" // k0
                        "vmlal.s16  q12, d5, d0[0]      \n"
                        "vmull.s16  q13, d6, d0[1]      \n" // k1
                        "vmull.s16  q14, d7, d0[1]      \n"
                        "vmlal.s16  q11, d8, d0[2]      \n" // k2
                        "vmlal.s16  q12, d9, d0[2]      \n"

                        "vmlal.s16  q13, d12, d1[0]     \n" // k4
                        "vmlal.s16  q14, d13, d1[0]     \n"
                        "vmlal.s16  q11, d10, d0[3]     \n" // k3
                        "vmlal.s16  q12, d11, d0[3]     \n"
                        "vmlal.s16  q13, d14, d1[1]     \n" // k5
                        "vmlal.s16  q14, d15, d1[1]     \n"

                        "vmlal.s16  q11, d16, d1[2]     \n" // k6
                        "vmlal.s16  q12, d17, d1[2]     \n"
                        "vmlal.s16  q13, d18, d1[3]     \n" // k7
                        "vmlal.s16  q14, d19, d1[3]     \n"
                        "vmlal.s16  q11, d20, d2[0]     \n" // k8
                        "vmlal.s16  q12, d21, d2[0]     \n"

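                        // combine the two partial accumulator chains
                        // (q11/q12 already hold the previous output values)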
                        "vadd.s32   q11, q11, q13       \n"
                        "vadd.s32   q12, q12, q14       \n"

                        "vst1.32    {d22-d25}, [%1]!    \n"

                        "subs       %0, #1              \n"
                        "bne        0b                  \n"
                        : "=r"(nn),     // %0
                        "=r"(outptr), // %1
                        "=r"(r0),     // %2
                        "=r"(r1),     // %3
                        "=r"(r2),     // %4
                        "=r"(ktmp)    // %5
                        : "0"(nn),
                        "1"(outptr),
                        "2"(r0),
                        "3"(r1),
                        "4"(r2),
                        "5"(ktmp)
                        : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
                }
#endif // __aarch64__
#endif // __ARM_NEON
                if (remain > 0)
                {
#if __ARM_NEON
                    int8x8_t _k01234567s8 = vld1_s8(ktmp);     // k0 - k7
                    int8x8_t _k8xxxxxxxs8 = vld1_s8(ktmp + 8); // k8 ...
                    int8x8_t _k34567xxxs8 = vext_s8(_k01234567s8, _k01234567s8, 3); // k3 k4 k5 k6 k7 k0 k1 k2
                    int8x8_t _k678xxxxxs8 = vext_s8(_k01234567s8, _k8xxxxxxxs8, 6); // k6 k7 k8 ...
                    int16x8_t _k0123_s16 = vmovl_s8(_k01234567s8); // low half: k0 k1 k2 k3
                    int16x8_t _k3456_s16 = vmovl_s8(_k34567xxxs8); // low half: k3 k4 k5 k6
                    int16x8_t _k678x_s16 = vmovl_s8(_k678xxxxxs8); // low half: k6 k7 k8 x
#endif
                    for (; remain > 0; remain--)
                    {
#if __ARM_NEON
                        int8x8_t _r00s8 = vld1_s8(r0);
                        int8x8_t _r10s8 = vld1_s8(r1);
                        int8x8_t _r20s8 = vld1_s8(r2);

                        int16x8_t _r00s16 = vmovl_s8(_r00s8);
                        int16x8_t _r10s16 = vmovl_s8(_r10s8);
                        int16x8_t _r20s16 = vmovl_s8(_r20s8);

                        int32x4_t _sum = vmull_s16(vget_low_s16(_r00s16), vget_low_s16(_k0123_s16));
                        _sum = vmlal_s16(_sum, vget_low_s16(_r10s16), vget_low_s16(_k3456_s16));
                        _sum = vmlal_s16(_sum, vget_low_s16(_r20s16), vget_low_s16(_k678x_s16));

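                        // lanes 0-2 accumulate the nine 3x3 products; lane 3 is a
                        // don't-care product, so overwrite it with the previous output
                        // value and let the horizontal add fold everything together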
                        _sum = vsetq_lane_s32(*outptr, _sum, 3);

#if __aarch64__
                        *outptr = vaddvq_s32(_sum);
#else
                        int32x2_t _ss = vadd_s32(vget_low_s32(_sum), vget_high_s32(_sum));
                        _ss = vpadd_s32(_ss, _ss);

                        *outptr = vget_lane_s32(_ss, 0);
#endif // __aarch64__
#else
                        int sum = 0;

                        sum += (int)r0[0] * ktmp[0];
                        sum += (int)r0[1] * ktmp[1];
                        sum += (int)r0[2] * ktmp[2];
                        sum += (int)r1[0] * ktmp[3];
                        sum += (int)r1[1] * ktmp[4];
                        sum += (int)r1[2] * ktmp[5];
                        sum += (int)r2[0] * ktmp[6];
                        sum += (int)r2[1] * ktmp[7];
                        sum += (int)r2[2] * ktmp[8];

                        *outptr += sum;
#endif // __ARM_NEON
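                        // horizontal stride 2: each output consumes two input columns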
                        r0 += 2;
                        r1 += 2;
                        r2 += 2;
                        outptr++;
                    }
                }

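                // advance all three row pointers to the start of the next
                // pair of input rows (the vertical stride is also 2)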
                r0 += tailstep;
                r1 += tailstep;
                r2 += tailstep;
            }

            ktmp += 9;
        }
    }
}
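
// A minimal scalar sketch of the per-output dot product that the NEON and
// assembly branches above compute, assuming this is the stride-2 3x3 int8 path;
// the name conv3x3s2_dot_int8_ref is illustrative only and mirrors the scalar
// fallback, it is not used by the kernels in this file.
static inline int conv3x3s2_dot_int8_ref(const signed char* r0, const signed char* r1,
                                         const signed char* r2, const signed char* ktmp)
{
    int sum = 0;
    for (int i = 0; i < 3; i++)
    {
        sum += (int)r0[i] * ktmp[i];     // taps k0 k1 k2
        sum += (int)r1[i] * ktmp[3 + i]; // taps k3 k4 k5
        sum += (int)r2[i] * ktmp[6 + i]; // taps k6 k7 k8
    }
    return sum;
}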