1 // BUG1989 is pleased to support the open source community by supporting ncnn available.
2 //
3 // author:BUG1989 (https://github.com/BUG1989/) Long-term support.
4 // author:FuGuangping (https://github.com/fu1899) Implemented the first version of INT8 quantization on ARMv7.
5 //
6 // Copyright (C) 2019 BUG1989. All rights reserved.
7 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
8 //
9 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
10 // in compliance with the License. You may obtain a copy of the License at
11 //
12 // https://opensource.org/licenses/BSD-3-Clause
13 //
14 // Unless required by applicable law or agreed to in writing, software distributed
15 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
16 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
17 // specific language governing permissions and limitations under the License.
18
conv3x3s1_winograd23_transform_kernel_int8_neon(const Mat & kernel,std::vector<Mat> & kernel_tm2,int inch,int outch)19 static void conv3x3s1_winograd23_transform_kernel_int8_neon(const Mat& kernel, std::vector<Mat>& kernel_tm2, int inch, int outch)
20 {
21 Mat kernel_tm(4 * 4, inch, outch, 2ul);
22
23 // G
24 const short ktm[4][3] = {
25 {2, 0, 0},
26 {1, 1, 1},
27 {1, -1, 1},
28 {0, 0, 2}
29 };
30
31 #pragma omp parallel for
32 for (int p = 0; p < outch; p++)
33 {
34 for (int q = 0; q < inch; q++)
35 {
36 const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9;
37 short* kernel_tm0 = kernel_tm.channel(p).row<short>(q);
38
39 // transform kernel
40 const signed char* k0 = kernel0;
41 const signed char* k1 = kernel0 + 3;
42 const signed char* k2 = kernel0 + 6;
43
44 // h
45 short tmp[4][3];
46 for (int i = 0; i < 4; i++)
47 {
48 tmp[i][0] = (short)k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
49 tmp[i][1] = (short)k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
50 tmp[i][2] = (short)k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
51 }
52
53 // U
54 for (int j = 0; j < 4; j++)
55 {
56 short* tmpp = &tmp[j][0];
57
58 for (int i = 0; i < 4; i++)
59 {
60 kernel_tm0[j * 4 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
61 }
62 }
63 }
64 }
65
66 for (int r = 0; r < 4; r++)
67 {
68 Mat kernel_tm_test(4 * 8, inch, outch / 8 + (outch % 8) / 4 + outch % 4, 2u);
69
70 int p = 0;
71 for (; p + 7 < outch; p += 8)
72 {
73 const short* kernel0 = (const short*)kernel_tm + (p + 0) * inch * 16;
74 const short* kernel1 = (const short*)kernel_tm + (p + 1) * inch * 16;
75 const short* kernel2 = (const short*)kernel_tm + (p + 2) * inch * 16;
76 const short* kernel3 = (const short*)kernel_tm + (p + 3) * inch * 16;
77 const short* kernel4 = (const short*)kernel_tm + (p + 4) * inch * 16;
78 const short* kernel5 = (const short*)kernel_tm + (p + 5) * inch * 16;
79 const short* kernel6 = (const short*)kernel_tm + (p + 6) * inch * 16;
80 const short* kernel7 = (const short*)kernel_tm + (p + 7) * inch * 16;
81
82 short* ktmp = kernel_tm_test.channel(p / 8);
83
84 for (int q = 0; q < inch; q++)
85 {
86 ktmp[0] = kernel0[r * 4 + 0];
87 ktmp[1] = kernel0[r * 4 + 1];
88 ktmp[2] = kernel0[r * 4 + 2];
89 ktmp[3] = kernel0[r * 4 + 3];
90
91 ktmp[4] = kernel1[r * 4 + 0];
92 ktmp[5] = kernel1[r * 4 + 1];
93 ktmp[6] = kernel1[r * 4 + 2];
94 ktmp[7] = kernel1[r * 4 + 3];
95
96 ktmp[8] = kernel2[r * 4 + 0];
97 ktmp[9] = kernel2[r * 4 + 1];
98 ktmp[10] = kernel2[r * 4 + 2];
99 ktmp[11] = kernel2[r * 4 + 3];
100
101 ktmp[12] = kernel3[r * 4 + 0];
102 ktmp[13] = kernel3[r * 4 + 1];
103 ktmp[14] = kernel3[r * 4 + 2];
104 ktmp[15] = kernel3[r * 4 + 3];
105
106 ktmp[16] = kernel4[r * 4 + 0];
107 ktmp[17] = kernel4[r * 4 + 1];
108 ktmp[18] = kernel4[r * 4 + 2];
109 ktmp[19] = kernel4[r * 4 + 3];
110
111 ktmp[20] = kernel5[r * 4 + 0];
112 ktmp[21] = kernel5[r * 4 + 1];
113 ktmp[22] = kernel5[r * 4 + 2];
114 ktmp[23] = kernel5[r * 4 + 3];
115
116 ktmp[24] = kernel6[r * 4 + 0];
117 ktmp[25] = kernel6[r * 4 + 1];
118 ktmp[26] = kernel6[r * 4 + 2];
119 ktmp[27] = kernel6[r * 4 + 3];
120
121 ktmp[28] = kernel7[r * 4 + 0];
122 ktmp[29] = kernel7[r * 4 + 1];
123 ktmp[30] = kernel7[r * 4 + 2];
124 ktmp[31] = kernel7[r * 4 + 3];
125
126 ktmp += 32;
127 kernel0 += 16;
128 kernel1 += 16;
129 kernel2 += 16;
130 kernel3 += 16;
131 kernel4 += 16;
132 kernel5 += 16;
133 kernel6 += 16;
134 kernel7 += 16;
135 }
136 }
137
138 for (; p + 3 < outch; p += 4)
139 {
140 const short* kernel0 = (const short*)kernel_tm + (p + 0) * inch * 16;
141 const short* kernel1 = (const short*)kernel_tm + (p + 1) * inch * 16;
142 const short* kernel2 = (const short*)kernel_tm + (p + 2) * inch * 16;
143 const short* kernel3 = (const short*)kernel_tm + (p + 3) * inch * 16;
144
145 short* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4);
146
147 for (int q = 0; q < inch; q++)
148 {
149 ktmp[0] = kernel0[r * 4 + 0];
150 ktmp[1] = kernel0[r * 4 + 1];
151 ktmp[2] = kernel0[r * 4 + 2];
152 ktmp[3] = kernel0[r * 4 + 3];
153
154 ktmp[4] = kernel1[r * 4 + 0];
155 ktmp[5] = kernel1[r * 4 + 1];
156 ktmp[6] = kernel1[r * 4 + 2];
157 ktmp[7] = kernel1[r * 4 + 3];
158
159 ktmp[8] = kernel2[r * 4 + 0];
160 ktmp[9] = kernel2[r * 4 + 1];
161 ktmp[10] = kernel2[r * 4 + 2];
162 ktmp[11] = kernel2[r * 4 + 3];
163
164 ktmp[12] = kernel3[r * 4 + 0];
165 ktmp[13] = kernel3[r * 4 + 1];
166 ktmp[14] = kernel3[r * 4 + 2];
167 ktmp[15] = kernel3[r * 4 + 3];
168
169 ktmp += 16;
170 kernel0 += 16;
171 kernel1 += 16;
172 kernel2 += 16;
173 kernel3 += 16;
174 }
175 }
176
177 for (; p < outch; p++)
178 {
179 const short* kernel0 = (const short*)kernel_tm + p * inch * 16;
180
181 short* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4 + p % 4);
182
183 for (int q = 0; q < inch; q++)
184 {
185 ktmp[0] = kernel0[r * 4 + 0];
186 ktmp[1] = kernel0[r * 4 + 1];
187 ktmp[2] = kernel0[r * 4 + 2];
188 ktmp[3] = kernel0[r * 4 + 3];
189
190 ktmp += 4;
191 kernel0 += 16;
192 }
193 }
194 kernel_tm2.push_back(kernel_tm_test);
195 }
196 }
197
// Winograd F(2,3) int8 forward convolution (3x3, stride 1) for ARM NEON.
//
// Pipeline:
//   1. Pad the input so output dims are even, plus a 2-pixel border (a 4x4
//      input tile produces each 2x2 output tile).
//   2. Input transform: w = B^T * d * B per 4x4 tile, int8 -> int16.
//   3. Dot: per transform row r, accumulate int16 x int16 -> int32 GEMM of
//      transformed input tiles against the repacked kernels in
//      kernel_tm_test (as produced by the matching transform_kernel routine),
//      processing 8 / 4 / 1 output channels per pass.
//   4. Output transform: Y = A^T * W * A, then >>2 to undo the x2 scaling
//      baked into the kernel transform G.
//   5. Cut the padding off to the requested output size.
//
// kernel_tm_test: one Mat per transform row r (4 entries expected here).
// Accumulation is int32; no activation/bias is applied here.
static void conv3x3s1_winograd23_int8_neon(const Mat& bottom_blob, Mat& top_blob, const std::vector<Mat>& kernel_tm_test, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 2n+2, winograd F(2,3)
    Mat bottom_blob_bordered = bottom_blob;

    // round output dims up to even so tiles divide exactly
    outw = (outw + 1) / 2 * 2;
    outh = (outh + 1) / 2 * 2;

    w = outw + 2;
    h = outh + 2;
    Option opt_b = opt;
    opt_b.blob_allocator = opt.workspace_allocator;
    // pad only on the bottom/right; border value 0 is the int8 zero point
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // may be the block num in FeatherCNN
        int nRowBlocks = w_tm / 4;

        const int tiles = nColBlocks * nRowBlocks;

        // layout: channel index = (transform row r)*tiles + tile index,
        // each channel holds inch rows of 4 int16 coefficients
        bottom_blob_tm.create(4, inch, tiles * 4, 2u, opt.workspace_allocator);

        // BT
        // const float itm[4][4] = {
        //     {1.0f,  0.0f, -1.0f,  0.0f},
        //     {0.0f,  1.0f,  1.00f, 0.0f},
        //     {0.0f, -1.0f,  1.00f, 0.0f},
        //     {0.0f, -1.0f,  0.00f, 1.0f}
        // };

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const signed char* img = bottom_blob_bordered.channel(q);

            for (int j = 0; j < nColBlocks; j++)
            {
                // 4 consecutive input rows of this tile row; tiles advance by
                // 2 pixels, overlapping their neighbors by 2
                const signed char* r0 = img + w * j * 2;
                const signed char* r1 = r0 + w;
                const signed char* r2 = r1 + w;
                const signed char* r3 = r2 + w;

                for (int i = 0; i < nRowBlocks; i++)
                {
                    short* out_tm0 = bottom_blob_tm.channel(tiles * 0 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm1 = bottom_blob_tm.channel(tiles * 1 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm2 = bottom_blob_tm.channel(tiles * 2 + j * nRowBlocks + i).row<short>(q);
                    short* out_tm3 = bottom_blob_tm.channel(tiles * 3 + j * nRowBlocks + i).row<short>(q);
#if __ARM_NEON
#if __aarch64__
                    // B^T*d fused with the int8->int16 widening (ssubl/saddl),
                    // then transpose, then B^T again on the transposed rows.
                    // Loads read 8 bytes but only the low 4 lanes are used.
                    asm volatile(
                        // load
                        "prfm pldl1keep, [%0, #64] \n"
                        "ld1 {v0.8b}, [%0] \n"
                        "prfm pldl1keep, [%1, #64] \n"
                        "ld1 {v1.8b}, [%1] \n"
                        "prfm pldl1keep, [%2, #64] \n"
                        "ld1 {v2.8b}, [%2] \n"
                        "prfm pldl1keep, [%3, #64] \n"
                        "ld1 {v3.8b}, [%3] \n"
                        // w = B_t * d, trans int8 to int16
                        "ssubl v4.8h, v0.8b, v2.8b \n" // d4
                        "saddl v5.8h, v1.8b, v2.8b \n" // d6
                        "ssubl v6.8h, v2.8b, v1.8b \n" // d8
                        "ssubl v7.8h, v3.8b, v1.8b \n" // d10
                        // transpose w to w_t
                        "trn1 v8.4h, v4.4h, v5.4h \n"
                        "trn2 v9.4h, v4.4h, v5.4h \n"
                        "trn1 v10.4h, v6.4h, v7.4h \n"
                        "trn2 v11.4h, v6.4h, v7.4h \n"

                        "trn1 v0.2s, v8.2s, v10.2s \n"
                        "trn2 v2.2s, v8.2s, v10.2s \n"
                        "trn1 v1.2s, v9.2s, v11.2s \n"
                        "trn2 v3.2s, v9.2s, v11.2s \n"
                        // U = B_t * d_t
                        "sub v4.4h, v0.4h, v2.4h \n"
                        "add v5.4h, v1.4h, v2.4h \n"
                        "sub v6.4h, v2.4h, v1.4h \n"
                        "sub v7.4h, v3.4h, v1.4h \n"
                        // save
                        "st1 {v4.4h}, [%4] \n"
                        "st1 {v5.4h}, [%5] \n"
                        "st1 {v6.4h}, [%6] \n"
                        "st1 {v7.4h}, [%7] \n"
                        : "=r"(r0),      // %0
                        "=r"(r1),      // %1
                        "=r"(r2),      // %2
                        "=r"(r3),      // %3
                        "=r"(out_tm0), // %4
                        "=r"(out_tm1), // %5
                        "=r"(out_tm2), // %6
                        "=r"(out_tm3)  // %7
                        : "0"(r0),
                        "1"(r1),
                        "2"(r2),
                        "3"(r3),
                        "4"(out_tm0),
                        "5"(out_tm1),
                        "6"(out_tm2),
                        "7"(out_tm3)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
#else
                    // armv7 variant of the same transform: vsubl/vaddl widen
                    // int8 to int16, vtrn transposes in-register.
                    asm volatile(
                        // load
                        "pld [%0, #64] \n"
                        "vld1.s8 {d0}, [%0] \n"
                        "pld [%1, #64] \n"
                        "vld1.s8 {d1}, [%1] \n"
                        "pld [%2, #64] \n"
                        "vld1.s8 {d2}, [%2] \n"
                        "pld [%3, #64] \n"
                        "vld1.s8 {d3}, [%3] \n"
                        // w = B_t * d, trans int8 to int16
                        "vsubl.s8 q2, d0, d2 \n" // d4
                        "vaddl.s8 q3, d1, d2 \n" // d6
                        "vsubl.s8 q4, d2, d1 \n" // d8
                        "vsubl.s8 q5, d3, d1 \n" // d10
                        // transpose w to w_t
                        "vtrn.s16 d4, d6 \n"
                        "vtrn.s16 d8, d10 \n"
                        "vtrn.s32 d4, d8 \n"
                        "vtrn.s32 d6, d10 \n"
                        // U = B_t * d_t
                        "vsub.s16 d11, d4, d8 \n"
                        "vadd.s16 d12, d6, d8 \n"
                        "vsub.s16 d13, d8, d6 \n"
                        "vsub.s16 d14, d10, d6 \n"
                        // save
                        "vst1.s32 {d11}, [%4] \n"
                        "vst1.s32 {d12}, [%5] \n"
                        "vst1.s32 {d13}, [%6] \n"
                        "vst1.s32 {d14}, [%7] \n"
                        : "=r"(r0),      // %0
                        "=r"(r1),      // %1
                        "=r"(r2),      // %2
                        "=r"(r3),      // %3
                        "=r"(out_tm0), // %4
                        "=r"(out_tm1), // %5
                        "=r"(out_tm2), // %6
                        "=r"(out_tm3)  // %7
                        : "0"(r0),
                        "1"(r1),
                        "2"(r2),
                        "3"(r3),
                        "4"(out_tm0),
                        "5"(out_tm1),
                        "6"(out_tm2),
                        "7"(out_tm3)
                        : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#endif // __aarch64__
#else
                    // scalar reference path: same B^T * d * B, 4x4 tile
                    short d0[4], d1[4], d2[4], d3[4];
                    short w0[4], w1[4], w2[4], w3[4];
                    short t0[4], t1[4], t2[4], t3[4];
                    // load
                    for (int n = 0; n < 4; n++)
                    {
                        d0[n] = r0[n];
                        d1[n] = r1[n];
                        d2[n] = r2[n];
                        d3[n] = r3[n];
                    }
                    // w = B_t * d
                    for (int n = 0; n < 4; n++)
                    {
                        w0[n] = d0[n] - d2[n];
                        w1[n] = d1[n] + d2[n];
                        w2[n] = d2[n] - d1[n];
                        w3[n] = d3[n] - d1[n];
                    }
                    // transpose d to d_t
                    {
                        t0[0] = w0[0];
                        t1[0] = w0[1];
                        t2[0] = w0[2];
                        t3[0] = w0[3];
                        t0[1] = w1[0];
                        t1[1] = w1[1];
                        t2[1] = w1[2];
                        t3[1] = w1[3];
                        t0[2] = w2[0];
                        t1[2] = w2[1];
                        t2[2] = w2[2];
                        t3[2] = w2[3];
                        t0[3] = w3[0];
                        t1[3] = w3[1];
                        t2[3] = w3[2];
                        t3[3] = w3[3];
                    }
                    // U = B_t * d_t
                    for (int n = 0; n < 4; n++)
                    {
                        d0[n] = t0[n] - t2[n];
                        d1[n] = t1[n] + t2[n];
                        d2[n] = t2[n] - t1[n];
                        d3[n] = t3[n] - t1[n];
                    }
                    // save to out_tm
                    for (int n = 0; n < 4; n++)
                    {
                        out_tm0[n] = d0[n];
                        out_tm1[n] = d1[n];
                        out_tm2[n] = d2[n];
                        out_tm3[n] = d3[n];
                    }
#endif
                    // advance 2 pixels: neighboring tiles overlap by 2 columns
                    r0 += 2;
                    r1 += 2;
                    r2 += 2;
                    r3 += 2;
                }
            }
        }
    }
    // release the padded input early to cap peak workspace memory
    bottom_blob_bordered = Mat();

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // may be the block num in FeatherCNN
        int nRowBlocks = w_tm / 4;

        const int tiles = nColBlocks * nRowBlocks;

        // per output channel: one row of 16 int32 per tile (4 transform rows
        // x 4 coefficients), written at offset r*4 within each 16-int group
        top_blob_tm.create(16, tiles, outch, 4u, opt.workspace_allocator);

        // parallel over transform rows r: each r writes disjoint offsets, so
        // the threads never touch the same int32
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 4; r++)
        {
            int nn_outch = 0;
            int remain_outch_start = 0;

            nn_outch = outch >> 3;
            remain_outch_start = nn_outch << 3;

            // 8 output channels at a time, matching the 8-wide kernel packing
            for (int pp = 0; pp < nn_outch; pp++)
            {
                int p = pp * 8;

                int* output0_tm = top_blob_tm.channel(p);
                int* output1_tm = top_blob_tm.channel(p + 1);
                int* output2_tm = top_blob_tm.channel(p + 2);
                int* output3_tm = top_blob_tm.channel(p + 3);
                int* output4_tm = top_blob_tm.channel(p + 4);
                int* output5_tm = top_blob_tm.channel(p + 5);
                int* output6_tm = top_blob_tm.channel(p + 6);
                int* output7_tm = top_blob_tm.channel(p + 7);

                output0_tm = output0_tm + r * 4;
                output1_tm = output1_tm + r * 4;
                output2_tm = output2_tm + r * 4;
                output3_tm = output3_tm + r * 4;
                output4_tm = output4_tm + r * 4;
                output5_tm = output5_tm + r * 4;
                output6_tm = output6_tm + r * 4;
                output7_tm = output7_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    const short* kptr = kernel_tm_test[r].channel(p / 8);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    // inch-deep multiply-accumulate: one 4-short input vector
                    // against 8 kernel vectors per iteration (smlal widens
                    // int16*int16 into int32 accumulators v0-v7)
                    asm volatile(
                        // inch loop
                        "eor v0.16b, v0.16b, v0.16b \n"
                        "eor v1.16b, v1.16b, v1.16b \n"
                        "eor v2.16b, v2.16b, v2.16b \n"
                        "eor v3.16b, v3.16b, v3.16b \n"
                        "eor v4.16b, v4.16b, v4.16b \n"
                        "eor v5.16b, v5.16b, v5.16b \n"
                        "eor v6.16b, v6.16b, v6.16b \n"
                        "eor v7.16b, v7.16b, v7.16b \n"
                        "mov w4, %w20 \n"

                        "0: \n" // for (int q=0; q<inch; q++)
                        "prfm pldl1keep, [%9, #128] \n" // _r0 = vld1_s16(r0); // input inch0
                        "ld1 {v8.4h}, [%8] \n"
                        "ld1 {v9.4h, v10.4h}, [%9] \n" // _k0 = vld1q_s16(kptr);
                        "add %9, %9, #16 \n"
                        "ld1 {v11.4h, v12.4h}, [%9] \n" // _k0n = vld1q_s16(kptr+8);
                        "add %9, %9, #16 \n"
                        "ld1 {v13.4h, v14.4h}, [%9] \n" // _k1 = vld1q_s16(kptr+16);
                        "add %9, %9, #16 \n"
                        "ld1 {v15.4h, v16.4h}, [%9] \n" // _k1n = vld1q_s16(kptr+24);
                        "add %8, %8, #8 \n"
                        "add %9, %9, #16 \n"

                        "subs w4, w4, #1 \n"

                        "smlal v0.4s, v8.4h, v9.4h \n"  // sum0 += (a00-a03) * (k00-k03)
                        "smlal v1.4s, v8.4h, v10.4h \n" // sum1 += (a00-a03) * (k10-k13)
                        "smlal v2.4s, v8.4h, v11.4h \n" // sum2 += (a00-a03) * (k20-k23)
                        "smlal v3.4s, v8.4h, v12.4h \n" // sum3 += (a00-a03) * (k30-k33)
                        "smlal v4.4s, v8.4h, v13.4h \n" // sum4 += (a00-a03) * (k40-k43)
                        "smlal v5.4s, v8.4h, v14.4h \n" // sum5 += (a00-a03) * (k50-k53)
                        "smlal v6.4s, v8.4h, v15.4h \n" // sum6 += (a00-a03) * (k60-k63)
                        "smlal v7.4s, v8.4h, v16.4h \n" // sum7 += (a00-a03) * (k70-k73)

                        "bne 0b \n" // end for

                        "st1 {v0.4s}, [%0] \n" // store the result to memory
                        "st1 {v1.4s}, [%1] \n" //
                        "st1 {v2.4s}, [%2] \n" //
                        "st1 {v3.4s}, [%3] \n" //
                        "st1 {v4.4s}, [%4] \n" //
                        "st1 {v5.4s}, [%5] \n" //
                        "st1 {v6.4s}, [%6] \n" //
                        "st1 {v7.4s}, [%7] \n" //

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(output4_tm), // %4
                        "=r"(output5_tm), // %5
                        "=r"(output6_tm), // %6
                        "=r"(output7_tm), // %7
                        "=r"(r0),         // %8
                        "=r"(kptr)        // %9
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(output4_tm),
                        "5"(output5_tm),
                        "6"(output6_tm),
                        "7"(output7_tm),
                        "8"(r0),
                        "9"(kptr),
                        "r"(inch) // %20
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16");
#else
                    // armv7 version: vmlal.s16 widening MAC into q0-q7
                    asm volatile(
                        // inch loop
                        "vmov.s32 q0, #0 \n"
                        "vmov.s32 q1, #0 \n"
                        "vmov.s32 q2, #0 \n"
                        "vmov.s32 q3, #0 \n"
                        "vmov.s32 q4, #0 \n"
                        "vmov.s32 q5, #0 \n"
                        "vmov.s32 q6, #0 \n"
                        "vmov.s32 q7, #0 \n"
                        "mov r4, %20 \n"

                        "0: \n"                          // for (int q=0; q<inch; q++)
                        "vld1.s16 {d16}, [%8]! \n"       // _r0 = vld1_s16(r0); // input inch0
                        "vld1.s16 {d18-d19}, [%9] \n"    // _k0 = vld1q_s16(kptr);
                        "add %9, #16 \n"
                        "vld1.s16 {d20-d21}, [%9] \n"    // _k0n = vld1q_s16(kptr+8);
                        "add %9, #16 \n"
                        "vld1.s16 {d22-d23}, [%9] \n"    // _k1 = vld1q_s16(kptr+16);
                        "add %9, #16 \n"
                        "vld1.s16 {d24-d25}, [%9] \n"    // _k1n = vld1q_s16(kptr+24);
                        "add %9, #16 \n"

                        "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)
                        "vmlal.s16 q1, d16, d19 \n" // sum1 += (a00-a03) * (k10-k13)
                        "vmlal.s16 q2, d16, d20 \n" // sum2 += (a00-a03) * (k20-k23)
                        "vmlal.s16 q3, d16, d21 \n" // sum3 += (a00-a03) * (k30-k33)
                        "vmlal.s16 q4, d16, d22 \n" // sum4 += (a00-a03) * (k40-k43)
                        "vmlal.s16 q5, d16, d23 \n" // sum5 += (a00-a03) * (k50-k53)
                        "vmlal.s16 q6, d16, d24 \n" // sum6 += (a00-a03) * (k60-k63)
                        "vmlal.s16 q7, d16, d25 \n" // sum7 += (a00-a03) * (k70-k73)

                        "subs r4, r4, #1 \n"
                        "bne 0b \n" // end for

                        "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory
                        "vst1.s32 {d2-d3}, [%1] \n"
                        "vst1.s32 {d4-d5}, [%2] \n"
                        "vst1.s32 {d6-d7}, [%3] \n"
                        "vst1.s32 {d8-d9}, [%4] \n"
                        "vst1.s32 {d10-d11}, [%5] \n"
                        "vst1.s32 {d12-d13}, [%6] \n"
                        "vst1.s32 {d14-d15}, [%7] \n"

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(output4_tm), // %4
                        "=r"(output5_tm), // %5
                        "=r"(output6_tm), // %6
                        "=r"(output7_tm), // %7
                        "=r"(r0),         // %8
                        "=r"(kptr)        // %9
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(output4_tm),
                        "5"(output5_tm),
                        "6"(output6_tm),
                        "7"(output7_tm),
                        "8"(r0),
                        "9"(kptr),
                        "r"(inch) // %20
                        : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
#endif // __aarch64__
#else
                    // scalar reference: int32 accumulation over inch
                    int sum0[4] = {0};
                    int sum1[4] = {0};
                    int sum2[4] = {0};
                    int sum3[4] = {0};
                    int sum4[4] = {0};
                    int sum5[4] = {0};
                    int sum6[4] = {0};
                    int sum7[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                            sum1[n] += (int)r0[n] * kptr[n + 4];
                            sum2[n] += (int)r0[n] * kptr[n + 8];
                            sum3[n] += (int)r0[n] * kptr[n + 12];
                            sum4[n] += (int)r0[n] * kptr[n + 16];
                            sum5[n] += (int)r0[n] * kptr[n + 20];
                            sum6[n] += (int)r0[n] * kptr[n + 24];
                            sum7[n] += (int)r0[n] * kptr[n + 28];
                        }
                        kptr += 32;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                        output1_tm[n] = sum1[n];
                        output2_tm[n] = sum2[n];
                        output3_tm[n] = sum3[n];
                        output4_tm[n] = sum4[n];
                        output5_tm[n] = sum5[n];
                        output6_tm[n] = sum6[n];
                        output7_tm[n] = sum7[n];
                    }
#endif // __ARM_NEON
                    // next tile: each tile owns 16 ints per output channel
                    output0_tm += 16;
                    output1_tm += 16;
                    output2_tm += 16;
                    output3_tm += 16;
                    output4_tm += 16;
                    output5_tm += 16;
                    output6_tm += 16;
                    output7_tm += 16;
                }
            }

            // 4 output channels at a time for the remainder
            nn_outch = (outch - remain_outch_start) >> 2;

            for (int pp = 0; pp < nn_outch; pp++)
            {
                int p = remain_outch_start + pp * 4;

                int* output0_tm = top_blob_tm.channel(p);
                int* output1_tm = top_blob_tm.channel(p + 1);
                int* output2_tm = top_blob_tm.channel(p + 2);
                int* output3_tm = top_blob_tm.channel(p + 3);

                output0_tm = output0_tm + r * 4;
                output1_tm = output1_tm + r * 4;
                output2_tm = output2_tm + r * 4;
                output3_tm = output3_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    // 4-wide packed kernels live after the 8-wide channels
                    const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor v0.16b, v0.16b, v0.16b \n"
                        "eor v1.16b, v1.16b, v1.16b \n"
                        "eor v2.16b, v2.16b, v2.16b \n"
                        "eor v3.16b, v3.16b, v3.16b \n"
                        "mov w4, %w12 \n"

                        "0: \n" // for (int q=0; q<inch; q++)
                        "prfm pldl1keep, [%5, #128] \n" // _r0 = vld1_s16(r0); // input inch0
                        "ld1 {v8.4h}, [%4] \n"
                        "ld1 {v9.4h, v10.4h}, [%5] \n" // _k0 = vld1q_s16(kptr);
                        "add %5, %5, #16 \n"
                        "ld1 {v11.4h, v12.4h}, [%5] \n" // _k0n = vld1q_s16(kptr+8);
                        "add %4, %4, #8 \n"
                        "add %5, %5, #16 \n"

                        "subs w4, w4, #1 \n"

                        "smlal v0.4s, v8.4h, v9.4h \n"  // sum0 += (a00-a03) * (k00-k03)
                        "smlal v1.4s, v8.4h, v10.4h \n" // sum1 += (a00-a03) * (k10-k13)
                        "smlal v2.4s, v8.4h, v11.4h \n" // sum2 += (a00-a03) * (k20-k23)
                        "smlal v3.4s, v8.4h, v12.4h \n" // sum3 += (a00-a03) * (k30-k33)

                        "bne 0b \n" // end for

                        "st1 {v0.4s}, [%0] \n" // store the result to memory
                        "st1 {v1.4s}, [%1] \n" //
                        "st1 {v2.4s}, [%2] \n" //
                        "st1 {v3.4s}, [%3] \n" //

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(r0),         // %4
                        "=r"(kptr)        // %5
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(r0),
                        "5"(kptr),
                        "r"(inch) // %12
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32 q0, #0 \n"
                        "vmov.s32 q1, #0 \n"
                        "vmov.s32 q2, #0 \n"
                        "vmov.s32 q3, #0 \n"
                        "mov r4, %12 \n"

                        "0: \n"                       // for (int q=0; q<inch; q++)
                        "vld1.s16 {d16}, [%4]! \n"    // _r0 = vld1_s16(r0); // input inch0
                        "vld1.s16 {d18-d19}, [%5] \n" // _k0 = vld1q_s16(kptr);
                        "add %5, #16 \n"
                        "vld1.s16 {d20-d21}, [%5] \n" // _k0n = vld1q_s16(kptr+8);
                        "add %5, #16 \n"

                        "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)
                        "vmlal.s16 q1, d16, d19 \n" // sum1 += (a00-a03) * (k10-k13)
                        "vmlal.s16 q2, d16, d20 \n" // sum2 += (a00-a03) * (k20-k23)
                        "vmlal.s16 q3, d16, d21 \n" // sum3 += (a00-a03) * (k30-k33)

                        "subs r4, r4, #1 \n"
                        "bne 0b \n" // end for

                        "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory
                        "vst1.s32 {d2-d3}, [%1] \n"
                        "vst1.s32 {d4-d5}, [%2] \n"
                        "vst1.s32 {d6-d7}, [%3] \n"

                        : "=r"(output0_tm), // %0
                        "=r"(output1_tm), // %1
                        "=r"(output2_tm), // %2
                        "=r"(output3_tm), // %3
                        "=r"(r0),         // %4
                        "=r"(kptr)        // %5
                        : "0"(output0_tm),
                        "1"(output1_tm),
                        "2"(output2_tm),
                        "3"(output3_tm),
                        "4"(r0),
                        "5"(kptr),
                        "r"(inch) // %12
                        : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q8", "q9", "q10");
#endif // __aarch64__
#else
                    int sum0[4] = {0};
                    int sum1[4] = {0};
                    int sum2[4] = {0};
                    int sum3[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                            sum1[n] += (int)r0[n] * kptr[n + 4];
                            sum2[n] += (int)r0[n] * kptr[n + 8];
                            sum3[n] += (int)r0[n] * kptr[n + 12];
                        }
                        kptr += 16;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                        output1_tm[n] = sum1[n];
                        output2_tm[n] = sum2[n];
                        output3_tm[n] = sum3[n];
                    }
#endif // __ARM_NEON
                    output0_tm += 16;
                    output1_tm += 16;
                    output2_tm += 16;
                    output3_tm += 16;
                }
            }

            remain_outch_start += nn_outch << 2;

            // single output channel tail
            for (int p = remain_outch_start; p < outch; p++)
            {
                int* output0_tm = top_blob_tm.channel(p);

                output0_tm = output0_tm + r * 4;

                for (int i = 0; i < tiles; i++)
                {
                    // 1-wide packed kernels follow the 8- and 4-wide ones
                    const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4 + p % 4);
                    const short* r0 = bottom_blob_tm.channel(tiles * r + i);
#if __ARM_NEON
#if __aarch64__
                    asm volatile(
                        // inch loop
                        "eor v0.16b, v0.16b, v0.16b \n"
                        "mov w4, %w6 \n"

                        "0: \n" // for (int q=0; q<inch; q++)
                        //"prfm pldl1keep, [%2, #128] \n" // _r0 = vld1_s16(r0); // input inch0
                        "ld1 {v8.4h}, [%1] \n"
                        "ld1 {v9.4h}, [%2] \n" // _k0 = vld1q_s16(kptr);
                        "add %1, %1, #8 \n"
                        "add %2, %2, #8 \n"

                        "subs w4, w4, #1 \n"

                        "smlal v0.4s, v8.4h, v9.4h \n" // sum0 += (a00-a03) * (k00-k03)

                        "bne 0b \n" // end for

                        "st1 {v0.4s}, [%0] \n" // store the result to memory

                        : "=r"(output0_tm), // %0
                        "=r"(r0),         // %1
                        "=r"(kptr)        // %2
                        : "0"(output0_tm),
                        "1"(r0),
                        "2"(kptr),
                        "r"(inch) // %6
                        : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9");
#else
                    asm volatile(
                        // inch loop
                        "vmov.s32 q0, #0 \n"
                        "mov r4, %6 \n"

                        "0: \n"                    // for (int q=0; q<inch; q++)
                        "vld1.s16 {d16}, [%1] \n"  // _r0 = vld1_s16(r0); // input inch0
                        "add %1, #8 \n"
                        "vld1.s16 {d18}, [%2] \n"  // _k0 = vld1q_s16(kptr);
                        "add %2, #8 \n"
                        "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)

                        "subs r4, r4, #1 \n"
                        "bne 0b \n" // end for

                        "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory

                        : "=r"(output0_tm), // %0
                        "=r"(r0),         // %1
                        "=r"(kptr)        // %2
                        : "0"(output0_tm),
                        "1"(r0),
                        "2"(kptr),
                        "r"(inch) // %6
                        : "cc", "memory", "r4", "q0", "q8", "q9");
#endif // __aarch64__
#else
                    int sum0[4] = {0};

                    for (int q = 0; q < inch; q++)
                    {
                        for (int n = 0; n < 4; n++)
                        {
                            sum0[n] += (int)r0[n] * kptr[n];
                        }
                        kptr += 4;
                        r0 += 4;
                    }

                    for (int n = 0; n < 4; n++)
                    {
                        output0_tm[n] = sum0[n];
                    }
#endif
                    output0_tm += 16;
                }
            }
        }
    }
    // release transformed input before the output transform
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator);
    {
        // AT
        // const float itm[2][4] = {
        //     {1.0f, 1.0f,  1.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 1.0f}
        // };

        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // may be the block num in FeatherCNN
        int nRowBlocks = w_tm / 4;

#if __ARM_NEON
        // sshl/vshl by -2 == arithmetic shift right by 2; compensates the x2
        // scaling of G applied in the kernel transform (2*2 over two passes)
        int32x2_t _shift = vdup_n_s32(-2);
#endif

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            int* out_tile = top_blob_tm.channel(p);
            int* outRow0 = top_blob_bordered.channel(p);
            int* outRow1 = outRow0 + outw; // each tile emits 2 output rows

            for (int j = 0; j < nColBlocks; j++)
            {
                for (int i = 0; i < nRowBlocks; i++)
                {
#if __ARM_NEON
#if __aarch64__
                    // Y = A^T * W * A on a 4x4 int32 tile, then >>2;
                    // writes a 2x2 output block
                    asm volatile(
                        "prfm pldl1keep, [%0, #512] \n"
                        "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n"

                        "add v0.4s, v0.4s, v1.4s \n" // s0 = s0 + s1 + s2;
                        "sub v1.4s, v1.4s, v2.4s \n"
                        "add v0.4s, v0.4s, v2.4s \n" // s1 = s1 - s2 + s3;
                        "add v1.4s, v1.4s, v3.4s \n"

                        "trn1 v4.4s, v0.4s, v1.4s \n"
                        "trn2 v5.4s, v0.4s, v1.4s \n"

                        "dup v6.2d, v4.d[1] \n"
                        "dup v7.2d, v5.d[1] \n"

                        "add v0.2s, v4.2s, v5.2s \n" // o0 = d0 + d1 + d2;
                        "sub v1.2s, v5.2s, v6.2s \n"
                        "add v0.2s, v0.2s, v6.2s \n" // o1 = d1 - d2 + d3;
                        "add v1.2s, v1.2s, v7.2s \n"

                        "sshl v0.2s, v0.2s, %6.2s \n" // o0 = o0 >> 2
                        "sshl v1.2s, v1.2s, %6.2s \n" // o1 = o1 >> 2

                        "st1 {v0.2s}, [%1], #8 \n"
                        "st1 {v1.2s}, [%2], #8 \n"
                        : "=r"(out_tile), // %0
                        "=r"(outRow0),  // %1
                        "=r"(outRow1)   // %2
                        : "0"(out_tile),
                        "1"(outRow0),
                        "2"(outRow1),
                        "w"(_shift) // %6
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
                    asm volatile(
                        "pld [%0, #512] \n"
                        "vldm %0!, {d0-d7} \n"

                        "vaddq.s32 q0, q0, q1 \n" // s0 = s0 + s1 + s2;
                        "vsubq.s32 q1, q1, q2 \n"
                        "vaddq.s32 q0, q0, q2 \n" // s1 = s1 - s2 + s3;
                        "vaddq.s32 q1, q1, q3 \n"

                        "vtrn.s32 q0, q1 \n"

                        "vadd.s32 d8, d0, d2 \n" // o0 = d0 + d1 + d2;
                        "vsub.s32 d9, d2, d1 \n"
                        "vadd.s32 d8, d8, d1 \n" // o1 = d1 - d2 + d3;
                        "vadd.s32 d9, d9, d3 \n"

                        "vshl.s32 d8, d8, %P6 \n" // o0 = o0 >> 2
                        "vshl.s32 d9, d9, %P6 \n" // o1 = o1 >> 2

                        "vst1.s32 {d8}, [%1]! \n"
                        "vst1.s32 {d9}, [%2]! \n"
                        : "=r"(out_tile), // %0
                        "=r"(outRow0),  // %1
                        "=r"(outRow1)   // %2
                        : "0"(out_tile),
                        "1"(outRow0),
                        "2"(outRow1),
                        "w"(_shift) // %6
                        : "cc", "memory", "q0", "q1", "q2", "q3", "q4");
#endif // __aarch64__
#else
                    // scalar reference: Y = A^T * W * A, then >>2
                    int s0[4], s1[4], s2[4], s3[4];
                    int w0[4], w1[4];
                    int d0[2], d1[2], d2[2], d3[2];
                    int o0[2], o1[2];
                    // load
                    for (int n = 0; n < 4; n++)
                    {
                        s0[n] = out_tile[n];
                        s1[n] = out_tile[n + 4];
                        s2[n] = out_tile[n + 8];
                        s3[n] = out_tile[n + 12];
                    }
                    // w = A_T * W
                    for (int n = 0; n < 4; n++)
                    {
                        w0[n] = s0[n] + s1[n] + s2[n];
                        w1[n] = s1[n] - s2[n] + s3[n];
                    }
                    // transpose w to w_t
                    {
                        d0[0] = w0[0];
                        d0[1] = w1[0];
                        d1[0] = w0[1];
                        d1[1] = w1[1];
                        d2[0] = w0[2];
                        d2[1] = w1[2];
                        d3[0] = w0[3];
                        d3[1] = w1[3];
                    }
                    // Y = A_T * w_t
                    for (int n = 0; n < 2; n++)
                    {
                        o0[n] = d0[n] + d1[n] + d2[n];
                        o1[n] = d1[n] - d2[n] + d3[n];
                    }
                    // save to top blob tm,why right 2,because the G' = G*2
                    outRow0[0] = o0[0] >> 2;
                    outRow0[1] = o0[1] >> 2;
                    outRow1[0] = o1[0] >> 2;
                    outRow1[1] = o1[1] >> 2;

                    out_tile += 16;

                    outRow0 += 2;
                    outRow1 += 2;
#endif // __ARM_NEON
                }

                // skip the row already written as the second row of each tile
                outRow0 += outw;
                outRow1 += outw;
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}
1061
conv3x3s1_winograd43_transform_kernel_int8_neon(const Mat & kernel,std::vector<Mat> & kernel_tm2,int inch,int outch)1062 static void conv3x3s1_winograd43_transform_kernel_int8_neon(const Mat& kernel, std::vector<Mat>& kernel_tm2, int inch, int outch)
1063 {
1064 Mat kernel_tm(6 * 6, inch, outch, 2ul);
1065
1066 // G
1067 // const float ktm[6][3] = {
1068 // { 1.0f/4, 0.0f, 0.0f},
1069 // { -1.0f/6, -1.0f/6, -1.0f/6},
1070 // { -1.0f/6, 1.0f/6, -1.0f/6},
1071 // { 1.0f/24, 1.0f/12, 1.0f/6},
1072 // { 1.0f/24, -1.0f/12, 1.0f/6},
1073 // { 0.0f, 0.0f, 1.0f}
1074 // };
1075 const short ktm[6][3] = {
1076 {6, 0, 0},
1077 {-4, -4, -4},
1078 {-4, 4, -4},
1079 {1, 2, 4},
1080 {1, -2, 4},
1081 {0, 0, 6}
1082 };
1083
1084 #pragma omp parallel for
1085 for (int p = 0; p < outch; p++)
1086 {
1087 for (int q = 0; q < inch; q++)
1088 {
1089 const signed char* kernel0 = (const signed char*)kernel + p * inch * 9 + q * 9;
1090 short* kernel_tm0 = kernel_tm.channel(p).row<short>(q);
1091
1092 // transform kernel
1093 const signed char* k0 = kernel0;
1094 const signed char* k1 = kernel0 + 3;
1095 const signed char* k2 = kernel0 + 6;
1096
1097 // h
1098 short tmp[6][3];
1099 for (int i = 0; i < 6; i++)
1100 {
1101 tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
1102 tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
1103 tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
1104 }
1105
1106 // U
1107 for (int j = 0; j < 6; j++)
1108 {
1109 short* tmpp = &tmp[j][0];
1110
1111 for (int i = 0; i < 6; i++)
1112 {
1113 kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
1114 }
1115 }
1116 }
1117 }
1118
1119 for (int r = 0; r < 9; r++)
1120 {
1121 Mat kernel_tm_test(4 * 8, inch, outch / 8 + (outch % 8) / 4 + outch % 4, 2u);
1122
1123 int p = 0;
1124 for (; p + 7 < outch; p += 8)
1125 {
1126 const short* kernel0 = (const short*)kernel_tm.channel(p);
1127 const short* kernel1 = (const short*)kernel_tm.channel(p + 1);
1128 const short* kernel2 = (const short*)kernel_tm.channel(p + 2);
1129 const short* kernel3 = (const short*)kernel_tm.channel(p + 3);
1130 const short* kernel4 = (const short*)kernel_tm.channel(p + 4);
1131 const short* kernel5 = (const short*)kernel_tm.channel(p + 5);
1132 const short* kernel6 = (const short*)kernel_tm.channel(p + 6);
1133 const short* kernel7 = (const short*)kernel_tm.channel(p + 7);
1134
1135 short* ktmp = kernel_tm_test.channel(p / 8);
1136
1137 for (int q = 0; q < inch; q++)
1138 {
1139 ktmp[0] = kernel0[r * 4 + 0];
1140 ktmp[1] = kernel0[r * 4 + 1];
1141 ktmp[2] = kernel0[r * 4 + 2];
1142 ktmp[3] = kernel0[r * 4 + 3];
1143
1144 ktmp[4] = kernel1[r * 4 + 0];
1145 ktmp[5] = kernel1[r * 4 + 1];
1146 ktmp[6] = kernel1[r * 4 + 2];
1147 ktmp[7] = kernel1[r * 4 + 3];
1148
1149 ktmp[8] = kernel2[r * 4 + 0];
1150 ktmp[9] = kernel2[r * 4 + 1];
1151 ktmp[10] = kernel2[r * 4 + 2];
1152 ktmp[11] = kernel2[r * 4 + 3];
1153
1154 ktmp[12] = kernel3[r * 4 + 0];
1155 ktmp[13] = kernel3[r * 4 + 1];
1156 ktmp[14] = kernel3[r * 4 + 2];
1157 ktmp[15] = kernel3[r * 4 + 3];
1158
1159 ktmp[16] = kernel4[r * 4 + 0];
1160 ktmp[17] = kernel4[r * 4 + 1];
1161 ktmp[18] = kernel4[r * 4 + 2];
1162 ktmp[19] = kernel4[r * 4 + 3];
1163
1164 ktmp[20] = kernel5[r * 4 + 0];
1165 ktmp[21] = kernel5[r * 4 + 1];
1166 ktmp[22] = kernel5[r * 4 + 2];
1167 ktmp[23] = kernel5[r * 4 + 3];
1168
1169 ktmp[24] = kernel6[r * 4 + 0];
1170 ktmp[25] = kernel6[r * 4 + 1];
1171 ktmp[26] = kernel6[r * 4 + 2];
1172 ktmp[27] = kernel6[r * 4 + 3];
1173
1174 ktmp[28] = kernel7[r * 4 + 0];
1175 ktmp[29] = kernel7[r * 4 + 1];
1176 ktmp[30] = kernel7[r * 4 + 2];
1177 ktmp[31] = kernel7[r * 4 + 3];
1178
1179 ktmp += 32;
1180 kernel0 += 36;
1181 kernel1 += 36;
1182 kernel2 += 36;
1183 kernel3 += 36;
1184 kernel4 += 36;
1185 kernel5 += 36;
1186 kernel6 += 36;
1187 kernel7 += 36;
1188 }
1189 }
1190
1191 for (; p + 3 < outch; p += 4)
1192 {
1193 const short* kernel0 = (const short*)kernel_tm.channel(p);
1194 const short* kernel1 = (const short*)kernel_tm.channel(p + 1);
1195 const short* kernel2 = (const short*)kernel_tm.channel(p + 2);
1196 const short* kernel3 = (const short*)kernel_tm.channel(p + 3);
1197
1198 short* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4);
1199
1200 for (int q = 0; q < inch; q++)
1201 {
1202 ktmp[0] = kernel0[r * 4 + 0];
1203 ktmp[1] = kernel0[r * 4 + 1];
1204 ktmp[2] = kernel0[r * 4 + 2];
1205 ktmp[3] = kernel0[r * 4 + 3];
1206
1207 ktmp[4] = kernel1[r * 4 + 0];
1208 ktmp[5] = kernel1[r * 4 + 1];
1209 ktmp[6] = kernel1[r * 4 + 2];
1210 ktmp[7] = kernel1[r * 4 + 3];
1211
1212 ktmp[8] = kernel2[r * 4 + 0];
1213 ktmp[9] = kernel2[r * 4 + 1];
1214 ktmp[10] = kernel2[r * 4 + 2];
1215 ktmp[11] = kernel2[r * 4 + 3];
1216
1217 ktmp[12] = kernel3[r * 4 + 0];
1218 ktmp[13] = kernel3[r * 4 + 1];
1219 ktmp[14] = kernel3[r * 4 + 2];
1220 ktmp[15] = kernel3[r * 4 + 3];
1221
1222 ktmp += 16;
1223 kernel0 += 36;
1224 kernel1 += 36;
1225 kernel2 += 36;
1226 kernel3 += 36;
1227 }
1228 }
1229
1230 for (; p < outch; p++)
1231 {
1232 const short* kernel0 = (const short*)kernel_tm.channel(p);
1233
1234 short* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4 + p % 4);
1235
1236 for (int q = 0; q < inch; q++)
1237 {
1238 ktmp[0] = kernel0[r * 4 + 0];
1239 ktmp[1] = kernel0[r * 4 + 1];
1240 ktmp[2] = kernel0[r * 4 + 2];
1241 ktmp[3] = kernel0[r * 4 + 3];
1242
1243 ktmp += 4;
1244 kernel0 += 36;
1245 }
1246 }
1247 kernel_tm2.push_back(kernel_tm_test);
1248 }
1249 }
1250
conv3x3s1_winograd43_int8_neon(const Mat & bottom_blob,Mat & top_blob,const std::vector<Mat> & kernel_tm_test,const Option & opt)1251 static void conv3x3s1_winograd43_int8_neon(const Mat& bottom_blob, Mat& top_blob, const std::vector<Mat>& kernel_tm_test, const Option& opt)
1252 {
1253 int w = bottom_blob.w;
1254 int h = bottom_blob.h;
1255 int inch = bottom_blob.c;
1256
1257 int outw = top_blob.w;
1258 int outh = top_blob.h;
1259 int outch = top_blob.c;
1260
1261 // pad to 4n+2, winograd F(4,3)
1262 Mat bottom_blob_bordered = bottom_blob;
1263
1264 outw = (outw + 3) / 4 * 4;
1265 outh = (outh + 3) / 4 * 4;
1266
1267 w = outw + 2;
1268 h = outh + 2;
1269
1270 Option opt_b = opt;
1271 opt_b.blob_allocator = opt.workspace_allocator;
1272 copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);
1273
1274 // BEGIN transform input
1275 Mat bottom_blob_tm;
1276 {
1277 int w_tm = outw / 4 * 6;
1278 int h_tm = outh / 4 * 6;
1279
1280 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
1281 int nRowBlocks = w_tm / 6;
1282
1283 const int tiles = nColBlocks * nRowBlocks;
1284
1285 bottom_blob_tm.create(4, inch, tiles * 9, 2u, opt.workspace_allocator);
1286
1287 // BT
1288 // const float itm[4][4] = {
1289 // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
1290 // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
1291 // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
1292 // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
1293 // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
1294 // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f}
1295 // };
1296
1297 // 0 = 4 * r00 - 5 * r02 + r04
1298 // 1 = -4 * (r01 + r02) + r03 + r04
1299 // 2 = 4 * (r01 - r02) - r03 + r04
1300 // 3 = -2 * r01 - r02 + 2 * r03 + r04
1301 // 4 = 2 * r01 - r02 - 2 * r03 + r04
1302 // 5 = 4 * r01 - 5 * r03 + r05
1303
1304 #pragma omp parallel for num_threads(opt.num_threads)
1305 for (int q = 0; q < inch; q++)
1306 {
1307 const signed char* img = bottom_blob_bordered.channel(q);
1308
1309 for (int j = 0; j < nColBlocks; j++)
1310 {
1311 const signed char* r0 = img + w * j * 4;
1312 const signed char* r1 = r0 + w;
1313 const signed char* r2 = r1 + w;
1314 const signed char* r3 = r2 + w;
1315 const signed char* r4 = r3 + w;
1316 const signed char* r5 = r4 + w;
1317
1318 for (int i = 0; i < nRowBlocks; i++)
1319 {
1320 short* out_tm0 = bottom_blob_tm.channel(tiles * 0 + j * nRowBlocks + i).row<short>(q);
1321 short* out_tm1 = bottom_blob_tm.channel(tiles * 1 + j * nRowBlocks + i).row<short>(q);
1322 short* out_tm2 = bottom_blob_tm.channel(tiles * 2 + j * nRowBlocks + i).row<short>(q);
1323 short* out_tm3 = bottom_blob_tm.channel(tiles * 3 + j * nRowBlocks + i).row<short>(q);
1324 short* out_tm4 = bottom_blob_tm.channel(tiles * 4 + j * nRowBlocks + i).row<short>(q);
1325 short* out_tm5 = bottom_blob_tm.channel(tiles * 5 + j * nRowBlocks + i).row<short>(q);
1326 short* out_tm6 = bottom_blob_tm.channel(tiles * 6 + j * nRowBlocks + i).row<short>(q);
1327 short* out_tm7 = bottom_blob_tm.channel(tiles * 7 + j * nRowBlocks + i).row<short>(q);
1328 short* out_tm8 = bottom_blob_tm.channel(tiles * 8 + j * nRowBlocks + i).row<short>(q);
1329 #if __ARM_NEON
1330 int8x8_t _d0, _d1, _d2, _d3, _d4, _d5;
1331 int16x8_t _w0, _w1, _w2, _w3, _w4, _w5;
1332 int16x8_t _t0, _t1, _t2, _t3, _t4, _t5;
1333 int16x8_t _n0, _n1, _n2, _n3, _n4, _n5;
1334 // load
1335 _d0 = vld1_s8(r0);
1336 _d1 = vld1_s8(r1);
1337 _d2 = vld1_s8(r2);
1338 _d3 = vld1_s8(r3);
1339 _d4 = vld1_s8(r4);
1340 _d5 = vld1_s8(r5);
1341
1342 int8x8_t _1_n = vdup_n_s8(-1);
1343 int8x8_t _2_p = vdup_n_s8(2);
1344 int8x8_t _2_n = vdup_n_s8(-2);
1345 int8x8_t _4_p = vdup_n_s8(4);
1346 int8x8_t _4_n = vdup_n_s8(-4);
1347 int8x8_t _5_n = vdup_n_s8(-5);
1348
1349 int16x8_t _1_n_s16 = vdupq_n_s16(-1);
1350 int16x8_t _2_p_s16 = vdupq_n_s16(2);
1351 int16x8_t _2_n_s16 = vdupq_n_s16(-2);
1352 int16x8_t _4_p_s16 = vdupq_n_s16(4);
1353 int16x8_t _4_n_s16 = vdupq_n_s16(-4);
1354 int16x8_t _5_n_s16 = vdupq_n_s16(-5);
1355 // w = B_t * d
1356 _w0 = vmull_s8(_d0, _4_p);
1357 _w0 = vmlal_s8(_w0, _d2, _5_n);
1358 _w0 = vaddw_s8(_w0, _d4);
1359
1360 _w1 = vmull_s8(_d1, _4_n);
1361 _w1 = vmlal_s8(_w1, _d2, _4_n);
1362 _w1 = vaddw_s8(_w1, _d3);
1363 _w1 = vaddw_s8(_w1, _d4);
1364
1365 _w2 = vmull_s8(_d1, _4_p);
1366 _w2 = vmlal_s8(_w2, _d2, _4_n);
1367 _w2 = vmlal_s8(_w2, _d3, _1_n);
1368 _w2 = vaddw_s8(_w2, _d4);
1369
1370 _w3 = vmull_s8(_d1, _2_n);
1371 _w3 = vmlal_s8(_w3, _d2, _1_n);
1372 _w3 = vmlal_s8(_w3, _d3, _2_p);
1373 _w3 = vaddw_s8(_w3, _d4);
1374
1375 _w4 = vmull_s8(_d1, _2_p);
1376 _w4 = vmlal_s8(_w4, _d2, _1_n);
1377 _w4 = vmlal_s8(_w4, _d3, _2_n);
1378 _w4 = vaddw_s8(_w4, _d4);
1379
1380 _w5 = vmull_s8(_d1, _4_p);
1381 _w5 = vmlal_s8(_w5, _d3, _5_n);
1382 _w5 = vaddw_s8(_w5, _d5);
1383 // transpose d to d_t
1384 {
1385 _t0[0] = _w0[0];
1386 _t1[0] = _w0[1];
1387 _t2[0] = _w0[2];
1388 _t3[0] = _w0[3];
1389 _t4[0] = _w0[4];
1390 _t5[0] = _w0[5];
1391 _t0[1] = _w1[0];
1392 _t1[1] = _w1[1];
1393 _t2[1] = _w1[2];
1394 _t3[1] = _w1[3];
1395 _t4[1] = _w1[4];
1396 _t5[1] = _w1[5];
1397 _t0[2] = _w2[0];
1398 _t1[2] = _w2[1];
1399 _t2[2] = _w2[2];
1400 _t3[2] = _w2[3];
1401 _t4[2] = _w2[4];
1402 _t5[2] = _w2[5];
1403 _t0[3] = _w3[0];
1404 _t1[3] = _w3[1];
1405 _t2[3] = _w3[2];
1406 _t3[3] = _w3[3];
1407 _t4[3] = _w3[4];
1408 _t5[3] = _w3[5];
1409 _t0[4] = _w4[0];
1410 _t1[4] = _w4[1];
1411 _t2[4] = _w4[2];
1412 _t3[4] = _w4[3];
1413 _t4[4] = _w4[4];
1414 _t5[4] = _w4[5];
1415 _t0[5] = _w5[0];
1416 _t1[5] = _w5[1];
1417 _t2[5] = _w5[2];
1418 _t3[5] = _w5[3];
1419 _t4[5] = _w5[4];
1420 _t5[5] = _w5[5];
1421 }
1422 // d = B_t * d_t
1423 _n0 = vmulq_s16(_t0, _4_p_s16);
1424 _n0 = vmlaq_s16(_n0, _t2, _5_n_s16);
1425 _n0 = vaddq_s16(_n0, _t4);
1426
1427 _n1 = vmulq_s16(_t1, _4_n_s16);
1428 _n1 = vmlaq_s16(_n1, _t2, _4_n_s16);
1429 _n1 = vaddq_s16(_n1, _t3);
1430 _n1 = vaddq_s16(_n1, _t4);
1431
1432 _n2 = vmulq_s16(_t1, _4_p_s16);
1433 _n2 = vmlaq_s16(_n2, _t2, _4_n_s16);
1434 _n2 = vmlaq_s16(_n2, _t3, _1_n_s16);
1435 _n2 = vaddq_s16(_n2, _t4);
1436
1437 _n3 = vmulq_s16(_t1, _2_n_s16);
1438 _n3 = vmlaq_s16(_n3, _t2, _1_n_s16);
1439 _n3 = vmlaq_s16(_n3, _t3, _2_p_s16);
1440 _n3 = vaddq_s16(_n3, _t4);
1441
1442 _n4 = vmulq_s16(_t1, _2_p_s16);
1443 _n4 = vmlaq_s16(_n4, _t2, _1_n_s16);
1444 _n4 = vmlaq_s16(_n4, _t3, _2_n_s16);
1445 _n4 = vaddq_s16(_n4, _t4);
1446
1447 _n5 = vmulq_s16(_t1, _4_p_s16);
1448 _n5 = vmlaq_s16(_n5, _t3, _5_n_s16);
1449 _n5 = vaddq_s16(_n5, _t5);
1450 // save to out_tm
1451 out_tm0[0] = _n0[0];
1452 out_tm0[1] = _n0[1];
1453 out_tm0[2] = _n0[2];
1454 out_tm0[3] = _n0[3];
1455 out_tm1[0] = _n0[4];
1456 out_tm1[1] = _n0[5];
1457 out_tm1[2] = _n1[0];
1458 out_tm1[3] = _n1[1];
1459 out_tm2[0] = _n1[2];
1460 out_tm2[1] = _n1[3];
1461 out_tm2[2] = _n1[4];
1462 out_tm2[3] = _n1[5];
1463
1464 out_tm3[0] = _n2[0];
1465 out_tm3[1] = _n2[1];
1466 out_tm3[2] = _n2[2];
1467 out_tm3[3] = _n2[3];
1468 out_tm4[0] = _n2[4];
1469 out_tm4[1] = _n2[5];
1470 out_tm4[2] = _n3[0];
1471 out_tm4[3] = _n3[1];
1472 out_tm5[0] = _n3[2];
1473 out_tm5[1] = _n3[3];
1474 out_tm5[2] = _n3[4];
1475 out_tm5[3] = _n3[5];
1476
1477 out_tm6[0] = _n4[0];
1478 out_tm6[1] = _n4[1];
1479 out_tm6[2] = _n4[2];
1480 out_tm6[3] = _n4[3];
1481 out_tm7[0] = _n4[4];
1482 out_tm7[1] = _n4[5];
1483 out_tm7[2] = _n5[0];
1484 out_tm7[3] = _n5[1];
1485 out_tm8[0] = _n5[2];
1486 out_tm8[1] = _n5[3];
1487 out_tm8[2] = _n5[4];
1488 out_tm8[3] = _n5[5];
1489 #else
1490 short d0[6], d1[6], d2[6], d3[6], d4[6], d5[6];
1491 short w0[6], w1[6], w2[6], w3[6], w4[6], w5[6];
1492 short t0[6], t1[6], t2[6], t3[6], t4[6], t5[6];
1493
1494 // load
1495 for (int n = 0; n < 6; n++)
1496 {
1497 d0[n] = r0[n];
1498 d1[n] = r1[n];
1499 d2[n] = r2[n];
1500 d3[n] = r3[n];
1501 d4[n] = r4[n];
1502 d5[n] = r5[n];
1503 }
1504 // w = B_t * d
1505 for (int n = 0; n < 6; n++)
1506 {
1507 w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n];
1508 w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n];
1509 w2[n] = 4 * d1[n] - 4 * d2[n] - d3[n] + d4[n];
1510 w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n];
1511 w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n];
1512 w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n];
1513 }
1514 // transpose d to d_t
1515 {
1516 t0[0] = w0[0];
1517 t1[0] = w0[1];
1518 t2[0] = w0[2];
1519 t3[0] = w0[3];
1520 t4[0] = w0[4];
1521 t5[0] = w0[5];
1522 t0[1] = w1[0];
1523 t1[1] = w1[1];
1524 t2[1] = w1[2];
1525 t3[1] = w1[3];
1526 t4[1] = w1[4];
1527 t5[1] = w1[5];
1528 t0[2] = w2[0];
1529 t1[2] = w2[1];
1530 t2[2] = w2[2];
1531 t3[2] = w2[3];
1532 t4[2] = w2[4];
1533 t5[2] = w2[5];
1534 t0[3] = w3[0];
1535 t1[3] = w3[1];
1536 t2[3] = w3[2];
1537 t3[3] = w3[3];
1538 t4[3] = w3[4];
1539 t5[3] = w3[5];
1540 t0[4] = w4[0];
1541 t1[4] = w4[1];
1542 t2[4] = w4[2];
1543 t3[4] = w4[3];
1544 t4[4] = w4[4];
1545 t5[4] = w4[5];
1546 t0[5] = w5[0];
1547 t1[5] = w5[1];
1548 t2[5] = w5[2];
1549 t3[5] = w5[3];
1550 t4[5] = w5[4];
1551 t5[5] = w5[5];
1552 }
1553 // d = B_t * d_t
1554 for (int n = 0; n < 6; n++)
1555 {
1556 d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n];
1557 d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n];
1558 d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n];
1559 d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n];
1560 d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n];
1561 d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n];
1562 }
1563 // save to out_tm
1564 {
1565 out_tm0[0] = d0[0];
1566 out_tm0[1] = d0[1];
1567 out_tm0[2] = d0[2];
1568 out_tm0[3] = d0[3];
1569 out_tm1[0] = d0[4];
1570 out_tm1[1] = d0[5];
1571 out_tm1[2] = d1[0];
1572 out_tm1[3] = d1[1];
1573 out_tm2[0] = d1[2];
1574 out_tm2[1] = d1[3];
1575 out_tm2[2] = d1[4];
1576 out_tm2[3] = d1[5];
1577
1578 out_tm3[0] = d2[0];
1579 out_tm3[1] = d2[1];
1580 out_tm3[2] = d2[2];
1581 out_tm3[3] = d2[3];
1582 out_tm4[0] = d2[4];
1583 out_tm4[1] = d2[5];
1584 out_tm4[2] = d3[0];
1585 out_tm4[3] = d3[1];
1586 out_tm5[0] = d3[2];
1587 out_tm5[1] = d3[3];
1588 out_tm5[2] = d3[4];
1589 out_tm5[3] = d3[5];
1590
1591 out_tm6[0] = d4[0];
1592 out_tm6[1] = d4[1];
1593 out_tm6[2] = d4[2];
1594 out_tm6[3] = d4[3];
1595 out_tm7[0] = d4[4];
1596 out_tm7[1] = d4[5];
1597 out_tm7[2] = d5[0];
1598 out_tm7[3] = d5[1];
1599 out_tm8[0] = d5[2];
1600 out_tm8[1] = d5[3];
1601 out_tm8[2] = d5[4];
1602 out_tm8[3] = d5[5];
1603 }
1604 #endif // __ARM_NEON
1605 r0 += 4;
1606 r1 += 4;
1607 r2 += 4;
1608 r3 += 4;
1609 r4 += 4;
1610 r5 += 4;
1611 }
1612 }
1613 }
1614 }
1615 bottom_blob_bordered = Mat();
1616
1617 // BEGIN dot
1618 Mat top_blob_tm;
1619 {
1620 int w_tm = outw / 4 * 6;
1621 int h_tm = outh / 4 * 6;
1622
1623 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
1624 int nRowBlocks = w_tm / 6;
1625
1626 const int tiles = nColBlocks * nRowBlocks;
1627
1628 top_blob_tm.create(36, tiles, outch, 4u, opt.workspace_allocator);
1629
1630 #pragma omp parallel for num_threads(opt.num_threads)
1631 for (int r = 0; r < 9; r++)
1632 {
1633 int nn_outch = 0;
1634 int remain_outch_start = 0;
1635
1636 nn_outch = outch >> 3;
1637 remain_outch_start = nn_outch << 3;
1638
1639 for (int pp = 0; pp < nn_outch; pp++)
1640 {
1641 int p = pp * 8;
1642
1643 int* output0_tm = top_blob_tm.channel(p);
1644 int* output1_tm = top_blob_tm.channel(p + 1);
1645 int* output2_tm = top_blob_tm.channel(p + 2);
1646 int* output3_tm = top_blob_tm.channel(p + 3);
1647 int* output4_tm = top_blob_tm.channel(p + 4);
1648 int* output5_tm = top_blob_tm.channel(p + 5);
1649 int* output6_tm = top_blob_tm.channel(p + 6);
1650 int* output7_tm = top_blob_tm.channel(p + 7);
1651
1652 output0_tm = output0_tm + r * 4;
1653 output1_tm = output1_tm + r * 4;
1654 output2_tm = output2_tm + r * 4;
1655 output3_tm = output3_tm + r * 4;
1656 output4_tm = output4_tm + r * 4;
1657 output5_tm = output5_tm + r * 4;
1658 output6_tm = output6_tm + r * 4;
1659 output7_tm = output7_tm + r * 4;
1660
1661 for (int i = 0; i < tiles; i++)
1662 {
1663 const short* kptr = kernel_tm_test[r].channel(p / 8);
1664 const short* r0 = bottom_blob_tm.channel(tiles * r + i);
1665 #if __ARM_NEON
1666 #if __aarch64__
1667 asm volatile(
1668 // inch loop
1669 "eor v0.16b, v0.16b, v0.16b \n"
1670 "eor v1.16b, v1.16b, v1.16b \n"
1671 "eor v2.16b, v2.16b, v2.16b \n"
1672 "eor v3.16b, v3.16b, v3.16b \n"
1673 "eor v4.16b, v4.16b, v4.16b \n"
1674 "eor v5.16b, v5.16b, v5.16b \n"
1675 "eor v6.16b, v6.16b, v6.16b \n"
1676 "eor v7.16b, v7.16b, v7.16b \n"
1677 "mov w4, %w20 \n"
1678
1679 "0: \n" // for (int q=0; q<inch; q++)
1680 "prfm pldl1keep, [%9, #128] \n" // _r0 = vld1_s16(r0);
1681 "ld1 {v8.4h}, [%8] \n"
1682 "ld1 {v9.4h, v10.4h}, [%9] \n" // _k01 = vld1q_s16(kptr);
1683 "add %9, %9, #16 \n"
1684 "ld1 {v11.4h, v12.4h}, [%9] \n" // _k23 = vld1q_s16(kptr+8);
1685 "add %9, %9, #16 \n"
1686 "ld1 {v13.4h, v14.4h}, [%9] \n" // _k45 = vld1q_s16(kptr+16);
1687 "add %9, %9, #16 \n"
1688 "ld1 {v15.4h, v16.4h}, [%9] \n" // _k67 = vld1q_s16(kptr+24);
1689 "add %8, %8, #8 \n"
1690 "add %9, %9, #16 \n"
1691
1692 "subs w4, w4, #1 \n"
1693
1694 "smlal v0.4s, v8.4h, v9.4h \n" // sum0 += (a00-a03) * (k00-k03)
1695 "smlal v1.4s, v8.4h, v10.4h \n" // sum1 += (a00-a03) * (k10-k13)
1696 "smlal v2.4s, v8.4h, v11.4h \n" // sum2 += (a00-a03) * (k20-k23)
1697 "smlal v3.4s, v8.4h, v12.4h \n" // sum3 += (a00-a03) * (k30-k33)
1698 "smlal v4.4s, v8.4h, v13.4h \n" // sum4 += (a00-a03) * (k40-k43)
1699 "smlal v5.4s, v8.4h, v14.4h \n" // sum5 += (a00-a03) * (k50-k53)
1700 "smlal v6.4s, v8.4h, v15.4h \n" // sum6 += (a00-a03) * (k60-k63)
1701 "smlal v7.4s, v8.4h, v16.4h \n" // sum7 += (a00-a03) * (k70-k73)
1702
1703 "bne 0b \n" // end for
1704
1705 "st1 {v0.4s}, [%0] \n" // store the result to memory
1706 "st1 {v1.4s}, [%1] \n" //
1707 "st1 {v2.4s}, [%2] \n" //
1708 "st1 {v3.4s}, [%3] \n" //
1709 "st1 {v4.4s}, [%4] \n" //
1710 "st1 {v5.4s}, [%5] \n" //
1711 "st1 {v6.4s}, [%6] \n" //
1712 "st1 {v7.4s}, [%7] \n" //
1713
1714 : "=r"(output0_tm), // %0
1715 "=r"(output1_tm), // %1
1716 "=r"(output2_tm), // %2
1717 "=r"(output3_tm), // %3
1718 "=r"(output4_tm), // %4
1719 "=r"(output5_tm), // %5
1720 "=r"(output6_tm), // %6
1721 "=r"(output7_tm), // %7
1722 "=r"(r0), // %8
1723 "=r"(kptr) // %9
1724 : "0"(output0_tm),
1725 "1"(output1_tm),
1726 "2"(output2_tm),
1727 "3"(output3_tm),
1728 "4"(output4_tm),
1729 "5"(output5_tm),
1730 "6"(output6_tm),
1731 "7"(output7_tm),
1732 "8"(r0),
1733 "9"(kptr),
1734 "r"(inch) // %20
1735 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16");
1736 #else
1737 asm volatile(
1738 // inch loop
1739 "vmov.s32 q0, #0 \n"
1740 "vmov.s32 q1, #0 \n"
1741 "vmov.s32 q2, #0 \n"
1742 "vmov.s32 q3, #0 \n"
1743 "vmov.s32 q4, #0 \n"
1744 "vmov.s32 q5, #0 \n"
1745 "vmov.s32 q6, #0 \n"
1746 "vmov.s32 q7, #0 \n"
1747 "mov r4, %20 \n"
1748
1749 "0: \n" // for (int q=0; q<inch; q++)
1750 "vld1.s16 {d16}, [%8]! \n" // _r0 = vld1_s16(r0); // input inch0
1751 "vld1.s16 {d18-d19}, [%9] \n" // _k01 = vld1q_s16(kptr);
1752 "add %9, #16 \n"
1753 "vld1.s16 {d20-d21}, [%9] \n" // _k23 = vld1q_s16(kptr+8);
1754 "add %9, #16 \n"
1755 "vld1.s16 {d22-d23}, [%9] \n" // _k45 = vld1q_s16(kptr+16);
1756 "add %9, #16 \n"
1757 "vld1.s16 {d24-d25}, [%9] \n" // _k67 = vld1q_s16(kptr+24);
1758 "add %9, #16 \n"
1759
1760 "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)
1761 "vmlal.s16 q1, d16, d19 \n" // sum1 += (a00-a03) * (k10-k13)
1762 "vmlal.s16 q2, d16, d20 \n" // sum2 += (a00-a03) * (k20-k23)
1763 "vmlal.s16 q3, d16, d21 \n" // sum3 += (a00-a03) * (k30-k33)
1764 "vmlal.s16 q4, d16, d22 \n" // sum4 += (a00-a03) * (k40-k43)
1765 "vmlal.s16 q5, d16, d23 \n" // sum5 += (a00-a03) * (k50-k53)
1766 "vmlal.s16 q6, d16, d24 \n" // sum6 += (a00-a03) * (k60-k63)
1767 "vmlal.s16 q7, d16, d25 \n" // sum7 += (a00-a03) * (k70-k73)
1768
1769 "subs r4, r4, #1 \n"
1770 "bne 0b \n" // end for
1771
1772 "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory
1773 "vst1.s32 {d2-d3}, [%1] \n"
1774 "vst1.s32 {d4-d5}, [%2] \n"
1775 "vst1.s32 {d6-d7}, [%3] \n"
1776 "vst1.s32 {d8-d9}, [%4] \n"
1777 "vst1.s32 {d10-d11}, [%5] \n"
1778 "vst1.s32 {d12-d13}, [%6] \n"
1779 "vst1.s32 {d14-d15}, [%7] \n"
1780
1781 : "=r"(output0_tm), // %0
1782 "=r"(output1_tm), // %1
1783 "=r"(output2_tm), // %2
1784 "=r"(output3_tm), // %3
1785 "=r"(output4_tm), // %4
1786 "=r"(output5_tm), // %5
1787 "=r"(output6_tm), // %6
1788 "=r"(output7_tm), // %7
1789 "=r"(r0), // %8
1790 "=r"(kptr) // %9
1791 : "0"(output0_tm),
1792 "1"(output1_tm),
1793 "2"(output2_tm),
1794 "3"(output3_tm),
1795 "4"(output4_tm),
1796 "5"(output5_tm),
1797 "6"(output6_tm),
1798 "7"(output7_tm),
1799 "8"(r0),
1800 "9"(kptr),
1801 "r"(inch) // %20
1802 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
1803 #endif // __aarch64__
1804 #else
1805 int sum0[4] = {0};
1806 int sum1[4] = {0};
1807 int sum2[4] = {0};
1808 int sum3[4] = {0};
1809 int sum4[4] = {0};
1810 int sum5[4] = {0};
1811 int sum6[4] = {0};
1812 int sum7[4] = {0};
1813
1814 for (int q = 0; q < inch; q++)
1815 {
1816 for (int n = 0; n < 4; n++)
1817 {
1818 sum0[n] += (int)r0[n] * kptr[n];
1819 sum1[n] += (int)r0[n] * kptr[n + 4];
1820 sum2[n] += (int)r0[n] * kptr[n + 8];
1821 sum3[n] += (int)r0[n] * kptr[n + 12];
1822 sum4[n] += (int)r0[n] * kptr[n + 16];
1823 sum5[n] += (int)r0[n] * kptr[n + 20];
1824 sum6[n] += (int)r0[n] * kptr[n + 24];
1825 sum7[n] += (int)r0[n] * kptr[n + 28];
1826 }
1827 kptr += 32;
1828 r0 += 4;
1829 }
1830
1831 for (int n = 0; n < 4; n++)
1832 {
1833 output0_tm[n] = sum0[n];
1834 output1_tm[n] = sum1[n];
1835 output2_tm[n] = sum2[n];
1836 output3_tm[n] = sum3[n];
1837 output4_tm[n] = sum4[n];
1838 output5_tm[n] = sum5[n];
1839 output6_tm[n] = sum6[n];
1840 output7_tm[n] = sum7[n];
1841 }
1842 #endif // __ARM_NEON
1843 output0_tm += 36;
1844 output1_tm += 36;
1845 output2_tm += 36;
1846 output3_tm += 36;
1847 output4_tm += 36;
1848 output5_tm += 36;
1849 output6_tm += 36;
1850 output7_tm += 36;
1851 }
1852 }
1853
1854 nn_outch = (outch - remain_outch_start) >> 2;
1855
1856 for (int pp = 0; pp < nn_outch; pp++)
1857 {
1858 int p = remain_outch_start + pp * 4;
1859
1860 int* output0_tm = top_blob_tm.channel(p);
1861 int* output1_tm = top_blob_tm.channel(p + 1);
1862 int* output2_tm = top_blob_tm.channel(p + 2);
1863 int* output3_tm = top_blob_tm.channel(p + 3);
1864
1865 output0_tm = output0_tm + r * 4;
1866 output1_tm = output1_tm + r * 4;
1867 output2_tm = output2_tm + r * 4;
1868 output3_tm = output3_tm + r * 4;
1869
1870 for (int i = 0; i < tiles; i++)
1871 {
1872 const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4);
1873 const short* r0 = bottom_blob_tm.channel(tiles * r + i);
1874 #if __ARM_NEON
1875 #if __aarch64__
1876 asm volatile(
1877 // inch loop
1878 "eor v0.16b, v0.16b, v0.16b \n"
1879 "eor v1.16b, v1.16b, v1.16b \n"
1880 "eor v2.16b, v2.16b, v2.16b \n"
1881 "eor v3.16b, v3.16b, v3.16b \n"
1882 "mov w4, %w12 \n"
1883
1884 "0: \n" // for (int q=0; q<inch; q++)
1885 "prfm pldl1keep, [%5, #128] \n" // _r0 = vld1_s16(r0); // input inch0
1886 "ld1 {v8.4h}, [%4] \n"
1887 "ld1 {v9.4h, v10.4h}, [%5] \n" // _k01 = vld1q_s16(kptr);
1888 "add %5, %5, #16 \n"
1889 "ld1 {v11.4h, v12.4h}, [%5] \n" // _k23 = vld1q_s16(kptr+8);
1890 "add %4, %4, #8 \n"
1891 "add %5, %5, #16 \n"
1892
1893 "subs w4, w4, #1 \n"
1894
1895 "smlal v0.4s, v8.4h, v9.4h \n" // sum0 += (a00-a03) * (k00-k03)
1896 "smlal v1.4s, v8.4h, v10.4h \n" // sum1 += (a00-a03) * (k10-k13)
1897 "smlal v2.4s, v8.4h, v11.4h \n" // sum2 += (a00-a03) * (k20-k23)
1898 "smlal v3.4s, v8.4h, v12.4h \n" // sum3 += (a00-a03) * (k30-k33)
1899
1900 "bne 0b \n" // end for
1901
1902 "st1 {v0.4s}, [%0] \n" // store the result to memory
1903 "st1 {v1.4s}, [%1] \n" //
1904 "st1 {v2.4s}, [%2] \n" //
1905 "st1 {v3.4s}, [%3] \n" //
1906
1907 : "=r"(output0_tm), // %0
1908 "=r"(output1_tm), // %1
1909 "=r"(output2_tm), // %2
1910 "=r"(output3_tm), // %3
1911 "=r"(r0), // %4
1912 "=r"(kptr) // %5
1913 : "0"(output0_tm),
1914 "1"(output1_tm),
1915 "2"(output2_tm),
1916 "3"(output3_tm),
1917 "4"(r0),
1918 "5"(kptr),
1919 "r"(inch) // %12
1920 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
1921 #else
1922 asm volatile(
1923 // inch loop
1924 "vmov.s32 q0, #0 \n"
1925 "vmov.s32 q1, #0 \n"
1926 "vmov.s32 q2, #0 \n"
1927 "vmov.s32 q3, #0 \n"
1928 "mov r4, %12 \n"
1929
1930 "0: \n" // for (int q=0; q<inch; q++)
1931 "vld1.s16 {d16}, [%4]! \n" // _r0 = vld1_s16(r0); // input inch0
1932 "vld1.s16 {d18-d19}, [%5] \n" // _k01 = vld1q_s16(kptr);
1933 "add %5, #16 \n"
1934 "vld1.s16 {d20-d21}, [%5] \n" // _k23 = vld1q_s16(kptr+8);
1935 "add %5, #16 \n"
1936
1937 "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)
1938 "vmlal.s16 q1, d16, d19 \n" // sum1 += (a00-a03) * (k10-k13)
1939 "vmlal.s16 q2, d16, d20 \n" // sum2 += (a00-a03) * (k20-k23)
1940 "vmlal.s16 q3, d16, d21 \n" // sum3 += (a00-a03) * (k30-k33)
1941
1942 "subs r4, r4, #1 \n"
1943 "bne 0b \n" // end for
1944
1945 "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory
1946 "vst1.s32 {d2-d3}, [%1] \n"
1947 "vst1.s32 {d4-d5}, [%2] \n"
1948 "vst1.s32 {d6-d7}, [%3] \n"
1949
1950 : "=r"(output0_tm), // %0
1951 "=r"(output1_tm), // %1
1952 "=r"(output2_tm), // %2
1953 "=r"(output3_tm), // %3
1954 "=r"(r0), // %4
1955 "=r"(kptr) // %5
1956 : "0"(output0_tm),
1957 "1"(output1_tm),
1958 "2"(output2_tm),
1959 "3"(output3_tm),
1960 "4"(r0),
1961 "5"(kptr),
1962 "r"(inch) // %12
1963 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q8", "q9", "q10");
1964 #endif // __aarch64__
1965 #else
1966 int sum0[4] = {0};
1967 int sum1[4] = {0};
1968 int sum2[4] = {0};
1969 int sum3[4] = {0};
1970
1971 for (int q = 0; q < inch; q++)
1972 {
1973 for (int n = 0; n < 4; n++)
1974 {
1975 sum0[n] += (int)r0[n] * kptr[n];
1976 sum1[n] += (int)r0[n] * kptr[n + 4];
1977 sum2[n] += (int)r0[n] * kptr[n + 8];
1978 sum3[n] += (int)r0[n] * kptr[n + 12];
1979 }
1980 kptr += 16;
1981 r0 += 4;
1982 }
1983
1984 for (int n = 0; n < 4; n++)
1985 {
1986 output0_tm[n] = sum0[n];
1987 output1_tm[n] = sum1[n];
1988 output2_tm[n] = sum2[n];
1989 output3_tm[n] = sum3[n];
1990 }
1991 #endif // __ARM_NEON
1992 output0_tm += 36;
1993 output1_tm += 36;
1994 output2_tm += 36;
1995 output3_tm += 36;
1996 }
1997 }
1998
1999 remain_outch_start += nn_outch << 2;
2000
2001 for (int p = remain_outch_start; p < outch; p++)
2002 {
2003 int* output0_tm = top_blob_tm.channel(p);
2004
2005 output0_tm = output0_tm + r * 4;
2006
2007 for (int i = 0; i < tiles; i++)
2008 {
2009 const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4 + p % 4);
2010 const short* r0 = bottom_blob_tm.channel(tiles * r + i);
2011 #if __ARM_NEON
2012 #if __aarch64__
2013 asm volatile(
2014 // inch loop
2015 "eor v0.16b, v0.16b, v0.16b \n"
2016 "mov w4, %w6 \n"
2017
2018 "0: \n" // for (int q=0; q<inch; q++)
2019 "ld1 {v8.4h}, [%1] \n" // _r0 = vld1_s16(r0); // input inch0
2020 "ld1 {v9.4h}, [%2] \n" // _k0 = vld1q_s16(kptr);
2021 "add %1, %1, #8 \n"
2022 "add %2, %2, #8 \n"
2023
2024 "subs w4, w4, #1 \n"
2025
2026 "smlal v0.4s, v8.4h, v9.4h \n" // sum0 += (a00-a03) * (k00-k03)
2027
2028 "bne 0b \n" // end for
2029
2030 "st1 {v0.4s}, [%0] \n" // store the result to memory
2031
2032 : "=r"(output0_tm), // %0
2033 "=r"(r0), // %1
2034 "=r"(kptr) // %2
2035 : "0"(output0_tm),
2036 "1"(r0),
2037 "2"(kptr),
2038 "r"(inch) // %6
2039 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9");
2040 #else
2041 asm volatile(
2042 // inch loop
2043 "vmov.s32 q0, #0 \n"
2044 "mov r4, %6 \n"
2045
2046 "0: \n" // for (int q=0; q<inch; q++)
2047 "vld1.s16 {d16}, [%1] \n" // _r0 = vld1_s16(r0); // input inch0
2048 "add %1, #8 \n"
2049 "vld1.s16 {d18}, [%2] \n" // _k0 = vld1q_s16(kptr);
2050 "add %2, #8 \n"
2051 "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)
2052
2053 "subs r4, r4, #1 \n"
2054 "bne 0b \n" // end for
2055
2056 "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory
2057
2058 : "=r"(output0_tm), // %0
2059 "=r"(r0), // %1
2060 "=r"(kptr) // %2
2061 : "0"(output0_tm),
2062 "1"(r0),
2063 "2"(kptr),
2064 "r"(inch) // %6
2065 : "cc", "memory", "r4", "q0", "q8", "q9");
2066 #endif // __aarch64__
2067 #else // __ARM_NEON
2068 int sum0[4] = {0};
2069
2070 for (int q = 0; q < inch; q++)
2071 {
2072 for (int n = 0; n < 4; n++)
2073 {
2074 sum0[n] += (int)r0[n] * kptr[n];
2075 }
2076 kptr += 4;
2077 r0 += 4;
2078 }
2079
2080 for (int n = 0; n < 4; n++)
2081 {
2082 output0_tm[n] = sum0[n];
2083 }
2084 #endif // __ARM_NEON
2085 output0_tm += 36;
2086 }
2087 }
2088
2089 // for (int p=0; p<outch; p++)
2090 // {
2091 // Mat out0_tm = top_blob_tm.channel(p);
2092 // const Mat kernel0_tm = kernel_tm.channel(p);
2093
2094 // for (int i=0; i<tiles; i++)
2095 // {
2096 // int* output0_tm = out0_tm.row<int>(i);
2097
2098 // int sum0[36] = {0};
2099
2100 // for (int q=0; q<inch; q++)
2101 // {
2102 // const short* r0 = bottom_blob_tm.channel(q).row<short>(i);
2103 // const short* k0 = kernel0_tm.row<short>(q);
2104
2105 // for (int n=0; n<36; n++)
2106 // {
2107 // sum0[n] += (int)r0[n] * k0[n];
2108 // }
2109 // }
2110
2111 // for (int n=0; n<36; n++)
2112 // {
2113 // output0_tm[n] = sum0[n];
2114 // }
2115 // }
2116 // }
2117 }
2118 }
2119 bottom_blob_tm = Mat();
2120 // END dot
2121
2122 // BEGIN transform output
2123 Mat top_blob_bordered;
2124 top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator);
2125 {
2126 // AT
2127 // const float itm[4][6] = {
2128 // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
2129 // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
2130 // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f},
2131 // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
2132 // };
2133
2134 // 0 = r00 + r01 + r02 + r03 + r04
2135 // 1 = r01 - r02 + 2 * (r03 - r04)
2136 // 2 = r01 + r02 + 4 * (r03 + r04)
2137 // 3 = r01 - r02 + 8 * (r03 - r04) + r05
2138
2139 int w_tm = outw / 4 * 6;
2140 int h_tm = outh / 4 * 6;
2141
2142 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
2143 int nRowBlocks = w_tm / 6;
2144
2145 #pragma omp parallel for num_threads(opt.num_threads)
2146 for (int p = 0; p < outch; p++)
2147 {
2148 int* out_tile = top_blob_tm.channel(p);
2149 int* outRow0 = top_blob_bordered.channel(p);
2150 int* outRow1 = outRow0 + outw;
2151 int* outRow2 = outRow0 + outw * 2;
2152 int* outRow3 = outRow0 + outw * 3;
2153
2154 for (int j = 0; j < nColBlocks; j++)
2155 {
2156 for (int i = 0; i < nRowBlocks; i++)
2157 {
2158 #if __ARM_NEON
2159 int32x4_t _s0, _s1, _s2, _s3, _s4, _s5;
2160 int32x2_t _s0n, _s1n, _s2n, _s3n, _s4n, _s5n;
2161 int32x4_t _w0, _w3;
2162 int32x2_t _w0n, _w3n;
2163 int32x4_t _d0, _d1, _d2, _d3, _d4, _d5;
2164 int32x4_t _o0, _o1, _o2, _o3;
2165 // load
2166 _s0 = vld1q_s32(out_tile);
2167 _s0n = vld1_s32(out_tile + 4);
2168 _s1 = vld1q_s32(out_tile + 6);
2169 _s1n = vld1_s32(out_tile + 10);
2170 _s2 = vld1q_s32(out_tile + 12);
2171 _s2n = vld1_s32(out_tile + 16);
2172 _s3 = vld1q_s32(out_tile + 18);
2173 _s3n = vld1_s32(out_tile + 22);
2174 _s4 = vld1q_s32(out_tile + 24);
2175 _s4n = vld1_s32(out_tile + 28);
2176 _s5 = vld1q_s32(out_tile + 30);
2177 _s5n = vld1_s32(out_tile + 34);
2178 // w = A_T * W
2179 int32x2_t _tp0 = {1, 4};
2180 int32x2_t _tp1 = {2, 8};
2181
2182 // 4*s5[n]
2183 int32x4_t _s5x4 = vshlq_n_s32(_s5, 2);
2184 int32x2_t _s5x4n = vshl_n_s32(_s5n, 2);
2185
2186 int32x4_t _t1p2 = vaddq_s32(_s1, _s2);
2187 int32x2_t _t1p2n = vadd_s32(_s1n, _s2n);
2188 int32x4_t _t3p4 = vaddq_s32(_s3, _s4);
2189 int32x2_t _t3p4n = vadd_s32(_s3n, _s4n);
2190 int32x4_t _t1s2 = vsubq_s32(_s1, _s2);
2191 int32x2_t _t1s2n = vsub_s32(_s1n, _s2n);
2192 int32x4_t _t3s4 = vsubq_s32(_s3, _s4);
2193 int32x2_t _t3s4n = vsub_s32(_s3n, _s4n);
2194
2195 _w0 = vaddq_s32(_s0, _t1p2);
2196 _w0n = vadd_s32(_s0n, _t1p2n);
2197 _w0 = vaddq_s32(_w0, _t3p4);
2198 _w0n = vadd_s32(_w0n, _t3p4n);
2199 _w0n = vmul_s32(_w0n, _tp0);
2200
2201 // _w2,_w2n
2202 _t1p2 = vmlaq_lane_s32(_t1p2, _t3p4, _tp0, 1);
2203 _t1p2n = vmla_lane_s32(_t1p2n, _t3p4n, _tp0, 1);
2204 _t1p2n = vmul_s32(_t1p2n, _tp0);
2205
2206 _w3 = vaddq_s32(_s5x4, _t1s2);
2207 _w3n = vadd_s32(_s5x4n, _t1s2n);
2208 _w3 = vmlaq_lane_s32(_w3, _t3s4, _tp1, 1);
2209 _w3n = vmla_lane_s32(_w3n, _t3s4n, _tp1, 1);
2210 _w3n = vmul_s32(_w3n, _tp0);
2211
2212 // _w1, _w1n
2213 _t1s2 = vmlaq_lane_s32(_t1s2, _t3s4, _tp1, 0);
2214 _t1s2n = vmla_lane_s32(_t1s2n, _t3s4n, _tp1, 0);
2215 _t1s2n = vmul_s32(_t1s2n, _tp0);
2216
2217 int32x4_t _w02n = vcombine_s32(_w0n, _t1p2n);
2218 int32x4_t _w13n = vcombine_s32(_t1s2n, _w3n);
2219
2220 // transpose w to w_t
2221 #if __aarch64__
2222 int32x4_t _wt0 = vtrn1q_s32(_w0, _t1s2);
2223 int32x4_t _wt1 = vtrn2q_s32(_w0, _t1s2);
2224 int32x4_t _wt2 = vtrn1q_s32(_t1p2, _w3);
2225 int32x4_t _wt3 = vtrn2q_s32(_t1p2, _w3);
2226 int64x2_t _dt0 = vtrn1q_s64(vreinterpretq_s64_s32(_wt0), vreinterpretq_s64_s32(_wt2));
2227 int64x2_t _dt2 = vtrn2q_s64(vreinterpretq_s64_s32(_wt0), vreinterpretq_s64_s32(_wt2));
2228 int64x2_t _dt1 = vtrn1q_s64(vreinterpretq_s64_s32(_wt1), vreinterpretq_s64_s32(_wt3));
2229 int64x2_t _dt3 = vtrn2q_s64(vreinterpretq_s64_s32(_wt1), vreinterpretq_s64_s32(_wt3));
2230 _d0 = vreinterpretq_s32_s64(_dt0);
2231 _d1 = vreinterpretq_s32_s64(_dt1);
2232 _d2 = vreinterpretq_s32_s64(_dt2);
2233 _d3 = vreinterpretq_s32_s64(_dt3);
2234 _d4 = vtrn1q_s32(_w02n, _w13n);
2235 _d5 = vtrn2q_s32(_w02n, _w13n);
2236 #else
2237 asm volatile(
2238 "vtrn.32 %q[_w0], %q[_w1] \n"
2239 "vtrn.32 %q[_w2], %q[_w3] \n"
2240 "vswp %f[_w0], %e[_w2] \n"
2241 "vswp %f[_w1], %e[_w3] \n"
2242 "vtrn.32 %q[_w02n], %q[_w13n] \n"
2243 : [_w0] "+w"(_w0),
2244 [_w1] "+w"(_t1s2),
2245 [_w2] "+w"(_t1p2),
2246 [_w3] "+w"(_w3),
2247 [_w02n] "+w"(_w02n),
2248 [_w13n] "+w"(_w13n)
2249 :
2250 : "cc", "memory");
2251 _d0 = _w0;
2252 _d1 = _t1s2;
2253 _d2 = _t1p2;
2254 _d3 = _w3;
2255 _d4 = _w02n;
2256 _d5 = _w13n;
2257 #endif
2258 // Y = A_T * w_t
2259 _t1p2 = vaddq_s32(_d1, _d2);
2260 _t3p4 = vaddq_s32(_d3, _d4);
2261 _t1s2 = vsubq_s32(_d1, _d2);
2262 _t3s4 = vsubq_s32(_d3, _d4);
2263
2264 _o0 = vaddq_s32(_d0, _t1p2);
2265 _o0 = vaddq_s32(_o0, _t3p4);
2266
2267 // _o2
2268 _t1p2 = vmlaq_lane_s32(_t1p2, _t3p4, _tp0, 1);
2269
2270 _o3 = vaddq_s32(_d5, _t1s2);
2271 _o3 = vmlaq_lane_s32(_o3, _t3s4, _tp1, 1);
2272
2273 // _o1
2274 _t1s2 = vmlaq_lane_s32(_t1s2, _t3s4, _tp1, 0);
2275
2276 // save to top blob tm
2277 float32x4_t _ot0 = vcvtq_f32_s32(_o0);
2278 float32x4_t _ot1 = vcvtq_f32_s32(_t1s2);
2279 float32x4_t _ot2 = vcvtq_f32_s32(_t1p2);
2280 float32x4_t _ot3 = vcvtq_f32_s32(_o3);
2281
2282 _ot0 = vmulq_n_f32(_ot0, 0.0017361112);
2283 _ot1 = vmulq_n_f32(_ot1, 0.0017361112);
2284 _ot2 = vmulq_n_f32(_ot2, 0.0017361112);
2285 _ot3 = vmulq_n_f32(_ot3, 0.0017361112);
2286
2287 _o0 = vcvtq_s32_f32(_ot0);
2288 _o1 = vcvtq_s32_f32(_ot1);
2289 _o2 = vcvtq_s32_f32(_ot2);
2290 _o3 = vcvtq_s32_f32(_ot3);
2291
2292 vst1q_s32(outRow0, _o0);
2293 vst1q_s32(outRow1, _o1);
2294 vst1q_s32(outRow2, _o2);
2295 vst1q_s32(outRow3, _o3);
2296 #else
2297 int s0[6], s1[6], s2[6], s3[6], s4[6], s5[6];
2298 int w0[6], w1[6], w2[6], w3[6];
2299 int d0[4], d1[4], d2[4], d3[4], d4[4], d5[4];
2300 int o0[4], o1[4], o2[4], o3[4];
2301
2302 // load
2303 for (int n = 0; n < 6; n++)
2304 {
2305 s0[n] = out_tile[n];
2306 s1[n] = out_tile[n + 6];
2307 s2[n] = out_tile[n + 12];
2308 s3[n] = out_tile[n + 18];
2309 s4[n] = out_tile[n + 24];
2310 s5[n] = out_tile[n + 30];
2311 }
2312 // w = A_T * W
2313 for (int n = 0; n < 5; n++)
2314 {
2315 w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n];
2316 w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n];
2317 w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n];
2318 w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + 4 * s5[n];
2319 }
2320 for (int n = 5; n < 6; n++)
2321 {
2322 w0[n] = 4 * (s0[n] + s1[n] + s2[n] + s3[n] + s4[n]);
2323 w1[n] = 4 * (s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n]);
2324 w2[n] = 4 * (s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n]);
2325 w3[n] = 4 * (s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + 4 * s5[n]);
2326 }
2327 // transpose w to w_t
2328 {
2329 d0[0] = w0[0];
2330 d0[1] = w1[0];
2331 d0[2] = w2[0];
2332 d0[3] = w3[0];
2333 d1[0] = w0[1];
2334 d1[1] = w1[1];
2335 d1[2] = w2[1];
2336 d1[3] = w3[1];
2337 d2[0] = w0[2];
2338 d2[1] = w1[2];
2339 d2[2] = w2[2];
2340 d2[3] = w3[2];
2341 d3[0] = w0[3];
2342 d3[1] = w1[3];
2343 d3[2] = w2[3];
2344 d3[3] = w3[3];
2345 d4[0] = w0[4];
2346 d4[1] = w1[4];
2347 d4[2] = w2[4];
2348 d4[3] = w3[4];
2349 d5[0] = w0[5];
2350 d5[1] = w1[5];
2351 d5[2] = w2[5];
2352 d5[3] = w3[5];
2353 }
2354 // Y = A_T * w_t
2355 for (int n = 0; n < 4; n++)
2356 {
2357 o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n];
2358 o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n];
2359 o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n];
2360 o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n];
2361 }
2362 // save to top blob tm
2363 for (int n = 0; n < 4; n++)
2364 {
2365 outRow0[n] = o0[n] / 576;
2366 outRow1[n] = o1[n] / 576;
2367 outRow2[n] = o2[n] / 576;
2368 outRow3[n] = o3[n] / 576;
2369 }
2370 #endif // __ARM_NEON
2371 out_tile += 36;
2372
2373 outRow0 += 4;
2374 outRow1 += 4;
2375 outRow2 += 4;
2376 outRow3 += 4;
2377 }
2378
2379 outRow0 += outw * 3;
2380 outRow1 += outw * 3;
2381 outRow2 += outw * 3;
2382 outRow3 += outw * 3;
2383 }
2384 }
2385 }
2386 // END transform output
2387
2388 // cut result pad
2389 copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
2390 }
2391
conv3x3s1_winograd43_dequant_int8_neon(const Mat & bottom_blob,Mat & top_blob,const std::vector<Mat> & kernel_tm_test,const Mat & _bias,std::vector<float> scales_dequant,const Option & opt)2392 static void conv3x3s1_winograd43_dequant_int8_neon(const Mat& bottom_blob, Mat& top_blob, const std::vector<Mat>& kernel_tm_test, const Mat& _bias, std::vector<float> scales_dequant, const Option& opt)
2393 {
2394 int w = bottom_blob.w;
2395 int h = bottom_blob.h;
2396 int inch = bottom_blob.c;
2397
2398 int outw = top_blob.w;
2399 int outh = top_blob.h;
2400 int outch = top_blob.c;
2401
2402 const float* bias = _bias;
2403
2404 // pad to 4n+2, winograd F(4,3)
2405 Mat bottom_blob_bordered = bottom_blob;
2406
2407 outw = (outw + 3) / 4 * 4;
2408 outh = (outh + 3) / 4 * 4;
2409
2410 w = outw + 2;
2411 h = outh + 2;
2412 Option opt_b = opt;
2413 opt_b.blob_allocator = opt.workspace_allocator;
2414 copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);
2415
2416 // BEGIN transform input
2417 Mat bottom_blob_tm;
2418 {
2419 int w_tm = outw / 4 * 6;
2420 int h_tm = outh / 4 * 6;
2421
2422 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
2423 int nRowBlocks = w_tm / 6;
2424
2425 const int tiles = nColBlocks * nRowBlocks;
2426
2427 bottom_blob_tm.create(4, inch, tiles * 9, 2u, opt.workspace_allocator);
2428
2429 // BT
2430 // const float itm[4][4] = {
2431 // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
2432 // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
2433 // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
2434 // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
2435 // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
2436 // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f}
2437 // };
2438
2439 // 0 = 4 * r00 - 5 * r02 + r04
2440 // 1 = -4 * (r01 + r02) + r03 + r04
2441 // 2 = 4 * (r01 - r02) - r03 + r04
2442 // 3 = -2 * r01 - r02 + 2 * r03 + r04
2443 // 4 = 2 * r01 - r02 - 2 * r03 + r04
2444 // 5 = 4 * r01 - 5 * r03 + r05
2445
2446 #pragma omp parallel for num_threads(opt.num_threads)
2447 for (int q = 0; q < inch; q++)
2448 {
2449 const signed char* img = bottom_blob_bordered.channel(q);
2450
2451 for (int j = 0; j < nColBlocks; j++)
2452 {
2453 const signed char* r0 = img + w * j * 4;
2454 const signed char* r1 = r0 + w;
2455 const signed char* r2 = r1 + w;
2456 const signed char* r3 = r2 + w;
2457 const signed char* r4 = r3 + w;
2458 const signed char* r5 = r4 + w;
2459
2460 for (int i = 0; i < nRowBlocks; i++)
2461 {
2462 short* out_tm0 = bottom_blob_tm.channel(tiles * 0 + j * nRowBlocks + i).row<short>(q);
2463 short* out_tm1 = bottom_blob_tm.channel(tiles * 1 + j * nRowBlocks + i).row<short>(q);
2464 short* out_tm2 = bottom_blob_tm.channel(tiles * 2 + j * nRowBlocks + i).row<short>(q);
2465 short* out_tm3 = bottom_blob_tm.channel(tiles * 3 + j * nRowBlocks + i).row<short>(q);
2466 short* out_tm4 = bottom_blob_tm.channel(tiles * 4 + j * nRowBlocks + i).row<short>(q);
2467 short* out_tm5 = bottom_blob_tm.channel(tiles * 5 + j * nRowBlocks + i).row<short>(q);
2468 short* out_tm6 = bottom_blob_tm.channel(tiles * 6 + j * nRowBlocks + i).row<short>(q);
2469 short* out_tm7 = bottom_blob_tm.channel(tiles * 7 + j * nRowBlocks + i).row<short>(q);
2470 short* out_tm8 = bottom_blob_tm.channel(tiles * 8 + j * nRowBlocks + i).row<short>(q);
2471 #if __ARM_NEON
2472 int8x8_t _d0, _d1, _d2, _d3, _d4, _d5;
2473 int16x8_t _w0, _w1, _w2, _w3, _w4, _w5;
2474 int16x8_t _t0, _t1, _t2, _t3, _t4, _t5;
2475 int16x8_t _n0, _n1, _n2, _n3, _n4, _n5;
2476 // load
2477 _d0 = vld1_s8(r0);
2478 _d1 = vld1_s8(r1);
2479 _d2 = vld1_s8(r2);
2480 _d3 = vld1_s8(r3);
2481 _d4 = vld1_s8(r4);
2482 _d5 = vld1_s8(r5);
2483
2484 int8x8_t _1_n = vdup_n_s8(-1);
2485 int8x8_t _2_p = vdup_n_s8(2);
2486 int8x8_t _2_n = vdup_n_s8(-2);
2487 int8x8_t _4_p = vdup_n_s8(4);
2488 int8x8_t _4_n = vdup_n_s8(-4);
2489 int8x8_t _5_n = vdup_n_s8(-5);
2490
2491 int16x8_t _1_n_s16 = vdupq_n_s16(-1);
2492 int16x8_t _2_p_s16 = vdupq_n_s16(2);
2493 int16x8_t _2_n_s16 = vdupq_n_s16(-2);
2494 int16x8_t _4_p_s16 = vdupq_n_s16(4);
2495 int16x8_t _4_n_s16 = vdupq_n_s16(-4);
2496 int16x8_t _5_n_s16 = vdupq_n_s16(-5);
2497 // w = B_t * d
2498 _w0 = vmull_s8(_d0, _4_p);
2499 _w0 = vmlal_s8(_w0, _d2, _5_n);
2500 _w0 = vaddw_s8(_w0, _d4);
2501
2502 _w1 = vmull_s8(_d1, _4_n);
2503 _w1 = vmlal_s8(_w1, _d2, _4_n);
2504 _w1 = vaddw_s8(_w1, _d3);
2505 _w1 = vaddw_s8(_w1, _d4);
2506
2507 _w2 = vmull_s8(_d1, _4_p);
2508 _w2 = vmlal_s8(_w2, _d2, _4_n);
2509 _w2 = vmlal_s8(_w2, _d3, _1_n);
2510 _w2 = vaddw_s8(_w2, _d4);
2511
2512 _w3 = vmull_s8(_d1, _2_n);
2513 _w3 = vmlal_s8(_w3, _d2, _1_n);
2514 _w3 = vmlal_s8(_w3, _d3, _2_p);
2515 _w3 = vaddw_s8(_w3, _d4);
2516
2517 _w4 = vmull_s8(_d1, _2_p);
2518 _w4 = vmlal_s8(_w4, _d2, _1_n);
2519 _w4 = vmlal_s8(_w4, _d3, _2_n);
2520 _w4 = vaddw_s8(_w4, _d4);
2521
2522 _w5 = vmull_s8(_d1, _4_p);
2523 _w5 = vmlal_s8(_w5, _d3, _5_n);
2524 _w5 = vaddw_s8(_w5, _d5);
2525 // transpose d to d_t
2526 {
2527 _t0[0] = _w0[0];
2528 _t1[0] = _w0[1];
2529 _t2[0] = _w0[2];
2530 _t3[0] = _w0[3];
2531 _t4[0] = _w0[4];
2532 _t5[0] = _w0[5];
2533 _t0[1] = _w1[0];
2534 _t1[1] = _w1[1];
2535 _t2[1] = _w1[2];
2536 _t3[1] = _w1[3];
2537 _t4[1] = _w1[4];
2538 _t5[1] = _w1[5];
2539 _t0[2] = _w2[0];
2540 _t1[2] = _w2[1];
2541 _t2[2] = _w2[2];
2542 _t3[2] = _w2[3];
2543 _t4[2] = _w2[4];
2544 _t5[2] = _w2[5];
2545 _t0[3] = _w3[0];
2546 _t1[3] = _w3[1];
2547 _t2[3] = _w3[2];
2548 _t3[3] = _w3[3];
2549 _t4[3] = _w3[4];
2550 _t5[3] = _w3[5];
2551 _t0[4] = _w4[0];
2552 _t1[4] = _w4[1];
2553 _t2[4] = _w4[2];
2554 _t3[4] = _w4[3];
2555 _t4[4] = _w4[4];
2556 _t5[4] = _w4[5];
2557 _t0[5] = _w5[0];
2558 _t1[5] = _w5[1];
2559 _t2[5] = _w5[2];
2560 _t3[5] = _w5[3];
2561 _t4[5] = _w5[4];
2562 _t5[5] = _w5[5];
2563 }
2564 // d = B_t * d_t
2565 _n0 = vmulq_s16(_t0, _4_p_s16);
2566 _n0 = vmlaq_s16(_n0, _t2, _5_n_s16);
2567 _n0 = vaddq_s16(_n0, _t4);
2568
2569 _n1 = vmulq_s16(_t1, _4_n_s16);
2570 _n1 = vmlaq_s16(_n1, _t2, _4_n_s16);
2571 _n1 = vaddq_s16(_n1, _t3);
2572 _n1 = vaddq_s16(_n1, _t4);
2573
2574 _n2 = vmulq_s16(_t1, _4_p_s16);
2575 _n2 = vmlaq_s16(_n2, _t2, _4_n_s16);
2576 _n2 = vmlaq_s16(_n2, _t3, _1_n_s16);
2577 _n2 = vaddq_s16(_n2, _t4);
2578
2579 _n3 = vmulq_s16(_t1, _2_n_s16);
2580 _n3 = vmlaq_s16(_n3, _t2, _1_n_s16);
2581 _n3 = vmlaq_s16(_n3, _t3, _2_p_s16);
2582 _n3 = vaddq_s16(_n3, _t4);
2583
2584 _n4 = vmulq_s16(_t1, _2_p_s16);
2585 _n4 = vmlaq_s16(_n4, _t2, _1_n_s16);
2586 _n4 = vmlaq_s16(_n4, _t3, _2_n_s16);
2587 _n4 = vaddq_s16(_n4, _t4);
2588
2589 _n5 = vmulq_s16(_t1, _4_p_s16);
2590 _n5 = vmlaq_s16(_n5, _t3, _5_n_s16);
2591 _n5 = vaddq_s16(_n5, _t5);
2592 // save to out_tm
2593 out_tm0[0] = _n0[0];
2594 out_tm0[1] = _n0[1];
2595 out_tm0[2] = _n0[2];
2596 out_tm0[3] = _n0[3];
2597 out_tm1[0] = _n0[4];
2598 out_tm1[1] = _n0[5];
2599 out_tm1[2] = _n1[0];
2600 out_tm1[3] = _n1[1];
2601 out_tm2[0] = _n1[2];
2602 out_tm2[1] = _n1[3];
2603 out_tm2[2] = _n1[4];
2604 out_tm2[3] = _n1[5];
2605
2606 out_tm3[0] = _n2[0];
2607 out_tm3[1] = _n2[1];
2608 out_tm3[2] = _n2[2];
2609 out_tm3[3] = _n2[3];
2610 out_tm4[0] = _n2[4];
2611 out_tm4[1] = _n2[5];
2612 out_tm4[2] = _n3[0];
2613 out_tm4[3] = _n3[1];
2614 out_tm5[0] = _n3[2];
2615 out_tm5[1] = _n3[3];
2616 out_tm5[2] = _n3[4];
2617 out_tm5[3] = _n3[5];
2618
2619 out_tm6[0] = _n4[0];
2620 out_tm6[1] = _n4[1];
2621 out_tm6[2] = _n4[2];
2622 out_tm6[3] = _n4[3];
2623 out_tm7[0] = _n4[4];
2624 out_tm7[1] = _n4[5];
2625 out_tm7[2] = _n5[0];
2626 out_tm7[3] = _n5[1];
2627 out_tm8[0] = _n5[2];
2628 out_tm8[1] = _n5[3];
2629 out_tm8[2] = _n5[4];
2630 out_tm8[3] = _n5[5];
2631 #else
2632 short d0[6], d1[6], d2[6], d3[6], d4[6], d5[6];
2633 short w0[6], w1[6], w2[6], w3[6], w4[6], w5[6];
2634 short t0[6], t1[6], t2[6], t3[6], t4[6], t5[6];
2635
2636 // load
2637 for (int n = 0; n < 6; n++)
2638 {
2639 d0[n] = r0[n];
2640 d1[n] = r1[n];
2641 d2[n] = r2[n];
2642 d3[n] = r3[n];
2643 d4[n] = r4[n];
2644 d5[n] = r5[n];
2645 }
2646 // w = B_t * d
2647 for (int n = 0; n < 6; n++)
2648 {
2649 w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n];
2650 w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n];
2651 w2[n] = 4 * d1[n] - 4 * d2[n] - d3[n] + d4[n];
2652 w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n];
2653 w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n];
2654 w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n];
2655 }
2656 // transpose d to d_t
2657 {
2658 t0[0] = w0[0];
2659 t1[0] = w0[1];
2660 t2[0] = w0[2];
2661 t3[0] = w0[3];
2662 t4[0] = w0[4];
2663 t5[0] = w0[5];
2664 t0[1] = w1[0];
2665 t1[1] = w1[1];
2666 t2[1] = w1[2];
2667 t3[1] = w1[3];
2668 t4[1] = w1[4];
2669 t5[1] = w1[5];
2670 t0[2] = w2[0];
2671 t1[2] = w2[1];
2672 t2[2] = w2[2];
2673 t3[2] = w2[3];
2674 t4[2] = w2[4];
2675 t5[2] = w2[5];
2676 t0[3] = w3[0];
2677 t1[3] = w3[1];
2678 t2[3] = w3[2];
2679 t3[3] = w3[3];
2680 t4[3] = w3[4];
2681 t5[3] = w3[5];
2682 t0[4] = w4[0];
2683 t1[4] = w4[1];
2684 t2[4] = w4[2];
2685 t3[4] = w4[3];
2686 t4[4] = w4[4];
2687 t5[4] = w4[5];
2688 t0[5] = w5[0];
2689 t1[5] = w5[1];
2690 t2[5] = w5[2];
2691 t3[5] = w5[3];
2692 t4[5] = w5[4];
2693 t5[5] = w5[5];
2694 }
2695 // d = B_t * d_t
2696 for (int n = 0; n < 6; n++)
2697 {
2698 d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n];
2699 d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n];
2700 d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n];
2701 d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n];
2702 d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n];
2703 d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n];
2704 }
2705 // save to out_tm
2706 {
2707 out_tm0[0] = d0[0];
2708 out_tm0[1] = d0[1];
2709 out_tm0[2] = d0[2];
2710 out_tm0[3] = d0[3];
2711 out_tm1[0] = d0[4];
2712 out_tm1[1] = d0[5];
2713 out_tm1[2] = d1[0];
2714 out_tm1[3] = d1[1];
2715 out_tm2[0] = d1[2];
2716 out_tm2[1] = d1[3];
2717 out_tm2[2] = d1[4];
2718 out_tm2[3] = d1[5];
2719
2720 out_tm3[0] = d2[0];
2721 out_tm3[1] = d2[1];
2722 out_tm3[2] = d2[2];
2723 out_tm3[3] = d2[3];
2724 out_tm4[0] = d2[4];
2725 out_tm4[1] = d2[5];
2726 out_tm4[2] = d3[0];
2727 out_tm4[3] = d3[1];
2728 out_tm5[0] = d3[2];
2729 out_tm5[1] = d3[3];
2730 out_tm5[2] = d3[4];
2731 out_tm5[3] = d3[5];
2732
2733 out_tm6[0] = d4[0];
2734 out_tm6[1] = d4[1];
2735 out_tm6[2] = d4[2];
2736 out_tm6[3] = d4[3];
2737 out_tm7[0] = d4[4];
2738 out_tm7[1] = d4[5];
2739 out_tm7[2] = d5[0];
2740 out_tm7[3] = d5[1];
2741 out_tm8[0] = d5[2];
2742 out_tm8[1] = d5[3];
2743 out_tm8[2] = d5[4];
2744 out_tm8[3] = d5[5];
2745 }
2746 #endif // __ARM_NEON
2747 r0 += 4;
2748 r1 += 4;
2749 r2 += 4;
2750 r3 += 4;
2751 r4 += 4;
2752 r5 += 4;
2753 }
2754 }
2755 }
2756 }
2757 bottom_blob_bordered = Mat();
2758
2759 // BEGIN dot
2760 Mat top_blob_tm;
2761 {
2762 int w_tm = outw / 4 * 6;
2763 int h_tm = outh / 4 * 6;
2764
2765 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
2766 int nRowBlocks = w_tm / 6;
2767
2768 const int tiles = nColBlocks * nRowBlocks;
2769
2770 top_blob_tm.create(36, tiles, outch, 4u, opt.workspace_allocator);
2771
2772 #pragma omp parallel for num_threads(opt.num_threads)
2773 for (int r = 0; r < 9; r++)
2774 {
2775 int nn_outch = 0;
2776 int remain_outch_start = 0;
2777
2778 nn_outch = outch >> 3;
2779 remain_outch_start = nn_outch << 3;
2780
2781 for (int pp = 0; pp < nn_outch; pp++)
2782 {
2783 int p = pp * 8;
2784
2785 int* output0_tm = top_blob_tm.channel(p);
2786 int* output1_tm = top_blob_tm.channel(p + 1);
2787 int* output2_tm = top_blob_tm.channel(p + 2);
2788 int* output3_tm = top_blob_tm.channel(p + 3);
2789 int* output4_tm = top_blob_tm.channel(p + 4);
2790 int* output5_tm = top_blob_tm.channel(p + 5);
2791 int* output6_tm = top_blob_tm.channel(p + 6);
2792 int* output7_tm = top_blob_tm.channel(p + 7);
2793
2794 output0_tm = output0_tm + r * 4;
2795 output1_tm = output1_tm + r * 4;
2796 output2_tm = output2_tm + r * 4;
2797 output3_tm = output3_tm + r * 4;
2798 output4_tm = output4_tm + r * 4;
2799 output5_tm = output5_tm + r * 4;
2800 output6_tm = output6_tm + r * 4;
2801 output7_tm = output7_tm + r * 4;
2802
2803 for (int i = 0; i < tiles; i++)
2804 {
2805 const short* kptr = kernel_tm_test[r].channel(p / 8);
2806 const short* r0 = bottom_blob_tm.channel(tiles * r + i);
2807 #if __ARM_NEON
2808 #if __aarch64__
2809 asm volatile(
2810 // inch loop
2811 "eor v0.16b, v0.16b, v0.16b \n"
2812 "eor v1.16b, v1.16b, v1.16b \n"
2813 "eor v2.16b, v2.16b, v2.16b \n"
2814 "eor v3.16b, v3.16b, v3.16b \n"
2815 "eor v4.16b, v4.16b, v4.16b \n"
2816 "eor v5.16b, v5.16b, v5.16b \n"
2817 "eor v6.16b, v6.16b, v6.16b \n"
2818 "eor v7.16b, v7.16b, v7.16b \n"
2819 "mov w4, %w20 \n"
2820
2821 "0: \n" // for (int q=0; q<inch; q++)
2822 "prfm pldl1keep, [%9, #128] \n" // _r0 = vld1_s16(r0);
2823 "ld1 {v8.4h}, [%8] \n"
2824 "ld1 {v9.4h, v10.4h}, [%9] \n" // _k01 = vld1q_s16(kptr);
2825 "add %9, %9, #16 \n"
2826 "ld1 {v11.4h, v12.4h}, [%9] \n" // _k23 = vld1q_s16(kptr+8);
2827 "add %9, %9, #16 \n"
2828 "ld1 {v13.4h, v14.4h}, [%9] \n" // _k45 = vld1q_s16(kptr+16);
2829 "add %9, %9, #16 \n"
2830 "ld1 {v15.4h, v16.4h}, [%9] \n" // _k67 = vld1q_s16(kptr+24);
2831 "add %8, %8, #8 \n"
2832 "add %9, %9, #16 \n"
2833
2834 "subs w4, w4, #1 \n"
2835
2836 "smlal v0.4s, v8.4h, v9.4h \n" // sum0 += (a00-a03) * (k00-k03)
2837 "smlal v1.4s, v8.4h, v10.4h \n" // sum1 += (a00-a03) * (k10-k13)
2838 "smlal v2.4s, v8.4h, v11.4h \n" // sum2 += (a00-a03) * (k20-k23)
2839 "smlal v3.4s, v8.4h, v12.4h \n" // sum3 += (a00-a03) * (k30-k33)
2840 "smlal v4.4s, v8.4h, v13.4h \n" // sum4 += (a00-a03) * (k40-k43)
2841 "smlal v5.4s, v8.4h, v14.4h \n" // sum5 += (a00-a03) * (k50-k53)
2842 "smlal v6.4s, v8.4h, v15.4h \n" // sum6 += (a00-a03) * (k60-k63)
2843 "smlal v7.4s, v8.4h, v16.4h \n" // sum7 += (a00-a03) * (k70-k73)
2844
2845 "bne 0b \n" // end for
2846
2847 "st1 {v0.4s}, [%0] \n" // store the result to memory
2848 "st1 {v1.4s}, [%1] \n" //
2849 "st1 {v2.4s}, [%2] \n" //
2850 "st1 {v3.4s}, [%3] \n" //
2851 "st1 {v4.4s}, [%4] \n" //
2852 "st1 {v5.4s}, [%5] \n" //
2853 "st1 {v6.4s}, [%6] \n" //
2854 "st1 {v7.4s}, [%7] \n" //
2855
2856 : "=r"(output0_tm), // %0
2857 "=r"(output1_tm), // %1
2858 "=r"(output2_tm), // %2
2859 "=r"(output3_tm), // %3
2860 "=r"(output4_tm), // %4
2861 "=r"(output5_tm), // %5
2862 "=r"(output6_tm), // %6
2863 "=r"(output7_tm), // %7
2864 "=r"(r0), // %8
2865 "=r"(kptr) // %9
2866 : "0"(output0_tm),
2867 "1"(output1_tm),
2868 "2"(output2_tm),
2869 "3"(output3_tm),
2870 "4"(output4_tm),
2871 "5"(output5_tm),
2872 "6"(output6_tm),
2873 "7"(output7_tm),
2874 "8"(r0),
2875 "9"(kptr),
2876 "r"(inch) // %20
2877 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16");
2878 #else
2879 asm volatile(
2880 // inch loop
2881 "vmov.s32 q0, #0 \n"
2882 "vmov.s32 q1, #0 \n"
2883 "vmov.s32 q2, #0 \n"
2884 "vmov.s32 q3, #0 \n"
2885 "vmov.s32 q4, #0 \n"
2886 "vmov.s32 q5, #0 \n"
2887 "vmov.s32 q6, #0 \n"
2888 "vmov.s32 q7, #0 \n"
2889 "mov r4, %20 \n"
2890
2891 "0: \n" // for (int q=0; q<inch; q++)
2892 "vld1.s16 {d16}, [%8]! \n" // _r0 = vld1_s16(r0); // input inch0
2893 "vld1.s16 {d18-d19}, [%9] \n" // _k01 = vld1q_s16(kptr);
2894 "add %9, #16 \n"
2895 "vld1.s16 {d20-d21}, [%9] \n" // _k23 = vld1q_s16(kptr+8);
2896 "add %9, #16 \n"
2897 "vld1.s16 {d22-d23}, [%9] \n" // _k45 = vld1q_s16(kptr+16);
2898 "add %9, #16 \n"
2899 "vld1.s16 {d24-d25}, [%9] \n" // _k67 = vld1q_s16(kptr+24);
2900 "add %9, #16 \n"
2901
2902 "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)
2903 "vmlal.s16 q1, d16, d19 \n" // sum1 += (a00-a03) * (k10-k13)
2904 "vmlal.s16 q2, d16, d20 \n" // sum2 += (a00-a03) * (k20-k23)
2905 "vmlal.s16 q3, d16, d21 \n" // sum3 += (a00-a03) * (k30-k33)
2906 "vmlal.s16 q4, d16, d22 \n" // sum4 += (a00-a03) * (k40-k43)
2907 "vmlal.s16 q5, d16, d23 \n" // sum5 += (a00-a03) * (k50-k53)
2908 "vmlal.s16 q6, d16, d24 \n" // sum6 += (a00-a03) * (k60-k63)
2909 "vmlal.s16 q7, d16, d25 \n" // sum7 += (a00-a03) * (k70-k73)
2910
2911 "subs r4, r4, #1 \n"
2912 "bne 0b \n" // end for
2913
2914 "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory
2915 "vst1.s32 {d2-d3}, [%1] \n"
2916 "vst1.s32 {d4-d5}, [%2] \n"
2917 "vst1.s32 {d6-d7}, [%3] \n"
2918 "vst1.s32 {d8-d9}, [%4] \n"
2919 "vst1.s32 {d10-d11}, [%5] \n"
2920 "vst1.s32 {d12-d13}, [%6] \n"
2921 "vst1.s32 {d14-d15}, [%7] \n"
2922
2923 : "=r"(output0_tm), // %0
2924 "=r"(output1_tm), // %1
2925 "=r"(output2_tm), // %2
2926 "=r"(output3_tm), // %3
2927 "=r"(output4_tm), // %4
2928 "=r"(output5_tm), // %5
2929 "=r"(output6_tm), // %6
2930 "=r"(output7_tm), // %7
2931 "=r"(r0), // %8
2932 "=r"(kptr) // %9
2933 : "0"(output0_tm),
2934 "1"(output1_tm),
2935 "2"(output2_tm),
2936 "3"(output3_tm),
2937 "4"(output4_tm),
2938 "5"(output5_tm),
2939 "6"(output6_tm),
2940 "7"(output7_tm),
2941 "8"(r0),
2942 "9"(kptr),
2943 "r"(inch) // %20
2944 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
2945 #endif // __aarch64__
2946 #else
2947 int sum0[4] = {0};
2948 int sum1[4] = {0};
2949 int sum2[4] = {0};
2950 int sum3[4] = {0};
2951 int sum4[4] = {0};
2952 int sum5[4] = {0};
2953 int sum6[4] = {0};
2954 int sum7[4] = {0};
2955
2956 for (int q = 0; q < inch; q++)
2957 {
2958 for (int n = 0; n < 4; n++)
2959 {
2960 sum0[n] += (int)r0[n] * kptr[n];
2961 sum1[n] += (int)r0[n] * kptr[n + 4];
2962 sum2[n] += (int)r0[n] * kptr[n + 8];
2963 sum3[n] += (int)r0[n] * kptr[n + 12];
2964 sum4[n] += (int)r0[n] * kptr[n + 16];
2965 sum5[n] += (int)r0[n] * kptr[n + 20];
2966 sum6[n] += (int)r0[n] * kptr[n + 24];
2967 sum7[n] += (int)r0[n] * kptr[n + 28];
2968 }
2969 kptr += 32;
2970 r0 += 4;
2971 }
2972
2973 for (int n = 0; n < 4; n++)
2974 {
2975 output0_tm[n] = sum0[n];
2976 output1_tm[n] = sum1[n];
2977 output2_tm[n] = sum2[n];
2978 output3_tm[n] = sum3[n];
2979 output4_tm[n] = sum4[n];
2980 output5_tm[n] = sum5[n];
2981 output6_tm[n] = sum6[n];
2982 output7_tm[n] = sum7[n];
2983 }
2984 #endif // __ARM_NEON
2985 output0_tm += 36;
2986 output1_tm += 36;
2987 output2_tm += 36;
2988 output3_tm += 36;
2989 output4_tm += 36;
2990 output5_tm += 36;
2991 output6_tm += 36;
2992 output7_tm += 36;
2993 }
2994 }
2995
2996 nn_outch = (outch - remain_outch_start) >> 2;
2997
2998 for (int pp = 0; pp < nn_outch; pp++)
2999 {
3000 int p = remain_outch_start + pp * 4;
3001
3002 int* output0_tm = top_blob_tm.channel(p);
3003 int* output1_tm = top_blob_tm.channel(p + 1);
3004 int* output2_tm = top_blob_tm.channel(p + 2);
3005 int* output3_tm = top_blob_tm.channel(p + 3);
3006
3007 output0_tm = output0_tm + r * 4;
3008 output1_tm = output1_tm + r * 4;
3009 output2_tm = output2_tm + r * 4;
3010 output3_tm = output3_tm + r * 4;
3011
3012 for (int i = 0; i < tiles; i++)
3013 {
3014 const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4);
3015 const short* r0 = bottom_blob_tm.channel(tiles * r + i);
3016 #if __ARM_NEON
3017 #if __aarch64__
3018 asm volatile(
3019 // inch loop
3020 "eor v0.16b, v0.16b, v0.16b \n"
3021 "eor v1.16b, v1.16b, v1.16b \n"
3022 "eor v2.16b, v2.16b, v2.16b \n"
3023 "eor v3.16b, v3.16b, v3.16b \n"
3024 "mov w4, %w12 \n"
3025
3026 "0: \n" // for (int q=0; q<inch; q++)
3027 "prfm pldl1keep, [%5, #128] \n" // _r0 = vld1_s16(r0); // input inch0
3028 "ld1 {v8.4h}, [%4] \n"
3029 "ld1 {v9.4h, v10.4h}, [%5] \n" // _k01 = vld1q_s16(kptr);
3030 "add %5, %5, #16 \n"
3031 "ld1 {v11.4h, v12.4h}, [%5] \n" // _k23 = vld1q_s16(kptr+8);
3032 "add %4, %4, #8 \n"
3033 "add %5, %5, #16 \n"
3034
3035 "subs w4, w4, #1 \n"
3036
3037 "smlal v0.4s, v8.4h, v9.4h \n" // sum0 += (a00-a03) * (k00-k03)
3038 "smlal v1.4s, v8.4h, v10.4h \n" // sum1 += (a00-a03) * (k10-k13)
3039 "smlal v2.4s, v8.4h, v11.4h \n" // sum2 += (a00-a03) * (k20-k23)
3040 "smlal v3.4s, v8.4h, v12.4h \n" // sum3 += (a00-a03) * (k30-k33)
3041
3042 "bne 0b \n" // end for
3043
3044 "st1 {v0.4s}, [%0] \n" // store the result to memory
3045 "st1 {v1.4s}, [%1] \n" //
3046 "st1 {v2.4s}, [%2] \n" //
3047 "st1 {v3.4s}, [%3] \n" //
3048
3049 : "=r"(output0_tm), // %0
3050 "=r"(output1_tm), // %1
3051 "=r"(output2_tm), // %2
3052 "=r"(output3_tm), // %3
3053 "=r"(r0), // %4
3054 "=r"(kptr) // %5
3055 : "0"(output0_tm),
3056 "1"(output1_tm),
3057 "2"(output2_tm),
3058 "3"(output3_tm),
3059 "4"(r0),
3060 "5"(kptr),
3061 "r"(inch) // %12
3062 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
3063 #else
3064 asm volatile(
3065 // inch loop
3066 "vmov.s32 q0, #0 \n"
3067 "vmov.s32 q1, #0 \n"
3068 "vmov.s32 q2, #0 \n"
3069 "vmov.s32 q3, #0 \n"
3070 "mov r4, %12 \n"
3071
3072 "0: \n" // for (int q=0; q<inch; q++)
3073 "vld1.s16 {d16}, [%4]! \n" // _r0 = vld1_s16(r0); // input inch0
3074 "vld1.s16 {d18-d19}, [%5] \n" // _k01 = vld1q_s16(kptr);
3075 "add %5, #16 \n"
3076 "vld1.s16 {d20-d21}, [%5] \n" // _k23 = vld1q_s16(kptr+8);
3077 "add %5, #16 \n"
3078
3079 "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)
3080 "vmlal.s16 q1, d16, d19 \n" // sum1 += (a00-a03) * (k10-k13)
3081 "vmlal.s16 q2, d16, d20 \n" // sum2 += (a00-a03) * (k20-k23)
3082 "vmlal.s16 q3, d16, d21 \n" // sum3 += (a00-a03) * (k30-k33)
3083
3084 "subs r4, r4, #1 \n"
3085 "bne 0b \n" // end for
3086
3087 "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory
3088 "vst1.s32 {d2-d3}, [%1] \n"
3089 "vst1.s32 {d4-d5}, [%2] \n"
3090 "vst1.s32 {d6-d7}, [%3] \n"
3091
3092 : "=r"(output0_tm), // %0
3093 "=r"(output1_tm), // %1
3094 "=r"(output2_tm), // %2
3095 "=r"(output3_tm), // %3
3096 "=r"(r0), // %4
3097 "=r"(kptr) // %5
3098 : "0"(output0_tm),
3099 "1"(output1_tm),
3100 "2"(output2_tm),
3101 "3"(output3_tm),
3102 "4"(r0),
3103 "5"(kptr),
3104 "r"(inch) // %12
3105 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q8", "q9", "q10");
3106 #endif // __aarch64__
3107 #else
3108 int sum0[4] = {0};
3109 int sum1[4] = {0};
3110 int sum2[4] = {0};
3111 int sum3[4] = {0};
3112
3113 for (int q = 0; q < inch; q++)
3114 {
3115 for (int n = 0; n < 4; n++)
3116 {
3117 sum0[n] += (int)r0[n] * kptr[n];
3118 sum1[n] += (int)r0[n] * kptr[n + 4];
3119 sum2[n] += (int)r0[n] * kptr[n + 8];
3120 sum3[n] += (int)r0[n] * kptr[n + 12];
3121 }
3122 kptr += 16;
3123 r0 += 4;
3124 }
3125
3126 for (int n = 0; n < 4; n++)
3127 {
3128 output0_tm[n] = sum0[n];
3129 output1_tm[n] = sum1[n];
3130 output2_tm[n] = sum2[n];
3131 output3_tm[n] = sum3[n];
3132 }
3133 #endif // __ARM_NEON
3134 output0_tm += 36;
3135 output1_tm += 36;
3136 output2_tm += 36;
3137 output3_tm += 36;
3138 }
3139 }
3140
3141 remain_outch_start += nn_outch << 2;
3142
3143 for (int p = remain_outch_start; p < outch; p++)
3144 {
3145 int* output0_tm = top_blob_tm.channel(p);
3146
3147 output0_tm = output0_tm + r * 4;
3148
3149 for (int i = 0; i < tiles; i++)
3150 {
3151 const short* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4 + p % 4);
3152 const short* r0 = bottom_blob_tm.channel(tiles * r + i);
3153 #if __ARM_NEON
3154 #if __aarch64__
3155 asm volatile(
3156 // inch loop
3157 "eor v0.16b, v0.16b, v0.16b \n"
3158 "mov w4, %w6 \n"
3159
3160 "0: \n" // for (int q=0; q<inch; q++)
3161 "ld1 {v8.4h}, [%1] \n" // _r0 = vld1_s16(r0); // input inch0
3162 "ld1 {v9.4h}, [%2] \n" // _k0 = vld1q_s16(kptr);
3163 "add %1, %1, #8 \n"
3164 "add %2, %2, #8 \n"
3165
3166 "subs w4, w4, #1 \n"
3167
3168 "smlal v0.4s, v8.4h, v9.4h \n" // sum0 += (a00-a03) * (k00-k03)
3169
3170 "bne 0b \n" // end for
3171
3172 "st1 {v0.4s}, [%0] \n" // store the result to memory
3173
3174 : "=r"(output0_tm), // %0
3175 "=r"(r0), // %1
3176 "=r"(kptr) // %2
3177 : "0"(output0_tm),
3178 "1"(r0),
3179 "2"(kptr),
3180 "r"(inch) // %6
3181 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9");
3182 #else
3183 asm volatile(
3184 // inch loop
3185 "vmov.s32 q0, #0 \n"
3186 "mov r4, %6 \n"
3187
3188 "0: \n" // for (int q=0; q<inch; q++)
3189 "vld1.s16 {d16}, [%1] \n" // _r0 = vld1_s16(r0); // input inch0
3190 "add %1, #8 \n"
3191 "vld1.s16 {d18}, [%2] \n" // _k0 = vld1q_s16(kptr);
3192 "add %2, #8 \n"
3193 "vmlal.s16 q0, d16, d18 \n" // sum0 += (a00-a03) * (k00-k03)
3194
3195 "subs r4, r4, #1 \n"
3196 "bne 0b \n" // end for
3197
3198 "vst1.s32 {d0-d1}, [%0] \n" // store the result to memory
3199
3200 : "=r"(output0_tm), // %0
3201 "=r"(r0), // %1
3202 "=r"(kptr) // %2
3203 : "0"(output0_tm),
3204 "1"(r0),
3205 "2"(kptr),
3206 "r"(inch) // %6
3207 : "cc", "memory", "r4", "q0", "q8", "q9");
3208 #endif // __aarch64__
3209 #else // __ARM_NEON
3210 int sum0[4] = {0};
3211
3212 for (int q = 0; q < inch; q++)
3213 {
3214 for (int n = 0; n < 4; n++)
3215 {
3216 sum0[n] += (int)r0[n] * kptr[n];
3217 }
3218 kptr += 4;
3219 r0 += 4;
3220 }
3221
3222 for (int n = 0; n < 4; n++)
3223 {
3224 output0_tm[n] = sum0[n];
3225 }
3226 #endif // __ARM_NEON
3227 output0_tm += 36;
3228 }
3229 }
3230
3231 // for (int p=0; p<outch; p++)
3232 // {
3233 // Mat out0_tm = top_blob_tm.channel(p);
3234 // const Mat kernel0_tm = kernel_tm.channel(p);
3235
3236 // for (int i=0; i<tiles; i++)
3237 // {
3238 // int* output0_tm = out0_tm.row<int>(i);
3239
3240 // int sum0[36] = {0};
3241
3242 // for (int q=0; q<inch; q++)
3243 // {
3244 // const short* r0 = bottom_blob_tm.channel(q).row<short>(i);
3245 // const short* k0 = kernel0_tm.row<short>(q);
3246
3247 // for (int n=0; n<36; n++)
3248 // {
3249 // sum0[n] += (int)r0[n] * k0[n];
3250 // }
3251 // }
3252
3253 // for (int n=0; n<36; n++)
3254 // {
3255 // output0_tm[n] = sum0[n];
3256 // }
3257 // }
3258 // }
3259 }
3260 }
3261 bottom_blob_tm = Mat();
3262 // END dot
3263
3264 // BEGIN transform output
3265 Mat top_blob_bordered;
3266 top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator);
3267 {
3268 // AT
3269 // const float itm[4][6] = {
3270 // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
3271 // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
3272 // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f},
3273 // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
3274 // };
3275
3276 // 0 = r00 + r01 + r02 + r03 + r04
3277 // 1 = r01 - r02 + 2 * (r03 - r04)
3278 // 2 = r01 + r02 + 4 * (r03 + r04)
3279 // 3 = r01 - r02 + 8 * (r03 - r04) + r05
3280
3281 int w_tm = outw / 4 * 6;
3282 int h_tm = outh / 4 * 6;
3283
3284 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
3285 int nRowBlocks = w_tm / 6;
3286
3287 #pragma omp parallel for num_threads(opt.num_threads)
3288 for (int p = 0; p < outch; p++)
3289 {
3290 int* out_tile = top_blob_tm.channel(p);
3291 float* outRow0 = top_blob_bordered.channel(p);
3292 float* outRow1 = outRow0 + outw;
3293 float* outRow2 = outRow0 + outw * 2;
3294 float* outRow3 = outRow0 + outw * 3;
3295
3296 const float bias0 = bias ? bias[p] : 0.f;
3297
3298 const float scale_dequant0 = scales_dequant[p];
3299
3300 const float scale0 = scale_dequant0 / 576.0;
3301
3302 for (int j = 0; j < nColBlocks; j++)
3303 {
3304 for (int i = 0; i < nRowBlocks; i++)
3305 {
3306 #if __ARM_NEON
3307 int32x4_t _s0, _s1, _s2, _s3, _s4, _s5;
3308 int32x2_t _s0n, _s1n, _s2n, _s3n, _s4n, _s5n;
3309 int32x4_t _w0, _w3;
3310 int32x2_t _w0n, _w3n;
3311 int32x4_t _d0, _d1, _d2, _d3, _d4, _d5;
3312 int32x4_t _o0, _o3;
3313 // load
3314 _s0 = vld1q_s32(out_tile);
3315 _s0n = vld1_s32(out_tile + 4);
3316 _s1 = vld1q_s32(out_tile + 6);
3317 _s1n = vld1_s32(out_tile + 10);
3318 _s2 = vld1q_s32(out_tile + 12);
3319 _s2n = vld1_s32(out_tile + 16);
3320 _s3 = vld1q_s32(out_tile + 18);
3321 _s3n = vld1_s32(out_tile + 22);
3322 _s4 = vld1q_s32(out_tile + 24);
3323 _s4n = vld1_s32(out_tile + 28);
3324 _s5 = vld1q_s32(out_tile + 30);
3325 _s5n = vld1_s32(out_tile + 34);
3326 // w = A_T * W
3327 int32x2_t _tp0 = {1, 4};
3328 int32x2_t _tp1 = {2, 8};
3329
3330 // 4*s5[n]
3331 int32x4_t _s5x4 = vshlq_n_s32(_s5, 2);
3332 int32x2_t _s5x4n = vshl_n_s32(_s5n, 2);
3333
3334 int32x4_t _t1p2 = vaddq_s32(_s1, _s2);
3335 int32x2_t _t1p2n = vadd_s32(_s1n, _s2n);
3336 int32x4_t _t3p4 = vaddq_s32(_s3, _s4);
3337 int32x2_t _t3p4n = vadd_s32(_s3n, _s4n);
3338 int32x4_t _t1s2 = vsubq_s32(_s1, _s2);
3339 int32x2_t _t1s2n = vsub_s32(_s1n, _s2n);
3340 int32x4_t _t3s4 = vsubq_s32(_s3, _s4);
3341 int32x2_t _t3s4n = vsub_s32(_s3n, _s4n);
3342
3343 _w0 = vaddq_s32(_s0, _t1p2);
3344 _w0n = vadd_s32(_s0n, _t1p2n);
3345 _w0 = vaddq_s32(_w0, _t3p4);
3346 _w0n = vadd_s32(_w0n, _t3p4n);
3347 _w0n = vmul_s32(_w0n, _tp0);
3348
3349 // _w2,_w2n
3350 _t1p2 = vmlaq_lane_s32(_t1p2, _t3p4, _tp0, 1);
3351 _t1p2n = vmla_lane_s32(_t1p2n, _t3p4n, _tp0, 1);
3352 _t1p2n = vmul_s32(_t1p2n, _tp0);
3353
3354 _w3 = vaddq_s32(_s5x4, _t1s2);
3355 _w3n = vadd_s32(_s5x4n, _t1s2n);
3356 _w3 = vmlaq_lane_s32(_w3, _t3s4, _tp1, 1);
3357 _w3n = vmla_lane_s32(_w3n, _t3s4n, _tp1, 1);
3358 _w3n = vmul_s32(_w3n, _tp0);
3359
3360 // _w1, _w1n
3361 _t1s2 = vmlaq_lane_s32(_t1s2, _t3s4, _tp1, 0);
3362 _t1s2n = vmla_lane_s32(_t1s2n, _t3s4n, _tp1, 0);
3363 _t1s2n = vmul_s32(_t1s2n, _tp0);
3364
3365 int32x4_t _w02n = vcombine_s32(_w0n, _t1p2n);
3366 int32x4_t _w13n = vcombine_s32(_t1s2n, _w3n);
3367
3368 // transpose w to w_t
3369 #if __aarch64__
3370 int32x4_t _wt0 = vtrn1q_s32(_w0, _t1s2);
3371 int32x4_t _wt1 = vtrn2q_s32(_w0, _t1s2);
3372 int32x4_t _wt2 = vtrn1q_s32(_t1p2, _w3);
3373 int32x4_t _wt3 = vtrn2q_s32(_t1p2, _w3);
3374 int64x2_t _dt0 = vtrn1q_s64(vreinterpretq_s64_s32(_wt0), vreinterpretq_s64_s32(_wt2));
3375 int64x2_t _dt2 = vtrn2q_s64(vreinterpretq_s64_s32(_wt0), vreinterpretq_s64_s32(_wt2));
3376 int64x2_t _dt1 = vtrn1q_s64(vreinterpretq_s64_s32(_wt1), vreinterpretq_s64_s32(_wt3));
3377 int64x2_t _dt3 = vtrn2q_s64(vreinterpretq_s64_s32(_wt1), vreinterpretq_s64_s32(_wt3));
3378 _d0 = vreinterpretq_s32_s64(_dt0);
3379 _d1 = vreinterpretq_s32_s64(_dt1);
3380 _d2 = vreinterpretq_s32_s64(_dt2);
3381 _d3 = vreinterpretq_s32_s64(_dt3);
3382 _d4 = vtrn1q_s32(_w02n, _w13n);
3383 _d5 = vtrn2q_s32(_w02n, _w13n);
3384 #else
3385 asm volatile(
3386 "vtrn.32 %q[_w0], %q[_w1] \n"
3387 "vtrn.32 %q[_w2], %q[_w3] \n"
3388 "vswp %f[_w0], %e[_w2] \n"
3389 "vswp %f[_w1], %e[_w3] \n"
3390 "vtrn.32 %q[_w02n], %q[_w13n] \n"
3391 : [_w0] "+w"(_w0),
3392 [_w1] "+w"(_t1s2),
3393 [_w2] "+w"(_t1p2),
3394 [_w3] "+w"(_w3),
3395 [_w02n] "+w"(_w02n),
3396 [_w13n] "+w"(_w13n)
3397 :
3398 : "cc", "memory");
3399 _d0 = _w0;
3400 _d1 = _t1s2;
3401 _d2 = _t1p2;
3402 _d3 = _w3;
3403 _d4 = _w02n;
3404 _d5 = _w13n;
3405 #endif
3406 // Y = A_T * w_t
3407 _t1p2 = vaddq_s32(_d1, _d2);
3408 _t3p4 = vaddq_s32(_d3, _d4);
3409 _t1s2 = vsubq_s32(_d1, _d2);
3410 _t3s4 = vsubq_s32(_d3, _d4);
3411
3412 _o0 = vaddq_s32(_d0, _t1p2);
3413 _o0 = vaddq_s32(_o0, _t3p4);
3414
3415 // _o2
3416 _t1p2 = vmlaq_lane_s32(_t1p2, _t3p4, _tp0, 1);
3417
3418 _o3 = vaddq_s32(_d5, _t1s2);
3419 _o3 = vmlaq_lane_s32(_o3, _t3s4, _tp1, 1);
3420
3421 // _o1
3422 _t1s2 = vmlaq_lane_s32(_t1s2, _t3s4, _tp1, 0);
3423
3424 // save to top blob tm
3425 float32x4_t _scale0 = vdupq_n_f32(scale0);
3426 float32x4_t _out0_f32 = vdupq_n_f32(bias0);
3427 float32x4_t _out1_f32 = vdupq_n_f32(bias0);
3428 float32x4_t _out2_f32 = vdupq_n_f32(bias0);
3429 float32x4_t _out3_f32 = vdupq_n_f32(bias0);
3430
3431 _out0_f32 = vmlaq_f32(_out0_f32, vcvtq_f32_s32(_o0), _scale0);
3432 _out1_f32 = vmlaq_f32(_out1_f32, vcvtq_f32_s32(_t1s2), _scale0);
3433 _out2_f32 = vmlaq_f32(_out2_f32, vcvtq_f32_s32(_t1p2), _scale0);
3434 _out3_f32 = vmlaq_f32(_out3_f32, vcvtq_f32_s32(_o3), _scale0);
3435
3436 vst1q_f32(outRow0, _out0_f32);
3437 vst1q_f32(outRow1, _out1_f32);
3438 vst1q_f32(outRow2, _out2_f32);
3439 vst1q_f32(outRow3, _out3_f32);
3440 #else
3441 int s0[6], s1[6], s2[6], s3[6], s4[6], s5[6];
3442 int w0[6], w1[6], w2[6], w3[6];
3443 int d0[4], d1[4], d2[4], d3[4], d4[4], d5[4];
3444 int o0[4], o1[4], o2[4], o3[4];
3445
3446 // load
3447 for (int n = 0; n < 6; n++)
3448 {
3449 s0[n] = out_tile[n];
3450 s1[n] = out_tile[n + 6];
3451 s2[n] = out_tile[n + 12];
3452 s3[n] = out_tile[n + 18];
3453 s4[n] = out_tile[n + 24];
3454 s5[n] = out_tile[n + 30];
3455 }
3456 // w = A_T * W
3457 for (int n = 0; n < 5; n++)
3458 {
3459 w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n];
3460 w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n];
3461 w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n];
3462 w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + 4 * s5[n];
3463 }
3464 for (int n = 5; n < 6; n++)
3465 {
3466 w0[n] = 4 * (s0[n] + s1[n] + s2[n] + s3[n] + s4[n]);
3467 w1[n] = 4 * (s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n]);
3468 w2[n] = 4 * (s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n]);
3469 w3[n] = 4 * (s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + 4 * s5[n]);
3470 }
3471 // transpose w to w_t
3472 {
3473 d0[0] = w0[0];
3474 d0[1] = w1[0];
3475 d0[2] = w2[0];
3476 d0[3] = w3[0];
3477 d1[0] = w0[1];
3478 d1[1] = w1[1];
3479 d1[2] = w2[1];
3480 d1[3] = w3[1];
3481 d2[0] = w0[2];
3482 d2[1] = w1[2];
3483 d2[2] = w2[2];
3484 d2[3] = w3[2];
3485 d3[0] = w0[3];
3486 d3[1] = w1[3];
3487 d3[2] = w2[3];
3488 d3[3] = w3[3];
3489 d4[0] = w0[4];
3490 d4[1] = w1[4];
3491 d4[2] = w2[4];
3492 d4[3] = w3[4];
3493 d5[0] = w0[5];
3494 d5[1] = w1[5];
3495 d5[2] = w2[5];
3496 d5[3] = w3[5];
3497 }
3498 // Y = A_T * w_t
3499 for (int n = 0; n < 4; n++)
3500 {
3501 o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n];
3502 o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n];
3503 o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n];
3504 o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n];
3505 }
3506 // save to top blob tm
3507 for (int n = 0; n < 4; n++)
3508 {
3509 outRow0[n] = (float)o0[n] * scale0 + bias0;
3510 outRow1[n] = (float)o1[n] * scale0 + bias0;
3511 outRow2[n] = (float)o2[n] * scale0 + bias0;
3512 outRow3[n] = (float)o3[n] * scale0 + bias0;
3513 }
3514 #endif // __ARM_NEON
3515 out_tile += 36;
3516
3517 outRow0 += 4;
3518 outRow1 += 4;
3519 outRow2 += 4;
3520 outRow3 += 4;
3521 }
3522
3523 outRow0 += outw * 3;
3524 outRow1 += outw * 3;
3525 outRow2 += outw * 3;
3526 outRow3 += outw * 3;
3527 }
3528 }
3529 }
3530 // END transform output
3531
3532 // cut result pad
3533 copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
3534 }
3535
conv3x3s2_transform_kernel_int8_neon(const Mat & _kernel,Mat & kernel_tm,int inch,int outch)3536 static void conv3x3s2_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch)
3537 {
3538 kernel_tm.create(8 * 9, inch, outch / 8 + outch % 8, (size_t)1u);
3539
3540 const signed char* kernel = _kernel;
3541
3542 int p = 0;
3543 for (; p + 7 < outch; p += 8)
3544 {
3545 const signed char* k0 = kernel + (p + 0) * inch * 9;
3546 const signed char* k1 = kernel + (p + 1) * inch * 9;
3547 const signed char* k2 = kernel + (p + 2) * inch * 9;
3548 const signed char* k3 = kernel + (p + 3) * inch * 9;
3549 const signed char* k4 = kernel + (p + 4) * inch * 9;
3550 const signed char* k5 = kernel + (p + 5) * inch * 9;
3551 const signed char* k6 = kernel + (p + 6) * inch * 9;
3552 const signed char* k7 = kernel + (p + 7) * inch * 9;
3553
3554 signed char* ktmp = kernel_tm.channel(p / 8);
3555
3556 for (int q = 0; q < inch; q++)
3557 {
3558 for (int k = 0; k < 9; k++)
3559 {
3560 ktmp[0] = k0[k];
3561 ktmp[1] = k1[k];
3562 ktmp[2] = k2[k];
3563 ktmp[3] = k3[k];
3564 ktmp[4] = k4[k];
3565 ktmp[5] = k5[k];
3566 ktmp[6] = k6[k];
3567 ktmp[7] = k7[k];
3568 ktmp += 8;
3569 }
3570
3571 k0 += 9;
3572 k1 += 9;
3573 k2 += 9;
3574 k3 += 9;
3575 k4 += 9;
3576 k5 += 9;
3577 k6 += 9;
3578 k7 += 9;
3579 }
3580 }
3581 for (; p < outch; p++)
3582 {
3583 const signed char* k0 = kernel + (p + 0) * inch * 9;
3584
3585 signed char* ktmp = kernel_tm.channel(p / 8 + p % 8);
3586
3587 for (int q = 0; q < inch; q++)
3588 {
3589 for (int k = 0; k < 9; k++)
3590 {
3591 ktmp[k] = k0[k];
3592 }
3593 ktmp += 9;
3594
3595 k0 += 9;
3596 }
3597 }
3598 }
3599
conv3x3s2_packed_int8_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Option & opt)3600 static void conv3x3s2_packed_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Option& opt)
3601 {
3602 int w = bottom_blob.w;
3603 int inch = bottom_blob.c;
3604
3605 int outw = top_blob.w;
3606 int outh = top_blob.h;
3607 int outch = top_blob.c;
3608
3609 const int tailstep = w - 2 * outw + w;
3610
3611 int nn_outch = outch >> 3;
3612 int remain_outch_start = nn_outch << 3;
3613
3614 #pragma omp parallel for num_threads(opt.num_threads)
3615 for (int pp = 0; pp < nn_outch; pp++)
3616 {
3617 int p = pp * 8;
3618
3619 Mat out0 = top_blob.channel(p + 0);
3620 Mat out1 = top_blob.channel(p + 1);
3621 Mat out2 = top_blob.channel(p + 2);
3622 Mat out3 = top_blob.channel(p + 3);
3623 Mat out4 = top_blob.channel(p + 4);
3624 Mat out5 = top_blob.channel(p + 5);
3625 Mat out6 = top_blob.channel(p + 6);
3626 Mat out7 = top_blob.channel(p + 7);
3627
3628 out0.fill(0);
3629 out1.fill(0);
3630 out2.fill(0);
3631 out3.fill(0);
3632 out4.fill(0);
3633 out5.fill(0);
3634 out6.fill(0);
3635 out7.fill(0);
3636
3637 const signed char* ktmp = _kernel.channel(p / 8);
3638
3639 for (int q = 0; q < inch; q++)
3640 {
3641 int* outptr0 = out0;
3642 int* outptr1 = out1;
3643 int* outptr2 = out2;
3644 int* outptr3 = out3;
3645 int* outptr4 = out4;
3646 int* outptr5 = out5;
3647 int* outptr6 = out6;
3648 int* outptr7 = out7;
3649
3650 const signed char* img0 = bottom_blob.channel(q);
3651
3652 const signed char* r0 = img0;
3653 const signed char* r1 = img0 + w;
3654 const signed char* r2 = img0 + w * 2;
3655
3656 int i = 0;
3657
3658 for (; i < outh; i++)
3659 {
3660 #if __ARM_NEON
3661 #if __aarch64__
3662 int nn = outw >> 3;
3663 int remain = outw & 7;
3664 #else
3665 int nn = outw >> 2;
3666 int remain = outw & 3;
3667 #endif // __aarch64__
3668 #else
3669 int remain = outw;
3670 #endif // __ARM_NEON
3671
3672 #if __ARM_NEON
3673 #if __aarch64__
3674 if (nn > 0)
3675 {
3676 asm volatile(
3677 "0: \n"
3678
3679 "ld1 {v0.8b, v1.8b, v2.8b}, [%12], #24 \n" //ktmp
3680 "ld2 {v3.8b, v4.8b}, [%9], #16 \n" //r0-r2
3681 "ld2 {v5.8b, v6.8b}, [%9] \n"
3682
3683 "ld1 {v8.4s, v9.4s}, [%1] \n" //out0
3684 "ld1 {v10.4s, v11.4s}, [%2] \n" //out1
3685 "ld1 {v12.4s, v13.4s}, [%3] \n" //out2
3686 "ld1 {v14.4s, v15.4s}, [%4] \n" //out3
3687 "ld1 {v16.4s, v17.4s}, [%5] \n" //out4
3688 "ld1 {v18.4s, v19.4s}, [%6] \n" //out5
3689 "ld1 {v20.4s, v21.4s}, [%7] \n" //out6
3690 "ld1 {v22.4s, v23.4s}, [%8] \n" //out7
3691
3692 "ext v7.8b, v3.8b, v5.8b, #1 \n"
3693
3694 "sshll v0.8h, v0.8b, #0 \n" //(k00-k70)
3695 "sshll v1.8h, v1.8b, #0 \n" //(k01-k71)
3696 "sshll v2.8h, v2.8b, #0 \n" //(k02-k72)
3697 "sshll v3.8h, v3.8b, #0 \n" // r0
3698 "sshll v4.8h, v4.8b, #0 \n" // r1
3699 "sshll v7.8h, v7.8b, #0 \n" // r2
3700
3701 // r0
3702 "smlal v8.4s, v3.4h, v0.h[0] \n" // out0 += (r00-r07)*k00
3703 "smlal2 v9.4s, v3.8h, v0.h[0] \n"
3704 "smlal v10.4s, v3.4h, v0.h[1] \n" // out1 += (r00-r07)*k10
3705 "smlal2 v11.4s, v3.8h, v0.h[1] \n"
3706 "smlal v12.4s, v3.4h, v0.h[2] \n" // out2 += (r00-r07)*k20
3707 "smlal2 v13.4s, v3.8h, v0.h[2] \n"
3708 "smlal v14.4s, v3.4h, v0.h[3] \n" // out3 += (r00-r07)*k30
3709 "smlal2 v15.4s, v3.8h, v0.h[3] \n"
3710 "smlal v16.4s, v3.4h, v0.h[4] \n" // out4 += (r00-r07)*k40
3711 "smlal2 v17.4s, v3.8h, v0.h[4] \n"
3712 "smlal v18.4s, v3.4h, v0.h[5] \n" // out5 += (r00-r07)*k50
3713 "smlal2 v19.4s, v3.8h, v0.h[5] \n"
3714 "smlal v20.4s, v3.4h, v0.h[6] \n" // out6 += (r00-r07)*k60
3715 "smlal2 v21.4s, v3.8h, v0.h[6] \n"
3716 "smlal v22.4s, v3.4h, v0.h[7] \n" // out7 += (r00-r07)*k70
3717 "smlal2 v23.4s, v3.8h, v0.h[7] \n"
3718 // r1
3719 "smlal v8.4s, v4.4h, v1.h[0] \n" // out0 += (r10-r17)*k01
3720 "smlal2 v9.4s, v4.8h, v1.h[0] \n"
3721 "smlal v10.4s, v4.4h, v1.h[1] \n" // out1 += (r10-r17)*k11
3722 "smlal2 v11.4s, v4.8h, v1.h[1] \n"
3723 "smlal v12.4s, v4.4h, v1.h[2] \n" // out2 += (r10-r17)*k21
3724 "smlal2 v13.4s, v4.8h, v1.h[2] \n"
3725 "smlal v14.4s, v4.4h, v1.h[3] \n" // out3 += (r10-r17)*k31
3726 "smlal2 v15.4s, v4.8h, v1.h[3] \n"
3727 "smlal v16.4s, v4.4h, v1.h[4] \n" // out4 += (r10-r17)*k41
3728 "smlal2 v17.4s, v4.8h, v1.h[4] \n"
3729 "smlal v18.4s, v4.4h, v1.h[5] \n" // out5 += (r10-r17)*k51
3730 "smlal2 v19.4s, v4.8h, v1.h[5] \n"
3731 "smlal v20.4s, v4.4h, v1.h[6] \n" // out6 += (r10-r17)*k61
3732 "smlal2 v21.4s, v4.8h, v1.h[6] \n"
3733 "smlal v22.4s, v4.4h, v1.h[7] \n" // out7 += (r10-r17)*k71
3734 "smlal2 v23.4s, v4.8h, v1.h[7] \n"
3735 // r2
3736 "smlal v8.4s, v7.4h, v2.h[0] \n" // out0 += (r20-r27)*k02
3737 "smlal2 v9.4s, v7.8h, v2.h[0] \n"
3738 "smlal v10.4s, v7.4h, v2.h[1] \n" // out1 += (r20-r27)*k12
3739 "smlal2 v11.4s, v7.8h, v2.h[1] \n"
3740 "smlal v12.4s, v7.4h, v2.h[2] \n" // out2 += (r20-r27)*k22
3741 "smlal2 v13.4s, v7.8h, v2.h[2] \n"
3742 "smlal v14.4s, v7.4h, v2.h[3] \n" // out3 += (r20-r27)*k32
3743 "smlal2 v15.4s, v7.8h, v2.h[3] \n"
3744 "smlal v16.4s, v7.4h, v2.h[4] \n" // out4 += (r20-r27)*k42
3745 "smlal2 v17.4s, v7.8h, v2.h[4] \n"
3746 "smlal v18.4s, v7.4h, v2.h[5] \n" // out5 += (r20-r27)*k52
3747 "smlal2 v19.4s, v7.8h, v2.h[5] \n"
3748 "smlal v20.4s, v7.4h, v2.h[6] \n" // out6 += (r20-r27)*k62
3749 "smlal2 v21.4s, v7.8h, v2.h[6] \n"
3750 "smlal v22.4s, v7.4h, v2.h[7] \n" // out7 += (r20-r27)*k72
3751 "smlal2 v23.4s, v7.8h, v2.h[7] \n"
3752
3753 "ld1 {v0.8b, v1.8b, v2.8b}, [%12], #24 \n" //ktmp
3754 "ld2 {v3.8b, v4.8b}, [%10], #16 \n" //r3-r5
3755 "ld2 {v5.8b, v6.8b}, [%10] \n"
3756
3757 "ext v7.8b, v3.8b, v5.8b, #1 \n"
3758
3759 "sshll v0.8h, v0.8b, #0 \n" //(k03-k73)
3760 "sshll v1.8h, v1.8b, #0 \n" //(k04-k74)
3761 "sshll v2.8h, v2.8b, #0 \n" //(k05-k75)
3762 "sshll v3.8h, v3.8b, #0 \n" // r3
3763 "sshll v4.8h, v4.8b, #0 \n" // r4
3764 "sshll v7.8h, v7.8b, #0 \n" // r5
3765
3766 // r3
3767 "smlal v8.4s, v3.4h, v0.h[0] \n" // out0 += (r30-r37)*k03
3768 "smlal2 v9.4s, v3.8h, v0.h[0] \n"
3769 "smlal v10.4s, v3.4h, v0.h[1] \n" // out1 += (r30-r37)*k13
3770 "smlal2 v11.4s, v3.8h, v0.h[1] \n"
3771 "smlal v12.4s, v3.4h, v0.h[2] \n" // out2 += (r30-r37)*k23
3772 "smlal2 v13.4s, v3.8h, v0.h[2] \n"
3773 "smlal v14.4s, v3.4h, v0.h[3] \n" // out3 += (r30-r37)*k33
3774 "smlal2 v15.4s, v3.8h, v0.h[3] \n"
3775 "smlal v16.4s, v3.4h, v0.h[4] \n" // out4 += (r30-r37)*k43
3776 "smlal2 v17.4s, v3.8h, v0.h[4] \n"
3777 "smlal v18.4s, v3.4h, v0.h[5] \n" // out5 += (r30-r37)*k53
3778 "smlal2 v19.4s, v3.8h, v0.h[5] \n"
3779 "smlal v20.4s, v3.4h, v0.h[6] \n" // out6 += (r30-r37)*k63
3780 "smlal2 v21.4s, v3.8h, v0.h[6] \n"
3781 "smlal v22.4s, v3.4h, v0.h[7] \n" // out7 += (r30-r37)*k73
3782 "smlal2 v23.4s, v3.8h, v0.h[7] \n"
3783 // r4
3784 "smlal v8.4s, v4.4h, v1.h[0] \n" // out0 += (r40-r47)*k04
3785 "smlal2 v9.4s, v4.8h, v1.h[0] \n"
3786 "smlal v10.4s, v4.4h, v1.h[1] \n" // out1 += (r40-r47)*k14
3787 "smlal2 v11.4s, v4.8h, v1.h[1] \n"
3788 "smlal v12.4s, v4.4h, v1.h[2] \n" // out2 += (r40-r47)*k24
3789 "smlal2 v13.4s, v4.8h, v1.h[2] \n"
3790 "smlal v14.4s, v4.4h, v1.h[3] \n" // out3 += (r40-r47)*k34
3791 "smlal2 v15.4s, v4.8h, v1.h[3] \n"
3792 "smlal v16.4s, v4.4h, v1.h[4] \n" // out4 += (r40-r47)*k44
3793 "smlal2 v17.4s, v4.8h, v1.h[4] \n"
3794 "smlal v18.4s, v4.4h, v1.h[5] \n" // out5 += (r40-r47)*k54
3795 "smlal2 v19.4s, v4.8h, v1.h[5] \n"
3796 "smlal v20.4s, v4.4h, v1.h[6] \n" // out6 += (r40-r47)*k64
3797 "smlal2 v21.4s, v4.8h, v1.h[6] \n"
3798 "smlal v22.4s, v4.4h, v1.h[7] \n" // out7 += (r40-r47)*k74
3799 "smlal2 v23.4s, v4.8h, v1.h[7] \n"
3800 // r5
3801 "smlal v8.4s, v7.4h, v2.h[0] \n" // out0 += (r50-r57)*k05
3802 "smlal2 v9.4s, v7.8h, v2.h[0] \n"
3803 "smlal v10.4s, v7.4h, v2.h[1] \n" // out1 += (r50-r57)*k15
3804 "smlal2 v11.4s, v7.8h, v2.h[1] \n"
3805 "smlal v12.4s, v7.4h, v2.h[2] \n" // out2 += (r50-r57)*k25
3806 "smlal2 v13.4s, v7.8h, v2.h[2] \n"
3807 "smlal v14.4s, v7.4h, v2.h[3] \n" // out3 += (r50-r57)*k35
3808 "smlal2 v15.4s, v7.8h, v2.h[3] \n"
3809 "smlal v16.4s, v7.4h, v2.h[4] \n" // out4 += (r50-r57)*k45
3810 "smlal2 v17.4s, v7.8h, v2.h[4] \n"
3811 "smlal v18.4s, v7.4h, v2.h[5] \n" // out5 += (r50-r57)*k55
3812 "smlal2 v19.4s, v7.8h, v2.h[5] \n"
3813 "smlal v20.4s, v7.4h, v2.h[6] \n" // out6 += (r50-r57)*k65
3814 "smlal2 v21.4s, v7.8h, v2.h[6] \n"
3815 "smlal v22.4s, v7.4h, v2.h[7] \n" // out7 += (r50-r57)*k75
3816 "smlal2 v23.4s, v7.8h, v2.h[7] \n"
3817
3818 "ld1 {v0.8b, v1.8b, v2.8b}, [%12], #24 \n" //ktmp
3819 "ld2 {v3.8b, v4.8b}, [%11], #16 \n" //r6-r8
3820 "ld2 {v5.8b, v6.8b}, [%11] \n"
3821
3822 "ext v7.8b, v3.8b, v5.8b, #1 \n"
3823
3824 "sshll v0.8h, v0.8b, #0 \n" //(k06-k76)
3825 "sshll v1.8h, v1.8b, #0 \n" //(k07-k77)
3826 "sshll v2.8h, v2.8b, #0 \n" //(k08-k78)
3827 "sshll v3.8h, v3.8b, #0 \n" // r6
3828 "sshll v4.8h, v4.8b, #0 \n" // r7
3829 "sshll v7.8h, v7.8b, #0 \n" // r8
3830
3831 // r6
3832 "smlal v8.4s, v3.4h, v0.h[0] \n" // out0 += (r60-r67)*k06
3833 "smlal2 v9.4s, v3.8h, v0.h[0] \n"
3834 "smlal v10.4s, v3.4h, v0.h[1] \n" // out1 += (r60-r67)*k16
3835 "smlal2 v11.4s, v3.8h, v0.h[1] \n"
3836 "smlal v12.4s, v3.4h, v0.h[2] \n" // out2 += (r60-r67)*k26
3837 "smlal2 v13.4s, v3.8h, v0.h[2] \n"
3838 "smlal v14.4s, v3.4h, v0.h[3] \n" // out3 += (r60-r67)*k36
3839 "smlal2 v15.4s, v3.8h, v0.h[3] \n"
3840 "smlal v16.4s, v3.4h, v0.h[4] \n" // out4 += (r60-r67)*k46
3841 "smlal2 v17.4s, v3.8h, v0.h[4] \n"
3842 "smlal v18.4s, v3.4h, v0.h[5] \n" // out5 += (r60-r67)*k56
3843 "smlal2 v19.4s, v3.8h, v0.h[5] \n"
3844 "smlal v20.4s, v3.4h, v0.h[6] \n" // out6 += (r60-r67)*k66
3845 "smlal2 v21.4s, v3.8h, v0.h[6] \n"
3846 "smlal v22.4s, v3.4h, v0.h[7] \n" // out7 += (r60-r67)*k76
3847 "smlal2 v23.4s, v3.8h, v0.h[7] \n"
3848 // r7
3849 "smlal v8.4s, v4.4h, v1.h[0] \n" // out0 += (r70-r77)*k07
3850 "smlal2 v9.4s, v4.8h, v1.h[0] \n"
3851 "smlal v10.4s, v4.4h, v1.h[1] \n" // out1 += (r70-r77)*k17
3852 "smlal2 v11.4s, v4.8h, v1.h[1] \n"
3853 "smlal v12.4s, v4.4h, v1.h[2] \n" // out2 += (r70-r77)*k27
3854 "smlal2 v13.4s, v4.8h, v1.h[2] \n"
3855 "smlal v14.4s, v4.4h, v1.h[3] \n" // out3 += (r70-r77)*k37
3856 "smlal2 v15.4s, v4.8h, v1.h[3] \n"
3857 "smlal v16.4s, v4.4h, v1.h[4] \n" // out4 += (r70-r77)*k47
3858 "smlal2 v17.4s, v4.8h, v1.h[4] \n"
3859 "smlal v18.4s, v4.4h, v1.h[5] \n" // out5 += (r70-r77)*k57
3860 "smlal2 v19.4s, v4.8h, v1.h[5] \n"
3861 "smlal v20.4s, v4.4h, v1.h[6] \n" // out6 += (r70-r77)*k67
3862 "smlal2 v21.4s, v4.8h, v1.h[6] \n"
3863 "smlal v22.4s, v4.4h, v1.h[7] \n" // out7 += (r70-r77)*k77
3864 "smlal2 v23.4s, v4.8h, v1.h[7] \n"
3865 // r8
3866 "smlal v8.4s, v7.4h, v2.h[0] \n" // out0 += (r80-r87)*k08
3867 "smlal2 v9.4s, v7.8h, v2.h[0] \n"
3868 "smlal v10.4s, v7.4h, v2.h[1] \n" // out1 += (r80-r87)*k18
3869 "smlal2 v11.4s, v7.8h, v2.h[1] \n"
3870 "smlal v12.4s, v7.4h, v2.h[2] \n" // out2 += (r80-r87)*k28
3871 "smlal2 v13.4s, v7.8h, v2.h[2] \n"
3872 "smlal v14.4s, v7.4h, v2.h[3] \n" // out3 += (r80-r87)*k38
3873 "smlal2 v15.4s, v7.8h, v2.h[3] \n"
3874 "smlal v16.4s, v7.4h, v2.h[4] \n" // out4 += (r80-r87)*k48
3875 "smlal2 v17.4s, v7.8h, v2.h[4] \n"
3876 "smlal v18.4s, v7.4h, v2.h[5] \n" // out5 += (r80-r87)*k58
3877 "smlal2 v19.4s, v7.8h, v2.h[5] \n"
3878 "smlal v20.4s, v7.4h, v2.h[6] \n" // out6 += (r80-r87)*k68
3879 "smlal2 v21.4s, v7.8h, v2.h[6] \n"
3880 "smlal v22.4s, v7.4h, v2.h[7] \n" // out7 += (r80-r87)*k78
3881 "smlal2 v23.4s, v7.8h, v2.h[7] \n"
3882
3883 "st1 {v8.4s, v9.4s}, [%1], #32 \n"
3884 "st1 {v10.4s, v11.4s}, [%2], #32 \n"
3885 "st1 {v12.4s, v13.4s}, [%3], #32 \n"
3886 "st1 {v14.4s, v15.4s}, [%4], #32 \n"
3887 "st1 {v16.4s, v17.4s}, [%5], #32 \n"
3888 "st1 {v18.4s, v19.4s}, [%6], #32 \n"
3889 "st1 {v20.4s, v21.4s}, [%7], #32 \n"
3890 "st1 {v22.4s, v23.4s}, [%8], #32 \n"
3891
3892 "subs %w0, %w0, #1 \n"
3893 "sub %12, %12, #72 \n" // reset ktmp
3894
3895 "bne 0b \n"
3896
3897 : "=r"(nn), // %0
3898 "=r"(outptr0), // %1
3899 "=r"(outptr1), // %2
3900 "=r"(outptr2), // %3
3901 "=r"(outptr3), // %4
3902 "=r"(outptr4), // %5
3903 "=r"(outptr5), // %6
3904 "=r"(outptr6), // %7
3905 "=r"(outptr7), // %8
3906 "=r"(r0), // %9
3907 "=r"(r1), // %10
3908 "=r"(r2), // %11
3909 "=r"(ktmp) // %12
3910 : "0"(nn),
3911 "1"(outptr0),
3912 "2"(outptr1),
3913 "3"(outptr2),
3914 "4"(outptr3),
3915 "5"(outptr4),
3916 "6"(outptr5),
3917 "7"(outptr6),
3918 "8"(outptr7),
3919 "9"(r0),
3920 "10"(r1),
3921 "11"(r2),
3922 "12"(ktmp)
3923 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
3924 }
3925 #else // __aarch64__
3926 if (nn > 0)
3927 {
3928 asm volatile(
3929 "0: \n"
3930 "pld [%1, #128] \n"
3931 "vld1.s32 {d16-d17}, [%1] \n" // out0
3932 "pld [%2, #128] \n"
3933 "vld1.s32 {d18-d19}, [%2] \n" // out1
3934 "pld [%3, #128] \n"
3935 "vld1.s32 {d20-d21}, [%3] \n" // out2
3936 "pld [%4, #128] \n"
3937 "vld1.s32 {d22-d23}, [%4] \n" // out3
3938
3939 // r0
3940 "pld [%9, #64] \n"
3941 "vld2.s8 {d8-d9}, [%9] \n" // d8(a00 a02 a04 a06 a08 a010 a012 a014), d9(a01 a03 a05 a07 a09 a011 a013 a015)
3942 "add %9, #8 \n"
3943 "pld [%12, #64] \n"
3944 "vld1.s8 {d0-d2}, [%12]! \n" // d0(k00-k70) d1(k01-k71) d2(k02-k72)
3945
3946 "pld [%5, #128] \n"
3947 "vld1.s32 {d24-d25}, [%5] \n" // out4
3948 "pld [%6, #128] \n"
3949 "vld1.s32 {d26-d27}, [%6] \n" // out5
3950
3951 "vmovl.s8 q2, d2 \n" // q2(k02-k72)
3952 "vmovl.s8 q1, d1 \n" // q1(k01-k71)
3953 "vmovl.s8 q0, d0 \n" // q0(k00-k70)
3954 "vext.s8 d12, d8, d8, #1 \n" // d12(a02 a04 a06 a08 x x x x)
3955
3956 "pld [%7, #128] \n"
3957 "vld1.s32 {d28-d29}, [%7] \n" // out6
3958
3959 "vmovl.s8 q5, d9 \n" // q5(a01 a03 a05 a07 a09 a011 a013 a015) d11
3960 "vmovl.s8 q4, d8 \n" // q4(a00 a02 a04 a06 a08 a010 a012 a014) d9
3961 "vmovl.s8 q6, d12 \n" // q6(a02 a04 a06 a08 a010 a012 a014 a016) d13
3962
3963 "pld [%8, #128] \n"
3964 "vld1.s32 {d30-d31}, [%8] \n" // out7
3965
3966 "vmlal.s16 q8, d8, d0[0] \n" // sum0 += (a00 a02 a04 a06) * k00
3967 "vmlal.s16 q9, d8, d0[1] \n" // sum1 += (a00 a02 a04 a06) * k10
3968 "vmlal.s16 q10, d8, d0[2] \n" // sum2 += (a00 a02 a04 a06) * k20
3969 "vmlal.s16 q11, d8, d0[3] \n" // sum3 += (a00 a02 a04 a06) * k30
3970 "vmlal.s16 q12, d8, d1[0] \n" // sum4 += (a00 a02 a04 a06) * k40
3971 "vmlal.s16 q13, d8, d1[1] \n" // sum5 += (a00 a02 a04 a06) * k50
3972 "vmlal.s16 q14, d8, d1[2] \n" // sum6 += (a00 a02 a04 a06) * k60
3973 "vmlal.s16 q15, d8, d1[3] \n" // sum7 += (a00 a02 a04 a06) * k70
3974
3975 "vmlal.s16 q8, d10, d2[0] \n" // sum0 += (a01-a07) * k01
3976 "vmlal.s16 q9, d10, d2[1] \n" // sum1 += (a01-a07) * k11
3977 "vmlal.s16 q10, d10, d2[2] \n" // sum2 += (a01-a07) * k21
3978 "vmlal.s16 q11, d10, d2[3] \n" // sum3 += (a01-a07) * k31
3979 "vmlal.s16 q12, d10, d3[0] \n" // sum4 += (a01-a07) * k41
3980 "vmlal.s16 q13, d10, d3[1] \n" // sum5 += (a01-a07) * k51
3981 "vmlal.s16 q14, d10, d3[2] \n" // sum6 += (a01-a07) * k61
3982 "vmlal.s16 q15, d10, d3[3] \n" // sum7 += (a01-a07) * k71
3983
3984 "pld [%10, #64] \n"
3985 "vld2.s8 {d8-d9}, [%10] \n" // d8(a10 a12 a14 a16 a18 a110 a112 a114), d9(a11 a13 a15 a17 a19 a111 a113 a115)
3986 "add %10, #8 \n"
3987
3988 "vmlal.s16 q8, d12, d4[0] \n" // sum0 += (a02-a08) * k02
3989 "vmlal.s16 q9, d12, d4[1] \n" // sum1 += (a02-a08) * k12
3990 "vmlal.s16 q10, d12, d4[2] \n" // sum2 += (a02-a08) * k22
3991 "vmlal.s16 q11, d12, d4[3] \n" // sum3 += (a02-a08) * k32
3992
3993 "pld [%12, #64] \n"
3994 "vld1.s8 {d0-d2}, [%12]! \n" // d0(k03-k73) d1(k04-k74) d2(k05-k75)
3995
3996 "vmlal.s16 q12, d12, d5[0] \n" // sum4 += (a02-a08) * k42
3997 "vmlal.s16 q13, d12, d5[1] \n" // sum5 += (a02-a08) * k52
3998 "vmlal.s16 q14, d12, d5[2] \n" // sum6 += (a02-a08) * k62
3999 "vmlal.s16 q15, d12, d5[3] \n" // sum7 += (a02-a08) * k72
4000
4001 // r1
4002 "vext.s8 d12, d8, d8, #1 \n" // d12(a12 a14 a16 a18 x x x x)
4003
4004 "vmovl.s8 q2, d2 \n" // q2(k05-k75)
4005 "vmovl.s8 q1, d1 \n" // q1(k04-k74)
4006 "vmovl.s8 q0, d0 \n" // q0(k03-k73)
4007 "vmovl.s8 q5, d9 \n" // q5(a11-a115)
4008 "vmovl.s8 q4, d8 \n" // q4(a10-a114)
4009 "vmovl.s8 q6, d12 \n" // q6(a12-a116)
4010
4011 "vmlal.s16 q8, d8, d0[0] \n" // sum0 += (a10-a16) * k03
4012 "vmlal.s16 q9, d8, d0[1] \n" // sum1 += (a10-a16) * k13
4013 "vmlal.s16 q10, d8, d0[2] \n" // sum2 += (a10-a16) * k23
4014 "vmlal.s16 q11, d8, d0[3] \n" // sum3 += (a10-a16) * k33
4015 "vmlal.s16 q12, d8, d1[0] \n" // sum4 += (a10-a16) * k43
4016 "vmlal.s16 q13, d8, d1[1] \n" // sum5 += (a10-a16) * k53
4017 "vmlal.s16 q14, d8, d1[2] \n" // sum6 += (a10-a16) * k63
4018 "vmlal.s16 q15, d8, d1[3] \n" // sum7 += (a10-a16) * k73
4019
4020 "vmlal.s16 q8, d10, d2[0] \n" // sum0 += (a11-a17) * k04
4021 "vmlal.s16 q9, d10, d2[1] \n" // sum1 += (a11-a17) * k14
4022 "vmlal.s16 q10, d10, d2[2] \n" // sum2 += (a11-a17) * k24
4023 "vmlal.s16 q11, d10, d2[3] \n" // sum3 += (a11-a17) * k34
4024 "vmlal.s16 q12, d10, d3[0] \n" // sum4 += (a11-a17) * k44
4025 "vmlal.s16 q13, d10, d3[1] \n" // sum5 += (a11-a17) * k54
4026 "vmlal.s16 q14, d10, d3[2] \n" // sum6 += (a11-a17) * k64
4027 "vmlal.s16 q15, d10, d3[3] \n" // sum7 += (a11-a17) * k74
4028
4029 "pld [%11, #64] \n"
4030 "vld2.s8 {d8-d9}, [%11] \n" // d8(a20 a22 a24 a26 a28 a210 a212 a214), d9(a21 a23 a25 a27 a29 a211 a213 a215)
4031 "add %11, #8 \n"
4032
4033 "vmlal.s16 q8, d12, d4[0] \n" // sum0 += (a12-a18) * k05
4034 "vmlal.s16 q9, d12, d4[1] \n" // sum1 += (a12-a18) * k15
4035 "vmlal.s16 q10, d12, d4[2] \n" // sum2 += (a12-a18) * k25
4036 "vmlal.s16 q11, d12, d4[3] \n" // sum3 += (a12-a18) * k35
4037
4038 "pld [%12, #64] \n"
4039 "vld1.s8 {d0-d2}, [%12]! \n" // d0(k06-k76) d1(k07-k77) d2(k08-k78)
4040
4041 "vmlal.s16 q12, d12, d5[0] \n" // sum4 += (a12-a18) * k45
4042 "vmlal.s16 q13, d12, d5[1] \n" // sum5 += (a12-a18) * k55
4043 "vmlal.s16 q14, d12, d5[2] \n" // sum6 += (a12-a18) * k65
4044 "vmlal.s16 q15, d12, d5[3] \n" // sum7 += (a12-a18) * k75
4045
4046 // r2
4047 "vext.s8 d12, d8, d8, #1 \n" // d12(a22 a24 a26 a28 x x x x)
4048
4049 "vmovl.s8 q2, d2 \n" // q2(k08-k78)
4050 "vmovl.s8 q1, d1 \n" // q1(k07-k77)
4051 "vmovl.s8 q0, d0 \n" // q0(k06-k76)
4052 "vmovl.s8 q5, d9 \n" // q5(a21-a215)
4053 "vmovl.s8 q4, d8 \n" // q4(a20-a214)
4054 "vmovl.s8 q6, d12 \n" // q6(a22-a216)
4055
4056 "vmlal.s16 q8, d8, d0[0] \n" // sum0 += (a20-a26) * k06
4057 "vmlal.s16 q9, d8, d0[1] \n" // sum1 += (a20-a26) * k16
4058 "vmlal.s16 q10, d8, d0[2] \n" // sum2 += (a20-a26) * k26
4059 "vmlal.s16 q11, d8, d0[3] \n" // sum3 += (a20-a26) * k36
4060 "vmlal.s16 q12, d8, d1[0] \n" // sum4 += (a20-a26) * k46
4061 "vmlal.s16 q13, d8, d1[1] \n" // sum5 += (a20-a26) * k56
4062 "vmlal.s16 q14, d8, d1[2] \n" // sum6 += (a20-a26) * k66
4063 "vmlal.s16 q15, d8, d1[3] \n" // sum7 += (a20-a26) * k76
4064
4065 "vmlal.s16 q8, d10, d2[0] \n" // sum0 += (a21-a27) * k07
4066 "vmlal.s16 q9, d10, d2[1] \n" // sum1 += (a21-a27) * k17
4067 "vmlal.s16 q10, d10, d2[2] \n" // sum2 += (a21-a27) * k27
4068 "vmlal.s16 q11, d10, d2[3] \n" // sum3 += (a21-a27) * k37
4069 "vmlal.s16 q12, d10, d3[0] \n" // sum4 += (a21-a27) * k47
4070 "vmlal.s16 q13, d10, d3[1] \n" // sum5 += (a21-a27) * k57
4071 "vmlal.s16 q14, d10, d3[2] \n" // sum6 += (a21-a27) * k67
4072 "vmlal.s16 q15, d10, d3[3] \n" // sum7 += (a21-a27) * k77
4073
4074 "vmlal.s16 q8, d12, d4[0] \n" // sum0 += (a22-a28) * k08
4075 "vmlal.s16 q9, d12, d4[1] \n" // sum1 += (a22-a28) * k18
4076 "vmlal.s16 q10, d12, d4[2] \n" // sum2 += (a22-a28) * k28
4077 "vmlal.s16 q11, d12, d4[3] \n" // sum3 += (a22-a28) * k38
4078 "vmlal.s16 q12, d12, d5[0] \n" // sum4 += (a22-a28) * k48
4079 "vmlal.s16 q13, d12, d5[1] \n" // sum5 += (a22-a28) * k58
4080 "vmlal.s16 q14, d12, d5[2] \n" // sum6 += (a22-a28) * k68
4081 "vmlal.s16 q15, d12, d5[3] \n" // sum7 += (a22-a28) * k78
4082
4083 // save s32 to memory
4084 "sub %12, %12, #72 \n"
4085 "vst1.s32 {d16-d17}, [%1]! \n" // out0
4086 "vst1.s32 {d18-d19}, [%2]! \n" // out1
4087 "vst1.s32 {d20-d21}, [%3]! \n" // out2
4088 "vst1.s32 {d22-d23}, [%4]! \n" // out3
4089 "subs %0, #1 \n"
4090 "vst1.s32 {d24-d25}, [%5]! \n" // out4
4091 "vst1.s32 {d26-d27}, [%6]! \n" // out5
4092 "vst1.s32 {d28-d29}, [%7]! \n" // out6
4093 "vst1.s32 {d30-d31}, [%8]! \n" // out7
4094
4095 "bne 0b \n"
4096 : "=r"(nn), // %0
4097 "=r"(outptr0), // %1
4098 "=r"(outptr1), // %2
4099 "=r"(outptr2), // %3
4100 "=r"(outptr3), // %4
4101 "=r"(outptr4), // %5
4102 "=r"(outptr5), // %6
4103 "=r"(outptr6), // %7
4104 "=r"(outptr7), // %8
4105 "=r"(r0), // %9
4106 "=r"(r1), // %10
4107 "=r"(r2), // %11
4108 "=r"(ktmp) // %12
4109 : "0"(nn),
4110 "1"(outptr0),
4111 "2"(outptr1),
4112 "3"(outptr2),
4113 "4"(outptr3),
4114 "5"(outptr4),
4115 "6"(outptr5),
4116 "7"(outptr6),
4117 "8"(outptr7),
4118 "9"(r0),
4119 "10"(r1),
4120 "11"(r2),
4121 "12"(ktmp)
4122 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
4123 }
4124 #endif // __aarch64__
4125 #endif // __ARM_NEON
4126 for (; remain > 0; remain--)
4127 {
4128 #if __ARM_NEON
4129 #if __aarch64__
4130 int8x8_t _r0_s8 = vld1_s8(r0); // (a00 a01 a02 ....)
4131 int8x8_t _r1_s8 = vld1_s8(r1); // (a10 a11 a12 ....)
4132 int8x8_t _r2_s8 = vld1_s8(r2); // (a20 a21 a22 ....)
4133
4134 int16x8_t _r0 = vmovl_s8(_r0_s8);
4135 int16x8_t _r1 = vmovl_s8(_r1_s8);
4136 int16x8_t _r2 = vmovl_s8(_r2_s8);
4137
4138 int32x4_t _sum03 = {};
4139 int32x4_t _sum47 = {};
4140
4141 _sum03 = vld1q_lane_s32(outptr0, _sum03, 0); // out0
4142 _sum03 = vld1q_lane_s32(outptr1, _sum03, 1); // out1
4143 _sum03 = vld1q_lane_s32(outptr2, _sum03, 2); // out2
4144 _sum03 = vld1q_lane_s32(outptr3, _sum03, 3); // out3
4145 _sum47 = vld1q_lane_s32(outptr4, _sum47, 0); // out4
4146 _sum47 = vld1q_lane_s32(outptr5, _sum47, 1); // out5
4147 _sum47 = vld1q_lane_s32(outptr6, _sum47, 2); // out6
4148 _sum47 = vld1q_lane_s32(outptr7, _sum47, 3); // out7
4149
4150 // k0 - k2
4151 int8x8_t _k0_8 = vld1_s8(ktmp); //(k00-k70)
4152 int8x8_t _k1_8 = vld1_s8(ktmp + 8); //(k01-k71)
4153 int8x8_t _k2_8 = vld1_s8(ktmp + 16); //(k02-k72)
4154
4155 int16x8_t _k0 = vmovl_s8(_k0_8);
4156 int16x8_t _k1 = vmovl_s8(_k1_8);
4157 int16x8_t _k2 = vmovl_s8(_k2_8);
4158
4159 int32x4_t _sum0 = vmull_laneq_s16(vget_low_s16(_k0), _r0, 0);
4160 int32x4_t _sum0n = vmull_laneq_s16(vget_high_s16(_k0), _r0, 0);
4161 int32x4_t _sum1 = vmull_laneq_s16(vget_low_s16(_k1), _r0, 1);
4162 int32x4_t _sum1n = vmull_laneq_s16(vget_high_s16(_k1), _r0, 1);
4163 _sum03 = vmlal_laneq_s16(_sum03, vget_low_s16(_k2), _r0, 2);
4164 _sum47 = vmlal_laneq_s16(_sum47, vget_high_s16(_k2), _r0, 2);
4165
4166 // k3 - k5
4167 _k0_8 = vld1_s8(ktmp + 24); //(k03-k73)
4168 _k1_8 = vld1_s8(ktmp + 32); //(k04-k74)
4169 _k2_8 = vld1_s8(ktmp + 40); //(k05-k75)
4170
4171 _k0 = vmovl_s8(_k0_8);
4172 _k1 = vmovl_s8(_k1_8);
4173 _k2 = vmovl_s8(_k2_8);
4174
4175 _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_k0), _r1, 0);
4176 _sum0n = vmlal_laneq_s16(_sum0n, vget_high_s16(_k0), _r1, 0);
4177 _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_k1), _r1, 1);
4178 _sum1n = vmlal_laneq_s16(_sum1n, vget_high_s16(_k1), _r1, 1);
4179 _sum03 = vmlal_laneq_s16(_sum03, vget_low_s16(_k2), _r1, 2);
4180 _sum47 = vmlal_laneq_s16(_sum47, vget_high_s16(_k2), _r1, 2);
4181
4182 // k6 - k8
4183 _k0_8 = vld1_s8(ktmp + 48); //(k06-k76)
4184 _k1_8 = vld1_s8(ktmp + 56); //(k07-k77)
4185 _k2_8 = vld1_s8(ktmp + 64); //(k08-k78)
4186
4187 _k0 = vmovl_s8(_k0_8);
4188 _k1 = vmovl_s8(_k1_8);
4189 _k2 = vmovl_s8(_k2_8);
4190
4191 _sum0 = vmlal_laneq_s16(_sum0, vget_low_s16(_k0), _r2, 0);
4192 _sum0n = vmlal_laneq_s16(_sum0n, vget_high_s16(_k0), _r2, 0);
4193 _sum1 = vmlal_laneq_s16(_sum1, vget_low_s16(_k1), _r2, 1);
4194 _sum1n = vmlal_laneq_s16(_sum1n, vget_high_s16(_k1), _r2, 1);
4195 _sum03 = vmlal_laneq_s16(_sum03, vget_low_s16(_k2), _r2, 2);
4196 _sum47 = vmlal_laneq_s16(_sum47, vget_high_s16(_k2), _r2, 2);
4197
4198 _sum0 = vaddq_s32(_sum0, _sum1);
4199 _sum0n = vaddq_s32(_sum0n, _sum1n);
4200 _sum03 = vaddq_s32(_sum03, _sum0);
4201 _sum47 = vaddq_s32(_sum47, _sum0n);
4202
4203 vst1q_lane_s32(outptr0, _sum03, 0);
4204 vst1q_lane_s32(outptr1, _sum03, 1);
4205 vst1q_lane_s32(outptr2, _sum03, 2);
4206 vst1q_lane_s32(outptr3, _sum03, 3);
4207 vst1q_lane_s32(outptr4, _sum47, 0);
4208 vst1q_lane_s32(outptr5, _sum47, 1);
4209 vst1q_lane_s32(outptr6, _sum47, 2);
4210 vst1q_lane_s32(outptr7, _sum47, 3);
4211
4212 outptr0++;
4213 outptr1++;
4214 outptr2++;
4215 outptr3++;
4216 outptr4++;
4217 outptr5++;
4218 outptr6++;
4219 outptr7++;
4220 #else // __aarch64__
4221 asm volatile(
4222 "pld [%8, #64] \n"
4223 "vld1.s8 {d0}, [%8] \n" // d0(a00 a01 a02 ....)
4224 "pld [%9, #64] \n"
4225 "vld1.s8 {d2}, [%9] \n" // d2(a10 a11 a12 ....)
4226 "pld [%10, #64] \n"
4227 "vld1.s8 {d4}, [%10] \n" // d4(a20 a21 a22 ....)
4228
4229 "pld [%11, #64] \n"
4230 "vld1.s8 {d6-d8}, [%11]! \n" // d6(k00-k70) d7(k01-k71) d8(k02-k72)
4231
4232 "vmovl.s8 q0, d0 \n" // d0(a00 a01 a02 x)
4233 "vmovl.s8 q1, d2 \n" // d2(a10 a11 a12 x)
4234 "vmovl.s8 q2, d4 \n" // d4(a20 a21 a22 x)
4235
4236 "vmovl.s8 q5, d8 \n" // d10(k02-k32) d11(k42-k72)
4237 "vmovl.s8 q4, d7 \n" // d8(k01-k31) d9(k41-k71)
4238 "vmovl.s8 q3, d6 \n" // d6(k00-k30) d7(k40-k70)
4239
4240 "vld1.s32 {d20[0]}, [%0] \n" // out0 q10
4241 "vld1.s32 {d20[1]}, [%1] \n" // out1
4242 "vld1.s32 {d21[0]}, [%2] \n" // out2
4243 "vld1.s32 {d21[1]}, [%3] \n" // out3
4244
4245 "pld [%11, #64] \n"
4246 "vld1.s8 {d24-d26}, [%11]! \n"
4247 "vmovl.s8 q14, d26 \n" // d28(k05-k35) d29(k45-k75)
4248 "vmovl.s8 q13, d25 \n" // d26(k04-k34) d27(k44-k74)
4249 "vmovl.s8 q12, d24 \n" // d24(k03-k33) d25(k43-k73)
4250
4251 "vld1.s32 {d22[0]}, [%4] \n" // out4 q11
4252 "vld1.s32 {d22[1]}, [%5] \n" // out5
4253 "vld1.s32 {d23[0]}, [%6] \n" // out6
4254 "vld1.s32 {d23[1]}, [%7] \n" // out7
4255
4256 "vmull.s16 q6, d6, d0[0] \n" // a00 x (k00-k30)
4257 "vmull.s16 q7, d7, d0[0] \n" // a00 x (k40-k70)
4258 "vmull.s16 q8, d8, d0[1] \n" // a01 x (k01-k31)
4259 "vmull.s16 q9, d9, d0[1] \n" // a01 x (k41-k71)
4260 "vmlal.s16 q10, d10, d0[2] \n" // a02 x (k02-k32)
4261 "vmlal.s16 q11, d11, d0[2] \n" // a02 x (k42-k72)
4262
4263 "pld [%11, #64] \n"
4264 "vld1.s8 {d6-d8}, [%11]! \n"
4265 "vmovl.s8 q5, d8 \n" // d10(k08-k38) d11(k48-k78)
4266 "vmovl.s8 q4, d7 \n" // d8(k07-k37) d9(k47-k77)
4267 "vmovl.s8 q3, d6 \n" // d6(k06-k36) d7(k46-k76)
4268
4269 "vmlal.s16 q6, d24, d2[0] \n" // a10 x (k03-k33)
4270 "vmlal.s16 q7, d25, d2[0] \n" // a10 x (k43-k73)
4271 "vmlal.s16 q8, d26, d2[1] \n" // a11 x (k04-k34)
4272 "vmlal.s16 q9, d27, d2[1] \n" // a11 x (k44-k74)
4273 "vmlal.s16 q10, d28, d2[2] \n" // a12 x (k05-k35)
4274 "vmlal.s16 q11, d29, d2[2] \n" // a12 x (k45-k75)
4275
4276 "vmlal.s16 q6, d6, d4[0] \n" // a20 x (k06-k36)
4277 "vmlal.s16 q7, d7, d4[0] \n" // a20 x (k46-k76)
4278 "vmlal.s16 q8, d8, d4[1] \n" // a21 x (k07-k37)
4279 "vmlal.s16 q9, d9, d4[1] \n" // a21 x (k47-k77)
4280 "vmlal.s16 q10, d10, d4[2] \n" // a22 x (k08-k38)
4281 "vmlal.s16 q11, d11, d4[2] \n" // a22 x (k48-k78)
4282
4283 "vadd.s32 q8, q8, q6 \n"
4284 "vadd.s32 q9, q9, q7 \n"
4285
4286 "sub %11, %11, #72 \n"
4287
4288 "vadd.s32 q10, q10, q8 \n"
4289 "vadd.s32 q11, q11, q9 \n"
4290
4291 "vst1.s32 {d20[0]}, [%0]! \n" // out0
4292 "vst1.s32 {d20[1]}, [%1]! \n" // out1
4293 "vst1.s32 {d21[0]}, [%2]! \n" // out2
4294 "vst1.s32 {d21[1]}, [%3]! \n" // out3
4295 "vst1.s32 {d22[0]}, [%4]! \n" // out4
4296 "vst1.s32 {d22[1]}, [%5]! \n" // out5
4297 "vst1.s32 {d23[0]}, [%6]! \n" // out6
4298 "vst1.s32 {d23[1]}, [%7]! \n" // out7
4299
4300 : "=r"(outptr0), // %0
4301 "=r"(outptr1), // %1
4302 "=r"(outptr2), // %2
4303 "=r"(outptr3), // %3
4304 "=r"(outptr4), // %4
4305 "=r"(outptr5), // %5
4306 "=r"(outptr6), // %6
4307 "=r"(outptr7), // %7
4308 "=r"(r0), // %8
4309 "=r"(r1), // %9
4310 "=r"(r2), // %10
4311 "=r"(ktmp) // %11
4312 : "0"(outptr0),
4313 "1"(outptr1),
4314 "2"(outptr2),
4315 "3"(outptr3),
4316 "4"(outptr4),
4317 "5"(outptr5),
4318 "6"(outptr6),
4319 "7"(outptr7),
4320 "8"(r0),
4321 "9"(r1),
4322 "10"(r2),
4323 "11"(ktmp)
4324 : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
4325 #endif // __aarch64__
4326 #else // __ARM_NEON
4327 int sum0 = 0;
4328 int sum1 = 0;
4329 int sum2 = 0;
4330 int sum3 = 0;
4331 int sum4 = 0;
4332 int sum5 = 0;
4333 int sum6 = 0;
4334 int sum7 = 0;
4335
4336 sum0 += (int)r0[0] * ktmp[0];
4337 sum1 += (int)r0[0] * ktmp[1];
4338 sum2 += (int)r0[0] * ktmp[2];
4339 sum3 += (int)r0[0] * ktmp[3];
4340 sum4 += (int)r0[0] * ktmp[4];
4341 sum5 += (int)r0[0] * ktmp[5];
4342 sum6 += (int)r0[0] * ktmp[6];
4343 sum7 += (int)r0[0] * ktmp[7];
4344 ktmp += 8;
4345
4346 sum0 += (int)r0[1] * ktmp[0];
4347 sum1 += (int)r0[1] * ktmp[1];
4348 sum2 += (int)r0[1] * ktmp[2];
4349 sum3 += (int)r0[1] * ktmp[3];
4350 sum4 += (int)r0[1] * ktmp[4];
4351 sum5 += (int)r0[1] * ktmp[5];
4352 sum6 += (int)r0[1] * ktmp[6];
4353 sum7 += (int)r0[1] * ktmp[7];
4354 ktmp += 8;
4355
4356 sum0 += (int)r0[2] * ktmp[0];
4357 sum1 += (int)r0[2] * ktmp[1];
4358 sum2 += (int)r0[2] * ktmp[2];
4359 sum3 += (int)r0[2] * ktmp[3];
4360 sum4 += (int)r0[2] * ktmp[4];
4361 sum5 += (int)r0[2] * ktmp[5];
4362 sum6 += (int)r0[2] * ktmp[6];
4363 sum7 += (int)r0[2] * ktmp[7];
4364 ktmp += 8;
4365
4366 sum0 += (int)r1[0] * ktmp[0];
4367 sum1 += (int)r1[0] * ktmp[1];
4368 sum2 += (int)r1[0] * ktmp[2];
4369 sum3 += (int)r1[0] * ktmp[3];
4370 sum4 += (int)r1[0] * ktmp[4];
4371 sum5 += (int)r1[0] * ktmp[5];
4372 sum6 += (int)r1[0] * ktmp[6];
4373 sum7 += (int)r1[0] * ktmp[7];
4374 ktmp += 8;
4375
4376 sum0 += (int)r1[1] * ktmp[0];
4377 sum1 += (int)r1[1] * ktmp[1];
4378 sum2 += (int)r1[1] * ktmp[2];
4379 sum3 += (int)r1[1] * ktmp[3];
4380 sum4 += (int)r1[1] * ktmp[4];
4381 sum5 += (int)r1[1] * ktmp[5];
4382 sum6 += (int)r1[1] * ktmp[6];
4383 sum7 += (int)r1[1] * ktmp[7];
4384 ktmp += 8;
4385
4386 sum0 += (int)r1[2] * ktmp[0];
4387 sum1 += (int)r1[2] * ktmp[1];
4388 sum2 += (int)r1[2] * ktmp[2];
4389 sum3 += (int)r1[2] * ktmp[3];
4390 sum4 += (int)r1[2] * ktmp[4];
4391 sum5 += (int)r1[2] * ktmp[5];
4392 sum6 += (int)r1[2] * ktmp[6];
4393 sum7 += (int)r1[2] * ktmp[7];
4394 ktmp += 8;
4395
4396 sum0 += (int)r2[0] * ktmp[0];
4397 sum1 += (int)r2[0] * ktmp[1];
4398 sum2 += (int)r2[0] * ktmp[2];
4399 sum3 += (int)r2[0] * ktmp[3];
4400 sum4 += (int)r2[0] * ktmp[4];
4401 sum5 += (int)r2[0] * ktmp[5];
4402 sum6 += (int)r2[0] * ktmp[6];
4403 sum7 += (int)r2[0] * ktmp[7];
4404 ktmp += 8;
4405
4406 sum0 += (int)r2[1] * ktmp[0];
4407 sum1 += (int)r2[1] * ktmp[1];
4408 sum2 += (int)r2[1] * ktmp[2];
4409 sum3 += (int)r2[1] * ktmp[3];
4410 sum4 += (int)r2[1] * ktmp[4];
4411 sum5 += (int)r2[1] * ktmp[5];
4412 sum6 += (int)r2[1] * ktmp[6];
4413 sum7 += (int)r2[1] * ktmp[7];
4414 ktmp += 8;
4415
4416 sum0 += (int)r2[2] * ktmp[0];
4417 sum1 += (int)r2[2] * ktmp[1];
4418 sum2 += (int)r2[2] * ktmp[2];
4419 sum3 += (int)r2[2] * ktmp[3];
4420 sum4 += (int)r2[2] * ktmp[4];
4421 sum5 += (int)r2[2] * ktmp[5];
4422 sum6 += (int)r2[2] * ktmp[6];
4423 sum7 += (int)r2[2] * ktmp[7];
4424 ktmp += 8;
4425
4426 *outptr0 += sum0;
4427 *outptr1 += sum1;
4428 *outptr2 += sum2;
4429 *outptr3 += sum3;
4430 *outptr4 += sum4;
4431 *outptr5 += sum5;
4432 *outptr6 += sum6;
4433 *outptr7 += sum7;
4434
4435 ktmp -= 8 * 9;
4436
4437 outptr0++;
4438 outptr1++;
4439 outptr2++;
4440 outptr3++;
4441 outptr4++;
4442 outptr5++;
4443 outptr6++;
4444 outptr7++;
4445 #endif // __ARM_NEON
4446 r0 += 2;
4447 r1 += 2;
4448 r2 += 2;
4449 }
4450
4451 r0 += tailstep;
4452 r1 += tailstep;
4453 r2 += tailstep;
4454 }
4455
4456 ktmp += 8 * 9;
4457 }
4458 }
4459
4460 #pragma omp parallel for num_threads(opt.num_threads)
4461 for (int p = remain_outch_start; p < outch; p++)
4462 {
4463 Mat out = top_blob.channel(p);
4464
4465 out.fill(0);
4466
4467 const signed char* ktmp = _kernel.channel(p / 8 + p % 8);
4468
4469 for (int q = 0; q < inch; q++)
4470 {
4471 int* outptr = out;
4472
4473 const signed char* img0 = bottom_blob.channel(q);
4474
4475 const signed char* r0 = img0;
4476 const signed char* r1 = img0 + w;
4477 const signed char* r2 = img0 + w * 2;
4478
4479 int i = 0;
4480
4481 for (; i < outh; i++)
4482 {
4483 #if __ARM_NEON
4484 int nn = outw >> 3;
4485 int remain = outw & 7;
4486 #else
4487 int remain = outw;
4488 #endif // __ARM_NEON
4489
4490 #if __ARM_NEON
4491 #if __aarch64__
4492 if (nn > 0)
4493 {
4494 asm volatile(
4495 "0: \n"
4496
4497 "ld1 {v0.8b, v1.8b}, [%5] \n" //ktmp
4498 "ld2 {v2.8b, v3.8b}, [%2], #16 \n" //r0-r2
4499 "ld2 {v4.8b, v5.8b}, [%2] \n"
4500
4501 "ld2 {v6.8b, v7.8b}, [%3], #16 \n" //r3-r5
4502 "ld2 {v8.8b, v9.8b}, [%3] \n"
4503
4504 "ld2 {v10.8b, v11.8b}, [%4], #16 \n" //r6-r8
4505 "ld2 {v12.8b, v13.8b}, [%4] \n"
4506
4507 "ld1 {v14.4s, v15.4s}, [%1] \n" //out0
4508
4509 "ext v4.8b, v2.8b, v4.8b, #1 \n"
4510 "ext v8.8b, v6.8b, v8.8b, #1 \n"
4511 "ext v12.8b, v10.8b, v12.8b, #1 \n"
4512
4513 "sshll v0.8h, v0.8b, #0 \n" //(k0-k7)
4514 "sshll v1.8h, v1.8b, #0 \n" //(k8)
4515 "sshll v2.8h, v2.8b, #0 \n" // r0
4516 "sshll v3.8h, v3.8b, #0 \n" // r1
4517 "sshll v4.8h, v4.8b, #0 \n" // r2
4518 "sshll v6.8h, v6.8b, #0 \n" // r3
4519 "sshll v7.8h, v7.8b, #0 \n" // r4
4520 "sshll v8.8h, v8.8b, #0 \n" // r5
4521 "sshll v10.8h, v10.8b, #0 \n" // r6
4522 "sshll v11.8h, v11.8b, #0 \n" // r7
4523 "sshll v12.8h, v12.8b, #0 \n" // r8
4524
4525 // r0
4526 "smull v16.4s, v2.4h, v0.h[0] \n" // out = r0*k0
4527 "smull2 v17.4s, v2.8h, v0.h[0] \n"
4528 "smull v18.4s, v3.4h, v0.h[1] \n" // outn = r1*k1
4529 "smull2 v19.4s, v3.8h, v0.h[1] \n"
4530 "smlal v16.4s, v4.4h, v0.h[2] \n" // out = r2*k2
4531 "smlal2 v17.4s, v4.8h, v0.h[2] \n"
4532 "smlal v18.4s, v6.4h, v0.h[3] \n" // outn = r3*k3
4533 "smlal2 v19.4s, v6.8h, v0.h[3] \n"
4534 "smlal v16.4s, v7.4h, v0.h[4] \n" // out = r4*k4
4535 "smlal2 v17.4s, v7.8h, v0.h[4] \n"
4536 "smlal v18.4s, v8.4h, v0.h[5] \n" // outn = r5*k5
4537 "smlal2 v19.4s, v8.8h, v0.h[5] \n"
4538 "smlal v16.4s, v10.4h, v0.h[6] \n" // out = r6*k6
4539 "smlal2 v17.4s, v10.8h, v0.h[6] \n"
4540 "smlal v18.4s, v11.4h, v0.h[7] \n" // outn = r7*k7
4541 "smlal2 v19.4s, v11.8h, v0.h[7] \n"
4542 "smlal v16.4s, v12.4h, v1.h[0] \n" // out = r8*k8
4543 "smlal2 v17.4s, v12.8h, v1.h[0] \n"
4544
4545 "add v8.4s, v16.4s, v18.4s \n"
4546 "add v9.4s, v17.4s, v19.4s \n"
4547
4548 "st1 {v8.4s, v9.4s}, [%1], #32 \n"
4549
4550 "subs %w0, %w0, #1 \n"
4551
4552 "bne 0b \n"
4553
4554 : "=r"(nn), // %0
4555 "=r"(outptr), // %1
4556 "=r"(r0), // %2
4557 "=r"(r1), // %3
4558 "=r"(r2), // %4
4559 "=r"(ktmp) // %5
4560 : "0"(nn),
4561 "1"(outptr),
4562 "2"(r0),
4563 "3"(r1),
4564 "4"(r2),
4565 "5"(ktmp)
4566 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
4567 }
4568 #else
4569 if (nn > 0)
4570 {
4571 asm volatile(
4572 "vld1.s8 {d0-d1}, [%5] \n" // d0(k0 - k7) d1(k8 ...)
4573 "vmovl.s8 q1, d1 \n" // d2(k8 ...)
4574 "vmovl.s8 q0, d0 \n" // d0(k0 - k3) d1(k4 - k7)
4575 "0: \n"
4576 "pld [%2, #192] \n"
4577 "vld2.s8 {d4-d5}, [%2]! \n" // r0 d4(a00 a02 ... a014) d5(a01 a03 ... a015)
4578 "vld2.s8 {d8-d9}, [%2] \n" // d8(a016 ....)
4579 "vld2.s8 {d10-d11}, [%3]! \n" // r1 d10(a10 a12 ... a114) d11(a11 a13 ... a115)
4580 "vld2.s8 {d14-d15}, [%3] \n" // d14(a116 ....)
4581 "vld2.s8 {d16-d17}, [%4]! \n" // r2 d16(a20 a22 ... a214) d17(a21 a23 ... a215)
4582 "vld2.s8 {d20-d21}, [%4] \n" // d20(a216 ....)
4583 "vld1.s32 {d22-d25}, [%1] \n" // q11(out0 - out3) q12(out4 - out7)
4584
4585 "vext.s8 d8, d4, d8, #1 \n" // d8(a02 a04 ... a016)
4586 "vext.s8 d14, d10, d14, #1 \n" // d14(a12 a14 ... a116)
4587 "vext.s8 d20, d16, d20, #1 \n" // d20(a22 a24 ... a216)
4588
4589 "vmovl.s8 q3, d5 \n" // q3(a01 a03 ... a015)
4590 "vmovl.s8 q2, d4 \n" // q2(a00 a02 ... a014)
4591 "vmovl.s8 q4, d8 \n" // q4(a02 a04 ... a016)
4592
4593 "vmovl.s8 q6, d11 \n" // q6(a11 a13 ... a115)
4594 "vmovl.s8 q5, d10 \n" // q5(a10 a12 ... a114)
4595 "vmovl.s8 q7, d14 \n" // q7(a12 a14 ... a116)
4596
4597 "vmovl.s8 q9, d17 \n" // q9(a21 a23 ... a215)
4598 "vmovl.s8 q8, d16 \n" // q8(a20 a22 ... a214)
4599 "vmovl.s8 q10, d20 \n" // q10(a22 a24 ... a216)
4600
4601 "vmlal.s16 q11, d4, d0[0] \n" // k0
4602 "vmlal.s16 q12, d5, d0[0] \n"
4603 "vmull.s16 q13, d6, d0[1] \n" // k1
4604 "vmull.s16 q14, d7, d0[1] \n"
4605 "vmlal.s16 q11, d8, d0[2] \n" // k2
4606 "vmlal.s16 q12, d9, d0[2] \n"
4607
4608 "vmlal.s16 q13, d12, d1[0] \n" // k4
4609 "vmlal.s16 q14, d13, d1[0] \n"
4610 "vmlal.s16 q11, d10, d0[3] \n" // k3
4611 "vmlal.s16 q12, d11, d0[3] \n"
4612 "vmlal.s16 q13, d14, d1[1] \n" // k5
4613 "vmlal.s16 q14, d15, d1[1] \n"
4614
4615 "vmlal.s16 q11, d16, d1[2] \n" // k6
4616 "vmlal.s16 q12, d17, d1[2] \n"
4617 "vmlal.s16 q13, d18, d1[3] \n" // k7
4618 "vmlal.s16 q14, d19, d1[3] \n"
4619 "vmlal.s16 q11, d20, d2[0] \n" // k8
4620 "vmlal.s16 q12, d21, d2[0] \n"
4621
4622 "vadd.s32 q11, q11, q13 \n"
4623 "vadd.s32 q12, q12, q14 \n"
4624
4625 "vst1.32 {d22-d25}, [%1]! \n"
4626
4627 "subs %0, #1 \n"
4628 "bne 0b \n"
4629 : "=r"(nn), // %0
4630 "=r"(outptr), // %1
4631 "=r"(r0), // %2
4632 "=r"(r1), // %3
4633 "=r"(r2), // %4
4634 "=r"(ktmp) // %5
4635 : "0"(nn),
4636 "1"(outptr),
4637 "2"(r0),
4638 "3"(r1),
4639 "4"(r2),
4640 "5"(ktmp)
4641 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
4642 }
4643 #endif // __aarch64__
4644 #endif // __ARM_NEON
4645 if (remain > 0)
4646 {
4647 #if __ARM_NEON
4648 int8x8_t _k01234567s8 = vld1_s8(ktmp);
4649 int8x8_t _k8xxxxxxxs8 = vld1_s8(ktmp + 8);
4650 int8x8_t _k34567xxxs8 = vext_s8(_k01234567s8, _k01234567s8, 3);
4651 int8x8_t _k678xxxxxs8 = vext_s8(_k01234567s8, _k8xxxxxxxs8, 6);
4652 int16x8_t _k0123_s16 = vmovl_s8(_k01234567s8);
4653 int16x8_t _k3456_s16 = vmovl_s8(_k34567xxxs8);
4654 int16x8_t _k678x_s16 = vmovl_s8(_k678xxxxxs8);
4655 #endif
4656 for (; remain > 0; remain--)
4657 {
4658 #if __ARM_NEON
4659 int8x8_t _r00s8 = vld1_s8(r0);
4660 int8x8_t _r10s8 = vld1_s8(r1);
4661 int8x8_t _r20s8 = vld1_s8(r2);
4662
4663 int16x8_t _r00s16 = vmovl_s8(_r00s8);
4664 int16x8_t _r10s16 = vmovl_s8(_r10s8);
4665 int16x8_t _r20s16 = vmovl_s8(_r20s8);
4666
4667 int32x4_t _sum = vmull_s16(vget_low_s16(_r00s16), vget_low_s16(_k0123_s16));
4668 _sum = vmlal_s16(_sum, vget_low_s16(_r10s16), vget_low_s16(_k3456_s16));
4669 _sum = vmlal_s16(_sum, vget_low_s16(_r20s16), vget_low_s16(_k678x_s16));
4670
4671 _sum = vsetq_lane_s32(*outptr, _sum, 3);
4672
4673 #if __aarch64__
4674 *outptr = vaddvq_s32(_sum);
4675 #else
4676 int32x2_t _ss = vadd_s32(vget_low_s32(_sum), vget_high_s32(_sum));
4677 _ss = vpadd_s32(_ss, _ss);
4678
4679 *outptr = vget_lane_s32(_ss, 0);
4680 #endif // __aarch64__
4681 #else
4682 int sum = 0;
4683
4684 sum += (int)r0[0] * ktmp[0];
4685 sum += (int)r0[1] * ktmp[1];
4686 sum += (int)r0[2] * ktmp[2];
4687 sum += (int)r1[0] * ktmp[3];
4688 sum += (int)r1[1] * ktmp[4];
4689 sum += (int)r1[2] * ktmp[5];
4690 sum += (int)r2[0] * ktmp[6];
4691 sum += (int)r2[1] * ktmp[7];
4692 sum += (int)r2[2] * ktmp[8];
4693
4694 *outptr += sum;
4695 #endif // __ARM_NEON
4696 r0 += 2;
4697 r1 += 2;
4698 r2 += 2;
4699 outptr++;
4700 }
4701 }
4702
4703 r0 += tailstep;
4704 r1 += tailstep;
4705 r2 += tailstep;
4706 }
4707
4708 ktmp += 9;
4709 }
4710 }
4711 }
4712