1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
conv1x1s1_sgemm_transform_kernel_pack4to1_neon(const Mat & kernel,Mat & kernel_tm_pack4,int inch,int outch)15 static void conv1x1s1_sgemm_transform_kernel_pack4to1_neon(const Mat& kernel, Mat& kernel_tm_pack4, int inch, int outch)
16 {
17 // interleave
18 // src = inch-outch
19 // dst = 4a-inch/4a-outch
20 #if __aarch64__
21 kernel_tm_pack4.create(8, inch / 4, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)4u * 4, 4);
22 #else
23 kernel_tm_pack4.create(4, inch / 4, outch / 4 + outch % 4, (size_t)4u * 4, 4);
24 #endif
25
26 int p = 0;
27 #if __aarch64__
28 for (; p + 7 < outch; p += 8)
29 {
30 const float* k0 = (const float*)kernel + (p + 0) * inch;
31 const float* k1 = (const float*)kernel + (p + 1) * inch;
32 const float* k2 = (const float*)kernel + (p + 2) * inch;
33 const float* k3 = (const float*)kernel + (p + 3) * inch;
34 const float* k4 = (const float*)kernel + (p + 4) * inch;
35 const float* k5 = (const float*)kernel + (p + 5) * inch;
36 const float* k6 = (const float*)kernel + (p + 6) * inch;
37 const float* k7 = (const float*)kernel + (p + 7) * inch;
38
39 float* ktmp = kernel_tm_pack4.channel(p / 8);
40
41 for (int q = 0; q + 3 < inch; q += 4)
42 {
43 ktmp[0] = k0[0];
44 ktmp[1] = k1[0];
45 ktmp[2] = k2[0];
46 ktmp[3] = k3[0];
47 ktmp[4] = k4[0];
48 ktmp[5] = k5[0];
49 ktmp[6] = k6[0];
50 ktmp[7] = k7[0];
51
52 ktmp[8] = k0[1];
53 ktmp[9] = k1[1];
54 ktmp[10] = k2[1];
55 ktmp[11] = k3[1];
56 ktmp[12] = k4[1];
57 ktmp[13] = k5[1];
58 ktmp[14] = k6[1];
59 ktmp[15] = k7[1];
60
61 ktmp[16] = k0[2];
62 ktmp[17] = k1[2];
63 ktmp[18] = k2[2];
64 ktmp[19] = k3[2];
65 ktmp[20] = k4[2];
66 ktmp[21] = k5[2];
67 ktmp[22] = k6[2];
68 ktmp[23] = k7[2];
69
70 ktmp[24] = k0[3];
71 ktmp[25] = k1[3];
72 ktmp[26] = k2[3];
73 ktmp[27] = k3[3];
74 ktmp[28] = k4[3];
75 ktmp[29] = k5[3];
76 ktmp[30] = k6[3];
77 ktmp[31] = k7[3];
78
79 k0 += 4;
80 k1 += 4;
81 k2 += 4;
82 k3 += 4;
83 k4 += 4;
84 k5 += 4;
85 k6 += 4;
86 k7 += 4;
87 ktmp += 32;
88 }
89 }
90 #endif
91 for (; p + 3 < outch; p += 4)
92 {
93 const float* k0 = (const float*)kernel + (p + 0) * inch;
94 const float* k1 = (const float*)kernel + (p + 1) * inch;
95 const float* k2 = (const float*)kernel + (p + 2) * inch;
96 const float* k3 = (const float*)kernel + (p + 3) * inch;
97
98 #if __aarch64__
99 float* ktmp = kernel_tm_pack4.channel(p / 8 + (p % 8) / 4);
100 #else
101 float* ktmp = kernel_tm_pack4.channel(p / 4);
102 #endif
103
104 for (int q = 0; q + 3 < inch; q += 4)
105 {
106 ktmp[0] = k0[0];
107 ktmp[1] = k1[0];
108 ktmp[2] = k2[0];
109 ktmp[3] = k3[0];
110
111 ktmp[4] = k0[1];
112 ktmp[5] = k1[1];
113 ktmp[6] = k2[1];
114 ktmp[7] = k3[1];
115
116 ktmp[8] = k0[2];
117 ktmp[9] = k1[2];
118 ktmp[10] = k2[2];
119 ktmp[11] = k3[2];
120
121 ktmp[12] = k0[3];
122 ktmp[13] = k1[3];
123 ktmp[14] = k2[3];
124 ktmp[15] = k3[3];
125
126 k0 += 4;
127 k1 += 4;
128 k2 += 4;
129 k3 += 4;
130 ktmp += 16;
131 }
132 }
133 for (; p < outch; p++)
134 {
135 const float* k0 = (const float*)kernel + p * inch;
136
137 #if __aarch64__
138 float* ktmp = kernel_tm_pack4.channel(p / 8 + (p % 8) / 4 + p % 4);
139 #else
140 float* ktmp = kernel_tm_pack4.channel(p / 4 + p % 4);
141 #endif
142
143 for (int q = 0; q + 3 < inch; q += 4)
144 {
145 ktmp[0] = k0[0];
146 ktmp[1] = k0[1];
147 ktmp[2] = k0[2];
148 ktmp[3] = k0[3];
149
150 k0 += 4;
151 ktmp += 4;
152 }
153 }
154 }
155
conv1x1s1_sgemm_pack4to1_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)156 static void conv1x1s1_sgemm_pack4to1_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
157 {
158 int w = bottom_blob.w;
159 int h = bottom_blob.h;
160 int inch = bottom_blob.c;
161 int outch = top_blob.c;
162
163 size_t elemsize = bottom_blob.elemsize;
164 int elempack = bottom_blob.elempack;
165
166 const int size = w * h;
167
168 const float* bias = _bias;
169
170 // interleave
171 #if __aarch64__
172 Mat tmp(12, inch, size / 12 + (size % 12) / 8 + (size % 12 % 8) / 4 + size % 12 % 4, elemsize, elempack, opt.workspace_allocator);
173 #else
174 Mat tmp(8, inch, size / 8 + (size % 8) / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
175 #endif
176 {
177 int nn_size = 0;
178 int remain_size_start = 0;
179
180 #if __aarch64__
181 nn_size = size / 12;
182
183 #pragma omp parallel for num_threads(opt.num_threads)
184 for (int ii = 0; ii < nn_size; ii++)
185 {
186 int i = ii * 12;
187
188 const float* img0 = bottom_blob.channel(0);
189 img0 += i * 4;
190
191 float* tmpptr = tmp.channel(i / 12);
192
193 for (int q = 0; q < inch; q++)
194 {
195 // transpose 4x12
196 asm volatile(
197 "prfm pldl1keep, [%0, #512] \n"
198 "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n"
199 "prfm pldl1keep, [%0, #512] \n"
200 "ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%0], #64 \n"
201 "prfm pldl1keep, [%0, #512] \n"
202 "ld4 {v16.4s, v17.4s, v18.4s, v19.4s}, [%0] \n"
203 "sub %0, %0, #128 \n"
204 "st1 {v0.4s}, [%1], #16 \n"
205 "st1 {v4.4s}, [%1], #16 \n"
206 "st1 {v16.4s}, [%1], #16 \n"
207 "st1 {v1.4s}, [%1], #16 \n"
208 "st1 {v5.4s}, [%1], #16 \n"
209 "st1 {v17.4s}, [%1], #16 \n"
210 "st1 {v2.4s}, [%1], #16 \n"
211 "st1 {v6.4s}, [%1], #16 \n"
212 "st1 {v18.4s}, [%1], #16 \n"
213 "st1 {v3.4s}, [%1], #16 \n"
214 "st1 {v7.4s}, [%1], #16 \n"
215 "st1 {v19.4s}, [%1], #16 \n"
216 : "=r"(img0), // %0
217 "=r"(tmpptr) // %1
218 : "0"(img0),
219 "1"(tmpptr)
220 : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19");
221
222 img0 += bottom_blob.cstep * 4;
223 }
224 }
225
226 remain_size_start += nn_size * 12;
227 nn_size = (size - remain_size_start) >> 3;
228 #else // __aarch64__
229 nn_size = size >> 3;
230 #endif // __aarch64__
231
232 #pragma omp parallel for num_threads(opt.num_threads)
233 for (int ii = 0; ii < nn_size; ii++)
234 {
235 int i = remain_size_start + ii * 8;
236
237 const float* img0 = bottom_blob.channel(0);
238 img0 += i * 4;
239
240 #if __aarch64__
241 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
242 #else
243 float* tmpptr = tmp.channel(i / 8);
244 #endif
245
246 for (int q = 0; q < inch; q++)
247 {
248 // transpose 4x8
249 #if __aarch64__
250 asm volatile(
251 "prfm pldl1keep, [%0, #512] \n"
252 "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n"
253 "prfm pldl1keep, [%0, #512] \n"
254 "ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%0] \n"
255 "sub %0, %0, #64 \n"
256 "st1 {v0.4s}, [%1], #16 \n"
257 "st1 {v4.4s}, [%1], #16 \n"
258 "st1 {v1.4s}, [%1], #16 \n"
259 "st1 {v5.4s}, [%1], #16 \n"
260 "st1 {v2.4s}, [%1], #16 \n"
261 "st1 {v6.4s}, [%1], #16 \n"
262 "st1 {v3.4s}, [%1], #16 \n"
263 "st1 {v7.4s}, [%1], #16 \n"
264 : "=r"(img0), // %0
265 "=r"(tmpptr) // %1
266 : "0"(img0),
267 "1"(tmpptr)
268 : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
269 #else
270 asm volatile(
271 "pld [%0, #256] \n"
272 "vld4.f32 {d0-d3}, [%0 :128]! \n"
273 "pld [%0, #256] \n"
274 "vld4.f32 {d4-d7}, [%0 :128]! \n"
275 "pld [%0, #256] \n"
276 "vld4.f32 {d16-d19}, [%0 :128]! \n"
277 "pld [%0, #256] \n"
278 "vld4.f32 {d20-d23}, [%0 :128] \n"
279 "sub %0, %0, #96 \n"
280 "vswp d1, d4 \n"
281 "vswp d3, d6 \n"
282 "vswp d17, d20 \n"
283 "vswp d19, d22 \n"
284 "vst1.f32 {d0-d1}, [%1 :128]! \n"
285 "vst1.f32 {d16-d17}, [%1 :128]! \n"
286 "vst1.f32 {d4-d5}, [%1 :128]! \n"
287 "vst1.f32 {d20-d21}, [%1 :128]! \n"
288 "vst1.f32 {d2-d3}, [%1 :128]! \n"
289 "vst1.f32 {d18-d19}, [%1 :128]! \n"
290 "vst1.f32 {d6-d7}, [%1 :128]! \n"
291 "vst1.f32 {d22-d23}, [%1 :128]! \n"
292 : "=r"(img0), // %0
293 "=r"(tmpptr) // %1
294 : "0"(img0),
295 "1"(tmpptr)
296 : "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
297 #endif
298
299 img0 += bottom_blob.cstep * 4;
300 }
301 }
302
303 remain_size_start += nn_size << 3;
304 nn_size = (size - remain_size_start) >> 2;
305
306 #pragma omp parallel for num_threads(opt.num_threads)
307 for (int ii = 0; ii < nn_size; ii++)
308 {
309 int i = remain_size_start + ii * 4;
310
311 const float* img0 = bottom_blob.channel(0);
312 img0 += i * 4;
313
314 #if __aarch64__
315 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
316 #else
317 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
318 #endif
319
320 for (int q = 0; q < inch; q++)
321 {
322 // transpose 4x4
323 #if __aarch64__
324 asm volatile(
325 "prfm pldl1keep, [%0, #512] \n"
326 "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0] \n"
327 "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%1], #64 \n"
328 : "=r"(img0), // %0
329 "=r"(tmpptr) // %1
330 : "0"(img0),
331 "1"(tmpptr)
332 : "memory", "v0", "v1", "v2", "v3");
333 #else
334 asm volatile(
335 "pld [%0, #256] \n"
336 "vld4.f32 {d0-d3}, [%0 :128]! \n"
337 "pld [%0, #256] \n"
338 "vld4.f32 {d4-d7}, [%0 :128] \n"
339 "sub %0, %0, #32 \n"
340 "vswp d1, d4 \n"
341 "vswp d3, d6 \n"
342 "vst1.f32 {d0-d1}, [%1 :128]! \n"
343 "vst1.f32 {d4-d5}, [%1 :128]! \n"
344 "vst1.f32 {d2-d3}, [%1 :128]! \n"
345 "vst1.f32 {d6-d7}, [%1 :128]! \n"
346 : "=r"(img0), // %0
347 "=r"(tmpptr) // %1
348 : "0"(img0),
349 "1"(tmpptr)
350 : "memory", "q0", "q1", "q2", "q3");
351 #endif
352
353 img0 += bottom_blob.cstep * 4;
354 }
355 }
356
357 remain_size_start += nn_size << 2;
358
359 #pragma omp parallel for num_threads(opt.num_threads)
360 for (int i = remain_size_start; i < size; i++)
361 {
362 const float* img0 = bottom_blob.channel(0);
363 img0 += i * 4;
364
365 #if __aarch64__
366 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + i % 12 % 4);
367 #else
368 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
369 #endif
370
371 for (int q = 0; q < inch; q++)
372 {
373 #if __aarch64__
374 asm volatile(
375 "prfm pldl1keep, [%0, #128] \n"
376 "ld1 {v0.4s}, [%0] \n"
377 "st1 {v0.4s}, [%1], #16 \n"
378 : "=r"(img0), // %0
379 "=r"(tmpptr) // %1
380 : "0"(img0),
381 "1"(tmpptr)
382 : "memory", "v0");
383 #else
384 asm volatile(
385 "pld [%0, #128] \n"
386 "vld1.f32 {d0-d1}, [%0 :128] \n"
387 "vst1.f32 {d0-d1}, [%1 :128]! \n"
388 : "=r"(img0), // %0
389 "=r"(tmpptr) // %1
390 : "0"(img0),
391 "1"(tmpptr)
392 : "memory", "q0");
393 #endif
394
395 img0 += bottom_blob.cstep * 4;
396 }
397 }
398 }
399
400 int nn_outch = 0;
401 int remain_outch_start = 0;
402
403 #if __aarch64__
404 nn_outch = outch >> 3;
405
406 #pragma omp parallel for num_threads(opt.num_threads)
407 for (int pp = 0; pp < nn_outch; pp++)
408 {
409 int p = pp * 8;
410
411 float* outptr0 = top_blob.channel(p);
412 float* outptr1 = top_blob.channel(p + 1);
413 float* outptr2 = top_blob.channel(p + 2);
414 float* outptr3 = top_blob.channel(p + 3);
415 float* outptr4 = top_blob.channel(p + 4);
416 float* outptr5 = top_blob.channel(p + 5);
417 float* outptr6 = top_blob.channel(p + 6);
418 float* outptr7 = top_blob.channel(p + 7);
419
420 const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
421 const float* biasptr = bias ? bias + p : zeros;
422
423 int i = 0;
424 for (; i + 11 < size; i += 12)
425 {
426 float* tmpptr = tmp.channel(i / 12);
427 const float* kptr = (const float*)kernel.channel(p / 8);
428
429 int nn = inch; // inch always > 0
430
431 asm volatile(
432 "ld1 {v30.4s, v31.4s}, [%22] \n"
433 "dup v8.4s, v30.s[0] \n"
434 "dup v9.4s, v30.s[0] \n"
435 "dup v10.4s, v30.s[0] \n"
436 "dup v11.4s, v30.s[1] \n"
437 "dup v12.4s, v30.s[1] \n"
438 "dup v13.4s, v30.s[1] \n"
439 "dup v14.4s, v30.s[2] \n"
440 "dup v15.4s, v30.s[2] \n"
441 "dup v16.4s, v30.s[2] \n"
442 "dup v17.4s, v30.s[3] \n"
443 "dup v18.4s, v30.s[3] \n"
444 "dup v19.4s, v30.s[3] \n"
445 "dup v20.4s, v31.s[0] \n"
446 "dup v21.4s, v31.s[0] \n"
447 "dup v22.4s, v31.s[0] \n"
448 "dup v23.4s, v31.s[1] \n"
449 "dup v24.4s, v31.s[1] \n"
450 "dup v25.4s, v31.s[1] \n"
451 "dup v26.4s, v31.s[2] \n"
452 "dup v27.4s, v31.s[2] \n"
453 "dup v28.4s, v31.s[2] \n"
454 "dup v29.4s, v31.s[3] \n"
455 "dup v30.4s, v31.s[3] \n"
456 "dup v31.4s, v31.s[3] \n"
457
458 "0: \n"
459
460 "prfm pldl1keep, [%9, #512] \n"
461 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64 \n"
462
463 "prfm pldl1keep, [%10, #512] \n"
464 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%10], #64 \n"
465
466 "subs %w0, %w0, #1 \n"
467
468 "fmla v8.4s, v0.4s, v4.s[0] \n"
469 "fmla v11.4s, v0.4s, v4.s[1] \n"
470 "fmla v14.4s, v0.4s, v4.s[2] \n"
471 "fmla v17.4s, v0.4s, v4.s[3] \n"
472 "fmla v20.4s, v0.4s, v5.s[0] \n"
473 "fmla v23.4s, v0.4s, v5.s[1] \n"
474 "fmla v26.4s, v0.4s, v5.s[2] \n"
475 "fmla v29.4s, v0.4s, v5.s[3] \n"
476
477 "fmla v9.4s, v1.4s, v4.s[0] \n"
478 "fmla v12.4s, v1.4s, v4.s[1] \n"
479 "fmla v15.4s, v1.4s, v4.s[2] \n"
480 "fmla v18.4s, v1.4s, v4.s[3] \n"
481 "fmla v21.4s, v1.4s, v5.s[0] \n"
482 "fmla v24.4s, v1.4s, v5.s[1] \n"
483 "fmla v27.4s, v1.4s, v5.s[2] \n"
484 "fmla v30.4s, v1.4s, v5.s[3] \n"
485
486 "fmla v10.4s, v2.4s, v4.s[0] \n"
487 "fmla v13.4s, v2.4s, v4.s[1] \n"
488 "fmla v16.4s, v2.4s, v4.s[2] \n"
489 "fmla v19.4s, v2.4s, v4.s[3] \n"
490 "fmla v22.4s, v2.4s, v5.s[0] \n"
491 "fmla v25.4s, v2.4s, v5.s[1] \n"
492 "fmla v28.4s, v2.4s, v5.s[2] \n"
493 "fmla v31.4s, v2.4s, v5.s[3] \n"
494
495 "fmla v8.4s, v3.4s, v6.s[0] \n"
496 "fmla v11.4s, v3.4s, v6.s[1] \n"
497 "fmla v14.4s, v3.4s, v6.s[2] \n"
498 "fmla v17.4s, v3.4s, v6.s[3] \n"
499 "fmla v20.4s, v3.4s, v7.s[0] \n"
500 "fmla v23.4s, v3.4s, v7.s[1] \n"
501 "fmla v26.4s, v3.4s, v7.s[2] \n"
502 "fmla v29.4s, v3.4s, v7.s[3] \n"
503
504 "prfm pldl1keep, [%9, #512] \n"
505 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64 \n"
506
507 "fmla v9.4s, v0.4s, v6.s[0] \n"
508 "fmla v12.4s, v0.4s, v6.s[1] \n"
509 "fmla v15.4s, v0.4s, v6.s[2] \n"
510 "fmla v18.4s, v0.4s, v6.s[3] \n"
511 "fmla v21.4s, v0.4s, v7.s[0] \n"
512 "fmla v24.4s, v0.4s, v7.s[1] \n"
513 "fmla v27.4s, v0.4s, v7.s[2] \n"
514 "fmla v30.4s, v0.4s, v7.s[3] \n"
515
516 "fmla v10.4s, v1.4s, v6.s[0] \n"
517 "fmla v13.4s, v1.4s, v6.s[1] \n"
518 "fmla v16.4s, v1.4s, v6.s[2] \n"
519 "fmla v19.4s, v1.4s, v6.s[3] \n"
520 "fmla v22.4s, v1.4s, v7.s[0] \n"
521 "fmla v25.4s, v1.4s, v7.s[1] \n"
522 "fmla v28.4s, v1.4s, v7.s[2] \n"
523 "fmla v31.4s, v1.4s, v7.s[3] \n"
524
525 "prfm pldl1keep, [%10, #512] \n"
526 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%10], #64 \n"
527
528 "fmla v8.4s, v2.4s, v4.s[0] \n"
529 "fmla v11.4s, v2.4s, v4.s[1] \n"
530 "fmla v14.4s, v2.4s, v4.s[2] \n"
531 "fmla v17.4s, v2.4s, v4.s[3] \n"
532 "fmla v20.4s, v2.4s, v5.s[0] \n"
533 "fmla v23.4s, v2.4s, v5.s[1] \n"
534 "fmla v26.4s, v2.4s, v5.s[2] \n"
535 "fmla v29.4s, v2.4s, v5.s[3] \n"
536
537 "fmla v9.4s, v3.4s, v4.s[0] \n"
538 "fmla v12.4s, v3.4s, v4.s[1] \n"
539 "fmla v15.4s, v3.4s, v4.s[2] \n"
540 "fmla v18.4s, v3.4s, v4.s[3] \n"
541 "fmla v21.4s, v3.4s, v5.s[0] \n"
542 "fmla v24.4s, v3.4s, v5.s[1] \n"
543 "fmla v27.4s, v3.4s, v5.s[2] \n"
544 "fmla v30.4s, v3.4s, v5.s[3] \n"
545
546 "prfm pldl1keep, [%9, #512] \n"
547 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64 \n"
548
549 "fmla v10.4s, v0.4s, v4.s[0] \n"
550 "fmla v13.4s, v0.4s, v4.s[1] \n"
551 "fmla v16.4s, v0.4s, v4.s[2] \n"
552 "fmla v19.4s, v0.4s, v4.s[3] \n"
553 "fmla v22.4s, v0.4s, v5.s[0] \n"
554 "fmla v25.4s, v0.4s, v5.s[1] \n"
555 "fmla v28.4s, v0.4s, v5.s[2] \n"
556 "fmla v31.4s, v0.4s, v5.s[3] \n"
557
558 "fmla v8.4s, v1.4s, v6.s[0] \n"
559 "fmla v11.4s, v1.4s, v6.s[1] \n"
560 "fmla v14.4s, v1.4s, v6.s[2] \n"
561 "fmla v17.4s, v1.4s, v6.s[3] \n"
562 "fmla v20.4s, v1.4s, v7.s[0] \n"
563 "fmla v23.4s, v1.4s, v7.s[1] \n"
564 "fmla v26.4s, v1.4s, v7.s[2] \n"
565 "fmla v29.4s, v1.4s, v7.s[3] \n"
566
567 "fmla v9.4s, v2.4s, v6.s[0] \n"
568 "fmla v12.4s, v2.4s, v6.s[1] \n"
569 "fmla v15.4s, v2.4s, v6.s[2] \n"
570 "fmla v18.4s, v2.4s, v6.s[3] \n"
571 "fmla v21.4s, v2.4s, v7.s[0] \n"
572 "fmla v24.4s, v2.4s, v7.s[1] \n"
573 "fmla v27.4s, v2.4s, v7.s[2] \n"
574 "fmla v30.4s, v2.4s, v7.s[3] \n"
575
576 "fmla v10.4s, v3.4s, v6.s[0] \n"
577 "fmla v13.4s, v3.4s, v6.s[1] \n"
578 "fmla v16.4s, v3.4s, v6.s[2] \n"
579 "fmla v19.4s, v3.4s, v6.s[3] \n"
580 "fmla v22.4s, v3.4s, v7.s[0] \n"
581 "fmla v25.4s, v3.4s, v7.s[1] \n"
582 "fmla v28.4s, v3.4s, v7.s[2] \n"
583 "fmla v31.4s, v3.4s, v7.s[3] \n"
584
585 "bne 0b \n"
586
587 "st1 {v8.4s, v9.4s, v10.4s}, [%1], #48 \n"
588 "st1 {v11.4s, v12.4s, v13.4s}, [%2], #48 \n"
589 "st1 {v14.4s, v15.4s, v16.4s}, [%3], #48 \n"
590 "st1 {v17.4s, v18.4s, v19.4s}, [%4], #48 \n"
591 "st1 {v20.4s, v21.4s, v22.4s}, [%5], #48 \n"
592 "st1 {v23.4s, v24.4s, v25.4s}, [%6], #48 \n"
593 "st1 {v26.4s, v27.4s, v28.4s}, [%7], #48 \n"
594 "st1 {v29.4s, v30.4s, v31.4s}, [%8], #48 \n"
595
596 : "=r"(nn), // %0
597 "=r"(outptr0), // %1
598 "=r"(outptr1), // %2
599 "=r"(outptr2), // %3
600 "=r"(outptr3), // %4
601 "=r"(outptr4), // %5
602 "=r"(outptr5), // %6
603 "=r"(outptr6), // %7
604 "=r"(outptr7), // %8
605 "=r"(tmpptr), // %9
606 "=r"(kptr) // %10
607 : "0"(nn),
608 "1"(outptr0),
609 "2"(outptr1),
610 "3"(outptr2),
611 "4"(outptr3),
612 "5"(outptr4),
613 "6"(outptr5),
614 "7"(outptr6),
615 "8"(outptr7),
616 "9"(tmpptr),
617 "10"(kptr),
618 "r"(biasptr) // %22
619 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
620 }
621 for (; i + 7 < size; i += 8)
622 {
623 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
624 const float* kptr = (const float*)kernel.channel(p / 8);
625
626 int nn = inch; // inch always > 0
627
628 asm volatile(
629 "ld1 {v30.4s, v31.4s}, [%22] \n"
630 "dup v16.4s, v30.s[0] \n"
631 "dup v17.4s, v30.s[0] \n"
632 "dup v18.4s, v30.s[1] \n"
633 "dup v19.4s, v30.s[1] \n"
634 "dup v20.4s, v30.s[2] \n"
635 "dup v21.4s, v30.s[2] \n"
636 "dup v22.4s, v30.s[3] \n"
637 "dup v23.4s, v30.s[3] \n"
638 "dup v24.4s, v31.s[0] \n"
639 "dup v25.4s, v31.s[0] \n"
640 "dup v26.4s, v31.s[1] \n"
641 "dup v27.4s, v31.s[1] \n"
642 "dup v28.4s, v31.s[2] \n"
643 "dup v29.4s, v31.s[2] \n"
644 "dup v30.4s, v31.s[3] \n"
645 "dup v31.4s, v31.s[3] \n"
646
647 "0: \n"
648
649 "prfm pldl1keep, [%9, #512] \n"
650 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64 \n"
651
652 "prfm pldl1keep, [%10, #512] \n"
653 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%10], #64 \n"
654
655 "subs %w0, %w0, #1 \n"
656
657 "fmla v16.4s, v0.4s, v4.s[0] \n"
658 "fmla v18.4s, v0.4s, v4.s[1] \n"
659 "fmla v20.4s, v0.4s, v4.s[2] \n"
660 "fmla v22.4s, v0.4s, v4.s[3] \n"
661 "fmla v24.4s, v0.4s, v5.s[0] \n"
662 "fmla v26.4s, v0.4s, v5.s[1] \n"
663 "fmla v28.4s, v0.4s, v5.s[2] \n"
664 "fmla v30.4s, v0.4s, v5.s[3] \n"
665 "fmla v17.4s, v1.4s, v4.s[0] \n"
666 "fmla v19.4s, v1.4s, v4.s[1] \n"
667 "fmla v21.4s, v1.4s, v4.s[2] \n"
668 "fmla v23.4s, v1.4s, v4.s[3] \n"
669 "fmla v25.4s, v1.4s, v5.s[0] \n"
670 "fmla v27.4s, v1.4s, v5.s[1] \n"
671 "fmla v29.4s, v1.4s, v5.s[2] \n"
672 "fmla v31.4s, v1.4s, v5.s[3] \n"
673
674 "fmla v16.4s, v2.4s, v6.s[0] \n"
675 "fmla v18.4s, v2.4s, v6.s[1] \n"
676 "fmla v20.4s, v2.4s, v6.s[2] \n"
677 "fmla v22.4s, v2.4s, v6.s[3] \n"
678 "fmla v24.4s, v2.4s, v7.s[0] \n"
679 "fmla v26.4s, v2.4s, v7.s[1] \n"
680 "fmla v28.4s, v2.4s, v7.s[2] \n"
681 "fmla v30.4s, v2.4s, v7.s[3] \n"
682 "fmla v17.4s, v3.4s, v6.s[0] \n"
683 "fmla v19.4s, v3.4s, v6.s[1] \n"
684 "fmla v21.4s, v3.4s, v6.s[2] \n"
685 "fmla v23.4s, v3.4s, v6.s[3] \n"
686 "fmla v25.4s, v3.4s, v7.s[0] \n"
687 "fmla v27.4s, v3.4s, v7.s[1] \n"
688 "fmla v29.4s, v3.4s, v7.s[2] \n"
689 "fmla v31.4s, v3.4s, v7.s[3] \n"
690
691 "prfm pldl1keep, [%9, #512] \n"
692 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%9], #64 \n"
693
694 "prfm pldl1keep, [%10, #512] \n"
695 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%10], #64 \n"
696
697 "fmla v16.4s, v12.4s, v8.s[0] \n"
698 "fmla v18.4s, v12.4s, v8.s[1] \n"
699 "fmla v20.4s, v12.4s, v8.s[2] \n"
700 "fmla v22.4s, v12.4s, v8.s[3] \n"
701 "fmla v24.4s, v12.4s, v9.s[0] \n"
702 "fmla v26.4s, v12.4s, v9.s[1] \n"
703 "fmla v28.4s, v12.4s, v9.s[2] \n"
704 "fmla v30.4s, v12.4s, v9.s[3] \n"
705 "fmla v17.4s, v13.4s, v8.s[0] \n"
706 "fmla v19.4s, v13.4s, v8.s[1] \n"
707 "fmla v21.4s, v13.4s, v8.s[2] \n"
708 "fmla v23.4s, v13.4s, v8.s[3] \n"
709 "fmla v25.4s, v13.4s, v9.s[0] \n"
710 "fmla v27.4s, v13.4s, v9.s[1] \n"
711 "fmla v29.4s, v13.4s, v9.s[2] \n"
712 "fmla v31.4s, v13.4s, v9.s[3] \n"
713
714 "fmla v16.4s, v14.4s, v10.s[0] \n"
715 "fmla v18.4s, v14.4s, v10.s[1] \n"
716 "fmla v20.4s, v14.4s, v10.s[2] \n"
717 "fmla v22.4s, v14.4s, v10.s[3] \n"
718 "fmla v24.4s, v14.4s, v11.s[0] \n"
719 "fmla v26.4s, v14.4s, v11.s[1] \n"
720 "fmla v28.4s, v14.4s, v11.s[2] \n"
721 "fmla v30.4s, v14.4s, v11.s[3] \n"
722 "fmla v17.4s, v15.4s, v10.s[0] \n"
723 "fmla v19.4s, v15.4s, v10.s[1] \n"
724 "fmla v21.4s, v15.4s, v10.s[2] \n"
725 "fmla v23.4s, v15.4s, v10.s[3] \n"
726 "fmla v25.4s, v15.4s, v11.s[0] \n"
727 "fmla v27.4s, v15.4s, v11.s[1] \n"
728 "fmla v29.4s, v15.4s, v11.s[2] \n"
729 "fmla v31.4s, v15.4s, v11.s[3] \n"
730
731 "bne 0b \n"
732
733 "st1 {v16.4s, v17.4s}, [%1], #32 \n"
734 "st1 {v18.4s, v19.4s}, [%2], #32 \n"
735 "st1 {v20.4s, v21.4s}, [%3], #32 \n"
736 "st1 {v22.4s, v23.4s}, [%4], #32 \n"
737 "st1 {v24.4s, v25.4s}, [%5], #32 \n"
738 "st1 {v26.4s, v27.4s}, [%6], #32 \n"
739 "st1 {v28.4s, v29.4s}, [%7], #32 \n"
740 "st1 {v30.4s, v31.4s}, [%8], #32 \n"
741
742 : "=r"(nn), // %0
743 "=r"(outptr0), // %1
744 "=r"(outptr1), // %2
745 "=r"(outptr2), // %3
746 "=r"(outptr3), // %4
747 "=r"(outptr4), // %5
748 "=r"(outptr5), // %6
749 "=r"(outptr6), // %7
750 "=r"(outptr7), // %8
751 "=r"(tmpptr), // %9
752 "=r"(kptr) // %10
753 : "0"(nn),
754 "1"(outptr0),
755 "2"(outptr1),
756 "3"(outptr2),
757 "4"(outptr3),
758 "5"(outptr4),
759 "6"(outptr5),
760 "7"(outptr6),
761 "8"(outptr7),
762 "9"(tmpptr),
763 "10"(kptr),
764 "r"(biasptr) // %22
765 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
766 }
767 for (; i + 3 < size; i += 4)
768 {
769 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
770 const float* kptr = (const float*)kernel.channel(p / 8);
771
772 int nn = inch; // inch always > 0
773
774 asm volatile(
775 "ld1 {v22.4s, v23.4s}, [%22] \n"
776 "dup v16.4s, v22.s[0] \n"
777 "dup v17.4s, v22.s[1] \n"
778 "dup v18.4s, v22.s[2] \n"
779 "dup v19.4s, v22.s[3] \n"
780 "dup v20.4s, v23.s[0] \n"
781 "dup v21.4s, v23.s[1] \n"
782 "dup v22.4s, v23.s[2] \n"
783 "dup v23.4s, v23.s[3] \n"
784
785 "0: \n"
786
787 "prfm pldl1keep, [%9, #512] \n"
788 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64 \n"
789
790 "prfm pldl1keep, [%10, #512] \n"
791 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%10], #64 \n"
792
793 "subs %w0, %w0, #1 \n"
794
795 "fmla v16.4s, v0.4s, v4.s[0] \n"
796 "fmla v17.4s, v0.4s, v4.s[1] \n"
797 "fmla v18.4s, v0.4s, v4.s[2] \n"
798 "fmla v19.4s, v0.4s, v4.s[3] \n"
799 "fmla v20.4s, v0.4s, v5.s[0] \n"
800 "fmla v21.4s, v0.4s, v5.s[1] \n"
801 "fmla v22.4s, v0.4s, v5.s[2] \n"
802 "fmla v23.4s, v0.4s, v5.s[3] \n"
803
804 "prfm pldl1keep, [%10, #512] \n"
805 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%10], #64 \n"
806
807 "fmla v16.4s, v1.4s, v6.s[0] \n"
808 "fmla v17.4s, v1.4s, v6.s[1] \n"
809 "fmla v18.4s, v1.4s, v6.s[2] \n"
810 "fmla v19.4s, v1.4s, v6.s[3] \n"
811 "fmla v20.4s, v1.4s, v7.s[0] \n"
812 "fmla v21.4s, v1.4s, v7.s[1] \n"
813 "fmla v22.4s, v1.4s, v7.s[2] \n"
814 "fmla v23.4s, v1.4s, v7.s[3] \n"
815
816 "fmla v16.4s, v2.4s, v8.s[0] \n"
817 "fmla v17.4s, v2.4s, v8.s[1] \n"
818 "fmla v18.4s, v2.4s, v8.s[2] \n"
819 "fmla v19.4s, v2.4s, v8.s[3] \n"
820 "fmla v20.4s, v2.4s, v9.s[0] \n"
821 "fmla v21.4s, v2.4s, v9.s[1] \n"
822 "fmla v22.4s, v2.4s, v9.s[2] \n"
823 "fmla v23.4s, v2.4s, v9.s[3] \n"
824
825 "fmla v16.4s, v3.4s, v10.s[0] \n"
826 "fmla v17.4s, v3.4s, v10.s[1] \n"
827 "fmla v18.4s, v3.4s, v10.s[2] \n"
828 "fmla v19.4s, v3.4s, v10.s[3] \n"
829 "fmla v20.4s, v3.4s, v11.s[0] \n"
830 "fmla v21.4s, v3.4s, v11.s[1] \n"
831 "fmla v22.4s, v3.4s, v11.s[2] \n"
832 "fmla v23.4s, v3.4s, v11.s[3] \n"
833
834 "bne 0b \n"
835
836 "st1 {v16.4s}, [%1], #16 \n"
837 "st1 {v17.4s}, [%2], #16 \n"
838 "st1 {v18.4s}, [%3], #16 \n"
839 "st1 {v19.4s}, [%4], #16 \n"
840 "st1 {v20.4s}, [%5], #16 \n"
841 "st1 {v21.4s}, [%6], #16 \n"
842 "st1 {v22.4s}, [%7], #16 \n"
843 "st1 {v23.4s}, [%8], #16 \n"
844
845 : "=r"(nn), // %0
846 "=r"(outptr0), // %1
847 "=r"(outptr1), // %2
848 "=r"(outptr2), // %3
849 "=r"(outptr3), // %4
850 "=r"(outptr4), // %5
851 "=r"(outptr5), // %6
852 "=r"(outptr6), // %7
853 "=r"(outptr7), // %8
854 "=r"(tmpptr), // %9
855 "=r"(kptr) // %10
856 : "0"(nn),
857 "1"(outptr0),
858 "2"(outptr1),
859 "3"(outptr2),
860 "4"(outptr3),
861 "5"(outptr4),
862 "6"(outptr5),
863 "7"(outptr6),
864 "8"(outptr7),
865 "9"(tmpptr),
866 "10"(kptr),
867 "r"(biasptr) // %22
868 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
869 }
870 for (; i < size; i++)
871 {
872 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + i % 12 % 4);
873 const float* kptr = (const float*)kernel.channel(p / 8);
874
875 int nn = inch; // inch always > 0
876
877 asm volatile(
878 "ld1 {v16.4s, v17.4s}, [%22] \n"
879 "eor v18.16b, v18.16b, v18.16b \n"
880 "eor v19.16b, v19.16b, v19.16b \n"
881
882 "0: \n"
883
884 "prfm pldl1keep, [%9, #128] \n"
885 "ld1 {v0.4s}, [%9], #16 \n"
886
887 "prfm pldl1keep, [%10, #512] \n"
888 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%10], #64 \n"
889
890 "subs %w0, %w0, #1 \n"
891
892 "fmla v16.4s, v4.4s, v0.s[0] \n"
893 "fmla v17.4s, v5.4s, v0.s[0] \n"
894 "fmla v18.4s, v6.4s, v0.s[1] \n"
895 "fmla v19.4s, v7.4s, v0.s[1] \n"
896
897 "prfm pldl1keep, [%10, #512] \n"
898 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%10], #64 \n"
899
900 "fmla v16.4s, v8.4s, v0.s[2] \n"
901 "fmla v17.4s, v9.4s, v0.s[2] \n"
902 "fmla v18.4s, v10.4s, v0.s[3] \n"
903 "fmla v19.4s, v11.4s, v0.s[3] \n"
904
905 "bne 0b \n"
906
907 "fadd v16.4s, v16.4s, v18.4s \n"
908 "fadd v17.4s, v17.4s, v19.4s \n"
909
910 "st1 {v16.s}[0], [%1], #4 \n"
911 "st1 {v16.s}[1], [%2], #4 \n"
912 "st1 {v16.s}[2], [%3], #4 \n"
913 "st1 {v16.s}[3], [%4], #4 \n"
914 "st1 {v17.s}[0], [%5], #4 \n"
915 "st1 {v17.s}[1], [%6], #4 \n"
916 "st1 {v17.s}[2], [%7], #4 \n"
917 "st1 {v17.s}[3], [%8], #4 \n"
918
919 : "=r"(nn), // %0
920 "=r"(outptr0), // %1
921 "=r"(outptr1), // %2
922 "=r"(outptr2), // %3
923 "=r"(outptr3), // %4
924 "=r"(outptr4), // %5
925 "=r"(outptr5), // %6
926 "=r"(outptr6), // %7
927 "=r"(outptr7), // %8
928 "=r"(tmpptr), // %9
929 "=r"(kptr) // %10
930 : "0"(nn),
931 "1"(outptr0),
932 "2"(outptr1),
933 "3"(outptr2),
934 "4"(outptr3),
935 "5"(outptr4),
936 "6"(outptr5),
937 "7"(outptr6),
938 "8"(outptr7),
939 "9"(tmpptr),
940 "10"(kptr),
941 "r"(biasptr) // %22
942 : "cc", "memory", "v0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19");
943 }
944 }
945
946 remain_outch_start += nn_outch << 3;
947 nn_outch = (outch - remain_outch_start) >> 2;
948 #else // __aarch64__
949 nn_outch = outch >> 2;
950 #endif // __aarch64__
951
952 #pragma omp parallel for num_threads(opt.num_threads)
953 for (int pp = 0; pp < nn_outch; pp++)
954 {
955 int p = remain_outch_start + pp * 4;
956
957 float* outptr0 = top_blob.channel(p);
958 float* outptr1 = top_blob.channel(p + 1);
959 float* outptr2 = top_blob.channel(p + 2);
960 float* outptr3 = top_blob.channel(p + 3);
961
962 const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
963 const float* biasptr = bias ? bias + p : zeros;
964
965 int i = 0;
966 #if __aarch64__
967 for (; i + 11 < size; i += 12)
968 {
969 float* tmpptr = tmp.channel(i / 12);
970 const float* kptr = (const float*)kernel.channel(p / 8 + (p % 8) / 4);
971
972 int nn = inch; // inch always > 0
973
974 asm volatile(
975 "ld1 {v19.4s}, [%14] \n"
976 "dup v8.4s, v19.s[0] \n"
977 "dup v9.4s, v19.s[0] \n"
978 "dup v10.4s, v19.s[0] \n"
979 "dup v11.4s, v19.s[1] \n"
980 "dup v12.4s, v19.s[1] \n"
981 "dup v13.4s, v19.s[1] \n"
982 "dup v14.4s, v19.s[2] \n"
983 "dup v15.4s, v19.s[2] \n"
984 "dup v16.4s, v19.s[2] \n"
985 "dup v17.4s, v19.s[3] \n"
986 "dup v18.4s, v19.s[3] \n"
987 "dup v19.4s, v19.s[3] \n"
988
989 "0: \n"
990
991 "prfm pldl1keep, [%5, #512] \n"
992 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64 \n"
993
994 "prfm pldl1keep, [%6, #512] \n"
995 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%6], #64 \n"
996
997 "subs %w0, %w0, #1 \n"
998
999 "fmla v8.4s, v0.4s, v4.s[0] \n"
1000 "fmla v11.4s, v0.4s, v4.s[1] \n"
1001 "fmla v14.4s, v0.4s, v4.s[2] \n"
1002 "fmla v17.4s, v0.4s, v4.s[3] \n"
1003 "fmla v9.4s, v1.4s, v4.s[0] \n"
1004 "fmla v12.4s, v1.4s, v4.s[1] \n"
1005 "fmla v15.4s, v1.4s, v4.s[2] \n"
1006 "fmla v18.4s, v1.4s, v4.s[3] \n"
1007 "fmla v10.4s, v2.4s, v4.s[0] \n"
1008 "fmla v13.4s, v2.4s, v4.s[1] \n"
1009 "fmla v16.4s, v2.4s, v4.s[2] \n"
1010 "fmla v19.4s, v2.4s, v4.s[3] \n"
1011
1012 "prfm pldl1keep, [%5, #512] \n"
1013 "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%5], #64 \n"
1014
1015 "fmla v8.4s, v3.4s, v5.s[0] \n"
1016 "fmla v11.4s, v3.4s, v5.s[1] \n"
1017 "fmla v14.4s, v3.4s, v5.s[2] \n"
1018 "fmla v17.4s, v3.4s, v5.s[3] \n"
1019 "fmla v9.4s, v20.4s, v5.s[0] \n"
1020 "fmla v12.4s, v20.4s, v5.s[1] \n"
1021 "fmla v15.4s, v20.4s, v5.s[2] \n"
1022 "fmla v18.4s, v20.4s, v5.s[3] \n"
1023 "fmla v10.4s, v21.4s, v5.s[0] \n"
1024 "fmla v13.4s, v21.4s, v5.s[1] \n"
1025 "fmla v16.4s, v21.4s, v5.s[2] \n"
1026 "fmla v19.4s, v21.4s, v5.s[3] \n"
1027
1028 "prfm pldl1keep, [%5, #512] \n"
1029 "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%5], #64 \n"
1030
1031 "fmla v8.4s, v22.4s, v6.s[0] \n"
1032 "fmla v11.4s, v22.4s, v6.s[1] \n"
1033 "fmla v14.4s, v22.4s, v6.s[2] \n"
1034 "fmla v17.4s, v22.4s, v6.s[3] \n"
1035 "fmla v9.4s, v23.4s, v6.s[0] \n"
1036 "fmla v12.4s, v23.4s, v6.s[1] \n"
1037 "fmla v15.4s, v23.4s, v6.s[2] \n"
1038 "fmla v18.4s, v23.4s, v6.s[3] \n"
1039 "fmla v10.4s, v24.4s, v6.s[0] \n"
1040 "fmla v13.4s, v24.4s, v6.s[1] \n"
1041 "fmla v16.4s, v24.4s, v6.s[2] \n"
1042 "fmla v19.4s, v24.4s, v6.s[3] \n"
1043
1044 "fmla v8.4s, v25.4s, v7.s[0] \n"
1045 "fmla v11.4s, v25.4s, v7.s[1] \n"
1046 "fmla v14.4s, v25.4s, v7.s[2] \n"
1047 "fmla v17.4s, v25.4s, v7.s[3] \n"
1048 "fmla v9.4s, v26.4s, v7.s[0] \n"
1049 "fmla v12.4s, v26.4s, v7.s[1] \n"
1050 "fmla v15.4s, v26.4s, v7.s[2] \n"
1051 "fmla v18.4s, v26.4s, v7.s[3] \n"
1052 "fmla v10.4s, v27.4s, v7.s[0] \n"
1053 "fmla v13.4s, v27.4s, v7.s[1] \n"
1054 "fmla v16.4s, v27.4s, v7.s[2] \n"
1055 "fmla v19.4s, v27.4s, v7.s[3] \n"
1056
1057 "bne 0b \n"
1058
1059 "st1 {v8.4s, v9.4s, v10.4s}, [%1], #48 \n"
1060 "st1 {v11.4s, v12.4s, v13.4s}, [%2], #48 \n"
1061 "st1 {v14.4s, v15.4s, v16.4s}, [%3], #48 \n"
1062 "st1 {v17.4s, v18.4s, v19.4s}, [%4], #48 \n"
1063
1064 : "=r"(nn), // %0
1065 "=r"(outptr0), // %1
1066 "=r"(outptr1), // %2
1067 "=r"(outptr2), // %3
1068 "=r"(outptr3), // %4
1069 "=r"(tmpptr), // %5
1070 "=r"(kptr) // %6
1071 : "0"(nn),
1072 "1"(outptr0),
1073 "2"(outptr1),
1074 "3"(outptr2),
1075 "4"(outptr3),
1076 "5"(tmpptr),
1077 "6"(kptr),
1078 "r"(biasptr) // %14
1079 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
1080 }
1081 #endif // __aarch64__
1082 for (; i + 7 < size; i += 8)
1083 {
1084 #if __aarch64__
1085 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
1086 const float* kptr = (const float*)kernel.channel(p / 8 + (p % 8) / 4);
1087 #else
1088 float* tmpptr = tmp.channel(i / 8);
1089 const float* kptr = (const float*)kernel.channel(p / 4);
1090 #endif
1091
1092 int nn = inch; // inch always > 0
1093
1094 #if __aarch64__
1095 asm volatile(
1096 "ld1 {v15.4s}, [%14] \n"
1097 "dup v8.4s, v15.s[0] \n"
1098 "dup v9.4s, v15.s[0] \n"
1099 "dup v10.4s, v15.s[1] \n"
1100 "dup v11.4s, v15.s[1] \n"
1101 "dup v12.4s, v15.s[2] \n"
1102 "dup v13.4s, v15.s[2] \n"
1103 "dup v14.4s, v15.s[3] \n"
1104 "dup v15.4s, v15.s[3] \n"
1105
1106 "0: \n"
1107
1108 "prfm pldl1keep, [%5, #512] \n"
1109 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64 \n"
1110
1111 "prfm pldl1keep, [%6, #512] \n"
1112 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%6], #64 \n"
1113
1114 "subs %w0, %w0, #1 \n"
1115
1116 "fmla v8.4s, v0.4s, v4.s[0] \n"
1117 "fmla v10.4s, v0.4s, v4.s[1] \n"
1118 "fmla v12.4s, v0.4s, v4.s[2] \n"
1119 "fmla v14.4s, v0.4s, v4.s[3] \n"
1120 "fmla v9.4s, v1.4s, v4.s[0] \n"
1121 "fmla v11.4s, v1.4s, v4.s[1] \n"
1122 "fmla v13.4s, v1.4s, v4.s[2] \n"
1123 "fmla v15.4s, v1.4s, v4.s[3] \n"
1124
1125 "fmla v8.4s, v2.4s, v5.s[0] \n"
1126 "fmla v10.4s, v2.4s, v5.s[1] \n"
1127 "fmla v12.4s, v2.4s, v5.s[2] \n"
1128 "fmla v14.4s, v2.4s, v5.s[3] \n"
1129 "fmla v9.4s, v3.4s, v5.s[0] \n"
1130 "fmla v11.4s, v3.4s, v5.s[1] \n"
1131 "fmla v13.4s, v3.4s, v5.s[2] \n"
1132 "fmla v15.4s, v3.4s, v5.s[3] \n"
1133
1134 "prfm pldl1keep, [%5, #512] \n"
1135 "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%5], #64 \n"
1136
1137 "fmla v8.4s, v16.4s, v6.s[0] \n"
1138 "fmla v10.4s, v16.4s, v6.s[1] \n"
1139 "fmla v12.4s, v16.4s, v6.s[2] \n"
1140 "fmla v14.4s, v16.4s, v6.s[3] \n"
1141 "fmla v9.4s, v17.4s, v6.s[0] \n"
1142 "fmla v11.4s, v17.4s, v6.s[1] \n"
1143 "fmla v13.4s, v17.4s, v6.s[2] \n"
1144 "fmla v15.4s, v17.4s, v6.s[3] \n"
1145
1146 "fmla v8.4s, v18.4s, v7.s[0] \n"
1147 "fmla v10.4s, v18.4s, v7.s[1] \n"
1148 "fmla v12.4s, v18.4s, v7.s[2] \n"
1149 "fmla v14.4s, v18.4s, v7.s[3] \n"
1150 "fmla v9.4s, v19.4s, v7.s[0] \n"
1151 "fmla v11.4s, v19.4s, v7.s[1] \n"
1152 "fmla v13.4s, v19.4s, v7.s[2] \n"
1153 "fmla v15.4s, v19.4s, v7.s[3] \n"
1154
1155 "bne 0b \n"
1156
1157 "st1 {v8.4s, v9.4s}, [%1], #32 \n"
1158 "st1 {v10.4s, v11.4s}, [%2], #32 \n"
1159 "st1 {v12.4s, v13.4s}, [%3], #32 \n"
1160 "st1 {v14.4s, v15.4s}, [%4], #32 \n"
1161
1162 : "=r"(nn), // %0
1163 "=r"(outptr0), // %1
1164 "=r"(outptr1), // %2
1165 "=r"(outptr2), // %3
1166 "=r"(outptr3), // %4
1167 "=r"(tmpptr), // %5
1168 "=r"(kptr) // %6
1169 : "0"(nn),
1170 "1"(outptr0),
1171 "2"(outptr1),
1172 "3"(outptr2),
1173 "4"(outptr3),
1174 "5"(tmpptr),
1175 "6"(kptr),
1176 "r"(biasptr) // %14
1177 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
1178 #else // __aarch64__
1179 asm volatile(
1180 "vld1.f32 {d30-d31}, [%14] \n"
1181 "vdup.f32 q8, d30[0] \n"
1182 "vdup.f32 q9, d30[0] \n"
1183 "vdup.f32 q10, d30[1] \n"
1184 "vdup.f32 q11, d30[1] \n"
1185 "vdup.f32 q12, d31[0] \n"
1186 "vdup.f32 q13, d31[0] \n"
1187 "vdup.f32 q14, d31[1] \n"
1188 "vdup.f32 q15, d31[1] \n"
1189
1190 "0: \n"
1191
1192 "pld [%5, #512] \n"
1193 "vldm %5!, {d0-d7} \n"
1194
1195 "pld [%6, #512] \n"
1196 "vldm %6!, {d8-d15} \n"
1197
1198 "vmla.f32 q8, q0, d8[0] \n"
1199 "vmla.f32 q10, q0, d8[1] \n"
1200 "vmla.f32 q12, q0, d9[0] \n"
1201 "vmla.f32 q14, q0, d9[1] \n"
1202 "vmla.f32 q9, q1, d8[0] \n"
1203 "vmla.f32 q11, q1, d8[1] \n"
1204 "vmla.f32 q13, q1, d9[0] \n"
1205 "vmla.f32 q15, q1, d9[1] \n"
1206
1207 "vmla.f32 q8, q2, d10[0] \n"
1208 "vmla.f32 q10, q2, d10[1] \n"
1209 "vmla.f32 q12, q2, d11[0] \n"
1210 "vmla.f32 q14, q2, d11[1] \n"
1211 "vmla.f32 q9, q3, d10[0] \n"
1212 "vmla.f32 q11, q3, d10[1] \n"
1213 "vmla.f32 q13, q3, d11[0] \n"
1214 "vmla.f32 q15, q3, d11[1] \n"
1215
1216 "pld [%5, #512] \n"
1217 "vldm %5!, {d0-d7} \n"
1218
1219 "vmla.f32 q8, q0, d12[0] \n"
1220 "vmla.f32 q10, q0, d12[1] \n"
1221 "vmla.f32 q12, q0, d13[0] \n"
1222 "vmla.f32 q14, q0, d13[1] \n"
1223 "vmla.f32 q9, q1, d12[0] \n"
1224 "vmla.f32 q11, q1, d12[1] \n"
1225 "vmla.f32 q13, q1, d13[0] \n"
1226 "vmla.f32 q15, q1, d13[1] \n"
1227
1228 "subs %0, %0, #1 \n"
1229
1230 "vmla.f32 q8, q2, d14[0] \n"
1231 "vmla.f32 q10, q2, d14[1] \n"
1232 "vmla.f32 q12, q2, d15[0] \n"
1233 "vmla.f32 q14, q2, d15[1] \n"
1234 "vmla.f32 q9, q3, d14[0] \n"
1235 "vmla.f32 q11, q3, d14[1] \n"
1236 "vmla.f32 q13, q3, d15[0] \n"
1237 "vmla.f32 q15, q3, d15[1] \n"
1238
1239 "bne 0b \n"
1240
1241 "vst1.f32 {d16-d19}, [%1 :128]! \n"
1242 "vst1.f32 {d20-d23}, [%2 :128]! \n"
1243 "vst1.f32 {d24-d27}, [%3 :128]! \n"
1244 "vst1.f32 {d28-d31}, [%4 :128]! \n"
1245
1246 : "=r"(nn), // %0
1247 "=r"(outptr0), // %1
1248 "=r"(outptr1), // %2
1249 "=r"(outptr2), // %3
1250 "=r"(outptr3), // %4
1251 "=r"(tmpptr), // %5
1252 "=r"(kptr) // %6
1253 : "0"(nn),
1254 "1"(outptr0),
1255 "2"(outptr1),
1256 "3"(outptr2),
1257 "4"(outptr3),
1258 "5"(tmpptr),
1259 "6"(kptr),
1260 "r"(biasptr) // %14
1261 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1262 #endif // __aarch64__
1263 }
1264 for (; i + 3 < size; i += 4)
1265 {
1266 #if __aarch64__
1267 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
1268 const float* kptr = (const float*)kernel.channel(p / 8 + (p % 8) / 4);
1269 #else
1270 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
1271 const float* kptr = (const float*)kernel.channel(p / 4);
1272 #endif
1273
1274 int nn = inch; // inch always > 0
1275
1276 #if __aarch64__
1277 asm volatile(
1278 "ld1 {v11.4s}, [%14] \n"
1279 "dup v8.4s, v11.s[0] \n"
1280 "dup v9.4s, v11.s[1] \n"
1281 "dup v10.4s, v11.s[2] \n"
1282 "dup v11.4s, v11.s[3] \n"
1283
1284 "0: \n"
1285
1286 "prfm pldl1keep, [%5, #512] \n"
1287 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64 \n"
1288
1289 "prfm pldl1keep, [%6, #512] \n"
1290 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%6], #64 \n"
1291
1292 "subs %w0, %w0, #1 \n"
1293
1294 "fmla v8.4s, v0.4s, v4.s[0] \n"
1295 "fmla v9.4s, v0.4s, v4.s[1] \n"
1296 "fmla v10.4s, v0.4s, v4.s[2] \n"
1297 "fmla v11.4s, v0.4s, v4.s[3] \n"
1298
1299 "fmla v8.4s, v1.4s, v5.s[0] \n"
1300 "fmla v9.4s, v1.4s, v5.s[1] \n"
1301 "fmla v10.4s, v1.4s, v5.s[2] \n"
1302 "fmla v11.4s, v1.4s, v5.s[3] \n"
1303
1304 "fmla v8.4s, v2.4s, v6.s[0] \n"
1305 "fmla v9.4s, v2.4s, v6.s[1] \n"
1306 "fmla v10.4s, v2.4s, v6.s[2] \n"
1307 "fmla v11.4s, v2.4s, v6.s[3] \n"
1308
1309 "fmla v8.4s, v3.4s, v7.s[0] \n"
1310 "fmla v9.4s, v3.4s, v7.s[1] \n"
1311 "fmla v10.4s, v3.4s, v7.s[2] \n"
1312 "fmla v11.4s, v3.4s, v7.s[3] \n"
1313
1314 "bne 0b \n"
1315
1316 "st1 {v8.4s}, [%1], #16 \n"
1317 "st1 {v9.4s}, [%2], #16 \n"
1318 "st1 {v10.4s}, [%3], #16 \n"
1319 "st1 {v11.4s}, [%4], #16 \n"
1320
1321 : "=r"(nn), // %0
1322 "=r"(outptr0), // %1
1323 "=r"(outptr1), // %2
1324 "=r"(outptr2), // %3
1325 "=r"(outptr3), // %4
1326 "=r"(tmpptr), // %5
1327 "=r"(kptr) // %6
1328 : "0"(nn),
1329 "1"(outptr0),
1330 "2"(outptr1),
1331 "3"(outptr2),
1332 "4"(outptr3),
1333 "5"(tmpptr),
1334 "6"(kptr),
1335 "r"(biasptr) // %14
1336 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
1337 #else // __aarch64__
1338 asm volatile(
1339 "vld1.f32 {d22-d23}, [%14] \n"
1340 "vdup.f32 q8, d22[0] \n"
1341 "vdup.f32 q9, d22[1] \n"
1342 "vdup.f32 q10, d23[0] \n"
1343 "vdup.f32 q11, d23[1] \n"
1344
1345 "0: \n"
1346
1347 "pld [%5, #512] \n"
1348 "vldm %5!, {d0-d7} \n"
1349
1350 "pld [%6, #512] \n"
1351 "vldm %6!, {d8-d15} \n"
1352
1353 "vmla.f32 q8, q0, d8[0] \n"
1354 "vmla.f32 q9, q0, d8[1] \n"
1355 "vmla.f32 q10, q0, d9[0] \n"
1356 "vmla.f32 q11, q0, d9[1] \n"
1357
1358 "vmla.f32 q8, q1, d10[0] \n"
1359 "vmla.f32 q9, q1, d10[1] \n"
1360 "vmla.f32 q10, q1, d11[0] \n"
1361 "vmla.f32 q11, q1, d11[1] \n"
1362
1363 "subs %0, %0, #1 \n"
1364
1365 "vmla.f32 q8, q2, d12[0] \n"
1366 "vmla.f32 q9, q2, d12[1] \n"
1367 "vmla.f32 q10, q2, d13[0] \n"
1368 "vmla.f32 q11, q2, d13[1] \n"
1369
1370 "vmla.f32 q8, q3, d14[0] \n"
1371 "vmla.f32 q9, q3, d14[1] \n"
1372 "vmla.f32 q10, q3, d15[0] \n"
1373 "vmla.f32 q11, q3, d15[1] \n"
1374
1375 "bne 0b \n"
1376
1377 "vst1.f32 {d16-d17}, [%1 :128]! \n"
1378 "vst1.f32 {d18-d19}, [%2 :128]! \n"
1379 "vst1.f32 {d20-d21}, [%3 :128]! \n"
1380 "vst1.f32 {d22-d23}, [%4 :128]! \n"
1381
1382 : "=r"(nn), // %0
1383 "=r"(outptr0), // %1
1384 "=r"(outptr1), // %2
1385 "=r"(outptr2), // %3
1386 "=r"(outptr3), // %4
1387 "=r"(tmpptr), // %5
1388 "=r"(kptr) // %6
1389 : "0"(nn),
1390 "1"(outptr0),
1391 "2"(outptr1),
1392 "3"(outptr2),
1393 "4"(outptr3),
1394 "5"(tmpptr),
1395 "6"(kptr),
1396 "r"(biasptr) // %14
1397 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
1398 #endif // __aarch64__
1399 }
1400 for (; i < size; i++)
1401 {
1402 #if __aarch64__
1403 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + i % 12 % 4);
1404 const float* kptr = (const float*)kernel.channel(p / 8 + (p % 8) / 4);
1405 #else
1406 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
1407 const float* kptr = (const float*)kernel.channel(p / 4);
1408 #endif
1409
1410 int nn = inch; // inch always > 0
1411
1412 #if __aarch64__
1413 asm volatile(
1414 "ld1 {v8.4s}, [%14] \n"
1415 "eor v9.16b, v9.16b, v9.16b \n"
1416 "eor v10.16b, v10.16b, v10.16b \n"
1417 "eor v11.16b, v11.16b, v11.16b \n"
1418
1419 "0: \n"
1420
1421 "prfm pldl1keep, [%5, #128] \n"
1422 "ld1 {v0.4s}, [%5], #16 \n"
1423
1424 "prfm pldl1keep, [%6, #512] \n"
1425 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%6], #64 \n"
1426
1427 "subs %w0, %w0, #1 \n"
1428
1429 "fmla v8.4s, v4.4s, v0.s[0] \n"
1430 "fmla v9.4s, v5.4s, v0.s[1] \n"
1431 "fmla v10.4s, v6.4s, v0.s[2] \n"
1432 "fmla v11.4s, v7.4s, v0.s[3] \n"
1433
1434 "bne 0b \n"
1435
1436 "fadd v8.4s, v8.4s, v9.4s \n"
1437 "fadd v10.4s, v10.4s, v11.4s \n"
1438 "fadd v8.4s, v8.4s, v10.4s \n"
1439
1440 "st1 {v8.s}[0], [%1], #4 \n"
1441 "st1 {v8.s}[1], [%2], #4 \n"
1442 "st1 {v8.s}[2], [%3], #4 \n"
1443 "st1 {v8.s}[3], [%4], #4 \n"
1444
1445 : "=r"(nn), // %0
1446 "=r"(outptr0), // %1
1447 "=r"(outptr1), // %2
1448 "=r"(outptr2), // %3
1449 "=r"(outptr3), // %4
1450 "=r"(tmpptr), // %5
1451 "=r"(kptr) // %6
1452 : "0"(nn),
1453 "1"(outptr0),
1454 "2"(outptr1),
1455 "3"(outptr2),
1456 "4"(outptr3),
1457 "5"(tmpptr),
1458 "6"(kptr),
1459 "r"(biasptr) // %14
1460 : "cc", "memory", "v0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
1461 #else // __aarch64__
1462 asm volatile(
1463 "vld1.f32 {d16-d17}, [%14] \n"
1464 "veor q9, q9 \n"
1465 "veor q10, q10 \n"
1466 "veor q11, q11 \n"
1467
1468 "0: \n"
1469
1470 "pld [%5, #128] \n"
1471 "vld1.f32 {d0-d1}, [%5]! \n"
1472
1473 "pld [%6, #512] \n"
1474 "vldm %6!, {d8-d15} \n"
1475
1476 "subs %0, %0, #1 \n"
1477
1478 "vmla.f32 q8, q4, d0[0] \n"
1479 "vmla.f32 q9, q5, d0[1] \n"
1480 "vmla.f32 q10, q6, d1[0] \n"
1481 "vmla.f32 q11, q7, d1[1] \n"
1482
1483 "bne 0b \n"
1484
1485 "vadd.f32 q8, q8, q9 \n"
1486 "vadd.f32 q10, q10, q11 \n"
1487 "vadd.f32 q8, q8, q10 \n"
1488
1489 "vst1.f32 {d16[0]}, [%1]! \n"
1490 "vst1.f32 {d16[1]}, [%2]! \n"
1491 "vst1.f32 {d17[0]}, [%3]! \n"
1492 "vst1.f32 {d17[1]}, [%4]! \n"
1493
1494 : "=r"(nn), // %0
1495 "=r"(outptr0), // %1
1496 "=r"(outptr1), // %2
1497 "=r"(outptr2), // %3
1498 "=r"(outptr3), // %4
1499 "=r"(tmpptr), // %5
1500 "=r"(kptr) // %6
1501 : "0"(nn),
1502 "1"(outptr0),
1503 "2"(outptr1),
1504 "3"(outptr2),
1505 "4"(outptr3),
1506 "5"(tmpptr),
1507 "6"(kptr),
1508 "r"(biasptr) // %14
1509 : "cc", "memory", "q0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
1510 #endif // __aarch64__
1511 }
1512 }
1513
1514 remain_outch_start += nn_outch << 2;
1515
1516 #pragma omp parallel for num_threads(opt.num_threads)
1517 for (int p = remain_outch_start; p < outch; p++)
1518 {
1519 float* outptr0 = top_blob.channel(p);
1520
1521 const float bias0 = bias ? bias[p] : 0.f;
1522
1523 int i = 0;
1524 #if __aarch64__
1525 for (; i + 11 < size; i += 12)
1526 {
1527 float* tmpptr = tmp.channel(i / 12);
1528 const float* kptr = (const float*)kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
1529
1530 int nn = inch; // inch always > 0
1531
1532 asm volatile(
1533 "dup v8.4s, %w8 \n"
1534 "dup v9.4s, %w8 \n"
1535 "dup v10.4s, %w8 \n"
1536 "eor v5.16b, v5.16b, v5.16b \n"
1537 "eor v6.16b, v6.16b, v6.16b \n"
1538 "eor v7.16b, v7.16b, v7.16b \n"
1539
1540 "0: \n"
1541
1542 "prfm pldl1keep, [%2, #512] \n"
1543 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%2], #64 \n"
1544
1545 "prfm pldl1keep, [%3, #128] \n"
1546 "ld1 {v4.4s}, [%3], #16 \n"
1547
1548 "subs %w0, %w0, #1 \n"
1549
1550 "fmla v8.4s, v0.4s, v4.s[0] \n"
1551 "fmla v9.4s, v1.4s, v4.s[0] \n"
1552 "fmla v10.4s, v2.4s, v4.s[0] \n"
1553
1554 "prfm pldl1keep, [%2, #512] \n"
1555 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%2], #64 \n"
1556
1557 "fmla v5.4s, v3.4s, v4.s[1] \n"
1558 "fmla v6.4s, v12.4s, v4.s[1] \n"
1559 "fmla v7.4s, v13.4s, v4.s[1] \n"
1560
1561 "prfm pldl1keep, [%2, #512] \n"
1562 "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%2], #64 \n"
1563
1564 "fmla v8.4s, v14.4s, v4.s[2] \n"
1565 "fmla v9.4s, v15.4s, v4.s[2] \n"
1566 "fmla v10.4s, v16.4s, v4.s[2] \n"
1567
1568 "fmla v5.4s, v17.4s, v4.s[3] \n"
1569 "fmla v6.4s, v18.4s, v4.s[3] \n"
1570 "fmla v7.4s, v19.4s, v4.s[3] \n"
1571
1572 "bne 0b \n"
1573
1574 "fadd v8.4s, v8.4s, v5.4s \n"
1575 "fadd v9.4s, v9.4s, v6.4s \n"
1576 "fadd v10.4s, v10.4s, v7.4s \n"
1577
1578 "st1 {v8.4s, v9.4s, v10.4s}, [%1], #48 \n"
1579
1580 : "=r"(nn), // %0
1581 "=r"(outptr0), // %1
1582 "=r"(tmpptr), // %2
1583 "=r"(kptr) // %3
1584 : "0"(nn),
1585 "1"(outptr0),
1586 "2"(tmpptr),
1587 "3"(kptr),
1588 "r"(bias0) // %8
1589 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
1590 }
1591 #endif // __aarch64__
1592 for (; i + 7 < size; i += 8)
1593 {
1594 #if __aarch64__
1595 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
1596 const float* kptr = (const float*)kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
1597 #else
1598 float* tmpptr = tmp.channel(i / 8);
1599 const float* kptr = (const float*)kernel.channel(p / 4 + p % 4);
1600 #endif
1601
1602 int nn = inch; // inch always > 0
1603
1604 #if __aarch64__
1605 asm volatile(
1606 "dup v8.4s, %w8 \n"
1607 "dup v9.4s, %w8 \n"
1608 "eor v10.16b, v10.16b, v10.16b \n"
1609 "eor v11.16b, v11.16b, v11.16b \n"
1610
1611 "0: \n"
1612
1613 "prfm pldl1keep, [%2, #512] \n"
1614 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%2], #64 \n"
1615
1616 "prfm pldl1keep, [%3, #128] \n"
1617 "ld1 {v4.4s}, [%3], #16 \n"
1618
1619 "subs %w0, %w0, #1 \n"
1620
1621 "fmla v8.4s, v0.4s, v4.s[0] \n"
1622 "fmla v9.4s, v1.4s, v4.s[0] \n"
1623 "fmla v10.4s, v2.4s, v4.s[1] \n"
1624 "fmla v11.4s, v3.4s, v4.s[1] \n"
1625
1626 "prfm pldl1keep, [%2, #512] \n"
1627 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%2], #64 \n"
1628
1629 "fmla v8.4s, v12.4s, v4.s[2] \n"
1630 "fmla v9.4s, v13.4s, v4.s[2] \n"
1631 "fmla v10.4s, v14.4s, v4.s[3] \n"
1632 "fmla v11.4s, v15.4s, v4.s[3] \n"
1633
1634 "bne 0b \n"
1635
1636 "fadd v8.4s, v8.4s, v10.4s \n"
1637 "fadd v9.4s, v9.4s, v11.4s \n"
1638
1639 "st1 {v8.4s, v9.4s}, [%1], #32 \n"
1640
1641 : "=r"(nn), // %0
1642 "=r"(outptr0), // %1
1643 "=r"(tmpptr), // %2
1644 "=r"(kptr) // %3
1645 : "0"(nn),
1646 "1"(outptr0),
1647 "2"(tmpptr),
1648 "3"(kptr),
1649 "r"(bias0) // %8
1650 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");
1651 #else // __aarch64__
1652 asm volatile(
1653 "vdup.f32 q8, %8 \n"
1654 "vdup.f32 q9, %8 \n"
1655 "veor q10, q10 \n"
1656 "veor q11, q11 \n"
1657
1658 "0: \n"
1659
1660 "pld [%2, #512] \n"
1661 "vldm %2!, {d0-d7} \n"
1662
1663 "pld [%3, #128] \n"
1664 "vld1.f32 {d8-d9}, [%3]! \n"
1665
1666 "vmla.f32 q8, q0, d8[0] \n"
1667 "vmla.f32 q9, q1, d8[0] \n"
1668 "vmla.f32 q10, q2, d8[1] \n"
1669 "vmla.f32 q11, q3, d8[1] \n"
1670
1671 "pld [%2, #512] \n"
1672 "vldm %2!, {d24-d31} \n"
1673
1674 "subs %0, %0, #1 \n"
1675
1676 "vmla.f32 q8, q12, d9[0] \n"
1677 "vmla.f32 q9, q13, d9[0] \n"
1678 "vmla.f32 q10, q14, d9[1] \n"
1679 "vmla.f32 q11, q15, d9[1] \n"
1680
1681 "bne 0b \n"
1682
1683 "vadd.f32 q8, q8, q10 \n"
1684 "vadd.f32 q9, q9, q11 \n"
1685
1686 "vst1.f32 {d16-d19}, [%1 :128]! \n"
1687
1688 : "=r"(nn), // %0
1689 "=r"(outptr0), // %1
1690 "=r"(tmpptr), // %2
1691 "=r"(kptr) // %3
1692 : "0"(nn),
1693 "1"(outptr0),
1694 "2"(tmpptr),
1695 "3"(kptr),
1696 "r"(bias0) // %8
1697 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1698 #endif // __aarch64__
1699 }
1700 for (; i + 3 < size; i += 4)
1701 {
1702 #if __aarch64__
1703 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
1704 const float* kptr = (const float*)kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
1705 #else
1706 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
1707 const float* kptr = (const float*)kernel.channel(p / 4 + p % 4);
1708 #endif
1709
1710 int nn = inch; // inch always > 0
1711
1712 #if __aarch64__
1713 asm volatile(
1714 "dup v8.4s, %w8 \n"
1715 "eor v9.16b, v9.16b, v9.16b \n"
1716 "eor v10.16b, v10.16b, v10.16b \n"
1717 "eor v11.16b, v11.16b, v11.16b \n"
1718
1719 "0: \n"
1720
1721 "prfm pldl1keep, [%2, #512] \n"
1722 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%2], #64 \n"
1723
1724 "prfm pldl1keep, [%3, #128] \n"
1725 "ld1 {v4.4s}, [%3], #16 \n"
1726
1727 "subs %w0, %w0, #1 \n"
1728
1729 "fmla v8.4s, v0.4s, v4.s[0] \n"
1730 "fmla v9.4s, v1.4s, v4.s[1] \n"
1731 "fmla v10.4s, v2.4s, v4.s[2] \n"
1732 "fmla v11.4s, v3.4s, v4.s[3] \n"
1733
1734 "bne 0b \n"
1735
1736 "fadd v8.4s, v8.4s, v9.4s \n"
1737 "fadd v10.4s, v10.4s, v11.4s \n"
1738 "fadd v8.4s, v8.4s, v10.4s \n"
1739
1740 "st1 {v8.4s}, [%1], #16 \n"
1741
1742 : "=r"(nn), // %0
1743 "=r"(outptr0), // %1
1744 "=r"(tmpptr), // %2
1745 "=r"(kptr) // %3
1746 : "0"(nn),
1747 "1"(outptr0),
1748 "2"(tmpptr),
1749 "3"(kptr),
1750 "r"(bias0) // %8
1751 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v8", "v9", "v10", "v11");
1752 #else // __aarch64__
1753 asm volatile(
1754 "vdup.f32 q8, %8 \n"
1755 "veor q9, q9 \n"
1756 "veor q10, q10 \n"
1757 "veor q11, q11 \n"
1758
1759 "0: \n"
1760
1761 "pld [%2, #512] \n"
1762 "vldm %2!, {d0-d7} \n"
1763
1764 "pld [%3, #128] \n"
1765 "vld1.f32 {d8-d9}, [%3]! \n"
1766
1767 "subs %0, %0, #1 \n"
1768
1769 "vmla.f32 q8, q0, d8[0] \n"
1770 "vmla.f32 q9, q1, d8[1] \n"
1771 "vmla.f32 q10, q2, d9[0] \n"
1772 "vmla.f32 q11, q3, d9[1] \n"
1773
1774 "bne 0b \n"
1775
1776 "vadd.f32 q8, q8, q9 \n"
1777 "vadd.f32 q10, q10, q11 \n"
1778 "vadd.f32 q8, q8, q10 \n"
1779
1780 "vst1.f32 {d16-d17}, [%1]! \n"
1781
1782 : "=r"(nn), // %0
1783 "=r"(outptr0), // %1
1784 "=r"(tmpptr), // %2
1785 "=r"(kptr) // %3
1786 : "0"(nn),
1787 "1"(outptr0),
1788 "2"(tmpptr),
1789 "3"(kptr),
1790 "r"(bias0) // %8
1791 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q8", "q9", "q10", "q11");
1792 #endif // __aarch64__
1793 }
1794 for (; i < size; i++)
1795 {
1796 #if __aarch64__
1797 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + i % 12 % 4);
1798 const float* kptr = (const float*)kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
1799 #else
1800 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
1801 const float* kptr = (const float*)kernel.channel(p / 4 + p % 4);
1802 #endif
1803
1804 float32x4_t _sum0 = vdupq_n_f32(0.f);
1805
1806 for (int q = 0; q < inch; q++)
1807 {
1808 float32x4_t _r0 = vld1q_f32(tmpptr);
1809
1810 float32x4_t _k0 = vld1q_f32(kptr);
1811
1812 _sum0 = vmlaq_f32(_sum0, _r0, _k0);
1813
1814 kptr += 4;
1815 tmpptr += 4;
1816 }
1817
1818 #if __aarch64__
1819 float sum0 = vaddvq_f32(_sum0);
1820 #else
1821 float32x2_t _ss = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0));
1822 float32x2_t _ss2 = vpadd_f32(_ss, _ss);
1823 float sum0 = vget_lane_f32(_ss2, 0);
1824 #endif
1825
1826 outptr0[0] = bias0 + sum0;
1827
1828 outptr0++;
1829 }
1830 }
1831
1832 // // NOTE sgemm
1833 // for (; p<outch; p++)
1834 // {
1835 // Mat out0 = top_blob.channel(p);
1836 //
1837 // const float bias0 = bias ? bias[p] : 0.f;
1838 //
1839 // float* outptr0 = out0;
1840 //
1841 // for (int i=0; i<size; i++)
1842 // {
1843 // float sum = bias0;
1844 //
1845 // const float* kptr = _kernel.channel(p);
1846 //
1847 // for (int q=0; q<inch; q++)
1848 // {
1849 // const float* img0 = bottom_blob.channel(q);
1850 //
1851 // sum += img0[i] * kptr[0];
1852 // kptr ++;
1853 // }
1854 //
1855 // outptr0[i] = sum;
1856 // }
1857 // }
1858 }
1859
// 1x1 convolution, stride 2, input pack4 -> output pack1.
// Strategy: a stride-2 1x1 conv is exactly a stride-1 1x1 conv applied to a
// spatially 2x-subsampled copy of the input, so we shrink the input blob and
// delegate the math to the stride-1 sgemm kernel.
static void conv1x1s2_pack4to1_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;

    // Floats to skip at the end of each output row to land on the next
    // sampled input row: the horizontal leftover of the current row plus one
    // whole skipped row, scaled by pack4 (w - 2*outw + w elements, 4 floats each).
    const int tailstep = (w - 2 * outw + w) * 4;

    // Subsampled copy of the input; same elempack so the sgemm kernel sees
    // an ordinary pack4 stride-1 problem.
    Mat bottom_blob_shrinked;
    bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int q = 0; q < channels; q++)
    {
        const float* sptr = bottom_blob.channel(q);
        float* dptr = bottom_blob_shrinked.channel(q);

        for (int y = 0; y < outh; y++)
        {
            for (int x = 0; x < outw; x++)
            {
                // Copy one pack4 pixel, then step past the next one (stride 2
                // in pack4 layout = 8 floats forward on the source).
                vst1q_f32(dptr, vld1q_f32(sptr));

                sptr += 8;
                dptr += 4;
            }

            sptr += tailstep;
        }
    }

    conv1x1s1_sgemm_pack4to1_neon(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
}
1898