// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

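// Compute int8 convolution as a GEMM over the im2col matrix.
// bottom_im2col is size x maxk x inch with elempack 1; top_blob receives
// int32 accumulators with 4 output channels packed per element (pack1to4).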
static void im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Option& opt)
{
    // Mat bottom_im2col(size, maxk, inch, 8u, 8, opt.workspace_allocator);

    const int size = bottom_im2col.w;
    const int maxk = bottom_im2col.h;
    const int inch = bottom_im2col.c;

    const int outch = top_blob.c;

    // permute
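    // Repack the im2col matrix tile by tile so the GEMM kernels below can
    // stream contiguous blocks: columns are grouped into tiles (16/8/4/2/1
    // wide depending on the target) and input channels into groups of 8/4/1,
    // with the channel values of each column interleaved inside a tile.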
    Mat tmp;
#if __aarch64__
#if __ARM_FEATURE_DOTPROD
    if (inch >= 8)
    {
        if (size >= 16)
            tmp.create(16 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator);
        else if (size >= 8)
            tmp.create(8 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator);
        else if (size >= 4)
            tmp.create(4 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator);
        else if (size >= 2)
            tmp.create(2 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 2 + size % 2, 8u, 8, opt.workspace_allocator);
        else
            tmp.create(maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size, 8u, 8, opt.workspace_allocator);
    }
    else if (inch >= 4)
    {
        if (size >= 16)
            tmp.create(16 * maxk, inch / 4 + inch % 4, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator);
        else if (size >= 8)
            tmp.create(8 * maxk, inch / 4 + inch % 4, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator);
        else if (size >= 4)
            tmp.create(4 * maxk, inch / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator);
        else if (size >= 2)
            tmp.create(2 * maxk, inch / 4 + inch % 4, size / 2 + size % 2, 4u, 4, opt.workspace_allocator);
        else
            tmp.create(maxk, inch / 4 + inch % 4, size, 4u, 4, opt.workspace_allocator);
    }
    else
    {
        if (size >= 16)
            tmp.create(16 * maxk, inch, size / 16 + (size % 16) / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator);
        else if (size >= 8)
            tmp.create(8 * maxk, inch, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator);
        else if (size >= 4)
            tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator);
        else if (size >= 2)
            tmp.create(2 * maxk, inch, size / 2 + size % 2, 1u, 1, opt.workspace_allocator);
        else
            tmp.create(maxk, inch, size, 1u, 1, opt.workspace_allocator);
    }
#else // __ARM_FEATURE_DOTPROD
    if (inch >= 8)
    {
        if (size >= 4)
            tmp.create(4 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 8u, 8, opt.workspace_allocator);
        else if (size >= 2)
            tmp.create(2 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 2 + size % 2, 8u, 8, opt.workspace_allocator);
        else
            tmp.create(maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size, 8u, 8, opt.workspace_allocator);
    }
    else if (inch >= 4)
    {
        if (size >= 4)
            tmp.create(4 * maxk, inch / 4 + inch % 4, size / 4 + (size % 4) / 2 + size % 2, 4u, 4, opt.workspace_allocator);
        else if (size >= 2)
            tmp.create(2 * maxk, inch / 4 + inch % 4, size / 2 + size % 2, 4u, 4, opt.workspace_allocator);
        else
            tmp.create(maxk, inch / 4 + inch % 4, size, 4u, 4, opt.workspace_allocator);
    }
    else
    {
        if (size >= 4)
            tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 1u, 1, opt.workspace_allocator);
        else if (size >= 2)
            tmp.create(2 * maxk, inch, size / 2 + size % 2, 1u, 1, opt.workspace_allocator);
        else
            tmp.create(maxk, inch, size, 1u, 1, opt.workspace_allocator);
    }
#endif // __ARM_FEATURE_DOTPROD
#else // __aarch64__
    if (inch >= 8)
    {
        if (size >= 2)
            tmp.create(2 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size / 2 + size % 2, 8u, 8, opt.workspace_allocator);
        else
            tmp.create(maxk, inch / 8 + (inch % 8) / 4 + inch % 4, size, 8u, 8, opt.workspace_allocator);
    }
    else if (inch >= 4)
    {
        if (size >= 2)
            tmp.create(2 * maxk, inch / 4 + inch % 4, size / 2 + size % 2, 4u, 4, opt.workspace_allocator);
        else
            tmp.create(maxk, inch / 4 + inch % 4, size, 4u, 4, opt.workspace_allocator);
    }
    else
    {
        if (size >= 2)
            tmp.create(2 * maxk, inch, size / 2 + size % 2, 1u, 1, opt.workspace_allocator);
        else
            tmp.create(maxk, inch, size, 1u, 1, opt.workspace_allocator);
    }
#endif // __aarch64__
    {
#if __aarch64__
#if __ARM_FEATURE_DOTPROD
        int nn_size = size >> 4;
        int remain_size_start = 0;

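        // 16-column tiles: for each group of 8 input channels the st4 stores
        // interleave four channel rows at a time, leaving four consecutive
        // channel bytes per column, the operand shape one sdot lane consumes.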
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 16;

            signed char* tmpptr = tmp.channel(i / 16);

            int q = 0;
            for (; q + 7 < inch; q += 8)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;
                const signed char* img4 = (const signed char*)bottom_im2col.channel(q + 4) + i;
                const signed char* img5 = (const signed char*)bottom_im2col.channel(q + 5) + i;
                const signed char* img6 = (const signed char*)bottom_im2col.channel(q + 6) + i;
                const signed char* img7 = (const signed char*)bottom_im2col.channel(q + 7) + i;

                for (int k = 0; k < maxk; k++)
                {
                    asm volatile(
                        "ld1 {v0.16b}, [%0] \n"
                        "ld1 {v1.16b}, [%1] \n"
                        "ld1 {v2.16b}, [%2] \n"
                        "ld1 {v3.16b}, [%3] \n"
                        "ld1 {v4.16b}, [%4] \n"
                        "ld1 {v5.16b}, [%5] \n"
                        "ld1 {v6.16b}, [%6] \n"
                        "ld1 {v7.16b}, [%7] \n"
                        "st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [%8], #64 \n"
                        "st4 {v4.16b, v5.16b, v6.16b, v7.16b}, [%8], #64 \n"
                        : "=r"(img0), // %0
                        "=r"(img1),
                        "=r"(img2),
                        "=r"(img3),
                        "=r"(img4),
                        "=r"(img5),
                        "=r"(img6),
                        "=r"(img7),
                        "=r"(tmpptr) // %8
                        : "0"(img0),
                        "1"(img1),
                        "2"(img2),
                        "3"(img3),
                        "4"(img4),
                        "5"(img5),
                        "6"(img6),
                        "7"(img7),
                        "8"(tmpptr)
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                    img4 += size;
                    img5 += size;
                    img6 += size;
                    img7 += size;
                }
            }
            for (; q + 3 < inch; q += 4)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;

                for (int k = 0; k < maxk; k++)
                {
                    asm volatile(
                        "ld1 {v0.16b}, [%0] \n"
                        "ld1 {v1.16b}, [%1] \n"
                        "ld1 {v2.16b}, [%2] \n"
                        "ld1 {v3.16b}, [%3] \n"
                        "st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [%4], #64 \n"
                        : "=r"(img0), // %0
                        "=r"(img1),
                        "=r"(img2),
                        "=r"(img3),
                        "=r"(tmpptr) // %4
                        : "0"(img0),
                        "1"(img1),
                        "2"(img2),
                        "3"(img3),
                        "4"(tmpptr)
                        : "memory", "v0", "v1", "v2", "v3");
                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                }
            }
            for (; q < inch; q++)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;

                for (int k = 0; k < maxk; k++)
                {
                    asm volatile(
                        "prfm pldl1keep, [%0, #128] \n"
                        "ld1 {v0.16b}, [%0] \n"
                        "st1 {v0.16b}, [%1], #16 \n"
                        : "=r"(img0), // %0
                        "=r"(tmpptr) // %1
                        : "0"(img0),
                        "1"(tmpptr)
                        : "memory", "v0");
                    img0 += size;
                }
            }
        }

        remain_size_start += nn_size << 4;
        nn_size = (size - remain_size_start) >> 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 8;

            signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8);

            int q = 0;
            for (; q + 7 < inch; q += 8)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;
                const signed char* img4 = (const signed char*)bottom_im2col.channel(q + 4) + i;
                const signed char* img5 = (const signed char*)bottom_im2col.channel(q + 5) + i;
                const signed char* img6 = (const signed char*)bottom_im2col.channel(q + 6) + i;
                const signed char* img7 = (const signed char*)bottom_im2col.channel(q + 7) + i;

                for (int k = 0; k < maxk; k++)
                {
                    asm volatile(
                        "ld1 {v0.8b}, [%0] \n"
                        "ld1 {v1.8b}, [%1] \n"
                        "ld1 {v2.8b}, [%2] \n"
                        "ld1 {v3.8b}, [%3] \n"
                        "ld1 {v4.8b}, [%4] \n"
                        "ld1 {v5.8b}, [%5] \n"
                        "ld1 {v6.8b}, [%6] \n"
                        "ld1 {v7.8b}, [%7] \n"
                        "st4 {v0.8b, v1.8b, v2.8b, v3.8b}, [%8], #32 \n"
                        "st4 {v4.8b, v5.8b, v6.8b, v7.8b}, [%8], #32 \n"
                        : "=r"(img0), // %0
                        "=r"(img1),
                        "=r"(img2),
                        "=r"(img3),
                        "=r"(img4),
                        "=r"(img5),
                        "=r"(img6),
                        "=r"(img7),
                        "=r"(tmpptr) // %8
                        : "0"(img0),
                        "1"(img1),
                        "2"(img2),
                        "3"(img3),
                        "4"(img4),
                        "5"(img5),
                        "6"(img6),
                        "7"(img7),
                        "8"(tmpptr)
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                    img4 += size;
                    img5 += size;
                    img6 += size;
                    img7 += size;
                }
            }
            for (; q + 3 < inch; q += 4)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;

                for (int k = 0; k < maxk; k++)
                {
                    asm volatile(
                        "ld1 {v0.8b}, [%0] \n"
                        "ld1 {v1.8b}, [%1] \n"
                        "ld1 {v2.8b}, [%2] \n"
                        "ld1 {v3.8b}, [%3] \n"
                        "st4 {v0.8b, v1.8b, v2.8b, v3.8b}, [%4], #32 \n"
                        : "=r"(img0), // %0
                        "=r"(img1),
                        "=r"(img2),
                        "=r"(img3),
                        "=r"(tmpptr) // %4
                        : "0"(img0),
                        "1"(img1),
                        "2"(img2),
                        "3"(img3),
                        "4"(tmpptr)
                        : "memory", "v0", "v1", "v2", "v3");
                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                }
            }
            for (; q < inch; q++)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;

                for (int k = 0; k < maxk; k++)
                {
                    asm volatile(
                        "prfm pldl1keep, [%0, #64] \n"
                        "ld1 {v0.8b}, [%0] \n"
                        "st1 {v0.8b}, [%1], #8 \n"
                        : "=r"(img0), // %0
                        "=r"(tmpptr) // %1
                        : "0"(img0),
                        "1"(tmpptr)
                        : "memory", "v0");
                    img0 += size;
                }
            }
        }

        remain_size_start += nn_size << 3;
        nn_size = (size - remain_size_start) >> 2;
#else // __ARM_FEATURE_DOTPROD
        int remain_size_start = 0;
        int nn_size = (size - remain_size_start) >> 2;
#endif // __ARM_FEATURE_DOTPROD

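        // 4-column tiles. With sdot each column is stored as groups of four
        // consecutive channel bytes; without it all channels of a column in
        // the current group are stored contiguously before the next column.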
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 4;

#if __ARM_FEATURE_DOTPROD
            signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4);
#else
            signed char* tmpptr = tmp.channel(i / 4);
#endif

            int q = 0;
            for (; q + 7 < inch; q += 8)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;
                const signed char* img4 = (const signed char*)bottom_im2col.channel(q + 4) + i;
                const signed char* img5 = (const signed char*)bottom_im2col.channel(q + 5) + i;
                const signed char* img6 = (const signed char*)bottom_im2col.channel(q + 6) + i;
                const signed char* img7 = (const signed char*)bottom_im2col.channel(q + 7) + i;

                for (int k = 0; k < maxk; k++)
                {
#if __ARM_FEATURE_DOTPROD
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img1[0];
                    tmpptr[2] = img2[0];
                    tmpptr[3] = img3[0];
                    tmpptr[4] = img0[1];
                    tmpptr[5] = img1[1];
                    tmpptr[6] = img2[1];
                    tmpptr[7] = img3[1];
                    tmpptr += 8;

                    tmpptr[0] = img0[2];
                    tmpptr[1] = img1[2];
                    tmpptr[2] = img2[2];
                    tmpptr[3] = img3[2];
                    tmpptr[4] = img0[3];
                    tmpptr[5] = img1[3];
                    tmpptr[6] = img2[3];
                    tmpptr[7] = img3[3];
                    tmpptr += 8;

                    tmpptr[0] = img4[0];
                    tmpptr[1] = img5[0];
                    tmpptr[2] = img6[0];
                    tmpptr[3] = img7[0];
                    tmpptr[4] = img4[1];
                    tmpptr[5] = img5[1];
                    tmpptr[6] = img6[1];
                    tmpptr[7] = img7[1];
                    tmpptr += 8;

                    tmpptr[0] = img4[2];
                    tmpptr[1] = img5[2];
                    tmpptr[2] = img6[2];
                    tmpptr[3] = img7[2];
                    tmpptr[4] = img4[3];
                    tmpptr[5] = img5[3];
                    tmpptr[6] = img6[3];
                    tmpptr[7] = img7[3];
                    tmpptr += 8;
#else
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img1[0];
                    tmpptr[2] = img2[0];
                    tmpptr[3] = img3[0];
                    tmpptr[4] = img4[0];
                    tmpptr[5] = img5[0];
                    tmpptr[6] = img6[0];
                    tmpptr[7] = img7[0];
                    tmpptr += 8;

                    tmpptr[0] = img0[1];
                    tmpptr[1] = img1[1];
                    tmpptr[2] = img2[1];
                    tmpptr[3] = img3[1];
                    tmpptr[4] = img4[1];
                    tmpptr[5] = img5[1];
                    tmpptr[6] = img6[1];
                    tmpptr[7] = img7[1];
                    tmpptr += 8;

                    tmpptr[0] = img0[2];
                    tmpptr[1] = img1[2];
                    tmpptr[2] = img2[2];
                    tmpptr[3] = img3[2];
                    tmpptr[4] = img4[2];
                    tmpptr[5] = img5[2];
                    tmpptr[6] = img6[2];
                    tmpptr[7] = img7[2];
                    tmpptr += 8;

                    tmpptr[0] = img0[3];
                    tmpptr[1] = img1[3];
                    tmpptr[2] = img2[3];
                    tmpptr[3] = img3[3];
                    tmpptr[4] = img4[3];
                    tmpptr[5] = img5[3];
                    tmpptr[6] = img6[3];
                    tmpptr[7] = img7[3];
                    tmpptr += 8;
#endif // __ARM_FEATURE_DOTPROD

                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                    img4 += size;
                    img5 += size;
                    img6 += size;
                    img7 += size;
                }
            }
            for (; q + 3 < inch; q += 4)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;

                for (int k = 0; k < maxk; k++)
                {
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img1[0];
                    tmpptr[2] = img2[0];
                    tmpptr[3] = img3[0];
                    tmpptr[4] = img0[1];
                    tmpptr[5] = img1[1];
                    tmpptr[6] = img2[1];
                    tmpptr[7] = img3[1];
                    tmpptr += 8;

                    tmpptr[0] = img0[2];
                    tmpptr[1] = img1[2];
                    tmpptr[2] = img2[2];
                    tmpptr[3] = img3[2];
                    tmpptr[4] = img0[3];
                    tmpptr[5] = img1[3];
                    tmpptr[6] = img2[3];
                    tmpptr[7] = img3[3];
                    tmpptr += 8;

                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                }
            }
            for (; q < inch; q++)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;

                for (int k = 0; k < maxk; k++)
                {
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img0[1];
                    tmpptr[2] = img0[2];
                    tmpptr[3] = img0[3];

                    tmpptr += 4;

                    img0 += size;
                }
            }
        }

        remain_size_start += nn_size << 2;
        nn_size = (size - remain_size_start) >> 1;
#else // __aarch64__
        int remain_size_start = 0;
        int nn_size = (size - remain_size_start) >> 1;
#endif // __aarch64__

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 2;

#if __aarch64__
#if __ARM_FEATURE_DOTPROD
            signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2);
#else
            signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2);
#endif
#else
            signed char* tmpptr = tmp.channel(i / 2);
#endif

            int q = 0;
            for (; q + 7 < inch; q += 8)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;
                const signed char* img4 = (const signed char*)bottom_im2col.channel(q + 4) + i;
                const signed char* img5 = (const signed char*)bottom_im2col.channel(q + 5) + i;
                const signed char* img6 = (const signed char*)bottom_im2col.channel(q + 6) + i;
                const signed char* img7 = (const signed char*)bottom_im2col.channel(q + 7) + i;

                for (int k = 0; k < maxk; k++)
                {
#if __ARM_FEATURE_DOTPROD
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img1[0];
                    tmpptr[2] = img2[0];
                    tmpptr[3] = img3[0];
                    tmpptr[4] = img0[1];
                    tmpptr[5] = img1[1];
                    tmpptr[6] = img2[1];
                    tmpptr[7] = img3[1];
                    tmpptr += 8;

                    tmpptr[0] = img4[0];
                    tmpptr[1] = img5[0];
                    tmpptr[2] = img6[0];
                    tmpptr[3] = img7[0];
                    tmpptr[4] = img4[1];
                    tmpptr[5] = img5[1];
                    tmpptr[6] = img6[1];
                    tmpptr[7] = img7[1];
                    tmpptr += 8;
#else
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img1[0];
                    tmpptr[2] = img2[0];
                    tmpptr[3] = img3[0];
                    tmpptr[4] = img4[0];
                    tmpptr[5] = img5[0];
                    tmpptr[6] = img6[0];
                    tmpptr[7] = img7[0];
                    tmpptr += 8;

                    tmpptr[0] = img0[1];
                    tmpptr[1] = img1[1];
                    tmpptr[2] = img2[1];
                    tmpptr[3] = img3[1];
                    tmpptr[4] = img4[1];
                    tmpptr[5] = img5[1];
                    tmpptr[6] = img6[1];
                    tmpptr[7] = img7[1];
                    tmpptr += 8;
#endif // __ARM_FEATURE_DOTPROD

                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                    img4 += size;
                    img5 += size;
                    img6 += size;
                    img7 += size;
                }
            }
            for (; q + 3 < inch; q += 4)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;

                for (int k = 0; k < maxk; k++)
                {
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img1[0];
                    tmpptr[2] = img2[0];
                    tmpptr[3] = img3[0];
                    tmpptr[4] = img0[1];
                    tmpptr[5] = img1[1];
                    tmpptr[6] = img2[1];
                    tmpptr[7] = img3[1];
                    tmpptr += 8;

                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                }
            }
            for (; q < inch; q++)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;

                for (int k = 0; k < maxk; k++)
                {
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img0[1];

                    tmpptr += 2;

                    img0 += size;
                }
            }
        }

        remain_size_start += nn_size << 1;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < size; i++)
        {
#if __aarch64__
#if __ARM_FEATURE_DOTPROD
            signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2);
#else
            signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2);
#endif
#else
            signed char* tmpptr = tmp.channel(i / 2 + i % 2);
#endif

            int q = 0;
            for (; q + 7 < inch; q += 8)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;
                const signed char* img4 = (const signed char*)bottom_im2col.channel(q + 4) + i;
                const signed char* img5 = (const signed char*)bottom_im2col.channel(q + 5) + i;
                const signed char* img6 = (const signed char*)bottom_im2col.channel(q + 6) + i;
                const signed char* img7 = (const signed char*)bottom_im2col.channel(q + 7) + i;

                for (int k = 0; k < maxk; k++)
                {
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img1[0];
                    tmpptr[2] = img2[0];
                    tmpptr[3] = img3[0];
                    tmpptr[4] = img4[0];
                    tmpptr[5] = img5[0];
                    tmpptr[6] = img6[0];
                    tmpptr[7] = img7[0];
                    tmpptr += 8;

                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                    img4 += size;
                    img5 += size;
                    img6 += size;
                    img7 += size;
                }
            }
            for (; q + 3 < inch; q += 4)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;
                const signed char* img1 = (const signed char*)bottom_im2col.channel(q + 1) + i;
                const signed char* img2 = (const signed char*)bottom_im2col.channel(q + 2) + i;
                const signed char* img3 = (const signed char*)bottom_im2col.channel(q + 3) + i;

                for (int k = 0; k < maxk; k++)
                {
                    tmpptr[0] = img0[0];
                    tmpptr[1] = img1[0];
                    tmpptr[2] = img2[0];
                    tmpptr[3] = img3[0];
                    tmpptr += 4;

                    img0 += size;
                    img1 += size;
                    img2 += size;
                    img3 += size;
                }
            }
            for (; q < inch; q++)
            {
                const signed char* img0 = (const signed char*)bottom_im2col.channel(q) + i;

                for (int k = 0; k < maxk; k++)
                {
                    tmpptr[0] = img0[0];

                    tmpptr += 1;

                    img0 += size;
                }
            }
        }
    }

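    // GEMM: each top_blob channel holds 4 packed output channels, so one
    // kernel.channel(p) block produces 4 sums per column. The input channel
    // loop is decomposed into nn (groups of 8 channels), nn4 (one remaining
    // group of 4) and nn1 (leftover channels), matching the layouts above.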
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        int* outptr0 = top_blob.channel(p);

        int i = 0;
#if __aarch64__
#if __ARM_FEATURE_DOTPROD
        for (; i + 15 < size; i += 16)
        {
            const signed char* tmpptr = tmp.channel(i / 16);
            const signed char* kptr0 = kernel.channel(p);

            int nn = (inch / 8) * maxk;
            int nn4 = ((inch % 8) / 4) * maxk;
            int nn1 = (inch % 4) * maxk;

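            // v16-v31 accumulate 4 output channels for each of the 16 columns;
            // every sdot multiplies a 4b lane (4 input channels of one column)
            // against 16 kernel bytes (4 input channels x 4 output channels).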
            asm volatile(
                "eor v16.16b, v16.16b, v16.16b \n"
                "eor v17.16b, v17.16b, v17.16b \n"
                "eor v18.16b, v18.16b, v18.16b \n"
                "eor v19.16b, v19.16b, v19.16b \n"
                "eor v20.16b, v20.16b, v20.16b \n"
                "eor v21.16b, v21.16b, v21.16b \n"
                "eor v22.16b, v22.16b, v22.16b \n"
                "eor v23.16b, v23.16b, v23.16b \n"
                "eor v24.16b, v24.16b, v24.16b \n"
                "eor v25.16b, v25.16b, v25.16b \n"
                "eor v26.16b, v26.16b, v26.16b \n"
                "eor v27.16b, v27.16b, v27.16b \n"
                "eor v28.16b, v28.16b, v28.16b \n"
                "eor v29.16b, v29.16b, v29.16b \n"
                "eor v30.16b, v30.16b, v30.16b \n"
                "eor v31.16b, v31.16b, v31.16b \n"

                "cmp %w1, #0 \n"
                "beq 1f \n"

                "ld1 {v8.16b}, [%5], #16 \n" // _w0123_l

                "ld1 {v0.16b}, [%4], #16 \n" // _val0123_l

                "0: \n"

                "ld1 {v1.16b}, [%4], #16 \n" // _val4567_l

                "sdot v16.4s, v8.16b, v0.4b[0] \n"
                "sdot v17.4s, v8.16b, v0.4b[1] \n"
                "sdot v18.4s, v8.16b, v0.4b[2] \n"
                "sdot v19.4s, v8.16b, v0.4b[3] \n"

                "ld1 {v2.16b}, [%4], #16 \n" // _val891011_l

                "sdot v20.4s, v8.16b, v1.4b[0] \n"
                "sdot v21.4s, v8.16b, v1.4b[1] \n"
                "sdot v22.4s, v8.16b, v1.4b[2] \n"
                "sdot v23.4s, v8.16b, v1.4b[3] \n"

                "ld1 {v3.16b}, [%4], #16 \n" // _val12131415_l

                "sdot v24.4s, v8.16b, v2.4b[0] \n"
                "sdot v25.4s, v8.16b, v2.4b[1] \n"

                "ld1 {v9.16b}, [%5], #16 \n" // _w0123_h

                "sdot v26.4s, v8.16b, v2.4b[2] \n"
                "sdot v27.4s, v8.16b, v2.4b[3] \n"

                "ld1 {v4.16b}, [%4], #16 \n" // _val0123_h

                "sdot v28.4s, v8.16b, v3.4b[0] \n"
                "sdot v29.4s, v8.16b, v3.4b[1] \n"
                "sdot v30.4s, v8.16b, v3.4b[2] \n"
                "sdot v31.4s, v8.16b, v3.4b[3] \n"

                "ld1 {v5.16b}, [%4], #16 \n" // _val4567_h

                "sdot v16.4s, v9.16b, v4.4b[0] \n"
                "sdot v17.4s, v9.16b, v4.4b[1] \n"
                "sdot v18.4s, v9.16b, v4.4b[2] \n"
                "sdot v19.4s, v9.16b, v4.4b[3] \n"

                "ld1 {v6.16b}, [%4], #16 \n" // _val891011_h

                "sdot v20.4s, v9.16b, v5.4b[0] \n"
                "sdot v21.4s, v9.16b, v5.4b[1] \n"
                "sdot v22.4s, v9.16b, v5.4b[2] \n"
                "sdot v23.4s, v9.16b, v5.4b[3] \n"

                "ld1 {v7.16b}, [%4], #16 \n" // _val12131415_h

                "sdot v24.4s, v9.16b, v6.4b[0] \n"
                "sdot v25.4s, v9.16b, v6.4b[1] \n"

                "ld1 {v8.16b}, [%5], #16 \n" // _w0123_l

                "sdot v26.4s, v9.16b, v6.4b[2] \n"
                "sdot v27.4s, v9.16b, v6.4b[3] \n"

                "ld1 {v0.16b}, [%4], #16 \n" // _val0123_l

                "sdot v28.4s, v9.16b, v7.4b[0] \n"
                "sdot v29.4s, v9.16b, v7.4b[1] \n"

                "subs %w1, %w1, #1 \n"

                "sdot v30.4s, v9.16b, v7.4b[2] \n"
                "sdot v31.4s, v9.16b, v7.4b[3] \n"

                "bne 0b \n"

                "sub %4, %4, #16 \n"
                "sub %5, %5, #16 \n"

                "1: \n"

                "cmp %w2, #0 \n"
                "beq 3f \n"

                "2: \n"

                "ld1 {v8.16b}, [%5], #16 \n"

                "ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [%4], #64 \n"

                "sdot v16.4s, v8.16b, v0.4b[0] \n"
                "sdot v17.4s, v8.16b, v0.4b[1] \n"
                "sdot v18.4s, v8.16b, v0.4b[2] \n"
                "sdot v19.4s, v8.16b, v0.4b[3] \n"
                "sdot v20.4s, v8.16b, v1.4b[0] \n"
                "sdot v21.4s, v8.16b, v1.4b[1] \n"
                "sdot v22.4s, v8.16b, v1.4b[2] \n"
                "sdot v23.4s, v8.16b, v1.4b[3] \n"
                "sdot v24.4s, v8.16b, v2.4b[0] \n"
                "sdot v25.4s, v8.16b, v2.4b[1] \n"
                "sdot v26.4s, v8.16b, v2.4b[2] \n"
                "sdot v27.4s, v8.16b, v2.4b[3] \n"
                "sdot v28.4s, v8.16b, v3.4b[0] \n"
                "sdot v29.4s, v8.16b, v3.4b[1] \n"

                "subs %w2, %w2, #1 \n"

                "sdot v30.4s, v8.16b, v3.4b[2] \n"
                "sdot v31.4s, v8.16b, v3.4b[3] \n"

                "bne 2b \n"

                "3: \n"

                "lsr w4, %w3, #2 \n" // w4 = nn1 >> 2
                "cmp w4, #0 \n"
                "beq 5f \n"

                "4: \n"

                "ld1 {v8.8b, v9.8b}, [%5], #16 \n"

                "ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [%4], #64 \n"

                "uzp1 v10.8b, v8.8b, v9.8b \n"
                "uzp2 v11.8b, v8.8b, v9.8b \n"

                "uzp1 v4.16b, v0.16b, v1.16b \n"
                "uzp2 v5.16b, v0.16b, v1.16b \n"
                "uzp1 v6.16b, v2.16b, v3.16b \n"
                "uzp2 v7.16b, v2.16b, v3.16b \n"

                "uzp1 v8.8b, v10.8b, v11.8b \n"
                "uzp2 v9.8b, v10.8b, v11.8b \n"

                "uzp1 v0.16b, v4.16b, v5.16b \n" // 0 1 4 5
                "uzp2 v1.16b, v4.16b, v5.16b \n" // 8 9 c d

                "mov v8.d[1], v9.d[0] \n" // _w

                "uzp1 v2.16b, v6.16b, v7.16b \n" // 2 3 6 7
                "uzp2 v3.16b, v6.16b, v7.16b \n" // a b e f

                "sdot v16.4s, v8.16b, v0.4b[0] \n"
                "sdot v17.4s, v8.16b, v0.4b[1] \n"
                "sdot v18.4s, v8.16b, v2.4b[0] \n"
                "sdot v19.4s, v8.16b, v2.4b[1] \n"
                "sdot v20.4s, v8.16b, v0.4b[2] \n"
                "sdot v21.4s, v8.16b, v0.4b[3] \n"
                "sdot v22.4s, v8.16b, v2.4b[2] \n"
                "sdot v23.4s, v8.16b, v2.4b[3] \n"
                "sdot v24.4s, v8.16b, v1.4b[0] \n"
                "sdot v25.4s, v8.16b, v1.4b[1] \n"
                "sdot v26.4s, v8.16b, v3.4b[0] \n"
                "sdot v27.4s, v8.16b, v3.4b[1] \n"
                "sdot v28.4s, v8.16b, v1.4b[2] \n"
                "sdot v29.4s, v8.16b, v1.4b[3] \n"
                "sdot v30.4s, v8.16b, v3.4b[2] \n"
                "sdot v31.4s, v8.16b, v3.4b[3] \n"

                "subs w4, w4, #1 \n"
                "bne 4b \n"

                "5: \n"

                "and w4, %w3, #3 \n" // w4 = remain = nn1 & 3
                "cmp w4, #0 \n" // w4 > 0
                "beq 7f \n"

                "6: \n"

                "ld1 {v1.8b}, [%5] \n"
                "ld1 {v0.16b}, [%4] \n"

                "sshll v1.8h, v1.8b, #0 \n"
                "sshll v2.8h, v0.8b, #0 \n"
                "sshll2 v3.8h, v0.16b, #0 \n"

                "smlal v16.4s, v1.4h, v2.h[0] \n"
                "smlal v17.4s, v1.4h, v2.h[1] \n"
                "smlal v18.4s, v1.4h, v2.h[2] \n"
                "smlal v19.4s, v1.4h, v2.h[3] \n"
                "smlal v20.4s, v1.4h, v2.h[4] \n"
                "smlal v21.4s, v1.4h, v2.h[5] \n"
                "smlal v22.4s, v1.4h, v2.h[6] \n"
                "smlal v23.4s, v1.4h, v2.h[7] \n"
                "smlal v24.4s, v1.4h, v3.h[0] \n"
                "smlal v25.4s, v1.4h, v3.h[1] \n"
                "smlal v26.4s, v1.4h, v3.h[2] \n"
                "smlal v27.4s, v1.4h, v3.h[3] \n"
                "smlal v28.4s, v1.4h, v3.h[4] \n"
                "smlal v29.4s, v1.4h, v3.h[5] \n"
                "smlal v30.4s, v1.4h, v3.h[6] \n"
                "smlal v31.4s, v1.4h, v3.h[7] \n"

                "add %4, %4, #16 \n"
                "add %5, %5, #4 \n"

                "subs w4, w4, #1 \n"
                "bne 6b \n"

                "7: \n"

                "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0], #64 \n"
                "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%0], #64 \n"
                "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%0], #64 \n"
                "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%0], #64 \n"
                : "=r"(outptr0),
                "=r"(nn),
                "=r"(nn4),
                "=r"(nn1),
                "=r"(tmpptr),
                "=r"(kptr0)
                : "0"(outptr0),
                "1"(nn),
                "2"(nn4),
                "3"(nn1),
                "4"(tmpptr),
                "5"(kptr0)
                : "memory", "x4", "x5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
        }
        for (; i + 7 < size; i += 8)
        {
            const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8);
            const signed char* kptr0 = kernel.channel(p);

            int nn = (inch / 8) * maxk;
            int nn4 = ((inch % 8) / 4) * maxk;
            int nn1 = (inch % 4) * maxk;

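            // Intrinsics variant of the same scheme for 8-column tiles:
            // _sum0.._sum7 each hold the 4 output channels of one column.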
            int32x4_t _sum0 = vdupq_n_s32(0);
            int32x4_t _sum1 = vdupq_n_s32(0);
            int32x4_t _sum2 = vdupq_n_s32(0);
            int32x4_t _sum3 = vdupq_n_s32(0);
            int32x4_t _sum4 = vdupq_n_s32(0);
            int32x4_t _sum5 = vdupq_n_s32(0);
            int32x4_t _sum6 = vdupq_n_s32(0);
            int32x4_t _sum7 = vdupq_n_s32(0);

            for (int j = 0; j < nn; j++)
            {
                int8x16_t _val0123_l = vld1q_s8(tmpptr);
                int8x16_t _val4567_l = vld1q_s8(tmpptr + 16);

                int8x16_t _w0123_l = vld1q_s8(kptr0);

                _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val0123_l, 0);
                _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val0123_l, 1);
                _sum2 = vdotq_laneq_s32(_sum2, _w0123_l, _val0123_l, 2);
                _sum3 = vdotq_laneq_s32(_sum3, _w0123_l, _val0123_l, 3);
                _sum4 = vdotq_laneq_s32(_sum4, _w0123_l, _val4567_l, 0);
                _sum5 = vdotq_laneq_s32(_sum5, _w0123_l, _val4567_l, 1);
                _sum6 = vdotq_laneq_s32(_sum6, _w0123_l, _val4567_l, 2);
                _sum7 = vdotq_laneq_s32(_sum7, _w0123_l, _val4567_l, 3);

                int8x16_t _val0123_h = vld1q_s8(tmpptr + 32);
                int8x16_t _val4567_h = vld1q_s8(tmpptr + 48);

                int8x16_t _w0123_h = vld1q_s8(kptr0 + 16);

                _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val0123_h, 0);
                _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val0123_h, 1);
                _sum2 = vdotq_laneq_s32(_sum2, _w0123_h, _val0123_h, 2);
                _sum3 = vdotq_laneq_s32(_sum3, _w0123_h, _val0123_h, 3);
                _sum4 = vdotq_laneq_s32(_sum4, _w0123_h, _val4567_h, 0);
                _sum5 = vdotq_laneq_s32(_sum5, _w0123_h, _val4567_h, 1);
                _sum6 = vdotq_laneq_s32(_sum6, _w0123_h, _val4567_h, 2);
                _sum7 = vdotq_laneq_s32(_sum7, _w0123_h, _val4567_h, 3);

                tmpptr += 64;
                kptr0 += 32;
            }

            for (int j = 0; j < nn4; j++)
            {
                int8x16_t _val0123 = vld1q_s8(tmpptr);
                int8x16_t _val4567 = vld1q_s8(tmpptr + 16);
                int8x16_t _w0 = vld1q_s8(kptr0);

                _sum0 = vdotq_laneq_s32(_sum0, _w0, _val0123, 0);
                _sum1 = vdotq_laneq_s32(_sum1, _w0, _val0123, 1);
                _sum2 = vdotq_laneq_s32(_sum2, _w0, _val0123, 2);
                _sum3 = vdotq_laneq_s32(_sum3, _w0, _val0123, 3);
                _sum4 = vdotq_laneq_s32(_sum4, _w0, _val4567, 0);
                _sum5 = vdotq_laneq_s32(_sum5, _w0, _val4567, 1);
                _sum6 = vdotq_laneq_s32(_sum6, _w0, _val4567, 2);
                _sum7 = vdotq_laneq_s32(_sum7, _w0, _val4567, 3);

                tmpptr += 32;
                kptr0 += 16;
            }

            int j = 0;
            for (; j + 3 < nn1; j += 4)
            {
                int8x8x4_t _val4 = vld4_s8(tmpptr);

                int8x8x2_t _val0145 = vuzp_s8(_val4.val[0], _val4.val[1]);
                int8x8x2_t _val2367 = vuzp_s8(_val4.val[2], _val4.val[3]);

                int8x16_t _val0123 = vcombine_s8(_val0145.val[0], _val2367.val[0]);
                int8x16_t _val4567 = vcombine_s8(_val0145.val[1], _val2367.val[1]);

                int8x16_t _w = vld1q_s8(kptr0);

                int8x8x2_t _w01 = vuzp_s8(vget_low_s8(_w), vget_high_s8(_w));
                int8x8x2_t _w0123 = vuzp_s8(_w01.val[0], _w01.val[1]);
                int8x16_t _w0123f = vcombine_s8(_w0123.val[0], _w0123.val[1]);

                _sum0 = vdotq_laneq_s32(_sum0, _w0123f, _val0123, 0);
                _sum1 = vdotq_laneq_s32(_sum1, _w0123f, _val0123, 1);
                _sum2 = vdotq_laneq_s32(_sum2, _w0123f, _val0123, 2);
                _sum3 = vdotq_laneq_s32(_sum3, _w0123f, _val0123, 3);
                _sum4 = vdotq_laneq_s32(_sum4, _w0123f, _val4567, 0);
                _sum5 = vdotq_laneq_s32(_sum5, _w0123f, _val4567, 1);
                _sum6 = vdotq_laneq_s32(_sum6, _w0123f, _val4567, 2);
                _sum7 = vdotq_laneq_s32(_sum7, _w0123f, _val4567, 3);

                tmpptr += 32;
                kptr0 += 16;
            }
            for (; j < nn1; j++)
            {
                int16x4_t _val0 = vdup_n_s16(tmpptr[0]);
                int16x4_t _val1 = vdup_n_s16(tmpptr[1]);
                int16x4_t _val2 = vdup_n_s16(tmpptr[2]);
                int16x4_t _val3 = vdup_n_s16(tmpptr[3]);
                int16x4_t _val4 = vdup_n_s16(tmpptr[4]);
                int16x4_t _val5 = vdup_n_s16(tmpptr[5]);
                int16x4_t _val6 = vdup_n_s16(tmpptr[6]);
                int16x4_t _val7 = vdup_n_s16(tmpptr[7]);

                int16x4_t _w0123 = vdup_n_s16(0); // zero-init before vset_lane to avoid reading an indeterminate value
                _w0123 = vset_lane_s16(kptr0[0], _w0123, 0);
                _w0123 = vset_lane_s16(kptr0[1], _w0123, 1);
                _w0123 = vset_lane_s16(kptr0[2], _w0123, 2);
                _w0123 = vset_lane_s16(kptr0[3], _w0123, 3);

                _sum0 = vmlal_s16(_sum0, _val0, _w0123);
                _sum1 = vmlal_s16(_sum1, _val1, _w0123);
                _sum2 = vmlal_s16(_sum2, _val2, _w0123);
                _sum3 = vmlal_s16(_sum3, _val3, _w0123);
                _sum4 = vmlal_s16(_sum4, _val4, _w0123);
                _sum5 = vmlal_s16(_sum5, _val5, _w0123);
                _sum6 = vmlal_s16(_sum6, _val6, _w0123);
                _sum7 = vmlal_s16(_sum7, _val7, _w0123);

                tmpptr += 8;
                kptr0 += 4;
            }

            vst1q_s32(outptr0, _sum0);
            vst1q_s32(outptr0 + 4, _sum1);
            vst1q_s32(outptr0 + 8, _sum2);
            vst1q_s32(outptr0 + 12, _sum3);
            vst1q_s32(outptr0 + 16, _sum4);
            vst1q_s32(outptr0 + 20, _sum5);
            vst1q_s32(outptr0 + 24, _sum6);
            vst1q_s32(outptr0 + 28, _sum7);
            outptr0 += 32;
        }
#endif // __ARM_FEATURE_DOTPROD
        for (; i + 3 < size; i += 4)
        {
#if __ARM_FEATURE_DOTPROD
            const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4);
#else
            const signed char* tmpptr = tmp.channel(i / 4);
#endif
            const signed char* kptr0 = kernel.channel(p);

            int nn = (inch / 8) * maxk;
            int nn4 = ((inch % 8) / 4) * maxk;
            int nn1 = (inch % 4) * maxk;
#if __ARM_FEATURE_DOTPROD
            int32x4_t _sum0 = vdupq_n_s32(0);
            int32x4_t _sum1 = vdupq_n_s32(0);
            int32x4_t _sum2 = vdupq_n_s32(0);
            int32x4_t _sum3 = vdupq_n_s32(0);

            for (int j = 0; j < nn; j++)
            {
                int8x16_t _val0123_l = vld1q_s8(tmpptr);
                int8x16_t _w0123_l = vld1q_s8(kptr0);

                _sum0 = vdotq_laneq_s32(_sum0, _w0123_l, _val0123_l, 0);
                _sum1 = vdotq_laneq_s32(_sum1, _w0123_l, _val0123_l, 1);
                _sum2 = vdotq_laneq_s32(_sum2, _w0123_l, _val0123_l, 2);
                _sum3 = vdotq_laneq_s32(_sum3, _w0123_l, _val0123_l, 3);

                int8x16_t _val0123_h = vld1q_s8(tmpptr + 16);
                int8x16_t _w0123_h = vld1q_s8(kptr0 + 16);

                _sum0 = vdotq_laneq_s32(_sum0, _w0123_h, _val0123_h, 0);
                _sum1 = vdotq_laneq_s32(_sum1, _w0123_h, _val0123_h, 1);
                _sum2 = vdotq_laneq_s32(_sum2, _w0123_h, _val0123_h, 2);
                _sum3 = vdotq_laneq_s32(_sum3, _w0123_h, _val0123_h, 3);

                tmpptr += 32;
                kptr0 += 32;
            }

            for (int j = 0; j < nn4; j++)
            {
                int8x16_t _val0123 = vld1q_s8(tmpptr);
                int8x16_t _w0 = vld1q_s8(kptr0);

                _sum0 = vdotq_laneq_s32(_sum0, _w0, _val0123, 0);
                _sum1 = vdotq_laneq_s32(_sum1, _w0, _val0123, 1);
                _sum2 = vdotq_laneq_s32(_sum2, _w0, _val0123, 2);
                _sum3 = vdotq_laneq_s32(_sum3, _w0, _val0123, 3);

                tmpptr += 16;
                kptr0 += 16;
            }

            int j = 0;
            for (; j + 3 < nn1; j += 4)
            {
                int8x16_t _val = vld1q_s8(tmpptr);

                int8x8x2_t _val01 = vuzp_s8(vget_low_s8(_val), vget_high_s8(_val));
                int8x8x2_t _val0123 = vuzp_s8(_val01.val[0], _val01.val[1]);
                int8x16_t _val0123f = vcombine_s8(_val0123.val[0], _val0123.val[1]);

                int8x16_t _w = vld1q_s8(kptr0);

                int8x8x2_t _w01 = vuzp_s8(vget_low_s8(_w), vget_high_s8(_w));
                int8x8x2_t _w0123 = vuzp_s8(_w01.val[0], _w01.val[1]);
                int8x16_t _w0123f = vcombine_s8(_w0123.val[0], _w0123.val[1]);

                _sum0 = vdotq_laneq_s32(_sum0, _w0123f, _val0123f, 0);
                _sum1 = vdotq_laneq_s32(_sum1, _w0123f, _val0123f, 1);
                _sum2 = vdotq_laneq_s32(_sum2, _w0123f, _val0123f, 2);
                _sum3 = vdotq_laneq_s32(_sum3, _w0123f, _val0123f, 3);

                tmpptr += 16;
                kptr0 += 16;
            }
            for (; j < nn1; j++)
            {
                int16x4_t _val0 = vdup_n_s16(tmpptr[0]);
                int16x4_t _val1 = vdup_n_s16(tmpptr[1]);
                int16x4_t _val2 = vdup_n_s16(tmpptr[2]);
                int16x4_t _val3 = vdup_n_s16(tmpptr[3]);

                int16x4_t _w0123 = vdup_n_s16(0); // zero-init before vset_lane to avoid reading an indeterminate value
                _w0123 = vset_lane_s16(kptr0[0], _w0123, 0);
                _w0123 = vset_lane_s16(kptr0[1], _w0123, 1);
                _w0123 = vset_lane_s16(kptr0[2], _w0123, 2);
                _w0123 = vset_lane_s16(kptr0[3], _w0123, 3);

                _sum0 = vmlal_s16(_sum0, _val0, _w0123);
                _sum1 = vmlal_s16(_sum1, _val1, _w0123);
                _sum2 = vmlal_s16(_sum2, _val2, _w0123);
                _sum3 = vmlal_s16(_sum3, _val3, _w0123);

                tmpptr += 4;
                kptr0 += 4;
            }

            vst1q_s32(outptr0, _sum0);
            vst1q_s32(outptr0 + 4, _sum1);
            vst1q_s32(outptr0 + 8, _sum2);
            vst1q_s32(outptr0 + 12, _sum3);
            outptr0 += 16;
#else // __ARM_FEATURE_DOTPROD
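            // No sdot available: widen int8 products with smull/smlal into
            // int16, fold pairs into int32 accumulators with sadalp, then
            // reduce the per-channel partials with the addp ladders below.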
            asm volatile(
                "eor v0.16b, v0.16b, v0.16b \n"
                "eor v1.16b, v1.16b, v1.16b \n"
                "eor v2.16b, v2.16b, v2.16b \n"
                "eor v3.16b, v3.16b, v3.16b \n"

                "cmp %w1, #0 \n"
                "beq 3f \n"

                "eor v4.16b, v4.16b, v4.16b \n"
                "eor v5.16b, v5.16b, v5.16b \n"
                "eor v6.16b, v6.16b, v6.16b \n"
                "eor v7.16b, v7.16b, v7.16b \n"
                "eor v8.16b, v8.16b, v8.16b \n"
                "eor v9.16b, v9.16b, v9.16b \n"
                "eor v10.16b, v10.16b, v10.16b \n"
                "eor v11.16b, v11.16b, v11.16b \n"
                "eor v12.16b, v12.16b, v12.16b \n"
                "eor v13.16b, v13.16b, v13.16b \n"
                "eor v14.16b, v14.16b, v14.16b \n"
                "eor v15.16b, v15.16b, v15.16b \n"

                "prfm pldl1keep, [%4, #128] \n"

                "prfm pldl1keep, [%5, #256] \n"

                "lsr w4, %w1, #1 \n" // w4 = nn >> 1
                "cmp w4, #0 \n"
                "beq 1f \n"

                "prfm pldl1keep, [%5, #512] \n"

                "add x5, %4, #16 \n"

                "prfm pldl1keep, [x5, #128] \n"

                "ld1 {v16.16b}, [%4] \n" // val L H
                "ld1 {v20.16b, v21.16b, v22.16b, v23.16b}, [%5], #64 \n"
                "add %4, %4, #32 \n"
                "ext v17.16b, v16.16b, v16.16b, #8 \n" // val H L

                "ld1 {v18.16b}, [%4] \n"
                "add %4, %4, #32 \n"

                "0: \n"

                "smull v24.8h, v16.8b, v20.8b \n"
                "prfm pldl1keep, [%5, #256] \n"
                "smull2 v25.8h, v17.16b, v20.16b \n"
                "prfm pldl1keep, [%5, #512] \n"
                "smull v26.8h, v16.8b, v21.8b \n"
                "subs w4, w4, #1 \n"
                "smull2 v27.8h, v17.16b, v21.16b \n"
                "ext v19.16b, v18.16b, v18.16b, #8 \n" // val H L

                "smlal v24.8h, v18.8b, v22.8b \n"
                "smlal2 v25.8h, v19.16b, v22.16b \n"
                "smlal v26.8h, v18.8b, v23.8b \n"
                "smlal2 v27.8h, v19.16b, v23.16b \n"

                "smull2 v29.8h, v16.16b, v20.16b \n"
                "sadalp v0.4s, v24.8h \n"
                "smull v28.8h, v17.8b, v20.8b \n"
                "sadalp v1.4s, v25.8h \n"
                "smull2 v31.8h, v16.16b, v21.16b \n"
                "ld1 {v16.16b}, [x5] \n" // val L H
                "smull v30.8h, v17.8b, v21.8b \n"
                "add x5, x5, #32 \n"
                "smlal2 v29.8h, v18.16b, v22.16b \n"
                "sadalp v2.4s, v26.8h \n"
                "smlal v28.8h, v19.8b, v22.8b \n"
                "sadalp v3.4s, v27.8h \n"
                "smlal2 v31.8h, v18.16b, v23.16b \n"
                "ld1 {v18.16b}, [x5] \n"
                "smlal v30.8h, v19.8b, v23.8b \n"
                "ext v17.16b, v16.16b, v16.16b, #8 \n" // val H L

                "smull v24.8h, v16.8b, v20.8b \n"
                "add x5, x5, #32 \n"
                "smull2 v25.8h, v17.16b, v20.16b \n"
                "prfm pldl1keep, [x5, #128] \n"
                "smull v26.8h, v16.8b, v21.8b \n"
                "prfm pldl1keep, [x5, #384] \n"
                "smull2 v27.8h, v17.16b, v21.16b \n"
                "ext v19.16b, v18.16b, v18.16b, #8 \n" // val H L

                "smlal v24.8h, v18.8b, v22.8b \n"
                "sadalp v5.4s, v29.8h \n"
                "smlal2 v25.8h, v19.16b, v22.16b \n"
                "sadalp v4.4s, v28.8h \n"
                "smlal v26.8h, v18.8b, v23.8b \n"
                "sadalp v7.4s, v31.8h \n"
                "smlal2 v27.8h, v19.16b, v23.16b \n"
                "sadalp v6.4s, v30.8h \n"

                "smull2 v29.8h, v16.16b, v20.16b \n"
                "sadalp v8.4s, v24.8h \n"
                "smull v28.8h, v17.8b, v20.8b \n"
                "sadalp v9.4s, v25.8h \n"
                "smull2 v31.8h, v16.16b, v21.16b \n"
                "ld1 {v16.16b}, [%4] \n" // val L H
                "smull v30.8h, v17.8b, v21.8b \n"
                "add %4, %4, #32 \n"
                "smlal2 v29.8h, v18.16b, v22.16b \n"
                "sadalp v10.4s, v26.8h \n"
                "smlal v28.8h, v19.8b, v22.8b \n"
                "sadalp v11.4s, v27.8h \n"
                "smlal2 v31.8h, v18.16b, v23.16b \n"
                "ld1 {v18.16b}, [%4] \n"
                "smlal v30.8h, v19.8b, v23.8b \n"
                "add %4, %4, #32 \n"
                "ld1 {v20.16b, v21.16b, v22.16b, v23.16b}, [%5], #64 \n"

                "sadalp v13.4s, v29.8h \n"
                "prfm pldl1keep, [%4, #128] \n"
                "sadalp v12.4s, v28.8h \n"
                "prfm pldl1keep, [%4, #384] \n"
                "sadalp v15.4s, v31.8h \n"
                "ext v17.16b, v16.16b, v16.16b, #8 \n" // val H L

                "sadalp v14.4s, v30.8h \n"

                "bne 0b \n"

                "sub %4, %4, #64 \n"
                "sub %5, %5, #64 \n"

                "1: \n"
                "and w4, %w1, #1 \n" // w4 = remain = nn & 1
                "cmp w4, #0 \n" // w4 > 0
                "beq 2f \n"

                "ld1 {v16.8b, v17.8b}, [%4], #16 \n"
                "ld1 {v20.8b, v21.8b, v22.8b, v23.8b}, [%5], #32 \n"

                "smull v24.8h, v16.8b, v20.8b \n"
                "smull v25.8h, v16.8b, v21.8b \n"
                "smull v26.8h, v16.8b, v22.8b \n"
                "ld1 {v18.8b, v19.8b}, [%4], #16 \n"
                "smull v27.8h, v16.8b, v23.8b \n"
                "sadalp v0.4s, v24.8h \n"
                "smull v28.8h, v17.8b, v20.8b \n"
                "sadalp v1.4s, v25.8h \n"
                "smull v29.8h, v17.8b, v21.8b \n"
                "sadalp v2.4s, v26.8h \n"
                "smull v30.8h, v17.8b, v22.8b \n"
                "sadalp v3.4s, v27.8h \n"
                "smull v31.8h, v17.8b, v23.8b \n"
                "sadalp v4.4s, v28.8h \n"
                "smull v24.8h, v18.8b, v20.8b \n"
                "sadalp v5.4s, v29.8h \n"
                "smull v25.8h, v18.8b, v21.8b \n"
                "sadalp v6.4s, v30.8h \n"
                "smull v26.8h, v18.8b, v22.8b \n"
                "sadalp v7.4s, v31.8h \n"
                "smull v27.8h, v18.8b, v23.8b \n"
                "sadalp v8.4s, v24.8h \n"
                "smull v28.8h, v19.8b, v20.8b \n"
                "sadalp v9.4s, v25.8h \n"
                "smull v29.8h, v19.8b, v21.8b \n"
                "sadalp v10.4s, v26.8h \n"
                "smull v30.8h, v19.8b, v22.8b \n"
                "sadalp v11.4s, v27.8h \n"
                "smull v31.8h, v19.8b, v23.8b \n"

                "sadalp v12.4s, v28.8h \n"
                "sadalp v13.4s, v29.8h \n"
                "sadalp v14.4s, v30.8h \n"
                "sadalp v15.4s, v31.8h \n"

                "2: \n"

                "addp v0.4s, v0.4s, v1.4s \n"
                "addp v2.4s, v2.4s, v3.4s \n"
                "addp v4.4s, v4.4s, v5.4s \n"
                "addp v6.4s, v6.4s, v7.4s \n"
                "addp v8.4s, v8.4s, v9.4s \n"
                "addp v10.4s, v10.4s, v11.4s \n"
                "addp v12.4s, v12.4s, v13.4s \n"
                "addp v14.4s, v14.4s, v15.4s \n"

                "addp v0.4s, v0.4s, v2.4s \n"
                "addp v1.4s, v4.4s, v6.4s \n"
                "addp v2.4s, v8.4s, v10.4s \n"
                "addp v3.4s, v12.4s, v14.4s \n"

                "3: \n"

                "cmp %w2, #0 \n"
                "beq 7f \n"

                "eor v8.16b, v8.16b, v8.16b \n"
                "eor v9.16b, v9.16b, v9.16b \n"
                "eor v10.16b, v10.16b, v10.16b \n"
                "eor v11.16b, v11.16b, v11.16b \n"
                "eor v12.16b, v12.16b, v12.16b \n"
                "eor v13.16b, v13.16b, v13.16b \n"
                "eor v14.16b, v14.16b, v14.16b \n"
                "eor v15.16b, v15.16b, v15.16b \n"

                "lsr w4, %w2, #1 \n" // w4 = nn4 >> 1
                "cmp w4, #0 \n"
                "beq 5f \n"

                "4: \n"

                "ld1 {v16.8b, v17.8b}, [%4], #16 \n"
                "ld1 {v22.8b, v23.8b}, [%5], #16 \n"

                "zip1 v18.2s, v16.2s, v16.2s \n" // _val00
                "zip2 v19.2s, v16.2s, v16.2s \n" // _val11

                "smull v24.8h, v18.8b, v22.8b \n"
                "smull v25.8h, v18.8b, v23.8b \n"

                "zip1 v20.2s, v17.2s, v17.2s \n" // _val22

                "smull v26.8h, v19.8b, v22.8b \n"
                "smull v27.8h, v19.8b, v23.8b \n"

                "zip2 v21.2s, v17.2s, v17.2s \n" // _val33

                "smull v28.8h, v20.8b, v22.8b \n"
                "smull v29.8h, v20.8b, v23.8b \n"

                "ld1 {v16.8b, v17.8b}, [%4], #16 \n"

                "smull v30.8h, v21.8b, v22.8b \n"
                "smull v31.8h, v21.8b, v23.8b \n"

                "ld1 {v22.8b, v23.8b}, [%5], #16 \n"

                "zip1 v18.2s, v16.2s, v16.2s \n" // _val44
                "zip2 v19.2s, v16.2s, v16.2s \n" // _val55

                "smlal v24.8h, v18.8b, v22.8b \n"
                "smlal v25.8h, v18.8b, v23.8b \n"

                "zip1 v20.2s, v17.2s, v17.2s \n" // _val66

                "smlal v26.8h, v19.8b, v22.8b \n"
                "smlal v27.8h, v19.8b, v23.8b \n"

                "zip2 v21.2s, v17.2s, v17.2s \n" // _val77

                "sadalp v8.4s, v24.8h \n"
                "smlal v28.8h, v20.8b, v22.8b \n"
                "sadalp v9.4s, v25.8h \n"
                "smlal v29.8h, v20.8b, v23.8b \n"
                "sadalp v10.4s, v26.8h \n"
                "smlal v30.8h, v21.8b, v22.8b \n"
                "sadalp v11.4s, v27.8h \n"
                "smlal v31.8h, v21.8b, v23.8b \n"
                "sadalp v12.4s, v28.8h \n"
                "sadalp v13.4s, v29.8h \n"

                "subs w4, w4, #1 \n"

                "sadalp v14.4s, v30.8h \n"
                "sadalp v15.4s, v31.8h \n"

                "bne 4b \n"

                "5: \n"

                "and w4, %w2, #1 \n" // w4 = remain = nn4 & 1
                "cmp w4, #0 \n" // w4 > 0
                "beq 6f \n"

                "ld1 {v16.8b, v17.8b}, [%4], #16 \n"
                "ld1 {v22.8b, v23.8b}, [%5], #16 \n"

                "zip1 v18.2s, v16.2s, v16.2s \n" // _val00
                "zip2 v19.2s, v16.2s, v16.2s \n" // _val11

                "smull v24.8h, v18.8b, v22.8b \n"
                "smull v25.8h, v18.8b, v23.8b \n"

                "zip1 v20.2s, v17.2s, v17.2s \n" // _val22

                "smull v26.8h, v19.8b, v22.8b \n"
                "smull v27.8h, v19.8b, v23.8b \n"

                "zip2 v21.2s, v17.2s, v17.2s \n" // _val33

                "sadalp v8.4s, v24.8h \n"
                "smull v28.8h, v20.8b, v22.8b \n"
                "sadalp v9.4s, v25.8h \n"
                "smull v29.8h, v20.8b, v23.8b \n"
                "sadalp v10.4s, v26.8h \n"
                "smull v30.8h, v21.8b, v22.8b \n"
                "sadalp v11.4s, v27.8h \n"
                "smull v31.8h, v21.8b, v23.8b \n"
                "sadalp v12.4s, v28.8h \n"
                "sadalp v13.4s, v29.8h \n"
                "sadalp v14.4s, v30.8h \n"
                "sadalp v15.4s, v31.8h \n"

                "6: \n"

                "addp v8.4s, v8.4s, v9.4s \n"
                "addp v10.4s, v10.4s, v11.4s \n"
                "addp v12.4s, v12.4s, v13.4s \n"
                "addp v14.4s, v14.4s, v15.4s \n"

                "add v0.4s, v0.4s, v8.4s \n"
                "add v1.4s, v1.4s, v10.4s \n"
                "add v2.4s, v2.4s, v12.4s \n"
                "add v3.4s, v3.4s, v14.4s \n"

                "7: \n"

                "lsr w4, %w3, #2 \n" // w4 = nn1 >> 2
                "cmp w4, #0 \n"
                "beq 9f \n"

                "8: \n"

                "ld1 {v8.16b}, [%4], #16 \n"
                "ld1 {v9.16b}, [%5], #16 \n"

                "sshll v4.8h, v8.8b, #0 \n"
                "sshll2 v5.8h, v8.16b, #0 \n"
                "sshll v6.8h, v9.8b, #0 \n"
                "sshll2 v7.8h, v9.16b, #0 \n"

                "smlal v0.4s, v6.4h, v4.h[0] \n"
                "smlal v1.4s, v6.4h, v4.h[1] \n"
                "smlal v2.4s, v6.4h, v4.h[2] \n"
                "smlal v3.4s, v6.4h, v4.h[3] \n"
                "smlal2 v0.4s, v6.8h, v4.h[4] \n"
                "smlal2 v1.4s, v6.8h, v4.h[5] \n"
                "smlal2 v2.4s, v6.8h, v4.h[6] \n"
                "smlal2 v3.4s, v6.8h, v4.h[7] \n"
                "smlal v0.4s, v7.4h, v5.h[0] \n"
                "smlal v1.4s, v7.4h, v5.h[1] \n"
                "smlal v2.4s, v7.4h, v5.h[2] \n"
                "smlal v3.4s, v7.4h, v5.h[3] \n"
                "smlal2 v0.4s, v7.8h, v5.h[4] \n"
                "smlal2 v1.4s, v7.8h, v5.h[5] \n"
                "smlal2 v2.4s, v7.8h, v5.h[6] \n"
                "smlal2 v3.4s, v7.8h, v5.h[7] \n"

                "subs w4, w4, #1 \n"
                "bne 8b \n"

                "9: \n"

                "and w4, %w3, #3 \n" // w4 = nn1 & 3
                "cmp w4, #0 \n" // w4 > 0
                "beq 11f \n"

                "10: \n"

                "ld1 {v4.8b}, [%4] \n"
                "ld1 {v6.8b}, [%5] \n"

                "sshll v4.8h, v4.8b, #0 \n"
                "sshll v6.8h, v6.8b, #0 \n"

                "smlal v0.4s, v6.4h, v4.h[0] \n"
                "smlal v1.4s, v6.4h, v4.h[1] \n"
                "smlal v2.4s, v6.4h, v4.h[2] \n"
                "smlal v3.4s, v6.4h, v4.h[3] \n"

                "add %4, %4, #4 \n"
                "add %5, %5, #4 \n"

                "subs w4, w4, #1 \n"
                "bne 10b \n"

                "11: \n"

                "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n"

                : "=r"(outptr0),
                "=r"(nn),
                "=r"(nn4),
                "=r"(nn1),
                "=r"(tmpptr),
                "=r"(kptr0)
                : "0"(outptr0),
                "1"(nn),
                "2"(nn4),
                "3"(nn1),
                "4"(tmpptr),
                "5"(kptr0)
                : "memory", "x4", "x5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
#endif // __ARM_FEATURE_DOTPROD
        }
#endif // __aarch64__
        for (; i + 1 < size; i += 2)
        {
#if __aarch64__
#if __ARM_FEATURE_DOTPROD
            const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2);
#else
            const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2);
#endif
#else
            const signed char* tmpptr = tmp.channel(i / 2);
#endif
            const signed char* kptr0 = kernel.channel(p);

            int nn = (inch / 8) * maxk;
            int nn4 = ((inch % 8) / 4) * maxk;
            int nn1 = (inch % 4) * maxk;
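            // 2-column tiles: _sum00 and _sum10 hold the 4 output channels of
            // the first and second column respectively.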
1656 #if __aarch64__
1657 int32x4_t _sum00 = vdupq_n_s32(0);
1658 int32x4_t _sum10 = vdupq_n_s32(0);
1659 #if __ARM_FEATURE_DOTPROD
1660 for (int j = 0; j < nn; j++)
1661 {
1662 int8x16_t _val01_l_h = vld1q_s8(tmpptr);
1663 int8x16_t _w0123_l = vld1q_s8(kptr0);
1664
1665 _sum00 = vdotq_laneq_s32(_sum00, _w0123_l, _val01_l_h, 0);
1666 _sum10 = vdotq_laneq_s32(_sum10, _w0123_l, _val01_l_h, 1);
1667
1668 int8x16_t _w0123_h = vld1q_s8(kptr0 + 16);
1669
1670 _sum00 = vdotq_laneq_s32(_sum00, _w0123_h, _val01_l_h, 2);
1671 _sum10 = vdotq_laneq_s32(_sum10, _w0123_h, _val01_l_h, 3);
1672
1673 tmpptr += 16;
1674 kptr0 += 32;
1675 }
1676
1677 if (nn4 > 0)
1678 {
1679 int j = 0;
1680 for (; j + 1 < nn4; j += 2)
1681 {
1682 int8x16_t _val0123 = vld1q_s8(tmpptr);
1683 int8x16_t _w0 = vld1q_s8(kptr0);
1684
1685 _sum00 = vdotq_laneq_s32(_sum00, _w0, _val0123, 0);
1686 _sum10 = vdotq_laneq_s32(_sum10, _w0, _val0123, 1);
1687
1688 int8x16_t _w1 = vld1q_s8(kptr0 + 16);
1689
1690 _sum00 = vdotq_laneq_s32(_sum00, _w1, _val0123, 2);
1691 _sum10 = vdotq_laneq_s32(_sum10, _w1, _val0123, 3);
1692
1693 tmpptr += 16;
1694 kptr0 += 32;
1695 }
1696 for (; j < nn4; j++)
1697 {
1698 int8x8_t _val01 = vld1_s8(tmpptr);
1699 int8x16_t _w0 = vld1q_s8(kptr0);
1700
1701 _sum00 = vdotq_lane_s32(_sum00, _w0, _val01, 0);
1702 _sum10 = vdotq_lane_s32(_sum10, _w0, _val01, 1);
1703
1704 tmpptr += 8;
1705 kptr0 += 16;
1706 }
1707 }
1708 #else // __ARM_FEATURE_DOTPROD
1709 if (nn > 0)
1710 {
1711 int32x4_t _sum01 = vdupq_n_s32(0);
1712 int32x4_t _sum02 = vdupq_n_s32(0);
1713 int32x4_t _sum03 = vdupq_n_s32(0);
1714 int32x4_t _sum11 = vdupq_n_s32(0);
1715 int32x4_t _sum12 = vdupq_n_s32(0);
1716 int32x4_t _sum13 = vdupq_n_s32(0);
1717
1718 int j = 0;
1719 for (; j + 1 < nn; j += 2)
1720 {
1721 int8x16_t _val0 = vld1q_s8(tmpptr);
1722 int8x16_t _val1 = vld1q_s8(tmpptr + 16);
1723
1724 int8x16_t _w01 = vld1q_s8(kptr0);
1725 int8x16_t _w23 = vld1q_s8(kptr0 + 16);
1726
1727 int16x8_t _wv00 = vmull_s8(vget_low_s8(_val0), vget_low_s8(_w01));
1728 int16x8_t _wv01 = vmull_s8(vget_low_s8(_val0), vget_high_s8(_w01));
1729 int16x8_t _wv02 = vmull_s8(vget_low_s8(_val0), vget_low_s8(_w23));
1730 int16x8_t _wv03 = vmull_s8(vget_low_s8(_val0), vget_high_s8(_w23));
1731
1732 int16x8_t _wv10 = vmull_s8(vget_high_s8(_val0), vget_low_s8(_w01));
1733 int16x8_t _wv11 = vmull_s8(vget_high_s8(_val0), vget_high_s8(_w01));
1734 int16x8_t _wv12 = vmull_s8(vget_high_s8(_val0), vget_low_s8(_w23));
1735 int16x8_t _wv13 = vmull_s8(vget_high_s8(_val0), vget_high_s8(_w23));
1736
1737 int8x16_t _w45 = vld1q_s8(kptr0 + 32);
1738 int8x16_t _w67 = vld1q_s8(kptr0 + 48);
1739
1740 _wv00 = vmlal_s8(_wv00, vget_low_s8(_val1), vget_low_s8(_w45));
1741 _wv01 = vmlal_s8(_wv01, vget_low_s8(_val1), vget_high_s8(_w45));
1742 _wv02 = vmlal_s8(_wv02, vget_low_s8(_val1), vget_low_s8(_w67));
1743 _wv03 = vmlal_s8(_wv03, vget_low_s8(_val1), vget_high_s8(_w67));
1744
1745 _wv10 = vmlal_s8(_wv10, vget_high_s8(_val1), vget_low_s8(_w45));
1746 _wv11 = vmlal_s8(_wv11, vget_high_s8(_val1), vget_high_s8(_w45));
1747 _wv12 = vmlal_s8(_wv12, vget_high_s8(_val1), vget_low_s8(_w67));
1748 _wv13 = vmlal_s8(_wv13, vget_high_s8(_val1), vget_high_s8(_w67));
1749
1750 _sum00 = vpadalq_s16(_sum00, _wv00);
1751 _sum01 = vpadalq_s16(_sum01, _wv01);
1752 _sum02 = vpadalq_s16(_sum02, _wv02);
1753 _sum03 = vpadalq_s16(_sum03, _wv03);
1754 _sum10 = vpadalq_s16(_sum10, _wv10);
1755 _sum11 = vpadalq_s16(_sum11, _wv11);
1756 _sum12 = vpadalq_s16(_sum12, _wv12);
1757 _sum13 = vpadalq_s16(_sum13, _wv13);
1758
1759 tmpptr += 32;
1760 kptr0 += 64;
1761 }
1762 for (; j < nn; j++)
1763 {
1764 int8x16_t _val = vld1q_s8(tmpptr);
1765
1766 int8x16_t _w01 = vld1q_s8(kptr0);
1767 int8x16_t _w23 = vld1q_s8(kptr0 + 16);
1768
1769 int16x8_t _wv00 = vmull_s8(vget_low_s8(_val), vget_low_s8(_w01));
1770 int16x8_t _wv01 = vmull_s8(vget_low_s8(_val), vget_high_s8(_w01));
1771 int16x8_t _wv02 = vmull_s8(vget_low_s8(_val), vget_low_s8(_w23));
1772 int16x8_t _wv03 = vmull_s8(vget_low_s8(_val), vget_high_s8(_w23));
1773 int16x8_t _wv10 = vmull_s8(vget_high_s8(_val), vget_low_s8(_w01));
1774 int16x8_t _wv11 = vmull_s8(vget_high_s8(_val), vget_high_s8(_w01));
1775 int16x8_t _wv12 = vmull_s8(vget_high_s8(_val), vget_low_s8(_w23));
1776 int16x8_t _wv13 = vmull_s8(vget_high_s8(_val), vget_high_s8(_w23));
1777
1778 _sum00 = vpadalq_s16(_sum00, _wv00);
1779 _sum01 = vpadalq_s16(_sum01, _wv01);
1780 _sum02 = vpadalq_s16(_sum02, _wv02);
1781 _sum03 = vpadalq_s16(_sum03, _wv03);
1782 _sum10 = vpadalq_s16(_sum10, _wv10);
1783 _sum11 = vpadalq_s16(_sum11, _wv11);
1784 _sum12 = vpadalq_s16(_sum12, _wv12);
1785 _sum13 = vpadalq_s16(_sum13, _wv13);
1786
1787 tmpptr += 16;
1788 kptr0 += 32;
1789 }
1790
                int32x4_t _s001 = vpaddq_s32(_sum00, _sum01);
                int32x4_t _s023 = vpaddq_s32(_sum02, _sum03);
                int32x4_t _s101 = vpaddq_s32(_sum10, _sum11);
                int32x4_t _s123 = vpaddq_s32(_sum12, _sum13);

                _sum00 = vpaddq_s32(_s001, _s023);
                _sum10 = vpaddq_s32(_s101, _s123);
            }

            if (nn4 > 0)
            {
                int32x4_t _sum100 = vdupq_n_s32(0);
                int32x4_t _sum101 = vdupq_n_s32(0);
                int32x4_t _sum110 = vdupq_n_s32(0);
                int32x4_t _sum111 = vdupq_n_s32(0);

                int j = 0;
                for (; j + 1 < nn4; j += 2)
                {
                    int8x16_t _val0123 = vld1q_s8(tmpptr);

                    int32x4x2_t _val00221133 = vzipq_s32(vreinterpretq_s32_s8(_val0123), vreinterpretq_s32_s8(_val0123));
                    int8x8_t _val00 = vreinterpret_s8_s32(vget_low_s32(_val00221133.val[0]));
                    int8x8_t _val11 = vreinterpret_s8_s32(vget_high_s32(_val00221133.val[0]));
                    int8x8_t _val22 = vreinterpret_s8_s32(vget_low_s32(_val00221133.val[1]));
                    int8x8_t _val33 = vreinterpret_s8_s32(vget_high_s32(_val00221133.val[1]));

                    int8x16_t _w01 = vld1q_s8(kptr0);
                    int8x16_t _w23 = vld1q_s8(kptr0 + 16);

                    int16x8_t _wv00 = vmull_s8(_val00, vget_low_s8(_w01));
                    int16x8_t _wv01 = vmull_s8(_val00, vget_high_s8(_w01));
                    int16x8_t _wv10 = vmull_s8(_val11, vget_low_s8(_w01));
                    int16x8_t _wv11 = vmull_s8(_val11, vget_high_s8(_w01));

                    _wv00 = vmlal_s8(_wv00, _val22, vget_low_s8(_w23));
                    _wv01 = vmlal_s8(_wv01, _val22, vget_high_s8(_w23));
                    _wv10 = vmlal_s8(_wv10, _val33, vget_low_s8(_w23));
                    _wv11 = vmlal_s8(_wv11, _val33, vget_high_s8(_w23));

                    _sum100 = vpadalq_s16(_sum100, _wv00);
                    _sum101 = vpadalq_s16(_sum101, _wv01);
                    _sum110 = vpadalq_s16(_sum110, _wv10);
                    _sum111 = vpadalq_s16(_sum111, _wv11);

                    tmpptr += 16;
                    kptr0 += 32;
                }
                for (; j < nn4; j++)
                {
                    int8x8_t _val01 = vld1_s8(tmpptr);
                    int32x2x2_t _val0011 = vzip_s32(vreinterpret_s32_s8(_val01), vreinterpret_s32_s8(_val01));
                    int8x8_t _val00 = vreinterpret_s8_s32(_val0011.val[0]);
                    int8x8_t _val11 = vreinterpret_s8_s32(_val0011.val[1]);

                    int8x16_t _w01 = vld1q_s8(kptr0);

                    int16x8_t _wv00 = vmull_s8(_val00, vget_low_s8(_w01));
                    int16x8_t _wv01 = vmull_s8(_val00, vget_high_s8(_w01));
                    int16x8_t _wv10 = vmull_s8(_val11, vget_low_s8(_w01));
                    int16x8_t _wv11 = vmull_s8(_val11, vget_high_s8(_w01));

                    _sum100 = vpadalq_s16(_sum100, _wv00);
                    _sum101 = vpadalq_s16(_sum101, _wv01);
                    _sum110 = vpadalq_s16(_sum110, _wv10);
                    _sum111 = vpadalq_s16(_sum111, _wv11);

                    tmpptr += 8;
                    kptr0 += 16;
                }

                int32x4_t _s001 = vpaddq_s32(_sum100, _sum101);
                int32x4_t _s101 = vpaddq_s32(_sum110, _sum111);

                _sum00 = vaddq_s32(_sum00, _s001);
                _sum10 = vaddq_s32(_sum10, _s101);
            }
#endif // __ARM_FEATURE_DOTPROD

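            // nn1 tail: leftover input channels (inch % 4) are handled with widened
            // 16-bit multiplies, four channels per iteration where possible, then
            // one at a time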
            int j = 0;
            for (; j + 3 < nn1; j += 4)
            {
                int16x8_t _val01234567 = vmovl_s8(vld1_s8(tmpptr));

                int8x16_t _w = vld1q_s8(kptr0);
                int16x8_t _w01234567 = vmovl_s8(vget_low_s8(_w));
                int16x8_t _w89abcdef = vmovl_s8(vget_high_s8(_w));
                int16x4_t _w0123 = vget_low_s16(_w01234567);
                int16x4_t _w4567 = vget_high_s16(_w01234567);
                int16x4_t _w89ab = vget_low_s16(_w89abcdef);
                int16x4_t _wcdef = vget_high_s16(_w89abcdef);

                _sum00 = vmlal_laneq_s16(_sum00, _w0123, _val01234567, 0);
                _sum10 = vmlal_laneq_s16(_sum10, _w0123, _val01234567, 1);
                _sum00 = vmlal_laneq_s16(_sum00, _w4567, _val01234567, 2);
                _sum10 = vmlal_laneq_s16(_sum10, _w4567, _val01234567, 3);
                _sum00 = vmlal_laneq_s16(_sum00, _w89ab, _val01234567, 4);
                _sum10 = vmlal_laneq_s16(_sum10, _w89ab, _val01234567, 5);
                _sum00 = vmlal_laneq_s16(_sum00, _wcdef, _val01234567, 6);
                _sum10 = vmlal_laneq_s16(_sum10, _wcdef, _val01234567, 7);

                tmpptr += 8;
                kptr0 += 16;
            }
            for (; j < nn1; j++)
            {
                int16x4_t _val0 = vdup_n_s16(tmpptr[0]);
                int16x4_t _val1 = vdup_n_s16(tmpptr[1]);

                // zero-init before vset_lane_s16 so no indeterminate vector is read
                int16x4_t _w0123 = vdup_n_s16(0);
                _w0123 = vset_lane_s16(kptr0[0], _w0123, 0);
                _w0123 = vset_lane_s16(kptr0[1], _w0123, 1);
                _w0123 = vset_lane_s16(kptr0[2], _w0123, 2);
                _w0123 = vset_lane_s16(kptr0[3], _w0123, 3);

                _sum00 = vmlal_s16(_sum00, _val0, _w0123);
                _sum10 = vmlal_s16(_sum10, _val1, _w0123);

                tmpptr += 2;
                kptr0 += 4;
            }

            vst1q_s32(outptr0, _sum00);
            vst1q_s32(outptr0 + 4, _sum10);
            outptr0 += 8;
#else // __aarch64__
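            // armv7 inline asm mirrors the intrinsics path above: q0-q7 hold the
            // eight int32 accumulators, q8-q11 stage input/weight bytes, q12-q15
            // hold the int16 products; the vpadd.s32 cascade at label 2 performs
            // the same two-round pairwise reduction into d0-d3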
            asm volatile(
                "veor q0, q0 \n"
                "veor q1, q1 \n"
                "veor q2, q2 \n"
                "veor q3, q3 \n"
                "veor q4, q4 \n"
                "veor q5, q5 \n"
                "veor q6, q6 \n"
                "veor q7, q7 \n"

                "cmp %1, #0 \n"
                "beq 3f \n"

                "pld [%4, #256] \n"

1932 "lsr r4, %1, #1 \n" // r4 = nn = size >> 1
1933 "cmp r4, #0 \n"
1934 "beq 1f \n"
1935
1936 "add r5, %5, #16 \n"
1937 "pld [%5, #128] \n"
1938 "mov r6, #32 \n"
1939 "pld [%5, #384] \n"
1940
1941 "vld1.s8 {d20-d21}, [%5 :128], r6 \n" // _w01
1942
1943 "vld1.s8 {d16-d19}, [%4 :128]! \n" // _val0 _val1
1944
1945 "vld1.s8 {d22-d23}, [%5 :128], r6 \n" // _w45
1946
1947 "0: \n"
1948
1949 "vmull.s8 q12, d16, d20 \n"
1950 "pld [%4, #256] \n"
1951 "vmull.s8 q13, d16, d21 \n"
1952 "pld [%5, #384] \n"
1953 "vmull.s8 q14, d17, d20 \n"
1954 "vmull.s8 q15, d17, d21 \n"
1955 "vld1.s8 {d20-d21}, [r5 :128], r6 \n" // _w23
1956
1957 "vmlal.s8 q12, d18, d22 \n"
1958 "vmlal.s8 q13, d18, d23 \n"
1959 "subs r4, r4, #1 \n"
1960 "vmlal.s8 q14, d19, d22 \n"
1961 "vmlal.s8 q15, d19, d23 \n"
1962 "vld1.s8 {d22-d23}, [r5 :128], r6 \n" // _w67
1963
1964 "vpadal.s16 q0, q12 \n"
1965 "vmull.s8 q12, d16, d20 \n"
1966 "vpadal.s16 q1, q13 \n"
1967 "vmull.s8 q13, d16, d21 \n"
1968 "vpadal.s16 q4, q14 \n"
1969 "vmull.s8 q14, d17, d20 \n"
1970 "vpadal.s16 q5, q15 \n"
1971 "vmull.s8 q15, d17, d21 \n"
1972 "vld1.s8 {d16-d17}, [%4 :128]! \n" // _val0
1973
1974 "vmlal.s8 q12, d18, d22 \n"
1975 "vld1.s8 {d20-d21}, [%5 :128], r6 \n" // _w01
1976 "vmlal.s8 q13, d18, d23 \n"
1977 "pld [r5, #128] \n"
1978 "vmlal.s8 q14, d19, d22 \n"
1979 "pld [r5, #384] \n"
1980 "vmlal.s8 q15, d19, d23 \n"
1981 "vld1.s8 {d18-d19}, [%4 :128]! \n" // _val1
1982
1983 "vpadal.s16 q2, q12 \n"
1984 "vld1.s8 {d22-d23}, [%5 :128], r6 \n" // _w45
1985 "vpadal.s16 q3, q13 \n"
1986 "pld [%4, #128] \n"
1987 "vpadal.s16 q6, q14 \n"
1988 "pld [%5, #128] \n"
1989 "vpadal.s16 q7, q15 \n"
1990
1991 "bne 0b \n"
1992
1993 "sub %4, %4, #32 \n"
1994 "sub %5, %5, #64 \n"
1995
1996 "1: \n"
1997 "and r4, %1, #1 \n" // r4 = remain = size & 1
1998 "cmp r4, #0 \n" // r4 > 0
1999 "beq 2f \n"
2000
2001 "vld1.s8 {d16-d17}, [%4 :128]! \n" // _val
2002 "vld1.s8 {d20-d21}, [%5 :128]! \n" // _w01
2003
2004 "vmull.s8 q12, d16, d20 \n"
2005
2006 "vld1.s8 {d22-d23}, [%5 :128]! \n" // _w23
2007 "vmull.s8 q13, d16, d21 \n"
2008 "vmull.s8 q14, d17, d20 \n"
2009 "vmull.s8 q15, d17, d21 \n"
2010
2011 "vpadal.s16 q0, q12 \n"
2012 "vmull.s8 q12, d16, d22 \n"
2013 "vpadal.s16 q1, q13 \n"
2014 "vmull.s8 q13, d16, d23 \n"
2015 "vpadal.s16 q4, q14 \n"
2016 "vmull.s8 q14, d17, d22 \n"
2017 "vpadal.s16 q5, q15 \n"
2018 "vmull.s8 q15, d17, d23 \n"
2019
2020 "vpadal.s16 q2, q12 \n"
2021 "vpadal.s16 q3, q13 \n"
2022 "vpadal.s16 q6, q14 \n"
2023 "vpadal.s16 q7, q15 \n"
2024
2025 "2: \n"
2026
2027 "vpadd.s32 d16, d0, d1 \n"
2028 "vpadd.s32 d17, d2, d3 \n"
2029 "vpadd.s32 d18, d4, d5 \n"
2030 "vpadd.s32 d19, d6, d7 \n"
2031 "vpadd.s32 d20, d8, d9 \n"
2032 "vpadd.s32 d21, d10, d11 \n"
2033 "vpadd.s32 d22, d12, d13 \n"
2034 "vpadd.s32 d23, d14, d15 \n"
2035
2036 "vpadd.s32 d0, d16, d17 \n"
2037 "vpadd.s32 d1, d18, d19 \n"
2038 "vpadd.s32 d2, d20, d21 \n"
2039 "vpadd.s32 d3, d22, d23 \n"
2040
2041 "3: \n"
2042
2043 "cmp %2, #0 \n"
2044 "beq 7f \n"
2045
2046 "veor q2, q2 \n"
2047 "veor q3, q3 \n"
2048 "veor q4, q4 \n"
2049 "veor q5, q5 \n"
2050
2051 "lsr r4, %2, #1 \n" // r4 = nn4 >> 1
2052 "cmp r4, #0 \n"
2053 "beq 5f \n"
2054
2055 "4: \n"
2056
2057 "vld1.s8 {d16-d17}, [%4]! \n" // _val0123
2058 "vld1.s8 {d20-d23}, [%5]! \n" // _w01 _w23
2059
2060 "vmov.s8 q9, q8 \n"
2061 "vtrn.s32 q8, q9 \n" // _val00 _val22 _val11 _val33
2062
2063 "vmull.s8 q12, d16, d20 \n"
2064 "vmull.s8 q13, d16, d21 \n"
2065 "vmull.s8 q14, d18, d20 \n"
2066 "vmull.s8 q15, d18, d21 \n"
2067
2068 "vmlal.s8 q12, d17, d22 \n"
2069 "vmlal.s8 q13, d17, d23 \n"
2070 "vmlal.s8 q14, d19, d22 \n"
2071 "vmlal.s8 q15, d19, d23 \n"
2072
2073 "vpadal.s16 q2, q12 \n"
2074 "vpadal.s16 q3, q13 \n"
2075 "vpadal.s16 q4, q14 \n"
2076 "vpadal.s16 q5, q15 \n"
2077
2078 "subs r4, r4, #1 \n"
2079 "bne 4b \n"
2080
2081 "5: \n"
2082
2083 "and r4, %2, #1 \n" // r4 = nn4 & 1
2084 "cmp r4, #0 \n" // r4 > 0
2085 "beq 6f \n"
2086
2087 "vld1.s8 {d16}, [%4]! \n" // _val01
2088 "vld1.s8 {d18-d19}, [%5]! \n" // _w01
2089
2090 "vmov.s8 d17, d16 \n"
2091 "vtrn.s32 d16, d17 \n" // _val00 _val11
2092
2093 "vmull.s8 q12, d16, d18 \n"
2094 "vmull.s8 q13, d16, d19 \n"
2095 "vmull.s8 q14, d17, d18 \n"
2096 "vmull.s8 q15, d17, d19 \n"
2097
2098 "vpadal.s16 q2, q12 \n"
2099 "vpadal.s16 q3, q13 \n"
2100 "vpadal.s16 q4, q14 \n"
2101 "vpadal.s16 q5, q15 \n"
2102
2103 "6: \n"
2104
2105 "vpadd.s32 d16, d4, d5 \n"
2106 "vpadd.s32 d17, d6, d7 \n"
2107 "vpadd.s32 d18, d8, d9 \n"
2108 "vpadd.s32 d19, d10, d11 \n"
2109
2110 "vadd.s32 q0, q0, q8 \n"
2111 "vadd.s32 q1, q1, q9 \n"
2112
2113 "7: \n"
2114
2115 "lsr r4, %3, #2 \n" // r4 = nn1 >> 2
2116 "cmp r4, #0 \n"
2117 "beq 9f \n"
2118
2119 "8: \n"
2120
2121 "vld1.s8 {d4}, [%4]! \n"
2122 "vmovl.s8 q2, d4 \n"
2123
2124 "vld1.s8 {d10-d11}, [%5]! \n"
2125 "vmovl.s8 q3, d10 \n"
2126 "vmovl.s8 q4, d11 \n"
2127
2128 "vmlal.s16 q0, d6, d4[0] \n"
2129 "vmlal.s16 q1, d6, d4[1] \n"
2130 "vmlal.s16 q0, d7, d4[2] \n"
2131 "vmlal.s16 q1, d7, d4[3] \n"
2132 "vmlal.s16 q0, d8, d5[0] \n"
2133 "vmlal.s16 q1, d8, d5[1] \n"
2134 "vmlal.s16 q0, d9, d5[2] \n"
2135 "vmlal.s16 q1, d9, d5[3] \n"
2136
2137 "subs r4, r4, #1 \n"
2138 "bne 8b \n"
2139
2140 "9: \n"
2141
2142 "and r4, %3, #3 \n" // r4 = nn1 & 3
2143 "cmp r4, #0 \n" // w4 > 0
2144 "beq 11f \n"
2145
2146 "10: \n"
2147
2148 "vld1.s8 {d4[]}, [%4]! \n"
2149 "vld1.s8 {d6[]}, [%4]! \n"
2150 "vmovl.s8 q2, d4 \n"
2151 "vmovl.s8 q3, d6 \n"
2152
2153 "vld1.s8 {d8}, [%5] \n"
2154 "vmovl.s8 q4, d8 \n"
2155
2156 "vmlal.s16 q0, d4, d8 \n"
2157 "vmlal.s16 q1, d6, d8 \n"
2158
2159 "add %5, %5, #4 \n"
2160
2161 "subs r4, r4, #1 \n"
2162 "bne 10b \n"
2163
2164 "11: \n"
2165
2166 "vst1.s32 {d0-d3}, [%0 :128]! \n"
2167
2168 : "=r"(outptr0),
2169 "=r"(nn),
2170 "=r"(nn4),
2171 "=r"(nn1),
2172 "=r"(tmpptr),
2173 "=r"(kptr0)
2174 : "0"(outptr0),
2175 "1"(nn),
2176 "2"(nn4),
2177 "3"(nn1),
2178 "4"(tmpptr),
2179 "5"(kptr0)
2180 : "memory", "r4", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
2181 #endif // __aarch64__
2182 }
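        // tail: one output pixel per iteration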
        for (; i < size; i++)
        {
#if __aarch64__
#if __ARM_FEATURE_DOTPROD
            const signed char* tmpptr = tmp.channel(i / 16 + (i % 16) / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2);
#else
            const signed char* tmpptr = tmp.channel(i / 4 + (i % 4) / 2 + i % 2);
#endif
#else
            const signed char* tmpptr = tmp.channel(i / 2 + i % 2);
#endif
            const signed char* kptr0 = kernel.channel(p);

            int nn = (inch / 8) * maxk;
            int nn4 = ((inch % 8) / 4) * maxk;
            int nn1 = (inch % 4) * maxk;

            int32x4_t _sum0 = vdupq_n_s32(0);
#if __ARM_FEATURE_DOTPROD
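            // vdotq_lane_s32(acc, w, v, l) computes, for each int32 lane i:
            //   acc[i] += w[4*i+0] * v[4*l+0] + w[4*i+1] * v[4*l+1]
            //           + w[4*i+2] * v[4*l+2] + w[4*i+3] * v[4*l+3]
            // so one SDOT consumes four input channels for all four output channels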
            for (int j = 0; j < nn; j++)
            {
                int8x8_t _val0_l_h = vld1_s8(tmpptr);

                int8x16_t _w0123_l = vld1q_s8(kptr0);

                _sum0 = vdotq_lane_s32(_sum0, _w0123_l, _val0_l_h, 0);

                int8x16_t _w0123_h = vld1q_s8(kptr0 + 16);

                _sum0 = vdotq_lane_s32(_sum0, _w0123_h, _val0_l_h, 1);

                tmpptr += 8;
                kptr0 += 32;
            }

            if (nn4 > 0)
            {
                int j = 0;
                for (; j + 1 < nn4; j += 2)
                {
                    int8x8_t _val01 = vld1_s8(tmpptr);

                    int8x16_t _w0 = vld1q_s8(kptr0);

                    _sum0 = vdotq_lane_s32(_sum0, _w0, _val01, 0);

                    int8x16_t _w1 = vld1q_s8(kptr0 + 16);

                    _sum0 = vdotq_lane_s32(_sum0, _w1, _val01, 1);

                    tmpptr += 8;
                    kptr0 += 32;
                }
                for (; j < nn4; j++)
                {
                    int8x8_t _val_xxx = vld1_s8(tmpptr);

                    int8x16_t _w0 = vld1q_s8(kptr0);

                    _sum0 = vdotq_lane_s32(_sum0, _w0, _val_xxx, 0);

                    tmpptr += 4;
                    kptr0 += 16;
                }
            }
#else // __ARM_FEATURE_DOTPROD
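            // without SDOT: widening vmull/vmlal into int16 products, folded into
            // int32 by vpadalq_s16; each accumulator then holds interleaved partial
            // sums, so a pairwise-add pass follows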
            if (nn > 0)
            {
                int32x4_t _sum1 = vdupq_n_s32(0);
                int32x4_t _sum2 = vdupq_n_s32(0);
                int32x4_t _sum3 = vdupq_n_s32(0);

                int j = 0;
                for (; j + 1 < nn; j += 2)
                {
                    int8x16_t _val = vld1q_s8(tmpptr);

                    int8x16_t _w01 = vld1q_s8(kptr0);
                    int8x16_t _w23 = vld1q_s8(kptr0 + 16);

                    int16x8_t _wv0 = vmull_s8(vget_low_s8(_val), vget_low_s8(_w01));
                    int16x8_t _wv1 = vmull_s8(vget_low_s8(_val), vget_high_s8(_w01));
                    int16x8_t _wv2 = vmull_s8(vget_low_s8(_val), vget_low_s8(_w23));
                    int16x8_t _wv3 = vmull_s8(vget_low_s8(_val), vget_high_s8(_w23));

                    int8x16_t _w45 = vld1q_s8(kptr0 + 32);
                    int8x16_t _w67 = vld1q_s8(kptr0 + 48);

                    _wv0 = vmlal_s8(_wv0, vget_high_s8(_val), vget_low_s8(_w45));
                    _wv1 = vmlal_s8(_wv1, vget_high_s8(_val), vget_high_s8(_w45));
                    _wv2 = vmlal_s8(_wv2, vget_high_s8(_val), vget_low_s8(_w67));
                    _wv3 = vmlal_s8(_wv3, vget_high_s8(_val), vget_high_s8(_w67));

                    _sum0 = vpadalq_s16(_sum0, _wv0);
                    _sum1 = vpadalq_s16(_sum1, _wv1);
                    _sum2 = vpadalq_s16(_sum2, _wv2);
                    _sum3 = vpadalq_s16(_sum3, _wv3);

                    tmpptr += 16;
                    kptr0 += 64;
                }
                for (; j < nn; j++)
                {
                    int8x8_t _val = vld1_s8(tmpptr);

                    int8x16_t _w01 = vld1q_s8(kptr0);
                    int8x16_t _w23 = vld1q_s8(kptr0 + 16);

                    int16x8_t _wv0 = vmull_s8(_val, vget_low_s8(_w01));
                    int16x8_t _wv1 = vmull_s8(_val, vget_high_s8(_w01));
                    int16x8_t _wv2 = vmull_s8(_val, vget_low_s8(_w23));
                    int16x8_t _wv3 = vmull_s8(_val, vget_high_s8(_w23));

                    _sum0 = vpadalq_s16(_sum0, _wv0);
                    _sum1 = vpadalq_s16(_sum1, _wv1);
                    _sum2 = vpadalq_s16(_sum2, _wv2);
                    _sum3 = vpadalq_s16(_sum3, _wv3);

                    tmpptr += 8;
                    kptr0 += 32;
                }

#if __aarch64__
                int32x4_t _s01 = vpaddq_s32(_sum0, _sum1);
                int32x4_t _s23 = vpaddq_s32(_sum2, _sum3);

                _sum0 = vpaddq_s32(_s01, _s23);
#else
                int32x2_t _s01_low = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0));
                int32x2_t _s01_high = vpadd_s32(vget_low_s32(_sum1), vget_high_s32(_sum1));
                int32x2_t _s23_low = vpadd_s32(vget_low_s32(_sum2), vget_high_s32(_sum2));
                int32x2_t _s23_high = vpadd_s32(vget_low_s32(_sum3), vget_high_s32(_sum3));

                _sum0 = vcombine_s32(vpadd_s32(_s01_low, _s01_high), vpadd_s32(_s23_low, _s23_high));
#endif
            }

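            // nn4 tail: groups of four input channels; the same 4-byte pixel value
            // is duplicated into both halves of an int8x8 so each vmull covers two
            // output channels at once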
            if (nn4 > 0)
            {
                int32x4_t _sum10 = vdupq_n_s32(0);
                int32x4_t _sum11 = vdupq_n_s32(0);

                int j = 0;
                for (; j + 1 < nn4; j += 2)
                {
                    int8x8_t _val01 = vld1_s8(tmpptr);
                    int32x2x2_t _val0011 = vzip_s32(vreinterpret_s32_s8(_val01), vreinterpret_s32_s8(_val01));
                    int8x8_t _val00 = vreinterpret_s8_s32(_val0011.val[0]);
                    int8x8_t _val11 = vreinterpret_s8_s32(_val0011.val[1]);

                    int8x16_t _w0 = vld1q_s8(kptr0);
                    int8x16_t _w1 = vld1q_s8(kptr0 + 16);

                    int16x8_t _wv0 = vmull_s8(_val00, vget_low_s8(_w0));
                    int16x8_t _wv1 = vmull_s8(_val00, vget_high_s8(_w0));

                    _wv0 = vmlal_s8(_wv0, _val11, vget_low_s8(_w1));
                    _wv1 = vmlal_s8(_wv1, _val11, vget_high_s8(_w1));

                    _sum10 = vpadalq_s16(_sum10, _wv0);
                    _sum11 = vpadalq_s16(_sum11, _wv1);

                    tmpptr += 8;
                    kptr0 += 32;
                }
                for (; j < nn4; j++)
                {
                    int8x8_t _val_xxx = vld1_s8(tmpptr);
                    int8x8_t _val_val = vreinterpret_s8_s32(vzip_s32(vreinterpret_s32_s8(_val_xxx), vreinterpret_s32_s8(_val_xxx)).val[0]);

                    int8x16_t _w0 = vld1q_s8(kptr0);

                    int16x8_t _wv0 = vmull_s8(_val_val, vget_low_s8(_w0));
                    int16x8_t _wv1 = vmull_s8(_val_val, vget_high_s8(_w0));

                    _sum10 = vpadalq_s16(_sum10, _wv0);
                    _sum11 = vpadalq_s16(_sum11, _wv1);

                    tmpptr += 4;
                    kptr0 += 16;
                }

#if __aarch64__
                int32x4_t _s01 = vpaddq_s32(_sum10, _sum11);
#else
                int32x2_t _s01_low = vpadd_s32(vget_low_s32(_sum10), vget_high_s32(_sum10));
                int32x2_t _s01_high = vpadd_s32(vget_low_s32(_sum11), vget_high_s32(_sum11));

                int32x4_t _s01 = vcombine_s32(_s01_low, _s01_high);
#endif

                _sum0 = vaddq_s32(_sum0, _s01);
            }
#endif // __ARM_FEATURE_DOTPROD

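            // nn1 tail for the single-pixel case: four leftover channels per step
            // via vmlal_lane_s16, then one channel at a time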
            int32x4_t _sum1 = vdupq_n_s32(0);

            int j = 0;
            for (; j + 3 < nn1; j += 4)
            {
                int16x4_t _val0123 = vget_low_s16(vmovl_s8(vld1_s8(tmpptr)));

                int8x16_t _w = vld1q_s8(kptr0);
                int16x8_t _w01234567 = vmovl_s8(vget_low_s8(_w));
                int16x8_t _w89abcdef = vmovl_s8(vget_high_s8(_w));
                int16x4_t _w0123 = vget_low_s16(_w01234567);
                int16x4_t _w4567 = vget_high_s16(_w01234567);
                int16x4_t _w89ab = vget_low_s16(_w89abcdef);
                int16x4_t _wcdef = vget_high_s16(_w89abcdef);

                _sum0 = vmlal_lane_s16(_sum0, _w0123, _val0123, 0);
                _sum1 = vmlal_lane_s16(_sum1, _w4567, _val0123, 1);
                _sum0 = vmlal_lane_s16(_sum0, _w89ab, _val0123, 2);
                _sum1 = vmlal_lane_s16(_sum1, _wcdef, _val0123, 3);

                tmpptr += 4;
                kptr0 += 16;
            }
            for (; j < nn1; j++)
            {
                int16x4_t _val = vdup_n_s16(tmpptr[0]);

                // zero-init before vset_lane_s16 so no indeterminate vector is read
                int16x4_t _w0123 = vdup_n_s16(0);
                _w0123 = vset_lane_s16(kptr0[0], _w0123, 0);
                _w0123 = vset_lane_s16(kptr0[1], _w0123, 1);
                _w0123 = vset_lane_s16(kptr0[2], _w0123, 2);
                _w0123 = vset_lane_s16(kptr0[3], _w0123, 3);

                _sum0 = vmlal_s16(_sum0, _val, _w0123);

                tmpptr += 1;
                kptr0 += 4;
            }

            _sum0 = vaddq_s32(_sum0, _sum1);

            vst1q_s32(outptr0, _sum0);
            outptr0 += 4;
        }
    }
}
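
// For reference, a minimal scalar sketch of the product the intrinsics/asm paths
// above compute, assuming a hypothetical helper weight_of(p, q, k, c) that reads
// the interleaved kernel entry for output channel p*4+c, input channel q, tap k
// (not part of the optimized path):
//
//   for (int p = 0; p < outch; p++)
//   {
//       int* outptr = top_blob.channel(p);
//       for (int i = 0; i < size; i++)
//       {
//           for (int c = 0; c < 4; c++)
//           {
//               int sum = 0;
//               for (int q = 0; q < inch; q++)
//                   for (int k = 0; k < maxk; k++)
//                       sum += bottom_im2col.channel(q).row<const signed char>(k)[i] * weight_of(p, q, k, c);
//               outptr[i * 4 + c] = sum;
//           }
//       }
//   }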

static void convolution_im2col_sgemm_transform_kernel_pack1to4_int8_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h)
{
    const int maxk = kernel_w * kernel_h;

    // interleave
    // src = maxk-inch-outch
    // dst = 8a-4b-maxk-inch/8a-outch/4b
    // dst = 4a-4b-2-maxk-inch/8a-outch/4b (arm82)
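    // e.g. for inch = 8, outch = 4, maxk = 1 the first 32 bytes of kernel_tm are
    // w(out0,in0..7) w(out1,in0..7) w(out2,in0..7) w(out3,in0..7); on arm82
    // (dotprod) the in0..3 block of all four outputs precedes the in4..7 block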
    Mat kernel = _kernel.reshape(maxk, inch, outch);
    // the branches are mutually exclusive; without the else, an inch >= 8 kernel
    // would be re-created with the inch >= 4 shape
    if (inch >= 8)
        kernel_tm.create(32 * maxk, inch / 8 + (inch % 8) / 4 + inch % 4, outch / 4, 1u);
    else if (inch >= 4)
        kernel_tm.create(16 * maxk, inch / 4 + inch % 4, outch / 4, 1u);
    else
        kernel_tm.create(4 * maxk, inch, outch / 4, 1u);

    for (int q = 0; q + 3 < outch; q += 4)
    {
        signed char* g00 = kernel_tm.channel(q / 4);

        int p = 0;
        for (; p + 7 < inch; p += 8)
        {
            for (int k = 0; k < maxk; k++)
            {
#if __ARM_FEATURE_DOTPROD
                for (int i = 0; i < 4; i++)
                {
                    for (int j = 0; j < 4; j++)
                    {
                        const signed char* k00 = kernel.channel(q + i).row<const signed char>(p + j);

                        g00[0] = k00[k];

                        g00++;
                    }
                }
                for (int i = 0; i < 4; i++)
                {
                    for (int j = 4; j < 8; j++)
                    {
                        const signed char* k00 = kernel.channel(q + i).row<const signed char>(p + j);

                        g00[0] = k00[k];

                        g00++;
                    }
                }
#else
                for (int i = 0; i < 4; i++)
                {
                    for (int j = 0; j < 8; j++)
                    {
                        const signed char* k00 = kernel.channel(q + i).row<const signed char>(p + j);

                        g00[0] = k00[k];

                        g00++;
                    }
                }
#endif
            }
        }
        for (; p + 3 < inch; p += 4)
        {
            for (int k = 0; k < maxk; k++)
            {
                for (int i = 0; i < 4; i++)
                {
                    for (int j = 0; j < 4; j++)
                    {
                        const signed char* k00 = kernel.channel(q + i).row<const signed char>(p + j);

                        g00[0] = k00[k];

                        g00++;
                    }
                }
            }
        }
        for (; p < inch; p++)
        {
            for (int k = 0; k < maxk; k++)
            {
                for (int i = 0; i < 4; i++)
                {
                    const signed char* k00 = kernel.channel(q + i).row<const signed char>(p);

                    g00[0] = k00[k];

                    g00++;
                }
            }
        }
    }
}

static void convolution_im2col_sgemm_pack1to4_int8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    const int size = outw * outh;

    const int maxk = kernel_w * kernel_h;

    // im2col
    Mat bottom_im2col(size, maxk, inch, 1u, 1, opt.workspace_allocator);
    {
        const int gap = w * stride_h - outw * stride_w;
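        // gap jumps from the end of one sampled output row to the start of the
        // next: stride_h input rows down, minus the outw * stride_w columns
        // already consumed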

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < inch; p++)
        {
            const Mat img = bottom_blob.channel(p);
            signed char* ptr = bottom_im2col.channel(p);

            for (int u = 0; u < kernel_h; u++)
            {
                for (int v = 0; v < kernel_w; v++)
                {
                    const signed char* sptr = img.row<const signed char>(dilation_h * u) + dilation_w * v;

                    for (int i = 0; i < outh; i++)
                    {
                        int j = 0;
                        for (; j + 3 < outw; j += 4)
                        {
                            ptr[0] = sptr[0];
                            ptr[1] = sptr[stride_w];
                            ptr[2] = sptr[stride_w * 2];
                            ptr[3] = sptr[stride_w * 3];

                            sptr += stride_w * 4;
                            ptr += 4;
                        }
                        for (; j + 1 < outw; j += 2)
                        {
                            ptr[0] = sptr[0];
                            ptr[1] = sptr[stride_w];

                            sptr += stride_w * 2;
                            ptr += 2;
                        }
                        for (; j < outw; j++)
                        {
                            ptr[0] = sptr[0];

                            sptr += stride_w;
                            ptr += 1;
                        }

                        sptr += gap;
                    }
                }
            }
        }
    }

    im2col_sgemm_pack1to4_int8_neon(bottom_im2col, top_blob, kernel, opt);
}