1 // BUG1989 is pleased to support the open source community by supporting ncnn available.
2 //
3 // Copyright (C) 2019 BUG1989. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #if __aarch64__
16
17 #if 1
18 #include "gemm_symm_int8.h"
conv_im2col_sgemm_transform_kernel_int8_neon(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_size)19 static void conv_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
20 {
21 const int m = outch;
22 const int k = inch * kernel_size;
23 kernel_tm.create(m * k, (size_t)1u);
24 const int8_t* a = _kernel;
25 int8_t* sa = kernel_tm;
26 reorder_a((int8_t*)a, sa, m, k, k);
27 }
28
conv_im2col_sgemm_int8_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel_tm,const int kernel_w,const int kernel_h,const int stride_w,const int stride_h,const Option & opt)29 static void conv_im2col_sgemm_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm,
30 const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
31 {
32 int w = bottom_blob.w;
33 int inch = bottom_blob.c;
34
35 int outw = top_blob.w;
36 int outh = top_blob.h;
37 int outch = top_blob.c;
38
39 // im2col
40 Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, 1UL, opt.workspace_allocator);
41 {
42 const int stride = kernel_h * kernel_w * outw * outh;
43 signed char* ret = (signed char*)bottom_im2col;
44
45 #pragma omp parallel for num_threads(opt.num_threads)
46 for (int p = 0; p < inch; p++)
47 {
48 const signed char* input = bottom_blob.channel(p);
49 int retID = stride * p;
50 for (int u = 0; u < kernel_h; u++)
51 {
52 for (int v = 0; v < kernel_w; v++)
53 {
54 for (int i = 0; i < outh; i++)
55 {
56 for (int j = 0; j < outw; j++)
57 {
58 int row = u + i * stride_h;
59 int col = v + j * stride_w;
60 int index = row * w + col;
61 ret[retID] = input[index];
62 retID++;
63 }
64 }
65 }
66 }
67 }
68 }
69
70 const int m = outch;
71 const int n = outw * outh;
72 const int k = inch * kernel_w * kernel_h;
73
74 ncnn::Mat bottom_tm(k * n, (size_t)1u, opt.workspace_allocator);
75 {
76 const int8_t* pData = bottom_im2col;
77 int8_t* pReorder = bottom_tm;
78 reorder_b(pData, pReorder, k, n, n);
79 }
80 // GEMM
81 int32_t* pc = top_blob;
82 const int8_t* pa = kernel_tm;
83 int8_t* pb = bottom_tm;
84 const size_t ldc = top_blob.cstep;
85
86 int8kernel((void*)pc, pa, pb, m, k, n, ldc, 0, 0, opt);
87 }
88 #else
conv_im2col_sgemm_transform_kernel_int8_neon(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_size)89 static void conv_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
90 {
91 const signed char* kernel = _kernel;
92
93 // kernel memory packed 4 x 4
94 kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u);
95
96 int nn_outch = 0;
97 int remain_outch_start = 0;
98
99 nn_outch = outch >> 2;
100 remain_outch_start = nn_outch << 2;
101
102 for (int pp = 0; pp < nn_outch; pp++)
103 {
104 int p = pp * 4;
105
106 const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
107 const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
108 const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
109 const signed char* k3 = kernel + (p + 3) * inch * kernel_size;
110
111 signed char* ktmp = kernel_tm.channel(p / 4);
112
113 int q = 0;
114 for (; q + 1 < inch * kernel_size; q += 2)
115 {
116 ktmp[0] = k0[0];
117 ktmp[1] = k0[1];
118 ktmp[2] = k1[0];
119 ktmp[3] = k1[1];
120 ktmp[4] = k2[0];
121 ktmp[5] = k2[1];
122 ktmp[6] = k3[0];
123 ktmp[7] = k3[1];
124
125 ktmp += 8;
126
127 k0 += 2;
128 k1 += 2;
129 k2 += 2;
130 k3 += 2;
131 }
132
133 for (; q < inch * kernel_size; q++)
134 {
135 ktmp[0] = k0[0];
136 ktmp[1] = k1[0];
137 ktmp[2] = k2[0];
138 ktmp[3] = k3[0];
139 ktmp += 4;
140
141 k0 += 1;
142 k1 += 1;
143 k2 += 1;
144 k3 += 1;
145 }
146 }
147
148 for (int p = remain_outch_start; p < outch; p++)
149 {
150 const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
151
152 signed char* ktmp = kernel_tm.channel(p / 4 + p % 4);
153
154 int q = 0;
155 for (; q + 1 < inch * kernel_size; q = q + 2)
156 {
157 ktmp[0] = k0[0];
158 ktmp[1] = k0[1];
159 ktmp += 2;
160 k0 += 2;
161 }
162
163 for (; q < inch * kernel_size; q++)
164 {
165 ktmp[0] = k0[0];
166 ktmp++;
167 k0++;
168 }
169 }
170 }
171
conv_im2col_sgemm_int8_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel_tm,const int kernel_w,const int kernel_h,const int stride_w,const int stride_h,const Option & opt)172 static void conv_im2col_sgemm_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm,
173 const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
174 {
175 int w = bottom_blob.w;
176 int inch = bottom_blob.c;
177
178 int outw = top_blob.w;
179 int outh = top_blob.h;
180 int outch = top_blob.c;
181
182 // im2row
183 Mat bottom_im2row(kernel_h * kernel_w * inch, outw * outh, 1UL, opt.workspace_allocator);
184 {
185 int out_stride = kernel_h * kernel_w * inch * outw;
186 signed char* ret = (signed char*)bottom_im2row;
187
188 // #pragma omp parallel for num_threads(opt.num_threads)
189 for (int i = 0; i < outh; i++)
190 {
191 int retID = out_stride * i;
192 for (int j = 0; j < outw; j++)
193 {
194 for (int p = 0; p < inch; p++)
195 {
196 const signed char* input = bottom_blob.channel(p);
197
198 for (int u = 0; u < kernel_h; u++)
199 {
200 for (int v = 0; v < kernel_w; v++)
201 {
202 int row = u + i * stride_h;
203 int col = v + j * stride_w;
204 int index = row * w + col;
205 ret[retID] = input[index];
206 retID++;
207 }
208 }
209 }
210 }
211 }
212 }
213
214 int kernel_size = kernel_w * kernel_h;
215 int out_size = outw * outh;
216
217 // int M = outch; // outch
218 int N = outw * outh; // outsize or out stride
219 int K = kernel_w * kernel_h * inch; // ksize * inch
220
221 // bottom_im2row memory packed 4 x 4
222 Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, (size_t)1u, opt.workspace_allocator);
223 {
224 int nn_size = out_size >> 2;
225 int remain_size_start = nn_size << 2;
226
227 #pragma omp parallel for num_threads(opt.num_threads)
228 for (int ii = 0; ii < nn_size; ii++)
229 {
230 int i = ii * 4;
231
232 const signed char* img0 = bottom_im2row.row<signed char>(i);
233 const signed char* img1 = bottom_im2row.row<signed char>(i + 1);
234 const signed char* img2 = bottom_im2row.row<signed char>(i + 2);
235 const signed char* img3 = bottom_im2row.row<signed char>(i + 3);
236
237 signed char* tmpptr = bottom_tm.channel(i / 4);
238
239 int q = 0;
240 for (; q + 1 < inch * kernel_size; q = q + 2)
241 {
242 tmpptr[0] = img0[0];
243 tmpptr[1] = img0[1];
244 tmpptr[2] = img1[0];
245 tmpptr[3] = img1[1];
246 tmpptr[4] = img2[0];
247 tmpptr[5] = img2[1];
248 tmpptr[6] = img3[0];
249 tmpptr[7] = img3[1];
250
251 tmpptr += 8;
252 img0 += 2;
253 img1 += 2;
254 img2 += 2;
255 img3 += 2;
256 }
257
258 for (; q < inch * kernel_size; q++)
259 {
260 tmpptr[0] = img0[0];
261 tmpptr[1] = img1[0];
262 tmpptr[2] = img2[0];
263 tmpptr[3] = img3[0];
264
265 tmpptr += 4;
266 img0 += 1;
267 img1 += 1;
268 img2 += 1;
269 img3 += 1;
270 }
271 }
272
273 #pragma omp parallel for num_threads(opt.num_threads)
274 for (int i = remain_size_start; i < out_size; i++)
275 {
276 const signed char* img0 = bottom_im2row.row<signed char>(i);
277
278 signed char* tmpptr = bottom_tm.channel(i / 4 + i % 4);
279
280 int q = 0;
281 for (; q + 1 < inch * kernel_size; q = q + 2)
282 {
283 tmpptr[0] = img0[0];
284 tmpptr[1] = img0[1];
285
286 tmpptr += 2;
287 img0 += 2;
288 }
289
290 for (; q < inch * kernel_size; q++)
291 {
292 tmpptr[0] = img0[0];
293
294 tmpptr += 1;
295 img0 += 1;
296 }
297 }
298 }
299
300 // 4x4
301 // sgemm(int M, int N, int K, float* A, float* B, float* C)
302 {
303 // int M = outch; // outch
304 // int N = outw * outh; // outsize or out stride
305 // int L = kernel_w * kernel_h * inch; // ksize * inch
306
307 int nn_outch = 0;
308 int remain_outch_start = 0;
309
310 nn_outch = outch >> 2;
311 remain_outch_start = nn_outch << 2;
312
313 #pragma omp parallel for num_threads(opt.num_threads)
314 for (int pp = 0; pp < nn_outch; pp++)
315 {
316 int i = pp * 4;
317
318 int* output0 = top_blob.channel(i);
319 int* output1 = top_blob.channel(i + 1);
320 int* output2 = top_blob.channel(i + 2);
321 int* output3 = top_blob.channel(i + 3);
322
323 int j = 0;
324 for (; j + 3 < N; j = j + 4)
325 {
326 const signed char* vb = bottom_tm.channel(j / 4);
327 const signed char* va = kernel_tm.channel(i / 4);
328
329 #if __ARM_NEON
330 asm volatile(
331 "prfm pldl1keep, [%4, #128] \n"
332 "prfm pldl1keep, [%5, #128] \n"
333 "eor v16.16b, v16.16b, v16.16b \n" // sum0
334 "eor v17.16b, v17.16b, v17.16b \n" // sum1
335 "eor v18.16b, v18.16b, v18.16b \n" // sum2
336 "eor v19.16b, v19.16b, v19.16b \n" // sum3
337
338 "lsr w4, %w12, #2 \n" // r4 = nn = L >> 2
339 "cmp w4, #0 \n"
340 "beq 1f \n"
341
342 "0: \n" // for (; k+3<L; k=k+4)
343 "ld1 {v0.16b}, [%4] \n" // i0, i1, i2, i3
344 "ld1 {v4.16b}, [%5] \n" // k0, k1, k2, k3
345 "add %4, %4, #16 \n"
346 "add %5, %5, #16 \n"
347
348 "rev32 v1.8h, v0.8h \n" // i1, i0, i3, i2
349 "rev64 v2.4s, v0.4s \n" // i2, i3, i0, i1
350 "rev64 v3.8h, v0.8h \n" // i3, i2, i1, i0
351
352 "smull v8.8h, v4.8b, v0.8b \n"
353 "smull v9.8h, v4.8b, v1.8b \n"
354 "smull v10.8h, v4.8b, v2.8b \n"
355 "smull v11.8h, v4.8b, v3.8b \n"
356
357 "prfm pldl1keep, [%4, #128] \n"
358 "prfm pldl1keep, [%5, #128] \n"
359
360 "smlal2 v8.8h, v4.16b, v0.16b \n"
361 "smlal2 v9.8h, v4.16b, v1.16b \n"
362 "smlal2 v10.8h, v4.16b, v2.16b \n"
363 "smlal2 v11.8h, v4.16b, v3.16b \n"
364
365 "sadalp v16.4s, v8.8h \n" // i0k0, i1k1, i2k2, i3k3
366 "sadalp v17.4s, v9.8h \n" // i1k0, i0k1, i3k2, i2k3
367 "sadalp v18.4s, v10.8h \n" // i2k0, i3k1, i0k2, i1k3
368 "sadalp v19.4s, v11.8h \n" // i3k0, i2k1, i1k2, i0k3
369
370 "subs w4, w4, #1 \n"
371 "bne 0b \n"
372
373 "1: \n" // for (; k+1<L; k=k+2)
374
375 // remain loop
376 "and w4, %w12, #3 \n" // w4 = remain = K & 3;
377 "cmp w4, #0 \n"
378 "beq 3f \n"
379
380 "lsr w4, w4, #1 \n" // r4 = nn = L >> 1
381 "cmp w4, #0 \n"
382 "beq 3f \n"
383
384 "2: \n" // for (; k+1<L; k=k+2)
385
386 "ld1 {v0.8b}, [%4] \n" // i0, i1, i2, i3
387 "ld1 {v4.8b}, [%5] \n" // k0, k1, k2, k3
388 "add %4, %4, #8 \n"
389 "add %5, %5, #8 \n"
390
391 "rev32 v1.4h, v0.4h \n" // i2, i3, i0, i1
392 "rev64 v2.2s, v0.2s \n" // i1, i0, i3, i2
393 "rev64 v3.4h, v0.4h \n" // i0, i1, i2, i3
394
395 "smull v8.8h, v4.8b, v0.8b \n"
396 "smull v9.8h, v4.8b, v1.8b \n"
397 "smull v10.8h, v4.8b, v2.8b \n"
398 "smull v11.8h, v4.8b, v3.8b \n"
399 "sadalp v16.4s, v8.8h \n"
400 "sadalp v17.4s, v9.8h \n"
401 "sadalp v18.4s,v10.8h \n"
402 "sadalp v19.4s,v11.8h \n"
403
404 "subs w4, w4, #1 \n"
405 "bne 2b \n"
406
407 "3: \n" // realloc
408
409 "mov v20.s[0], v16.s[0] \n"
410 "mov v20.s[1], v17.s[0] \n"
411 "mov v20.s[2], v18.s[0] \n"
412 "mov v20.s[3], v19.s[0] \n"
413
414 "mov v21.s[0], v17.s[1] \n"
415 "mov v21.s[1], v16.s[1] \n"
416 "mov v21.s[2], v19.s[1] \n"
417 "mov v21.s[3], v18.s[1] \n"
418
419 "mov v22.s[0], v18.s[2] \n"
420 "mov v22.s[1], v19.s[2] \n"
421 "mov v22.s[2], v16.s[2] \n"
422 "mov v22.s[3], v17.s[2] \n"
423
424 "mov v23.s[0], v19.s[3] \n"
425 "mov v23.s[1], v18.s[3] \n"
426 "mov v23.s[2], v17.s[3] \n"
427 "mov v23.s[3], v16.s[3] \n"
428
429 "and w4, %w12, #1 \n" // w4 = remain = K & 1;
430 "cmp w4, #0 \n"
431 "beq 5f \n"
432
433 "4: \n"
434 "ld1 {v0.8b}, [%4] \n"
435 "ld1 {v1.8b}, [%5] \n"
436 "add %4, %4, #4 \n"
437 "add %5, %5, #4 \n"
438
439 "sshll v0.8h, v0.8b, #0 \n" // i0[0], i1[0], i2[0], i3[0]
440 "sshll v1.8h, v1.8b, #0 \n" // k0[0], k1[0], k2[0], k3[0]
441
442 "smlal v20.4s, v0.4h, v1.h[0] \n" // i0k0, i1k0, i2k0, i3k0
443 "smlal v21.4s, v0.4h, v1.h[1] \n" // i0k1, i1k1, i2k1, i3k1
444 "smlal v22.4s, v0.4h, v1.h[2] \n" // i0k2, i1k2, i2k2, i3k2
445 "smlal v23.4s, v0.4h, v1.h[3] \n" // i0k3, i1k3, i2k3, i3k3
446
447 "subs w4, w4, #1 \n"
448
449 "bne 2b \n"
450
451 "5: \n"
452
453 "st1 {v20.4s}, [%0] \n"
454 "st1 {v21.4s}, [%1] \n"
455 "st1 {v22.4s}, [%2] \n"
456 "st1 {v23.4s}, [%3] \n"
457
458 : "=r"(output0), // %0
459 "=r"(output1), // %1
460 "=r"(output2), // %2
461 "=r"(output3), // %3
462 "=r"(vb), // %4
463 "=r"(va) // %5
464 : "0"(output0),
465 "1"(output1),
466 "2"(output2),
467 "3"(output3),
468 "4"(vb),
469 "5"(va),
470 "r"(K) // %12
471 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
472 #else
473 int sum0[4] = {0};
474 int sum1[4] = {0};
475 int sum2[4] = {0};
476 int sum3[4] = {0};
477
478 int k = 0;
479
480 for (; k + 1 < K; k = k + 2)
481 {
482 for (int n = 0; n < 4; n++)
483 {
484 sum0[n] += (int)va[0] * vb[2 * n]; // k0
485 sum0[n] += (int)va[1] * vb[2 * n + 1];
486
487 sum1[n] += (int)va[2] * vb[2 * n]; // k1
488 sum1[n] += (int)va[3] * vb[2 * n + 1];
489
490 sum2[n] += (int)va[4] * vb[2 * n]; // k2
491 sum2[n] += (int)va[5] * vb[2 * n + 1];
492
493 sum3[n] += (int)va[6] * vb[2 * n]; // k3
494 sum3[n] += (int)va[7] * vb[2 * n + 1];
495 }
496
497 va += 8;
498 vb += 8;
499 }
500
501 for (; k < K; k++)
502 {
503 for (int n = 0; n < 4; n++)
504 {
505 sum0[n] += (int)va[0] * vb[n];
506 sum1[n] += (int)va[1] * vb[n];
507 sum2[n] += (int)va[2] * vb[n];
508 sum3[n] += (int)va[3] * vb[n];
509 }
510
511 va += 4;
512 vb += 4;
513 }
514
515 for (int n = 0; n < 4; n++)
516 {
517 output0[n] = sum0[n];
518 output1[n] = sum1[n];
519 output2[n] = sum2[n];
520 output3[n] = sum3[n];
521 }
522 #endif
523 output0 += 4;
524 output1 += 4;
525 output2 += 4;
526 output3 += 4;
527 }
528
529 for (; j < N; j++)
530 {
531 const signed char* vb = bottom_tm.channel(j / 4 + j % 4);
532 const signed char* va = kernel_tm.channel(i / 4);
533
534 #if 0 //__ARM_NEON
535 int32x4_t _sum = vdupq_n_s32(0);
536
537 int k=0;
538
539 for (; k+3<K; k=k+4)
540 {
541 int8x8_t _r0 = vld1_s8(vb); // i0[0-3]
542 int8x8x2_t _k = vld2_s8(va); // k0[0-1], k1[0-1], k2[0-1], k3[0-1];k0[2-3], k1[2-3], k2[2-3], k3[2-3]
543
544 int16x8_t _r0_s16 = vmovl_s8(_r0); // i0[0],i0[1],i0[2],i0[3]
545 int16x8_t _k02_s16 = vmovl_s8(_k.val[0]); // k0[0],k1[0],k2[0],k3[0],k0[2],k1[2],k2[2],k3[2]
546 int16x8_t _k13_s16 = vmovl_s8(_k.val[1]); // k0[1],k1[1],k2[1],k3[1],k0[3],k1[3],k2[3],k3[3]
547
548 _sum = vmlal_lane_s16(_sum, vget_low_s16(_k02_s16), vget_low_s16(_r0_s16), 0); // i0[0]*k[0-3][0]
549 _sum = vmlal_lane_s16(_sum, vget_low_s16(_k13_s16), vget_low_s16(_r0_s16), 1); // i0[1]*k[0-3][1]
550 _sum = vmlal_lane_s16(_sum, vget_high_s16(_k02_s16), vget_low_s16(_r0_s16), 2); // i0[2]*k[0-3][2]
551 _sum = vmlal_lane_s16(_sum, vget_high_s16(_k13_s16), vget_low_s16(_r0_s16), 3); // i0[3]*k[0-3][3]
552
553 va += 16;
554 vb += 4;
555 }
556
557 for (; k+1<K; k=k+2)
558 {
559 int8x8_t _r0 = vld1_s8(vb); // i0[0-3]
560 int8x8_t _k = vld1_s8(va); // k0[0-1], k1[0-1], k2[0-1], k3[0-1]
561
562 _r0[2] = _r0[0];
563 _r0[3] = _r0[1];
564 _r0[4] = _r0[0];
565 _r0[5] = _r0[1];
566 _r0[6] = _r0[0];
567 _r0[7] = _r0[1];
568
569 int16x8_t _tp0 = vmull_s8(_k, _r0);
570 _sum = vpadalq_s16(_sum, _tp0);
571
572 va += 8;
573 vb += 2;
574 }
575
576 for (; k<K; k++)
577 {
578 int8x8_t _r0 = vld1_s8(vb); // i0[0-3]
579 int8x8_t _k = vld1_s8(va); // k[0-3][0]
580
581 int16x8_t _tp0 = vmull_s8(_k, _r0);
582
583 _sum = vaddw_s16(_sum, vget_low_s16(_tp0));
584
585 va += 4;
586 vb += 1;
587 }
588
589 vst1q_lane_s32(output0, _sum, 0);
590 vst1q_lane_s32(output1, _sum, 1);
591 vst1q_lane_s32(output2, _sum, 2);
592 vst1q_lane_s32(output3, _sum, 3);
593 #else
594 int sum0 = 0;
595 int sum1 = 0;
596 int sum2 = 0;
597 int sum3 = 0;
598
599 int k = 0;
600
601 for (; k + 1 < K; k = k + 2)
602 {
603 sum0 += (int)va[0] * vb[0];
604 sum0 += (int)va[1] * vb[1];
605
606 sum1 += (int)va[2] * vb[0];
607 sum1 += (int)va[3] * vb[1];
608
609 sum2 += (int)va[4] * vb[0];
610 sum2 += (int)va[5] * vb[1];
611
612 sum3 += (int)va[6] * vb[0];
613 sum3 += (int)va[7] * vb[1];
614
615 va += 8;
616 vb += 2;
617 }
618
619 for (; k < K; k++)
620 {
621 sum0 += (int)va[0] * vb[0];
622 sum1 += (int)va[1] * vb[0];
623 sum2 += (int)va[2] * vb[0];
624 sum3 += (int)va[3] * vb[0];
625
626 va += 4;
627 vb += 1;
628 }
629
630 output0[0] = sum0;
631 output1[0] = sum1;
632 output2[0] = sum2;
633 output3[0] = sum3;
634 #endif
635 output0++;
636 output1++;
637 output2++;
638 output3++;
639 }
640 }
641
642 #pragma omp parallel for num_threads(opt.num_threads)
643 for (int i = remain_outch_start; i < outch; i++)
644 {
645 int* output = top_blob.channel(i);
646
647 int j = 0;
648 for (; j + 3 < N; j = j + 4)
649 {
650 const signed char* vb = bottom_tm.channel(j / 4);
651 const signed char* va = kernel_tm.channel(i / 4 + i % 4);
652
653 #if __ARM_NEON
654 int32x4_t _sum = vdupq_n_s32(0);
655
656 int k = 0;
657 for (; k + 1 < K; k = k + 2)
658 {
659 int8x8_t _r0 = vld1_s8(vb); // i0[0-1], i1[0-1], i2[0-1], i3[0-1]
660 int8x8_t _k = vld1_s8(va); // k0[0-1]
661
662 _k[2] = _k[0];
663 _k[3] = _k[1];
664 _k[4] = _k[0];
665 _k[5] = _k[1];
666 _k[6] = _k[0];
667 _k[7] = _k[1];
668
669 int16x8_t _tp0 = vmull_s8(_k, _r0);
670 _sum = vpadalq_s16(_sum, _tp0);
671
672 va += 2;
673 vb += 8;
674 }
675
676 for (; k < K; k++)
677 {
678 int8x8_t _r0 = vld1_s8(vb); // i0[0], i1[0], i2[0], i3[0]
679 int8x8_t _k = vld1_s8(va); // k[0][0]
680
681 int16x8_t _r0_s16 = vmovl_s8(_r0);
682 int16x8_t _k_s16 = vmovl_s8(_k);
683
684 _sum = vmlal_lane_s16(_sum, vget_low_s16(_r0_s16), vget_low_s16(_k_s16), 0); // i0k0, i1k0, i2k0, i3k0
685
686 va += 1;
687 vb += 4;
688 }
689
690 vst1q_s32(output, _sum);
691 #else
692 int sum[4] = {0};
693 int k = 0;
694 for (; k + 1 < K; k = k + 2)
695 {
696 for (int n = 0; n < 4; n++)
697 {
698 sum[n] += (int)va[0] * vb[2 * n];
699 sum[n] += (int)va[1] * vb[2 * n + 1];
700 }
701 va += 2;
702 vb += 8;
703 }
704
705 for (; k < K; k++)
706 {
707 for (int n = 0; n < 4; n++)
708 {
709 sum[n] += (int)va[0] * vb[n];
710 }
711 va += 1;
712 vb += 4;
713 }
714
715 for (int n = 0; n < 4; n++)
716 {
717 output[n] = sum[n];
718 }
719 #endif
720 output += 4;
721 }
722
723 for (; j < N; j++)
724 {
725 int sum = 0;
726
727 const signed char* vb = bottom_tm.channel(j / 4 + j % 4);
728 const signed char* va = kernel_tm.channel(i / 4 + i % 4);
729
730 for (int k = 0; k < K; k++)
731 {
732 sum += (int)va[0] * vb[0];
733
734 va += 1;
735 vb += 1;
736 }
737 output[0] = sum;
738
739 output++;
740 }
741 }
742 }
743
744 // // sgemm(int M, int N, int K, float* A, float* B, float* C)
745 // {
746 // for (int i=0; i<M; i++)
747 // {
748 // int* output = top_blob.channel(i);
749
750 // for (int j=0; j<N; j++)
751 // {
752 // int sum = 0;
753
754 // signed char* vb = (signed char*)bottom_im2row + K * j;
755 // const signed char* va = kernel + K * i;
756
757 // for (int k=0; k<K; k++)
758 // {
759 // sum += (int)va[0] * vb[0];
760
761 // va += 1;
762 // vb += 1;
763 // }
764 // output[0] = sum;
765
766 // output++;
767 // }
768 // }
769 // }
770 }
771 #endif
772 #else
conv_im2col_sgemm_transform_kernel_int8_neon(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_size)773 static void conv_im2col_sgemm_transform_kernel_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
774 {
775 const signed char* kernel = _kernel;
776
777 #if __ARM_NEON && __aarch64__
778 // kernel memory packed 8 x 8
779 kernel_tm.create(8 * kernel_size, inch, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)1u);
780 #else
781 // kernel memory packed 4 x 8
782 kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4, (size_t)1u);
783 #endif
784
785 int nn_outch = 0;
786 int remain_outch_start = 0;
787
788 #if __ARM_NEON && __aarch64__
789 nn_outch = outch >> 3;
790 remain_outch_start = nn_outch << 3;
791
792 for (int pp = 0; pp < nn_outch; pp++)
793 {
794 int p = pp * 8;
795
796 const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
797 const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
798 const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
799 const signed char* k3 = kernel + (p + 3) * inch * kernel_size;
800 const signed char* k4 = kernel + (p + 4) * inch * kernel_size;
801 const signed char* k5 = kernel + (p + 5) * inch * kernel_size;
802 const signed char* k6 = kernel + (p + 6) * inch * kernel_size;
803 const signed char* k7 = kernel + (p + 7) * inch * kernel_size;
804
805 signed char* ktmp = kernel_tm.channel(p / 8);
806
807 for (int q = 0; q < inch * kernel_size; q++)
808 {
809 ktmp[0] = k0[0];
810 ktmp[1] = k1[0];
811 ktmp[2] = k2[0];
812 ktmp[3] = k3[0];
813 ktmp[4] = k4[0];
814 ktmp[5] = k5[0];
815 ktmp[6] = k6[0];
816 ktmp[7] = k7[0];
817 ktmp += 8;
818
819 k0 += 1;
820 k1 += 1;
821 k2 += 1;
822 k3 += 1;
823 k4 += 1;
824 k5 += 1;
825 k6 += 1;
826 k7 += 1;
827 }
828 }
829 #endif
830
831 nn_outch = (outch - remain_outch_start) >> 2;
832
833 for (int pp = 0; pp < nn_outch; pp++)
834 {
835 int p = remain_outch_start + pp * 4;
836
837 const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
838 const signed char* k1 = kernel + (p + 1) * inch * kernel_size;
839 const signed char* k2 = kernel + (p + 2) * inch * kernel_size;
840 const signed char* k3 = kernel + (p + 3) * inch * kernel_size;
841
842 #if __ARM_NEON && __aarch64__
843 signed char* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4);
844 #else
845 signed char* ktmp = kernel_tm.channel(p / 4);
846 #endif // __ARM_NEON && __aarch64__
847
848 for (int q = 0; q < inch * kernel_size; q++)
849 {
850 ktmp[0] = k0[0];
851 ktmp[1] = k1[0];
852 ktmp[2] = k2[0];
853 ktmp[3] = k3[0];
854 ktmp += 4;
855
856 k0 += 1;
857 k1 += 1;
858 k2 += 1;
859 k3 += 1;
860 }
861 }
862
863 remain_outch_start += nn_outch << 2;
864
865 for (int p = remain_outch_start; p < outch; p++)
866 {
867 const signed char* k0 = kernel + (p + 0) * inch * kernel_size;
868
869 #if __ARM_NEON && __aarch64__
870 signed char* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4);
871 #else
872 signed char* ktmp = kernel_tm.channel(p / 4 + p % 4);
873 #endif // __ARM_NEON && __aarch64__
874
875 for (int q = 0; q < inch * kernel_size; q++)
876 {
877 ktmp[0] = k0[0];
878 ktmp++;
879 k0++;
880 }
881 }
882 }
883
conv_im2col_sgemm_int8_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel_tm,const int kernel_w,const int kernel_h,const int stride_w,const int stride_h,const Option & opt)884 static void conv_im2col_sgemm_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm,
885 const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
886 {
887 int w = bottom_blob.w;
888 int inch = bottom_blob.c;
889
890 int outw = top_blob.w;
891 int outh = top_blob.h;
892 int outch = top_blob.c;
893
894 // im2col
895 Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, 1UL, opt.workspace_allocator);
896 {
897 const int stride = kernel_h * kernel_w * outw * outh;
898 signed char* ret = (signed char*)bottom_im2col;
899
900 #pragma omp parallel for num_threads(opt.num_threads)
901 for (int p = 0; p < inch; p++)
902 {
903 const signed char* input = bottom_blob.channel(p);
904 int retID = stride * p;
905 for (int u = 0; u < kernel_h; u++)
906 {
907 for (int v = 0; v < kernel_w; v++)
908 {
909 for (int i = 0; i < outh; i++)
910 {
911 for (int j = 0; j < outw; j++)
912 {
913 int row = u + i * stride_h;
914 int col = v + j * stride_w;
915 int index = row * w + col;
916 ret[retID] = input[index];
917 retID++;
918 }
919 }
920 }
921 }
922 }
923 }
924
925 int kernel_size = kernel_w * kernel_h;
926 int out_size = outw * outh;
927
928 // bottom_im2col memory packed 8 x 8
929 Mat bottom_tm(8 * kernel_size, inch, out_size / 8 + out_size % 8, (size_t)1u, opt.workspace_allocator);
930 {
931 int nn_size = out_size >> 3;
932 int remain_size_start = nn_size << 3;
933
934 #pragma omp parallel for num_threads(opt.num_threads)
935 for (int ii = 0; ii < nn_size; ii++)
936 {
937 int i = ii * 8;
938
939 const signed char* img0 = bottom_im2col.channel(0);
940 img0 += i;
941
942 signed char* tmpptr = bottom_tm.channel(i / 8);
943
944 for (int q = 0; q < inch * kernel_size; q++)
945 {
946 #if __ARM_NEON
947 #if __aarch64__
948 asm volatile(
949 "prfm pldl1keep, [%0, #64] \n"
950 "ld1 {v0.8b}, [%0] \n"
951 "st1 {v0.8b}, [%1] \n"
952 : "=r"(img0), // %0
953 "=r"(tmpptr) // %1
954 : "0"(img0),
955 "1"(tmpptr)
956 : "cc", "memory", "v0");
957 #else
958 asm volatile(
959 "pld [%0, #64] \n"
960 "vld1.s8 {d0}, [%0] \n"
961 "vst1.s8 {d0}, [%1] \n"
962 : "=r"(img0), // %0
963 "=r"(tmpptr) // %1
964 : "0"(img0),
965 "1"(tmpptr)
966 : "cc", "memory", "d0");
967 #endif // __aarch64__
968 #else
969 tmpptr[0] = img0[0];
970 tmpptr[1] = img0[1];
971 tmpptr[2] = img0[2];
972 tmpptr[3] = img0[3];
973 tmpptr[4] = img0[4];
974 tmpptr[5] = img0[5];
975 tmpptr[6] = img0[6];
976 tmpptr[7] = img0[7];
977 #endif // __ARM_NEON
978 tmpptr += 8;
979 img0 += out_size;
980 }
981 }
982
983 #pragma omp parallel for num_threads(opt.num_threads)
984 for (int i = remain_size_start; i < out_size; i++)
985 {
986 const signed char* img0 = bottom_im2col.channel(0);
987 img0 += i;
988
989 signed char* tmpptr = bottom_tm.channel(i / 8 + i % 8);
990
991 for (int q = 0; q < inch * kernel_size; q++)
992 {
993 tmpptr[0] = img0[0];
994
995 tmpptr += 1;
996 img0 += out_size;
997 }
998 }
999 }
1000
1001 // sgemm(int M, int N, int L, float* A, float* B, float* C)
1002 {
1003 //int M = outch; // outch
1004 int N = outw * outh; // outsize or out stride
1005 int L = kernel_w * kernel_h * inch; // ksize * inch
1006
1007 int nn_outch = 0;
1008 int remain_outch_start = 0;
1009
1010 #if __ARM_NEON && __aarch64__
1011 nn_outch = outch >> 3;
1012 remain_outch_start = nn_outch << 3;
1013
1014 #pragma omp parallel for num_threads(opt.num_threads)
1015 for (int pp = 0; pp < nn_outch; pp++)
1016 {
1017 int i = pp * 8;
1018
1019 int* output0 = top_blob.channel(i);
1020 int* output1 = top_blob.channel(i + 1);
1021 int* output2 = top_blob.channel(i + 2);
1022 int* output3 = top_blob.channel(i + 3);
1023 int* output4 = top_blob.channel(i + 4);
1024 int* output5 = top_blob.channel(i + 5);
1025 int* output6 = top_blob.channel(i + 6);
1026 int* output7 = top_blob.channel(i + 7);
1027
1028 int j = 0;
1029 for (; j + 7 < N; j = j + 8)
1030 {
1031 signed char* vb = bottom_tm.channel(j / 8);
1032 const signed char* va = kernel_tm.channel(i / 8);
1033 #if __aarch64__
1034 asm volatile(
1035 "eor v16.16b, v16.16b, v16.16b \n" // sum0
1036 "eor v17.16b, v17.16b, v17.16b \n" // sum0n
1037 "eor v18.16b, v18.16b, v18.16b \n" // sum1
1038 "eor v19.16b, v19.16b, v19.16b \n" // sum1n
1039 "eor v20.16b, v20.16b, v20.16b \n" // sum2
1040 "eor v21.16b, v21.16b, v21.16b \n" // sum2n
1041 "eor v22.16b, v22.16b, v22.16b \n" // sum3
1042 "eor v23.16b, v23.16b, v23.16b \n" // sum3n
1043 "eor v24.16b, v24.16b, v24.16b \n" // sum4
1044 "eor v25.16b, v25.16b, v25.16b \n" // sum4n
1045 "eor v26.16b, v26.16b, v26.16b \n" // sum5
1046 "eor v27.16b, v27.16b, v27.16b \n" // sum5n
1047 "eor v28.16b, v28.16b, v28.16b \n" // sum6
1048 "eor v29.16b, v29.16b, v29.16b \n" // sum6n
1049 "eor v30.16b, v30.16b, v30.16b \n" // sum7
1050 "eor v31.16b, v31.16b, v31.16b \n" // sum7n
1051
1052 "lsr w4, %w20, #2 \n" // r4 = nn = L >> 2
1053 "cmp w4, #0 \n"
1054 "beq 1f \n"
1055
1056 "0: \n" // for (; k+3<L; k=k+4)
1057
1058 "prfm pldl1keep, [%9, #128] \n"
1059 "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%9], #32 \n"
1060
1061 "prfm pldl1keep, [%8, #128] \n"
1062 "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%8], #32 \n"
1063
1064 "sshll v0.8h, v0.8b, #0 \n" // k00 - k70
1065 "sshll v1.8h, v1.8b, #0 \n" // k01 - k71
1066 "sshll v2.8h, v2.8b, #0 \n" // k02 - k72
1067 "sshll v3.8h, v3.8b, #0 \n" // k03 - k73
1068
1069 "sshll v8.8h, v8.8b, #0 \n" // a00 - a70
1070 "sshll v9.8h, v9.8b, #0 \n" // a01 - a71
1071 "sshll v10.8h, v10.8b, #0 \n" // a02 - a72
1072 "sshll v11.8h, v11.8b, #0 \n" // a03 - a73
1073 // k0
1074 "smlal v16.4s, v8.4h, v0.h[0] \n" // sum0 += (a00-a70) * k00
1075 "smlal2 v17.4s, v8.8h, v0.h[0] \n" //
1076 "smlal v18.4s, v8.4h, v0.h[1] \n" // sum1 += (a00-a70) * k10
1077 "smlal2 v19.4s, v8.8h, v0.h[1] \n" //
1078 "smlal v20.4s, v8.4h, v0.h[2] \n" // sum2 += (a00-a70) * k20
1079 "smlal2 v21.4s, v8.8h, v0.h[2] \n" //
1080 "smlal v22.4s, v8.4h, v0.h[3] \n" // sum3 += (a00-a70) * k30
1081 "smlal2 v23.4s, v8.8h, v0.h[3] \n" //
1082 "smlal v24.4s, v8.4h, v0.h[4] \n" // sum4 += (a00-a70) * k40
1083 "smlal2 v25.4s, v8.8h, v0.h[4] \n" //
1084 "smlal v26.4s, v8.4h, v0.h[5] \n" // sum5 += (a00-a70) * k50
1085 "smlal2 v27.4s, v8.8h, v0.h[5] \n" //
1086 "smlal v28.4s, v8.4h, v0.h[6] \n" // sum6 += (a00-a70) * k60
1087 "smlal2 v29.4s, v8.8h, v0.h[6] \n" //
1088 "smlal v30.4s, v8.4h, v0.h[7] \n" // sum7 += (a00-a70) * k70
1089 "smlal2 v31.4s, v8.8h, v0.h[7] \n" //
1090 // k1
1091 "smlal v16.4s, v9.4h, v1.h[0] \n" // sum0 += (a01-a71) * k01
1092 "smlal2 v17.4s, v9.8h, v1.h[0] \n" //
1093 "smlal v18.4s, v9.4h, v1.h[1] \n" // sum1 += (a01-a71) * k11
1094 "smlal2 v19.4s, v9.8h, v1.h[1] \n" //
1095 "smlal v20.4s, v9.4h, v1.h[2] \n" // sum2 += (a01-a71) * k21
1096 "smlal2 v21.4s, v9.8h, v1.h[2] \n" //
1097 "smlal v22.4s, v9.4h, v1.h[3] \n" // sum3 += (a01-a71) * k31
1098 "smlal2 v23.4s, v9.8h, v1.h[3] \n" //
1099 "smlal v24.4s, v9.4h, v1.h[4] \n" // sum4 += (a01-a71) * k41
1100 "smlal2 v25.4s, v9.8h, v1.h[4] \n" //
1101 "smlal v26.4s, v9.4h, v1.h[5] \n" // sum5 += (a01-a71) * k51
1102 "smlal2 v27.4s, v9.8h, v1.h[5] \n" //
1103 "smlal v28.4s, v9.4h, v1.h[6] \n" // sum6 += (a01-a71) * k61
1104 "smlal2 v29.4s, v9.8h, v1.h[6] \n" //
1105 "smlal v30.4s, v9.4h, v1.h[7] \n" // sum7 += (a01-a71) * k71
1106 "smlal2 v31.4s, v9.8h, v1.h[7] \n" //
1107 // k2
1108 "smlal v16.4s, v10.4h, v2.h[0] \n" // sum0 += (a02-a72) * k02
1109 "smlal2 v17.4s, v10.8h, v2.h[0] \n" //
1110 "smlal v18.4s, v10.4h, v2.h[1] \n" // sum1 += (a02-a72) * k12
1111 "smlal2 v19.4s, v10.8h, v2.h[1] \n" //
1112 "smlal v20.4s, v10.4h, v2.h[2] \n" // sum2 += (a02-a72) * k22
1113 "smlal2 v21.4s, v10.8h, v2.h[2] \n" //
1114 "smlal v22.4s, v10.4h, v2.h[3] \n" // sum3 += (a02-a72) * k32
1115 "smlal2 v23.4s, v10.8h, v2.h[3] \n" //
1116 "smlal v24.4s, v10.4h, v2.h[4] \n" // sum4 += (a02-a72) * k42
1117 "smlal2 v25.4s, v10.8h, v2.h[4] \n" //
1118 "smlal v26.4s, v10.4h, v2.h[5] \n" // sum5 += (a02-a72) * k52
1119 "smlal2 v27.4s, v10.8h, v2.h[5] \n" //
1120 "smlal v28.4s, v10.4h, v2.h[6] \n" // sum6 += (a02-a72) * k62
1121 "smlal2 v29.4s, v10.8h, v2.h[6] \n" //
1122 "smlal v30.4s, v10.4h, v2.h[7] \n" // sum7 += (a02-a72) * k72
1123 "smlal2 v31.4s, v10.8h, v2.h[7] \n" //
1124 // k3
1125 "smlal v16.4s, v11.4h, v3.h[0] \n" // sum0 += (a03-a73) * k03
1126 "smlal2 v17.4s, v11.8h, v3.h[0] \n" //
1127 "smlal v18.4s, v11.4h, v3.h[1] \n" // sum1 += (a03-a73) * k13
1128 "smlal2 v19.4s, v11.8h, v3.h[1] \n" //
1129 "smlal v20.4s, v11.4h, v3.h[2] \n" // sum2 += (a03-a73) * k23
1130 "smlal2 v21.4s, v11.8h, v3.h[2] \n" //
1131 "smlal v22.4s, v11.4h, v3.h[3] \n" // sum3 += (a03-a73) * k33
1132 "smlal2 v23.4s, v11.8h, v3.h[3] \n" //
1133 "smlal v24.4s, v11.4h, v3.h[4] \n" // sum4 += (a03-a73) * k43
1134 "smlal2 v25.4s, v11.8h, v3.h[4] \n" //
1135 "smlal v26.4s, v11.4h, v3.h[5] \n" // sum5 += (a03-a73) * k53
1136 "smlal2 v27.4s, v11.8h, v3.h[5] \n" //
1137 "smlal v28.4s, v11.4h, v3.h[6] \n" // sum6 += (a03-a73) * k63
1138 "smlal2 v29.4s, v11.8h, v3.h[6] \n" //
1139 "smlal v30.4s, v11.4h, v3.h[7] \n" // sum7 += (a03-a73) * k73
1140 "smlal2 v31.4s, v11.8h, v3.h[7] \n" //
1141
1142 "subs w4, w4, #1 \n"
1143 "bne 0b \n"
1144
1145 "1: \n"
1146
1147 // remain loop
1148 "and w4, %w20, #3 \n" // w4 = remain = inch & 3;
1149 "cmp w4, #0 \n"
1150 "beq 3f \n"
1151
1152 "2: \n"
1153
1154 "prfm pldl1keep, [%9, #128] \n"
1155 "ld1 {v0.8b}, [%9], #8 \n"
1156
1157 "prfm pldl1keep, [%8, #128] \n"
1158 "ld1 {v8.8b}, [%8], #8 \n"
1159
1160 "sshll v0.8h, v0.8b, #0 \n" // k00 - k70
1161 "sshll v8.8h, v8.8b, #0 \n" // a00 - a70
1162
1163 // k0
1164 "smlal v16.4s, v8.4h, v0.h[0] \n" // sum0 += (a00-a70) * k00
1165 "smlal2 v17.4s, v8.8h, v0.h[0] \n" //
1166 "smlal v18.4s, v8.4h, v0.h[1] \n" // sum1 += (a00-a70) * k10
1167 "smlal2 v19.4s, v8.8h, v0.h[1] \n" //
1168 "smlal v20.4s, v8.4h, v0.h[2] \n" // sum2 += (a00-a70) * k20
1169 "smlal2 v21.4s, v8.8h, v0.h[2] \n" //
1170 "smlal v22.4s, v8.4h, v0.h[3] \n" // sum3 += (a00-a70) * k30
1171 "smlal2 v23.4s, v8.8h, v0.h[3] \n" //
1172 "smlal v24.4s, v8.4h, v0.h[4] \n" // sum4 += (a00-a70) * k40
1173 "smlal2 v25.4s, v8.8h, v0.h[4] \n" //
1174 "smlal v26.4s, v8.4h, v0.h[5] \n" // sum5 += (a00-a70) * k50
1175 "smlal2 v27.4s, v8.8h, v0.h[5] \n" //
1176 "smlal v28.4s, v8.4h, v0.h[6] \n" // sum6 += (a00-a70) * k60
1177 "smlal2 v29.4s, v8.8h, v0.h[6] \n" //
1178 "smlal v30.4s, v8.4h, v0.h[7] \n" // sum7 += (a00-a70) * k70
1179 "smlal2 v31.4s, v8.8h, v0.h[7] \n" //
1180
1181 "subs w4, w4, #1 \n"
1182
1183 "bne 2b \n"
1184
1185 "3: \n"
1186
1187 "st1 {v16.4s, v17.4s}, [%0] \n"
1188 "st1 {v18.4s, v19.4s}, [%1] \n"
1189 "st1 {v20.4s, v21.4s}, [%2] \n"
1190 "st1 {v22.4s, v23.4s}, [%3] \n"
1191 "st1 {v24.4s, v25.4s}, [%4] \n"
1192 "st1 {v26.4s, v27.4s}, [%5] \n"
1193 "st1 {v28.4s, v29.4s}, [%6] \n"
1194 "st1 {v30.4s, v31.4s}, [%7] \n"
1195
1196 : "=r"(output0), // %0
1197 "=r"(output1), // %1
1198 "=r"(output2), // %2
1199 "=r"(output3), // %3
1200 "=r"(output4), // %4
1201 "=r"(output5), // %5
1202 "=r"(output6), // %6
1203 "=r"(output7), // %7
1204 "=r"(vb), // %8
1205 "=r"(va) // %9
1206 : "0"(output0),
1207 "1"(output1),
1208 "2"(output2),
1209 "3"(output3),
1210 "4"(output4),
1211 "5"(output5),
1212 "6"(output6),
1213 "7"(output7),
1214 "8"(vb),
1215 "9"(va),
1216 "r"(L) // %20
1217 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
1218 #else
1219 int sum0[8] = {0};
1220 int sum1[8] = {0};
1221 int sum2[8] = {0};
1222 int sum3[8] = {0};
1223 int sum4[8] = {0};
1224 int sum5[8] = {0};
1225 int sum6[8] = {0};
1226 int sum7[8] = {0};
1227
1228 int k = 0;
1229 for (; k + 7 < L; k = k + 8)
1230 {
1231 for (int n = 0; n < 8; n++)
1232 {
1233 sum0[n] += (int)va[0] * vb[n];
1234 sum1[n] += (int)va[1] * vb[n];
1235 sum2[n] += (int)va[2] * vb[n];
1236 sum3[n] += (int)va[3] * vb[n];
1237 sum4[n] += (int)va[4] * vb[n];
1238 sum5[n] += (int)va[5] * vb[n];
1239 sum6[n] += (int)va[6] * vb[n];
1240 sum7[n] += (int)va[7] * vb[n];
1241 va += 8;
1242
1243 sum0[n] += (int)va[0] * vb[n + 8];
1244 sum1[n] += (int)va[1] * vb[n + 8];
1245 sum2[n] += (int)va[2] * vb[n + 8];
1246 sum3[n] += (int)va[3] * vb[n + 8];
1247 sum4[n] += (int)va[4] * vb[n + 8];
1248 sum5[n] += (int)va[5] * vb[n + 8];
1249 sum6[n] += (int)va[6] * vb[n + 8];
1250 sum7[n] += (int)va[7] * vb[n + 8];
1251 va += 8;
1252
1253 sum0[n] += (int)va[0] * vb[n + 16];
1254 sum1[n] += (int)va[1] * vb[n + 16];
1255 sum2[n] += (int)va[2] * vb[n + 16];
1256 sum3[n] += (int)va[3] * vb[n + 16];
1257 sum4[n] += (int)va[4] * vb[n + 16];
1258 sum5[n] += (int)va[5] * vb[n + 16];
1259 sum6[n] += (int)va[6] * vb[n + 16];
1260 sum7[n] += (int)va[7] * vb[n + 16];
1261 va += 8;
1262
1263 sum0[n] += (int)va[0] * vb[n + 24];
1264 sum1[n] += (int)va[1] * vb[n + 24];
1265 sum2[n] += (int)va[2] * vb[n + 24];
1266 sum3[n] += (int)va[3] * vb[n + 24];
1267 sum4[n] += (int)va[4] * vb[n + 24];
1268 sum5[n] += (int)va[5] * vb[n + 24];
1269 sum6[n] += (int)va[6] * vb[n + 24];
1270 sum7[n] += (int)va[7] * vb[n + 24];
1271 va += 8;
1272
1273 sum0[n] += (int)va[0] * vb[n + 32];
1274 sum1[n] += (int)va[1] * vb[n + 32];
1275 sum2[n] += (int)va[2] * vb[n + 32];
1276 sum3[n] += (int)va[3] * vb[n + 32];
1277 sum4[n] += (int)va[4] * vb[n + 32];
1278 sum5[n] += (int)va[5] * vb[n + 32];
1279 sum6[n] += (int)va[6] * vb[n + 32];
1280 sum7[n] += (int)va[7] * vb[n + 32];
1281 va += 8;
1282
1283 sum0[n] += (int)va[0] * vb[n + 40];
1284 sum1[n] += (int)va[1] * vb[n + 40];
1285 sum2[n] += (int)va[2] * vb[n + 40];
1286 sum3[n] += (int)va[3] * vb[n + 40];
1287 sum4[n] += (int)va[4] * vb[n + 40];
1288 sum5[n] += (int)va[5] * vb[n + 40];
1289 sum6[n] += (int)va[6] * vb[n + 40];
1290 sum7[n] += (int)va[7] * vb[n + 40];
1291 va += 8;
1292
1293 sum0[n] += (int)va[0] * vb[n + 48];
1294 sum1[n] += (int)va[1] * vb[n + 48];
1295 sum2[n] += (int)va[2] * vb[n + 48];
1296 sum3[n] += (int)va[3] * vb[n + 48];
1297 sum4[n] += (int)va[4] * vb[n + 48];
1298 sum5[n] += (int)va[5] * vb[n + 48];
1299 sum6[n] += (int)va[6] * vb[n + 48];
1300 sum7[n] += (int)va[7] * vb[n + 48];
1301 va += 8;
1302
1303 sum0[n] += (int)va[0] * vb[n + 56];
1304 sum1[n] += (int)va[1] * vb[n + 56];
1305 sum2[n] += (int)va[2] * vb[n + 56];
1306 sum3[n] += (int)va[3] * vb[n + 56];
1307 sum4[n] += (int)va[4] * vb[n + 56];
1308 sum5[n] += (int)va[5] * vb[n + 56];
1309 sum6[n] += (int)va[6] * vb[n + 56];
1310 sum7[n] += (int)va[7] * vb[n + 56];
1311 va -= 56;
1312 }
1313
1314 va += 64;
1315 vb += 64;
1316 }
1317
1318 for (; k < L; k++)
1319 {
1320 for (int n = 0; n < 8; n++)
1321 {
1322 sum0[n] += (int)va[0] * vb[n];
1323 sum1[n] += (int)va[1] * vb[n];
1324 sum2[n] += (int)va[2] * vb[n];
1325 sum3[n] += (int)va[3] * vb[n];
1326 sum4[n] += (int)va[4] * vb[n];
1327 sum5[n] += (int)va[5] * vb[n];
1328 sum6[n] += (int)va[6] * vb[n];
1329 sum7[n] += (int)va[7] * vb[n];
1330 }
1331
1332 va += 8;
1333 vb += 8;
1334 }
1335
1336 for (int n = 0; n < 8; n++)
1337 {
1338 output0[n] = sum0[n];
1339 output1[n] = sum1[n];
1340 output2[n] = sum2[n];
1341 output3[n] = sum3[n];
1342 output4[n] = sum4[n];
1343 output5[n] = sum5[n];
1344 output6[n] = sum6[n];
1345 output7[n] = sum7[n];
1346 }
1347 #endif // __aarch64__
1348 output0 += 8;
1349 output1 += 8;
1350 output2 += 8;
1351 output3 += 8;
1352 output4 += 8;
1353 output5 += 8;
1354 output6 += 8;
1355 output7 += 8;
1356 }
1357
1358 for (; j < N; j++)
1359 {
1360 signed char* vb = bottom_tm.channel(j / 8 + j % 8);
1361 const signed char* va = kernel_tm.channel(i / 8);
1362
1363 #if __aarch64__
1364 asm volatile(
1365 "eor v14.16b, v14.16b, v14.16b \n" // sum0_3
1366 "eor v15.16b, v15.16b, v15.16b \n" // sum4_7
1367 "eor v16.16b, v16.16b, v16.16b \n" // sum0
1368 "eor v17.16b, v17.16b, v17.16b \n" // sum1
1369 "eor v18.16b, v18.16b, v18.16b \n" // sum2
1370 "eor v19.16b, v19.16b, v19.16b \n" // sum3
1371 "eor v20.16b, v20.16b, v20.16b \n" // sum4
1372 "eor v21.16b, v21.16b, v21.16b \n" // sum5
1373 "eor v22.16b, v22.16b, v22.16b \n" // sum6
1374 "eor v23.16b, v23.16b, v23.16b \n" // sum7
1375
1376 "lsr w4, %w20, #2 \n" // r4 = nn = L >> 2
1377 "cmp w4, #0 \n"
1378 "beq 1f \n"
1379
1380 "0: \n" // for (; k+3<L; k=k+4)
1381
1382 "prfm pldl1keep, [%9, #128] \n"
1383 "ld1 {v0.8b, v1.8b, v2.8b, v3.8b}, [%9], #32 \n" // k
1384
1385 //"prfm pldl1keep, [%8, #128] \n"
1386 "ld1 {v4.8b}, [%8] \n" // d
1387 "add %8, %8, #4 \n"
1388
1389 "sshll v0.8h, v0.8b, #0 \n" // k00 - k70
1390 "sshll v1.8h, v1.8b, #0 \n" // k01 - k71
1391 "sshll v2.8h, v2.8b, #0 \n" // k02 - k72
1392 "sshll v3.8h, v3.8b, #0 \n" // k03 - k73
1393
1394 "sshll v4.8h, v4.8b, #0 \n" // a00 - a30
1395
1396 // k0
1397 "smlal v16.4s, v0.4h, v4.h[0] \n" // sum0 += (k00-k70) * a00
1398 "smlal2 v17.4s, v0.8h, v4.h[0] \n" //
1399 "smlal v18.4s, v1.4h, v4.h[1] \n" // sum1 += (k01-k71) * a10
1400 "smlal2 v19.4s, v1.8h, v4.h[1] \n" //
1401 "smlal v20.4s, v2.4h, v4.h[2] \n" // sum2 += (k02-k72) * a20
1402 "smlal2 v21.4s, v2.8h, v4.h[2] \n" //
1403 "smlal v22.4s, v3.4h, v4.h[3] \n" // sum3 += (k03-k73) * a30
1404 "smlal2 v23.4s, v3.8h, v4.h[3] \n" //
1405
1406 "subs w4, w4, #1 \n"
1407 "bne 0b \n"
1408
1409 "add v16.4s, v16.4s, v18.4s \n"
1410 "add v17.4s, v17.4s, v19.4s \n"
1411 "add v20.4s, v20.4s, v22.4s \n"
1412 "add v21.4s, v21.4s, v23.4s \n"
1413 "add v14.4s, v16.4s, v20.4s \n"
1414 "add v15.4s, v17.4s, v21.4s \n"
1415
1416 "1: \n"
1417
1418 // remain loop
1419 "and w4, %w20, #3 \n" // w4 = remain = inch & 3;
1420 "cmp w4, #0 \n"
1421 "beq 3f \n"
1422
1423 "2: \n"
1424
1425 //"prfm pldl1keep, [%9, #128] \n"
1426 "ld1 {v0.8b}, [%9], #8 \n"
1427 //"prfm pldl1keep, [%8, #128] \n"
1428 "ld1 {v4.8b}, [%8] \n"
1429 "add %8, %8, #1 \n"
1430
1431 "sshll v0.8h, v0.8b, #0 \n" // k00 - k70
1432 "sshll v4.8h, v4.8b, #0 \n" // a00
1433
1434 // k0
1435 "smlal v14.4s, v0.4h, v4.h[0] \n" // sum0 += (k00-k70) * a00
1436 "smlal2 v15.4s, v0.8h, v4.h[0] \n" //
1437
1438 "subs w4, w4, #1 \n"
1439
1440 "bne 2b \n"
1441
1442 "3: \n"
1443
1444 "st1 {v14.s}[0], [%0] \n"
1445 "st1 {v14.s}[1], [%1] \n"
1446 "st1 {v14.s}[2], [%2] \n"
1447 "st1 {v14.s}[3], [%3] \n"
1448 "st1 {v15.s}[0], [%4] \n"
1449 "st1 {v15.s}[1], [%5] \n"
1450 "st1 {v15.s}[2], [%6] \n"
1451 "st1 {v15.s}[3], [%7] \n"
1452
1453 : "=r"(output0), // %0
1454 "=r"(output1), // %1
1455 "=r"(output2), // %2
1456 "=r"(output3), // %3
1457 "=r"(output4), // %4
1458 "=r"(output5), // %5
1459 "=r"(output6), // %6
1460 "=r"(output7), // %7
1461 "=r"(vb), // %8
1462 "=r"(va) // %9
1463 : "0"(output0),
1464 "1"(output1),
1465 "2"(output2),
1466 "3"(output3),
1467 "4"(output4),
1468 "5"(output5),
1469 "6"(output6),
1470 "7"(output7),
1471 "8"(vb),
1472 "9"(va),
1473 "r"(L) // %20
1474 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
1475 #else
1476 int sum0 = 0;
1477 int sum1 = 0;
1478 int sum2 = 0;
1479 int sum3 = 0;
1480 int sum4 = 0;
1481 int sum5 = 0;
1482 int sum6 = 0;
1483 int sum7 = 0;
1484
1485 for (int k = 0; k < L; k++)
1486 {
1487 sum0 += (int)va[0] * vb[0];
1488 sum1 += (int)va[1] * vb[0];
1489 sum2 += (int)va[2] * vb[0];
1490 sum3 += (int)va[3] * vb[0];
1491 sum4 += (int)va[4] * vb[0];
1492 sum5 += (int)va[5] * vb[0];
1493 sum6 += (int)va[6] * vb[0];
1494 sum7 += (int)va[7] * vb[0];
1495
1496 va += 8;
1497 vb += 1;
1498 }
1499
1500 output0[0] = sum0;
1501 output1[0] = sum1;
1502 output2[0] = sum2;
1503 output3[0] = sum3;
1504 output4[0] = sum4;
1505 output5[0] = sum5;
1506 output6[0] = sum6;
1507 output7[0] = sum7;
1508 #endif // __aarch64__
1509 output0++;
1510 output1++;
1511 output2++;
1512 output3++;
1513 output4++;
1514 output5++;
1515 output6++;
1516 output7++;
1517 }
1518 }
1519 #endif // __ARM_NEON && __aarch64__
1520
1521 nn_outch = (outch - remain_outch_start) >> 2;
1522
1523 #pragma omp parallel for num_threads(opt.num_threads)
1524 for (int pp = 0; pp < nn_outch; pp++)
1525 {
1526 int i = remain_outch_start + pp * 4;
1527
1528 int* output0 = top_blob.channel(i);
1529 int* output1 = top_blob.channel(i + 1);
1530 int* output2 = top_blob.channel(i + 2);
1531 int* output3 = top_blob.channel(i + 3);
1532
1533 int j = 0;
1534 for (; j + 7 < N; j = j + 8)
1535 {
1536 signed char* vb = bottom_tm.channel(j / 8);
1537 #if __ARM_NEON && __aarch64__
1538 const signed char* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
1539 #else
1540 const signed char* va = kernel_tm.channel(i / 4);
1541 #endif // __ARM_NEON && __aarch64__
1542
1543 #if __ARM_NEON
1544 #if __aarch64__
1545 asm volatile(
1546 "eor v16.16b, v16.16b, v16.16b \n" // sum0
1547 "eor v17.16b, v17.16b, v17.16b \n" // sum0n
1548 "eor v18.16b, v18.16b, v18.16b \n" // sum1
1549 "eor v19.16b, v19.16b, v19.16b \n" // sum1n
1550 "eor v20.16b, v20.16b, v20.16b \n" // sum2
1551 "eor v21.16b, v21.16b, v21.16b \n" // sum2n
1552 "eor v22.16b, v22.16b, v22.16b \n" // sum3
1553 "eor v23.16b, v23.16b, v23.16b \n" // sum3n
1554
1555 "lsr w4, %w12, #2 \n" // r4 = nn = L >> 2
1556 "cmp w4, #0 \n"
1557 "beq 1f \n"
1558
1559 "0: \n" // for (; k+3<L; k=k+4)
1560
1561 "prfm pldl1keep, [%5, #128] \n"
1562 "ld1 {v0.8b, v1.8b}, [%5], #16 \n"
1563
1564 "prfm pldl1keep, [%4, #128] \n"
1565 "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%4], #32 \n"
1566
1567 "sshll v0.8h, v0.8b, #0 \n" // k00 - k30,k01 - k31
1568 "sshll v1.8h, v1.8b, #0 \n" // k02 - k32,k03 - k33
1569
1570 "sshll v8.8h, v8.8b, #0 \n" // a00 - a70
1571 "sshll v9.8h, v9.8b, #0 \n" // a01 - a71
1572 "sshll v10.8h, v10.8b, #0 \n" // a02 - a72
1573 "sshll v11.8h, v11.8b, #0 \n" // a03 - a73
1574
1575 // k0
1576 "smlal v16.4s, v8.4h, v0.h[0] \n" // sum0 += (a00-a70) * k00
1577 "smlal2 v17.4s, v8.8h, v0.h[0] \n" //
1578 "smlal v18.4s, v8.4h, v0.h[1] \n" // sum1 += (a00-a70) * k10
1579 "smlal2 v19.4s, v8.8h, v0.h[1] \n" //
1580 "smlal v20.4s, v8.4h, v0.h[2] \n" // sum2 += (a00-a70) * k20
1581 "smlal2 v21.4s, v8.8h, v0.h[2] \n" //
1582 "smlal v22.4s, v8.4h, v0.h[3] \n" // sum3 += (a00-a70) * k30
1583 "smlal2 v23.4s, v8.8h, v0.h[3] \n" //
1584 // k1
1585 "smlal v16.4s, v9.4h, v0.h[4] \n" // sum0 += (a01-a71) * k01
1586 "smlal2 v17.4s, v9.8h, v0.h[4] \n" //
1587 "smlal v18.4s, v9.4h, v0.h[5] \n" // sum1 += (a01-a71) * k11
1588 "smlal2 v19.4s, v9.8h, v0.h[5] \n" //
1589 "smlal v20.4s, v9.4h, v0.h[6] \n" // sum2 += (a01-a71) * k21
1590 "smlal2 v21.4s, v9.8h, v0.h[6] \n" //
1591 "smlal v22.4s, v9.4h, v0.h[7] \n" // sum3 += (a01-a71) * k31
1592 "smlal2 v23.4s, v9.8h, v0.h[7] \n" //
1593 // k2
1594 "smlal v16.4s, v10.4h, v1.h[0] \n" // sum0 += (a02-a72) * k02
1595 "smlal2 v17.4s, v10.8h, v1.h[0] \n" //
1596 "smlal v18.4s, v10.4h, v1.h[1] \n" // sum1 += (a02-a72) * k12
1597 "smlal2 v19.4s, v10.8h, v1.h[1] \n" //
1598 "smlal v20.4s, v10.4h, v1.h[2] \n" // sum2 += (a02-a72) * k22
1599 "smlal2 v21.4s, v10.8h, v1.h[2] \n" //
1600 "smlal v22.4s, v10.4h, v1.h[3] \n" // sum3 += (a02-a72) * k32
1601 "smlal2 v23.4s, v10.8h, v1.h[3] \n" //
1602 // k3
1603 "smlal v16.4s, v11.4h, v1.h[4] \n" // sum0 += (a03-a73) * k03
1604 "smlal2 v17.4s, v11.8h, v1.h[4] \n" //
1605 "smlal v18.4s, v11.4h, v1.h[5] \n" // sum1 += (a03-a73) * k13
1606 "smlal2 v19.4s, v11.8h, v1.h[5] \n" //
1607 "smlal v20.4s, v11.4h, v1.h[6] \n" // sum2 += (a03-a73) * k23
1608 "smlal2 v21.4s, v11.8h, v1.h[6] \n" //
1609 "smlal v22.4s, v11.4h, v1.h[7] \n" // sum3 += (a03-a73) * k33
1610 "smlal2 v23.4s, v11.8h, v1.h[7] \n" //
1611
1612 "subs w4, w4, #1 \n"
1613 "bne 0b \n"
1614
1615 "1: \n"
1616
1617 // remain loop
1618 "and w4, %w12, #3 \n" // w4 = remain = inch & 3;
1619 "cmp w4, #0 \n"
1620 "beq 3f \n"
1621
1622 "2: \n"
1623
1624 //"prfm pldl1keep, [%5, #128] \n"
1625 "ld1 {v0.8b}, [%5] \n"
1626 //"prfm pldl1keep, [%4, #128] \n"
1627 "ld1 {v8.8b}, [%4], #8 \n"
1628 "add %5, %5, #4 \n"
1629
1630 "sshll v0.8h, v0.8b, #0 \n" // k00 - k30
1631 "sshll v8.8h, v8.8b, #0 \n" // a00 - a70
1632
1633 // k0
1634 "smlal v16.4s, v8.4h, v0.h[0] \n" // sum0 += (a00-a70) * k00
1635 "smlal2 v17.4s, v8.8h, v0.h[0] \n" //
1636 "smlal v18.4s, v8.4h, v0.h[1] \n" // sum1 += (a00-a70) * k10
1637 "smlal2 v19.4s, v8.8h, v0.h[1] \n" //
1638 "smlal v20.4s, v8.4h, v0.h[2] \n" // sum2 += (a00-a70) * k20
1639 "smlal2 v21.4s, v8.8h, v0.h[2] \n" //
1640 "smlal v22.4s, v8.4h, v0.h[3] \n" // sum3 += (a00-a70) * k30
1641 "smlal2 v23.4s, v8.8h, v0.h[3] \n" //
1642
1643 "subs w4, w4, #1 \n"
1644
1645 "bne 2b \n"
1646
1647 "3: \n"
1648
1649 "st1 {v16.4s, v17.4s}, [%0] \n"
1650 "st1 {v18.4s, v19.4s}, [%1] \n"
1651 "st1 {v20.4s, v21.4s}, [%2] \n"
1652 "st1 {v22.4s, v23.4s}, [%3] \n"
1653
1654 : "=r"(output0), // %0
1655 "=r"(output1), // %1
1656 "=r"(output2), // %2
1657 "=r"(output3), // %3
1658 "=r"(vb), // %4
1659 "=r"(va) // %5
1660 : "0"(output0),
1661 "1"(output1),
1662 "2"(output2),
1663 "3"(output3),
1664 "4"(vb),
1665 "5"(va),
1666 "r"(L) // %12
1667 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
1668 #else
1669 asm volatile(
1670 // K loop
1671 "vmov.s32 q8, #0 \n"
1672 "vmov.s32 q9, #0 \n"
1673 "vmov.s32 q10, #0 \n"
1674 "vmov.s32 q11, #0 \n"
1675 "vmov.s32 q12, #0 \n"
1676 "vmov.s32 q13, #0 \n"
1677 "vmov.s32 q14, #0 \n"
1678 "vmov.s32 q15, #0 \n"
1679
1680 "lsr r4, %12, #3 \n" // r4 = nn = L >> 3
1681 "cmp r4, #0 \n"
1682 "beq 1f \n"
1683
1684 "0: \n" // for(; nn != 0; nn--)
1685 "pld [%4, #128] \n"
1686 "vld1.s8 {d8-d11}, [%4]! \n" // tmpr a00-a07,a10-a17,a20-a27,a30-a37 a(inch)(data)
1687 "vmovl.s8 q7, d11 \n" // a30-a37
1688 "vmovl.s8 q6, d10 \n" // a20-a27
1689 "vmovl.s8 q5, d9 \n" // a10-a17
1690 "vmovl.s8 q4, d8 \n" // a00-a07
1691
1692 "pld [%5, #128] \n"
1693 "vld1.s8 {d0-d3}, [%5]! \n" // kptr k00-k30,k01-k31, k02-k32,k03-k33, k04-k34,k05-k35, k06-k36,k07-k37 k(outch)(inch)
1694 "vmovl.s8 q3, d3 \n" // k06-k36,k07-k37
1695 "vmovl.s8 q2, d2 \n" // k04-k34,k05-k35
1696 "vmovl.s8 q1, d1 \n" // k02-k32,k03-k33
1697 "vmovl.s8 q0, d0 \n" // k00-k30,k01-k31
1698
1699 "vmlal.s16 q8, d8, d0[0] \n" // sum0 = (a00-a07) * k00
1700 "vmlal.s16 q9, d9, d0[0] \n"
1701 "vmlal.s16 q10, d8, d0[1] \n" // sum1 = (a00-a07) * k10
1702 "vmlal.s16 q11, d9, d0[1] \n"
1703 "vmlal.s16 q12, d8, d0[2] \n" // sum2 = (a00-a07) * k20
1704 "vmlal.s16 q13, d9, d0[2] \n"
1705 "vmlal.s16 q14, d8, d0[3] \n" // sum3 = (a00-a07) * k30
1706 "vmlal.s16 q15, d9, d0[3] \n"
1707
1708 "vmlal.s16 q8, d10, d1[0] \n" // sum0 += (a10-a17) * k01
1709 "vmlal.s16 q9, d11, d1[0] \n"
1710 "vmlal.s16 q10, d10, d1[1] \n" // sum1 += (a10-a17) * k11
1711 "vmlal.s16 q11, d11, d1[1] \n"
1712 "vmlal.s16 q12, d10, d1[2] \n" // sum2 += (a10-a17) * k21
1713 "vmlal.s16 q13, d11, d1[2] \n"
1714 "vmlal.s16 q14, d10, d1[3] \n" // sum3 += (a10-a17) * k31
1715 "vmlal.s16 q15, d11, d1[3] \n"
1716
1717 "pld [%4, #128] \n"
1718 "vld1.s8 {d8-d9}, [%4]! \n" // tmpr a00-a07,a10-a17,a20-a27,a30-a37 a(inch)(data)
1719 "vmovl.s8 q5, d9 \n" // a10-a17
1720 "vmovl.s8 q4, d8 \n" // a00-a07
1721
1722 "vmlal.s16 q8, d12, d2[0] \n" // sum0 += (a20-a27) * k02
1723 "vmlal.s16 q9, d13, d2[0] \n"
1724 "vmlal.s16 q10, d12, d2[1] \n" // sum1 += (a20-a27) * k12
1725 "vmlal.s16 q11, d13, d2[1] \n"
1726 "vmlal.s16 q12, d12, d2[2] \n" // sum2 += (a20-a27) * k22
1727 "vmlal.s16 q13, d13, d2[2] \n"
1728 "vmlal.s16 q14, d12, d2[3] \n" // sum3 += (a20-a27) * k32
1729 "vmlal.s16 q15, d13, d2[3] \n"
1730
1731 "vmlal.s16 q8, d14, d3[0] \n" // sum0 += (a30-a37) * k03
1732 "vmlal.s16 q9, d15, d3[0] \n"
1733 "vmlal.s16 q10, d14, d3[1] \n" // sum1 += (a30-a37) * k13
1734 "vmlal.s16 q11, d15, d3[1] \n"
1735 "vmlal.s16 q12, d14, d3[2] \n" // sum2 += (a30-a37) * k23
1736 "vmlal.s16 q13, d15, d3[2] \n"
1737 "vmlal.s16 q14, d14, d3[3] \n" // sum3 += (a30-a37) * k33
1738 "vmlal.s16 q15, d15, d3[3] \n"
1739
1740 "pld [%4, #128] \n"
1741 "vld1.s8 {d0-d1}, [%4]! \n" // tmpr a00-a07,a10-a17,a20-a27,a30-a37 a(inch)(data)
1742 "vmovl.s8 q1, d1 \n" // a10-a17
1743 "vmovl.s8 q0, d0 \n" // a00-a07
1744
1745 "vmlal.s16 q8, d8, d4[0] \n" // sum0 += (a40-a47) * k04
1746 "vmlal.s16 q9, d9, d4[0] \n"
1747 "vmlal.s16 q10, d8, d4[1] \n" // sum1 += (a40-a47) * k14
1748 "vmlal.s16 q11, d9, d4[1] \n"
1749 "vmlal.s16 q12, d8, d4[2] \n" // sum2 += (a40-a47) * k24
1750 "vmlal.s16 q13, d9, d4[2] \n"
1751 "vmlal.s16 q14, d8, d4[3] \n" // sum3 += (a40-a47) * k34
1752 "vmlal.s16 q15, d9, d4[3] \n"
1753
1754 "vmlal.s16 q8, d10, d5[0] \n" // sum0 += (a50-a57) * k05
1755 "vmlal.s16 q9, d11, d5[0] \n"
1756 "vmlal.s16 q10, d10, d5[1] \n" // sum1 += (a50-a57) * k15
1757 "vmlal.s16 q11, d11, d5[1] \n"
1758 "vmlal.s16 q12, d10, d5[2] \n" // sum2 += (a50-a57) * k25
1759 "vmlal.s16 q13, d11, d5[2] \n"
1760 "vmlal.s16 q14, d10, d5[3] \n" // sum3 += (a50-a57) * k35
1761 "vmlal.s16 q15, d11, d5[3] \n"
1762
1763 "vmlal.s16 q8, d0, d6[0] \n" // sum0 += (a60-a67) * k06
1764 "vmlal.s16 q9, d1, d6[0] \n"
1765 "vmlal.s16 q10, d0, d6[1] \n" // sum1 += (a60-a67) * k16
1766 "vmlal.s16 q11, d1, d6[1] \n"
1767 "vmlal.s16 q12, d0, d6[2] \n" // sum2 += (a60-a67) * k26
1768 "vmlal.s16 q13, d1, d6[2] \n"
1769 "vmlal.s16 q14, d0, d6[3] \n" // sum3 += (a60-a67) * k36
1770 "vmlal.s16 q15, d1, d6[3] \n"
1771
1772 "vmlal.s16 q8, d2, d7[0] \n" // sum0 += (a70-a77) * k07
1773 "vmlal.s16 q9, d3, d7[0] \n"
1774 "vmlal.s16 q10, d2, d7[1] \n" // sum1 += (a70-a77) * k17
1775 "vmlal.s16 q11, d3, d7[1] \n"
1776 "vmlal.s16 q12, d2, d7[2] \n" // sum2 += (a70-a77) * k27
1777 "vmlal.s16 q13, d3, d7[2] \n"
1778 "vmlal.s16 q14, d2, d7[3] \n" // sum3 += (a70-a77) * k37
1779 "vmlal.s16 q15, d3, d7[3] \n"
1780
1781 "subs r4, r4, #1 \n"
1782 "bne 0b \n" // end for
1783
1784 "1: \n"
1785 // remain loop
1786 "and r4, %12, #7 \n" // r4 = remain = inch & 7
1787 "cmp r4, #0 \n"
1788 "beq 3f \n"
1789
1790 "2: \n" // for(; remain != 0; remain--)
1791 "vld1.s8 {d2}, [%4]! \n" // tmpr a00-a70 a(inch)(data)
1792 "vld1.s8 {d0}, [%5] \n" // kptr k00-k30 k(outch)(inch)
1793 "vmovl.s8 q1, d2 \n"
1794 "vmovl.s8 q0, d0 \n"
1795 "add %5, #4 \n"
1796
1797 "vmlal.s16 q8, d2, d0[0] \n" // sum0 += (a00-a70) * k00
1798 "vmlal.s16 q9, d3, d0[0] \n"
1799 "vmlal.s16 q10, d2, d0[1] \n" // sum1 += (a00-a70) * k10
1800 "vmlal.s16 q11, d3, d0[1] \n"
1801 "vmlal.s16 q12, d2, d0[2] \n" // sum2 += (a00-a70) * k20
1802 "vmlal.s16 q13, d3, d0[2] \n"
1803 "vmlal.s16 q14, d2, d0[3] \n" // sum3 += (a00-a70) * k30
1804 "vmlal.s16 q15, d3, d0[3] \n"
1805
1806 "subs r4, r4, #1 \n"
1807 "bne 2b \n"
1808
1809 "3: \n" // store the result to memory
1810 "vst1.s32 {d16-d19}, [%0] \n"
1811 "vst1.s32 {d20-d23}, [%1] \n"
1812 "vst1.s32 {d24-d27}, [%2] \n"
1813 "vst1.s32 {d28-d31}, [%3] \n"
1814
1815 : "=r"(output0), // %0
1816 "=r"(output1), // %1
1817 "=r"(output2), // %2
1818 "=r"(output3), // %3
1819 "=r"(vb), // %4
1820 "=r"(va) // %5
1821 : "0"(output0),
1822 "1"(output1),
1823 "2"(output2),
1824 "3"(output3),
1825 "4"(vb),
1826 "5"(va),
1827 "r"(L) // %12
1828 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1829 #endif // __aarch64__
1830 #else
1831 int sum0[8] = {0};
1832 int sum1[8] = {0};
1833 int sum2[8] = {0};
1834 int sum3[8] = {0};
1835
1836 int k = 0;
1837 for (; k + 7 < L; k = k + 8)
1838 {
1839 for (int n = 0; n < 8; n++)
1840 {
1841 sum0[n] += (int)va[0] * vb[n];
1842 sum1[n] += (int)va[1] * vb[n];
1843 sum2[n] += (int)va[2] * vb[n];
1844 sum3[n] += (int)va[3] * vb[n];
1845 va += 4;
1846
1847 sum0[n] += (int)va[0] * vb[n + 8];
1848 sum1[n] += (int)va[1] * vb[n + 8];
1849 sum2[n] += (int)va[2] * vb[n + 8];
1850 sum3[n] += (int)va[3] * vb[n + 8];
1851 va += 4;
1852
1853 sum0[n] += (int)va[0] * vb[n + 16];
1854 sum1[n] += (int)va[1] * vb[n + 16];
1855 sum2[n] += (int)va[2] * vb[n + 16];
1856 sum3[n] += (int)va[3] * vb[n + 16];
1857 va += 4;
1858
1859 sum0[n] += (int)va[0] * vb[n + 24];
1860 sum1[n] += (int)va[1] * vb[n + 24];
1861 sum2[n] += (int)va[2] * vb[n + 24];
1862 sum3[n] += (int)va[3] * vb[n + 24];
1863 va += 4;
1864
1865 sum0[n] += (int)va[0] * vb[n + 32];
1866 sum1[n] += (int)va[1] * vb[n + 32];
1867 sum2[n] += (int)va[2] * vb[n + 32];
1868 sum3[n] += (int)va[3] * vb[n + 32];
1869 va += 4;
1870
1871 sum0[n] += (int)va[0] * vb[n + 40];
1872 sum1[n] += (int)va[1] * vb[n + 40];
1873 sum2[n] += (int)va[2] * vb[n + 40];
1874 sum3[n] += (int)va[3] * vb[n + 40];
1875 va += 4;
1876
1877 sum0[n] += (int)va[0] * vb[n + 48];
1878 sum1[n] += (int)va[1] * vb[n + 48];
1879 sum2[n] += (int)va[2] * vb[n + 48];
1880 sum3[n] += (int)va[3] * vb[n + 48];
1881 va += 4;
1882
1883 sum0[n] += (int)va[0] * vb[n + 56];
1884 sum1[n] += (int)va[1] * vb[n + 56];
1885 sum2[n] += (int)va[2] * vb[n + 56];
1886 sum3[n] += (int)va[3] * vb[n + 56];
1887 va -= 28;
1888 }
1889
1890 va += 32;
1891 vb += 64;
1892 }
1893
1894 for (; k < L; k++)
1895 {
1896 for (int n = 0; n < 8; n++)
1897 {
1898 sum0[n] += (int)va[0] * vb[n];
1899 sum1[n] += (int)va[1] * vb[n];
1900 sum2[n] += (int)va[2] * vb[n];
1901 sum3[n] += (int)va[3] * vb[n];
1902 }
1903
1904 va += 4;
1905 vb += 8;
1906 }
1907
1908 for (int n = 0; n < 8; n++)
1909 {
1910 output0[n] = sum0[n];
1911 output1[n] = sum1[n];
1912 output2[n] = sum2[n];
1913 output3[n] = sum3[n];
1914 }
1915 #endif // __ARM_NEON
1916 output0 += 8;
1917 output1 += 8;
1918 output2 += 8;
1919 output3 += 8;
1920 }
1921
1922 for (; j < N; j++)
1923 {
1924 signed char* vb = bottom_tm.channel(j / 8 + j % 8);
1925 #if __ARM_NEON && __aarch64__
1926 const signed char* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
1927 #else
1928 const signed char* va = kernel_tm.channel(i / 4);
1929 #endif // __ARM_NEON && __aarch64__
1930
1931 #if __ARM_NEON
1932 #if __aarch64__
1933 asm volatile(
1934 "eor v14.16b, v14.16b, v14.16b \n" // sum0_3
1935 "eor v16.16b, v16.16b, v16.16b \n" // sum0
1936 "eor v17.16b, v17.16b, v17.16b \n" // sum1
1937 "eor v18.16b, v18.16b, v18.16b \n" // sum2
1938 "eor v19.16b, v19.16b, v19.16b \n" // sum3
1939
1940 "lsr w4, %w12, #2 \n" // r4 = nn = L >> 2
1941 "cmp w4, #0 \n"
1942 "beq 1f \n"
1943
1944 "0: \n" // for (; k+3<L; k=k+4)
1945
1946 "prfm pldl1keep, [%5, #128] \n"
1947 "ld1 {v0.8b, v1.8b}, [%5], #16 \n" // k
1948
1949 //"prfm pldl1keep, [%4, #128] \n"
1950 "ld1 {v4.8b}, [%4] \n" // d
1951 "add %4, %4, #4 \n"
1952
1953 "sshll v0.8h, v0.8b, #0 \n" // k00 - k30,k01 - k31
1954 "sshll v1.8h, v1.8b, #0 \n" // k02 - k32,k03 - k33
1955 "sshll v4.8h, v4.8b, #0 \n" // a00 - a30
1956
1957 "subs w4, w4, #1 \n"
1958 // k0
1959 "smlal v16.4s, v0.4h, v4.h[0] \n" // sum0 += (k00-k30) * a00
1960 "smlal2 v17.4s, v0.8h, v4.h[0] \n" // sum1 += (k01-k31) * a10
1961 "smlal v18.4s, v1.4h, v4.h[1] \n" // sum2 += (k02-k32) * a20
1962 "smlal2 v19.4s, v1.8h, v4.h[1] \n" // sum3 += (k03-k33) * a30
1963
1964 "bne 0b \n"
1965
1966 "add v16.4s, v16.4s, v18.4s \n"
1967 "add v17.4s, v17.4s, v19.4s \n"
1968 "add v14.4s, v16.4s, v17.4s \n"
1969
1970 "1: \n"
1971
1972 // remain loop
1973 "and w4, %w12, #3 \n" // w4 = remain = inch & 3;
1974 "cmp w4, #0 \n"
1975 "beq 3f \n"
1976
1977 "2: \n"
1978
1979 //"prfm pldl1keep, [%5, #128] \n"
1980 "ld1 {v0.8b}, [%5] \n"
1981 //"prfm pldl1keep, [4, #128] \n"
1982 "ld1 {v4.8b}, [%4] \n"
1983 "add %4, %4, #1 \n"
1984 "add %5, %5, #4 \n"
1985
1986 "subs w4, w4, #1 \n"
1987
1988 "sshll v0.8h, v0.8b, #0 \n" // k00 - k30
1989 "sshll v4.8h, v4.8b, #0 \n" // a00
1990 // k0
1991 "smlal v14.4s, v0.4h, v4.h[0] \n" // sum0 += (k00-k30) * a00
1992
1993 "bne 2b \n"
1994
1995 "3: \n"
1996
1997 "st1 {v14.s}[0], [%0] \n"
1998 "st1 {v14.s}[1], [%1] \n"
1999 "st1 {v14.s}[2], [%2] \n"
2000 "st1 {v14.s}[3], [%3] \n"
2001
2002 : "=r"(output0), // %0
2003 "=r"(output1), // %1
2004 "=r"(output2), // %2
2005 "=r"(output3), // %3
2006 "=r"(vb), // %4
2007 "=r"(va) // %5
2008 : "0"(output0),
2009 "1"(output1),
2010 "2"(output2),
2011 "3"(output3),
2012 "4"(vb),
2013 "5"(va),
2014 "r"(L) // %12
2015 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
2016 #else
2017 asm volatile(
2018 // inch loop
2019 "veor q6, q6, q6 \n"
2020 "veor q7, q7, q7 \n"
2021 "veor q8, q8, q8 \n"
2022 "veor q9, q9, q9 \n"
2023 "veor q10, q10, q10 \n"
2024 "veor q11, q11, q11 \n"
2025 "veor q12, q12, q12 \n"
2026 "veor q13, q13, q13 \n"
2027 "vmov.s32 q14, #0 \n"
2028
2029 "lsr r4, %12, #3 \n" // r4 = nn = L >> 2
2030 "cmp r4, #0 \n"
2031 "beq 1f \n"
2032
2033 "0: \n" // for(; nn != 0; nn--)
2034 "pld [%4, #128] \n"
2035 "vld1.s8 {d0}, [%4]! \n" // tmpr a00,a10,a20,a30 a(inch)(data)
2036 "vmovl.s8 q0, d0 \n" // a00-a07
2037
2038 "pld [%5, #128] \n"
2039 "vld1.s8 {d2-d5}, [%5]! \n" // kptr k00-k30,k01-k31, k02-k32,k03-k33, k04-k34,k05-k35, k06-k36,k07-k37 k(outch)(inch)
2040 "vmovl.s8 q4, d5 \n" // k06-k36,k07-k37
2041 "vmovl.s8 q3, d4 \n" // k04-k34,k05-k35
2042 "vmovl.s8 q2, d3 \n" // k02-k32,k03-k33
2043 "vmovl.s8 q1, d2 \n" // k00-k30,k01-k31
2044
2045 "vmlal.s16 q6, d2, d0[0] \n" // (k00-k30) * a00
2046 "vmlal.s16 q7, d3, d0[1] \n" // (k01-k31) * a01
2047 "vmlal.s16 q8, d4, d0[2] \n" // (k02-k32) * a02
2048 "vmlal.s16 q9, d5, d0[3] \n" // (k03-k33) * a03
2049 "vmlal.s16 q10, d6, d1[0] \n" // (k04-k34) * a04
2050 "vmlal.s16 q11, d7, d1[1] \n" // (k05-k35) * a05
2051 "vmlal.s16 q12, d8, d1[2] \n" // (k06-k36) * a06
2052 "vmlal.s16 q13, d9, d1[3] \n" // (k07-k37) * a07
2053
2054 "subs r4, r4, #1 \n"
2055 "bne 0b \n" // end for
2056
2057 "vadd.s32 q6, q6, q7 \n"
2058 "vadd.s32 q9, q9, q8 \n"
2059 "vadd.s32 q11, q11, q10 \n"
2060 "vadd.s32 q13, q13, q12 \n"
2061
2062 "vadd.s32 q9, q9, q6 \n"
2063 "vadd.s32 q13, q13, q11 \n"
2064 "vadd.s32 q14, q13, q9 \n"
2065
2066 "1: \n"
2067 // remain loop
2068 "and r4, %12, #7 \n" // r4 = remain = inch & 3
2069 "cmp r4, #0 \n"
2070 "beq 3f \n"
2071
2072 "2: \n" // for(; remain != 0; remain--)
2073 "vld1.s8 {d2}, [%4] \n" // tmpr a00 a(inch)(data)
2074 "vld1.s8 {d0}, [%5] \n" // kptr k00-k30 k(outch)(inch)
2075 "vmovl.s8 q1, d2 \n"
2076 "vmovl.s8 q0, d0 \n"
2077 "add %4, #1 \n"
2078 "add %5, #4 \n"
2079
2080 "vmlal.s16 q14, d0, d2[0] \n"
2081
2082 "subs r4, r4, #1 \n"
2083 "bne 2b \n"
2084
2085 "3: \n" // store the result to memory
2086 "vst1.s32 {d28[0]}, [%0] \n"
2087 "vst1.s32 {d28[1]}, [%1] \n"
2088 "vst1.s32 {d29[0]}, [%2] \n"
2089 "vst1.s32 {d29[1]}, [%3] \n"
2090
2091 : "=r"(output0), // %0
2092 "=r"(output1), // %1
2093 "=r"(output2), // %2
2094 "=r"(output3), // %3
2095 "=r"(vb), // %4
2096 "=r"(va) // %5
2097 : "0"(output0),
2098 "1"(output1),
2099 "2"(output2),
2100 "3"(output3),
2101 "4"(vb),
2102 "5"(va),
2103 "r"(L) // %12
2104 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14");
2105 #endif // __aarch64__
2106 #else
2107 int sum0 = 0;
2108 int sum1 = 0;
2109 int sum2 = 0;
2110 int sum3 = 0;
2111
2112 for (int k = 0; k < L; k++)
2113 {
2114 sum0 += (int)va[0] * vb[0];
2115 sum1 += (int)va[1] * vb[0];
2116 sum2 += (int)va[2] * vb[0];
2117 sum3 += (int)va[3] * vb[0];
2118
2119 va += 4;
2120 vb += 1;
2121 }
2122
2123 output0[0] = sum0;
2124 output1[0] = sum1;
2125 output2[0] = sum2;
2126 output3[0] = sum3;
2127 #endif // __ARM_NEON
2128 output0++;
2129 output1++;
2130 output2++;
2131 output3++;
2132 }
2133 }
2134
2135 remain_outch_start += nn_outch << 2;
2136
2137 #pragma omp parallel for num_threads(opt.num_threads)
2138 for (int i = remain_outch_start; i < outch; i++)
2139 {
2140 int* output = top_blob.channel(i);
2141
2142 int j = 0;
2143 for (; j + 7 < N; j = j + 8)
2144 {
2145 signed char* vb = bottom_tm.channel(j / 8);
2146 #if __ARM_NEON && __aarch64__
2147 const signed char* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
2148 #else
2149 const signed char* va = kernel_tm.channel(i / 4 + i % 4);
2150 #endif // __ARM_NEON && __aarch64__
2151
2152 #if __ARM_NEON
2153 #if __aarch64__
2154 asm volatile(
2155 "eor v16.16b, v16.16b, v16.16b \n" // sum0
2156 "eor v17.16b, v17.16b, v17.16b \n" // sum0n
2157
2158 "lsr w4, %w6, #2 \n" // r4 = nn = L >> 2
2159 "cmp w4, #0 \n"
2160 "beq 1f \n"
2161
2162 "0: \n" // for (; k+3<L; k=k+4)
2163
2164 "prfm pldl1keep, [%2, #128] \n"
2165 "ld1 {v0.8b}, [%2] \n"
2166
2167 "prfm pldl1keep, [%1, #128] \n"
2168 "ld1 {v8.8b, v9.8b, v10.8b, v11.8b}, [%1], #32 \n"
2169 "add %2, %2, #4 \n"
2170
2171 "sshll v0.8h, v0.8b, #0 \n" // k00 - k03
2172
2173 "sshll v8.8h, v8.8b, #0 \n" // a00 - a70
2174 "sshll v9.8h, v9.8b, #0 \n" // a01 - a71
2175 "sshll v10.8h, v10.8b, #0 \n" // a02 - a72
2176 "sshll v11.8h, v11.8b, #0 \n" // a03 - a73
2177
2178 // k0
2179 "smlal v16.4s, v8.4h, v0.h[0] \n" // sum0 += (a00-a70) * k00
2180 "smlal2 v17.4s, v8.8h, v0.h[0] \n" //
2181 // k1
2182 "smlal v16.4s, v9.4h, v0.h[1] \n" // sum0 += (a01-a71) * k01
2183 "smlal2 v17.4s, v9.8h, v0.h[1] \n" //
2184 // k2
2185 "smlal v16.4s, v10.4h, v0.h[2] \n" // sum0 += (a02-a72) * k02
2186 "smlal2 v17.4s, v10.8h, v0.h[2] \n" //
2187 // k3
2188 "smlal v16.4s, v11.4h, v0.h[3] \n" // sum0 += (a03-a73) * k03
2189 "smlal2 v17.4s, v11.8h, v0.h[3] \n" //
2190
2191 "subs w4, w4, #1 \n"
2192 "bne 0b \n"
2193
2194 "1: \n"
2195
2196 // remain loop
2197 "and w4, %w6, #3 \n" // w4 = remain = inch & 3;
2198 "cmp w4, #0 \n"
2199 "beq 3f \n"
2200
2201 "2: \n"
2202
2203 //"prfm pldl1keep, [%2, #128] \n"
2204 "ld1 {v0.8b}, [%2] \n"
2205 //"prfm pldl1keep, [%1, #128] \n"
2206 "ld1 {v8.8b}, [%1], #8 \n"
2207 "add %2, %2, #1 \n"
2208
2209 "sshll v0.8h, v0.8b, #0 \n" // k00 - k30
2210 "sshll v8.8h, v8.8b, #0 \n" // a00 - a70
2211
2212 // k0
2213 "smlal v16.4s, v8.4h, v0.h[0] \n" // sum0 += (a00-a70) * k00
2214 "smlal2 v17.4s, v8.8h, v0.h[0] \n" //
2215
2216 "subs w4, w4, #1 \n"
2217
2218 "bne 2b \n"
2219
2220 "3: \n"
2221
2222 "st1 {v16.4s, v17.4s}, [%0] \n"
2223
2224 : "=r"(output), // %0
2225 "=r"(vb), // %1
2226 "=r"(va) // %2
2227 : "0"(output),
2228 "1"(vb),
2229 "2"(va),
2230 "r"(L) // %6
2231 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17");
2232 #else
2233 asm volatile(
2234 // inch loop
2235 "vmov.s32 q6, #0 \n"
2236 "vmov.s32 q7, #0 \n"
2237
2238 "lsr r4, %6, #3 \n" // r4 = nn = inch >> 3
2239 "cmp r4, #0 \n"
2240 "beq 1f \n"
2241
2242 "0: \n" // for(; nn != 0; nn--)
2243 "pld [%1, #128] \n"
2244 "vld1.s8 {d4-d7}, [%1]! \n" // tmpr a00-a07,a10-a17,a20-a27,a30-a37 a(inch)(data)
2245 "vmovl.s8 q5, d7 \n" // a30-a37
2246 "vmovl.s8 q4, d6 \n" // a20-a27
2247 "vmovl.s8 q3, d5 \n" // a10-a17
2248 "vmovl.s8 q2, d4 \n" // a00-a07
2249
2250 "pld [%2, #128] \n"
2251 "vld1.s8 {d0}, [%2]! \n" // kptr k00-k07 k(outch)(inch)
2252 "vmovl.s8 q1, d1 \n" // k04,k05,k06,k07
2253 "vmovl.s8 q0, d0 \n" // k00,k01,k02,k03
2254
2255 "vmlal.s16 q6, d4, d0[0] \n" // (a00-a07) * k00
2256 "vmlal.s16 q7, d5, d0[0] \n"
2257 "vmlal.s16 q6, d6, d0[1] \n" // (a10-a17) * k01
2258 "vmlal.s16 q7, d7, d0[1] \n"
2259 "vmlal.s16 q6, d8, d0[2] \n" // (a20-a27) * k02
2260 "vmlal.s16 q7, d9, d0[2] \n"
2261 "vmlal.s16 q6, d10, d0[3] \n" // (a30-a37) * k03
2262 "vmlal.s16 q7, d11, d0[3] \n"
2263
2264 "pld [%1, #128] \n"
2265 "vld1.s8 {d4-d7}, [%1]! \n" // tmpr a40-a47,a50-a57,a60-a67,a70-a77 a(inch)(data)
2266 "vmovl.s8 q5, d7 \n" // a70-a77
2267 "vmovl.s8 q4, d6 \n" // a60-a67
2268 "vmovl.s8 q3, d5 \n" // a50-a57
2269 "vmovl.s8 q2, d4 \n" // a40-a47
2270
2271 "vmlal.s16 q6, d4, d1[0] \n" // (a00-a07) * k00
2272 "vmlal.s16 q7, d5, d1[0] \n"
2273 "vmlal.s16 q6, d6, d1[1] \n" // (a10-a17) * k01
2274 "vmlal.s16 q7, d7, d1[1] \n"
2275 "vmlal.s16 q6, d8, d1[2] \n" // (a20-a27) * k02
2276 "vmlal.s16 q7, d9, d1[2] \n"
2277 "vmlal.s16 q6, d10, d1[3] \n" // (a30-a37) * k03
2278 "vmlal.s16 q7, d11, d1[3] \n"
2279
2280 "subs r4, r4, #1 \n"
2281 "bne 0b \n" // end for
2282
2283 "1: \n"
2284 // remain loop
2285 "and r4, %6, #7 \n" // r4 = remain = inch & 7
2286 "cmp r4, #0 \n"
2287 "beq 3f \n"
2288
2289 "2: \n" // for(; remain != 0; remain--)
2290 "vld1.s8 {d2}, [%1]! \n" // tmpr a00-a07 a(inch)(data)
2291 "vld1.s8 {d0}, [%2] \n" // kptr k00 k(outch)(inch)
2292 "vmovl.s8 q1, d2 \n"
2293 "vmovl.s8 q0, d0 \n"
2294 "add %2, #1 \n"
2295
2296 "vmlal.s16 q6, d2, d0[0] \n" // (a00-a07) * k00
2297 "vmlal.s16 q7, d3, d0[0] \n"
2298
2299 "subs r4, r4, #1 \n"
2300 "bne 2b \n"
2301
2302 "3: \n" // store the result to memory
2303 "vst1.s32 {d12-d15}, [%0] \n"
2304
2305 : "=r"(output), // %0
2306 "=r"(vb), // %1
2307 "=r"(va) // %2
2308 : "0"(output),
2309 "1"(vb),
2310 "2"(va),
2311 "r"(L) // %6
2312 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
2313 #endif // __aarch64__
2314 #else
2315 int sum[8] = {0};
2316
2317 int k = 0;
2318 for (; k + 7 < L; k = k + 8)
2319 {
2320 for (int n = 0; n < 8; n++)
2321 {
2322 sum[n] += (int)va[0] * vb[n];
2323 sum[n] += (int)va[1] * vb[n + 8];
2324 sum[n] += (int)va[2] * vb[n + 16];
2325 sum[n] += (int)va[3] * vb[n + 24];
2326 sum[n] += (int)va[4] * vb[n + 32];
2327 sum[n] += (int)va[5] * vb[n + 40];
2328 sum[n] += (int)va[6] * vb[n + 48];
2329 sum[n] += (int)va[7] * vb[n + 56];
2330 }
2331
2332 va += 8;
2333 vb += 64;
2334 }
2335
2336 for (; k < L; k++)
2337 {
2338 for (int n = 0; n < 8; n++)
2339 {
2340 sum[n] += (int)va[0] * vb[n];
2341 }
2342
2343 va += 1;
2344 vb += 8;
2345 }
2346
2347 for (int n = 0; n < 8; n++)
2348 {
2349 output[n] = sum[n];
2350 }
2351 #endif // __ARM_NEON
2352 output += 8;
2353 }
2354
2355 for (; j < N; j++)
2356 {
2357 int sum = 0;
2358
2359 signed char* vb = bottom_tm.channel(j / 8 + j % 8);
2360 #if __ARM_NEON && __aarch64__
2361 const signed char* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
2362 #else
2363 const signed char* va = kernel_tm.channel(i / 4 + i % 4);
2364 #endif // __ARM_NEON && __aarch64__
2365
2366 for (int k = 0; k < L; k++)
2367 {
2368 sum += (int)va[0] * vb[0];
2369
2370 va += 1;
2371 vb += 1;
2372 }
2373 output[0] = sum;
2374
2375 output++;
2376 }
2377 }
2378 }
2379 }
2380 #endif
2381