1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
conv1x1s1_sgemm_transform_kernel_pack4to1_bf16s_neon(const Mat & kernel,Mat & kernel_tm_pack4,int inch,int outch)15 static void conv1x1s1_sgemm_transform_kernel_pack4to1_bf16s_neon(const Mat& kernel, Mat& kernel_tm_pack4, int inch, int outch)
16 {
17 // interleave
18 // src = inch-outch
19 // dst = 4a-inch/4a-outch
20 #if __aarch64__
21 kernel_tm_pack4.create(8, inch / 4, outch / 8 + (outch % 8) / 4 + outch % 4, (size_t)2u * 4, 4);
22 #else
23 kernel_tm_pack4.create(4, inch / 4, outch / 4 + outch % 4, (size_t)2u * 4, 4);
24 #endif
25
26 int p = 0;
27 #if __aarch64__
28 for (; p + 7 < outch; p += 8)
29 {
30 const float* k0 = (const float*)kernel + (p + 0) * inch;
31 const float* k1 = (const float*)kernel + (p + 1) * inch;
32 const float* k2 = (const float*)kernel + (p + 2) * inch;
33 const float* k3 = (const float*)kernel + (p + 3) * inch;
34 const float* k4 = (const float*)kernel + (p + 4) * inch;
35 const float* k5 = (const float*)kernel + (p + 5) * inch;
36 const float* k6 = (const float*)kernel + (p + 6) * inch;
37 const float* k7 = (const float*)kernel + (p + 7) * inch;
38
39 unsigned short* ktmp = kernel_tm_pack4.channel(p / 8);
40
41 for (int q = 0; q + 3 < inch; q += 4)
42 {
43 ktmp[0] = float32_to_bfloat16(k0[0]);
44 ktmp[1] = float32_to_bfloat16(k1[0]);
45 ktmp[2] = float32_to_bfloat16(k2[0]);
46 ktmp[3] = float32_to_bfloat16(k3[0]);
47 ktmp[4] = float32_to_bfloat16(k4[0]);
48 ktmp[5] = float32_to_bfloat16(k5[0]);
49 ktmp[6] = float32_to_bfloat16(k6[0]);
50 ktmp[7] = float32_to_bfloat16(k7[0]);
51
52 ktmp[8] = float32_to_bfloat16(k0[1]);
53 ktmp[9] = float32_to_bfloat16(k1[1]);
54 ktmp[10] = float32_to_bfloat16(k2[1]);
55 ktmp[11] = float32_to_bfloat16(k3[1]);
56 ktmp[12] = float32_to_bfloat16(k4[1]);
57 ktmp[13] = float32_to_bfloat16(k5[1]);
58 ktmp[14] = float32_to_bfloat16(k6[1]);
59 ktmp[15] = float32_to_bfloat16(k7[1]);
60
61 ktmp[16] = float32_to_bfloat16(k0[2]);
62 ktmp[17] = float32_to_bfloat16(k1[2]);
63 ktmp[18] = float32_to_bfloat16(k2[2]);
64 ktmp[19] = float32_to_bfloat16(k3[2]);
65 ktmp[20] = float32_to_bfloat16(k4[2]);
66 ktmp[21] = float32_to_bfloat16(k5[2]);
67 ktmp[22] = float32_to_bfloat16(k6[2]);
68 ktmp[23] = float32_to_bfloat16(k7[2]);
69
70 ktmp[24] = float32_to_bfloat16(k0[3]);
71 ktmp[25] = float32_to_bfloat16(k1[3]);
72 ktmp[26] = float32_to_bfloat16(k2[3]);
73 ktmp[27] = float32_to_bfloat16(k3[3]);
74 ktmp[28] = float32_to_bfloat16(k4[3]);
75 ktmp[29] = float32_to_bfloat16(k5[3]);
76 ktmp[30] = float32_to_bfloat16(k6[3]);
77 ktmp[31] = float32_to_bfloat16(k7[3]);
78
79 k0 += 4;
80 k1 += 4;
81 k2 += 4;
82 k3 += 4;
83 k4 += 4;
84 k5 += 4;
85 k6 += 4;
86 k7 += 4;
87 ktmp += 32;
88 }
89 }
90 #endif
91 for (; p + 3 < outch; p += 4)
92 {
93 const float* k0 = (const float*)kernel + (p + 0) * inch;
94 const float* k1 = (const float*)kernel + (p + 1) * inch;
95 const float* k2 = (const float*)kernel + (p + 2) * inch;
96 const float* k3 = (const float*)kernel + (p + 3) * inch;
97
98 #if __aarch64__
99 unsigned short* ktmp = kernel_tm_pack4.channel(p / 8 + (p % 8) / 4);
100 #else
101 unsigned short* ktmp = kernel_tm_pack4.channel(p / 4);
102 #endif
103
104 for (int q = 0; q + 3 < inch; q += 4)
105 {
106 ktmp[0] = float32_to_bfloat16(k0[0]);
107 ktmp[1] = float32_to_bfloat16(k1[0]);
108 ktmp[2] = float32_to_bfloat16(k2[0]);
109 ktmp[3] = float32_to_bfloat16(k3[0]);
110
111 ktmp[4] = float32_to_bfloat16(k0[1]);
112 ktmp[5] = float32_to_bfloat16(k1[1]);
113 ktmp[6] = float32_to_bfloat16(k2[1]);
114 ktmp[7] = float32_to_bfloat16(k3[1]);
115
116 ktmp[8] = float32_to_bfloat16(k0[2]);
117 ktmp[9] = float32_to_bfloat16(k1[2]);
118 ktmp[10] = float32_to_bfloat16(k2[2]);
119 ktmp[11] = float32_to_bfloat16(k3[2]);
120
121 ktmp[12] = float32_to_bfloat16(k0[3]);
122 ktmp[13] = float32_to_bfloat16(k1[3]);
123 ktmp[14] = float32_to_bfloat16(k2[3]);
124 ktmp[15] = float32_to_bfloat16(k3[3]);
125
126 k0 += 4;
127 k1 += 4;
128 k2 += 4;
129 k3 += 4;
130 ktmp += 16;
131 }
132 }
133 for (; p < outch; p++)
134 {
135 const float* k0 = (const float*)kernel + p * inch;
136
137 #if __aarch64__
138 unsigned short* ktmp = kernel_tm_pack4.channel(p / 8 + (p % 8) / 4 + p % 4);
139 #else
140 unsigned short* ktmp = kernel_tm_pack4.channel(p / 4 + p % 4);
141 #endif
142
143 for (int q = 0; q + 3 < inch; q += 4)
144 {
145 ktmp[0] = float32_to_bfloat16(k0[0]);
146 ktmp[1] = float32_to_bfloat16(k0[1]);
147 ktmp[2] = float32_to_bfloat16(k0[2]);
148 ktmp[3] = float32_to_bfloat16(k0[3]);
149
150 k0 += 4;
151 ktmp += 4;
152 }
153 }
154 }
155
conv1x1s1_sgemm_pack4to1_bf16s_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)156 static void conv1x1s1_sgemm_pack4to1_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
157 {
158 int w = bottom_blob.w;
159 int h = bottom_blob.h;
160 int inch = bottom_blob.c;
161 int outch = top_blob.c;
162
163 size_t elemsize = bottom_blob.elemsize;
164 int elempack = bottom_blob.elempack;
165
166 const int size = w * h;
167
168 const float* bias = _bias;
169
170 // interleave
171 Mat tmp;
172 #if __aarch64__
173 if (size >= 12)
174 tmp.create(12, inch, size / 12 + (size % 12) / 8 + (size % 12 % 8) / 4 + size % 12 % 4, elemsize, elempack, opt.workspace_allocator);
175 else if (size >= 8)
176 tmp.create(8, inch, size / 8 + (size % 8) / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
177 else if (size >= 4)
178 tmp.create(4, inch, size / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
179 else // if (size >= 1)
180 tmp.create(1, inch, size, elemsize, elempack, opt.workspace_allocator);
181 #else
182 if (size >= 8)
183 tmp.create(8, inch, size / 8 + (size % 8) / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
184 else if (size >= 4)
185 tmp.create(4, inch, size / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
186 else // if (size >= 1)
187 tmp.create(1, inch, size, elemsize, elempack, opt.workspace_allocator);
188 #endif
189 {
190 int nn_size;
191 int remain_size_start;
192
193 #if __aarch64__
194 nn_size = size / 12;
195 remain_size_start = nn_size * 12;
196
197 #pragma omp parallel for num_threads(opt.num_threads)
198 for (int ii = 0; ii < nn_size; ii++)
199 {
200 int i = ii * 12;
201
202 const unsigned short* img0 = bottom_blob.channel(0);
203 img0 += i * 4;
204
205 unsigned short* tmpptr = tmp.channel(i / 12);
206
207 for (int q = 0; q < inch; q++)
208 {
209 // transpose 4x12
210 asm volatile(
211 "prfm pldl1keep, [%0, #512] \n"
212 "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n"
213 "ld4 {v4.4h, v5.4h, v6.4h, v7.4h}, [%0] \n"
214 "st1 {v0.8h}, [%1], #16 \n"
215 "st1 {v4.4h}, [%1], #8 \n"
216 "st1 {v1.8h}, [%1], #16 \n"
217 "st1 {v5.4h}, [%1], #8 \n"
218 "sub %0, %0, #64 \n"
219 "st1 {v2.8h}, [%1], #16 \n"
220 "st1 {v6.4h}, [%1], #8 \n"
221 "st1 {v3.8h}, [%1], #16 \n"
222 "st1 {v7.4h}, [%1], #8 \n"
223 : "=r"(img0), // %0
224 "=r"(tmpptr) // %1
225 : "0"(img0),
226 "1"(tmpptr)
227 : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
228 img0 += bottom_blob.cstep * 4;
229 }
230 }
231 #else
232 remain_size_start = 0;
233 #endif
234 nn_size = (size - remain_size_start) >> 3;
235
236 #pragma omp parallel for num_threads(opt.num_threads)
237 for (int ii = 0; ii < nn_size; ii++)
238 {
239 int i = remain_size_start + ii * 8;
240
241 const unsigned short* img0 = bottom_blob.channel(0);
242 img0 += i * 4;
243
244 #if __aarch64__
245 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
246 #else
247 unsigned short* tmpptr = tmp.channel(i / 8);
248 #endif
249
250 for (int q = 0; q < inch; q++)
251 {
252 // transpose 4x8
253 #if __aarch64__
254 asm volatile(
255 "prfm pldl1keep, [%0, #512] \n"
256 "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n"
257 "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
258 : "=r"(img0), // %0
259 "=r"(tmpptr) // %1
260 : "0"(img0),
261 "1"(tmpptr)
262 : "memory", "v0", "v1", "v2", "v3");
263 #else
264 asm volatile(
265 "pld [%0, #256] \n"
266 "vld4.u16 {d0-d3}, [%0]! \n"
267 "pld [%0, #256] \n"
268 "vld4.u16 {d4-d7}, [%0] \n"
269 "sub %0, %0, #32 \n"
270 "vst1.u16 {d0}, [%1 :64]! \n"
271 "vst1.u16 {d4}, [%1 :64]! \n"
272 "vst1.u16 {d1}, [%1 :64]! \n"
273 "vst1.u16 {d5}, [%1 :64]! \n"
274 "vst1.u16 {d2}, [%1 :64]! \n"
275 "vst1.u16 {d6}, [%1 :64]! \n"
276 "vst1.u16 {d3}, [%1 :64]! \n"
277 "vst1.u16 {d7}, [%1 :64]! \n"
278 : "=r"(img0), // %0
279 "=r"(tmpptr) // %1
280 : "0"(img0),
281 "1"(tmpptr)
282 : "memory", "q0", "q1", "q2", "q3");
283 #endif // __aarch64__
284 img0 += bottom_blob.cstep * 4;
285 }
286 }
287
288 remain_size_start += nn_size << 3;
289 nn_size = (size - remain_size_start) >> 2;
290
291 #pragma omp parallel for num_threads(opt.num_threads)
292 for (int ii = 0; ii < nn_size; ii++)
293 {
294 int i = remain_size_start + ii * 4;
295
296 const unsigned short* img0 = bottom_blob.channel(0);
297 img0 += i * 4;
298
299 #if __aarch64__
300 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
301 #else
302 unsigned short* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
303 #endif
304
305 for (int q = 0; q < inch; q++)
306 {
307 // transpose 4x4
308 #if __aarch64__
309 asm volatile(
310 "prfm pldl1keep, [%0, #256] \n"
311 "ld4 {v0.4h, v1.4h, v2.4h, v3.4h}, [%0] \n"
312 "st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
313 : "=r"(img0), // %0
314 "=r"(tmpptr) // %1
315 : "0"(img0),
316 "1"(tmpptr)
317 : "memory", "v0", "v1");
318 #else
319 asm volatile(
320 "pld [%0, #256] \n"
321 "vld4.u16 {d0-d3}, [%0 :128] \n"
322 "vst1.u16 {d0-d3}, [%1 :128]! \n"
323 : "=r"(img0), // %0
324 "=r"(tmpptr) // %1
325 : "0"(img0),
326 "1"(tmpptr)
327 : "memory", "q0", "q1");
328 #endif // __aarch64__
329 img0 += bottom_blob.cstep * 4;
330 }
331 }
332
333 remain_size_start += nn_size << 2;
334
335 #pragma omp parallel for num_threads(opt.num_threads)
336 for (int i = remain_size_start; i < size; i++)
337 {
338 const unsigned short* img0 = bottom_blob.channel(0);
339 img0 += i * 4;
340
341 #if __aarch64__
342 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + i % 12 % 4);
343 #else
344 unsigned short* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
345 #endif
346
347 for (int q = 0; q < inch; q++)
348 {
349 #if __aarch64__
350 asm volatile(
351 "prfm pldl1keep, [%0, #64] \n"
352 "ld1 {v0.4h}, [%0] \n"
353 "st1 {v0.4h}, [%1], #8 \n"
354 : "=r"(img0), // %0
355 "=r"(tmpptr) // %1
356 : "0"(img0),
357 "1"(tmpptr)
358 : "memory", "v0");
359 #else
360 asm volatile(
361 "pld [%0, #64] \n"
362 "vld1.u16 {d0}, [%0 :64] \n"
363 "vst1.u16 {d0}, [%1 :64]! \n"
364 : "=r"(img0), // %0
365 "=r"(tmpptr) // %1
366 : "0"(img0),
367 "1"(tmpptr)
368 : "memory", "q0");
369 #endif // __aarch64__
370 img0 += bottom_blob.cstep * 4;
371 }
372 }
373 }
374
375 int nn_outch = 0;
376 int remain_outch_start = 0;
377
378 #if __aarch64__
379 nn_outch = outch >> 3;
380
381 #pragma omp parallel for num_threads(opt.num_threads)
382 for (int pp = 0; pp < nn_outch; pp++)
383 {
384 int p = pp * 8;
385
386 unsigned short* outptr0 = top_blob.channel(p);
387 unsigned short* outptr1 = top_blob.channel(p + 1);
388 unsigned short* outptr2 = top_blob.channel(p + 2);
389 unsigned short* outptr3 = top_blob.channel(p + 3);
390 unsigned short* outptr4 = top_blob.channel(p + 4);
391 unsigned short* outptr5 = top_blob.channel(p + 5);
392 unsigned short* outptr6 = top_blob.channel(p + 6);
393 unsigned short* outptr7 = top_blob.channel(p + 7);
394
395 const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
396 const float* biasptr = bias ? bias + p : zeros;
397
398 int i = 0;
399 for (; i + 11 < size; i += 12)
400 {
401 unsigned short* tmpptr = tmp.channel(i / 12);
402 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8);
403
404 int nn = inch; // inch always > 0
405
406 asm volatile(
407 "ld1 {v30.4s, v31.4s}, [%22] \n"
408 "dup v8.4s, v30.s[0] \n"
409 "dup v9.4s, v30.s[0] \n"
410 "dup v10.4s, v30.s[0] \n"
411 "dup v11.4s, v30.s[1] \n"
412 "dup v12.4s, v30.s[1] \n"
413 "dup v13.4s, v30.s[1] \n"
414 "dup v14.4s, v30.s[2] \n"
415 "dup v15.4s, v30.s[2] \n"
416 "dup v16.4s, v30.s[2] \n"
417 "dup v17.4s, v30.s[3] \n"
418 "dup v18.4s, v30.s[3] \n"
419 "dup v19.4s, v30.s[3] \n"
420 "dup v20.4s, v31.s[0] \n"
421 "dup v21.4s, v31.s[0] \n"
422 "dup v22.4s, v31.s[0] \n"
423 "dup v23.4s, v31.s[1] \n"
424 "dup v24.4s, v31.s[1] \n"
425 "dup v25.4s, v31.s[1] \n"
426 "dup v26.4s, v31.s[2] \n"
427 "dup v27.4s, v31.s[2] \n"
428 "dup v28.4s, v31.s[2] \n"
429 "dup v29.4s, v31.s[3] \n"
430 "dup v30.4s, v31.s[3] \n"
431 "dup v31.4s, v31.s[3] \n"
432
433 "0: \n"
434
435 "prfm pldl1keep, [%9, #256] \n"
436 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%9], #32 \n"
437
438 "prfm pldl1keep, [%10, #256] \n"
439 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%10], #32 \n"
440
441 "shll v0.4s, v0.4h, #16 \n"
442 "shll v1.4s, v1.4h, #16 \n"
443 "shll v2.4s, v2.4h, #16 \n"
444 "shll v3.4s, v3.4h, #16 \n"
445
446 "shll v4.4s, v4.4h, #16 \n"
447 "shll v5.4s, v5.4h, #16 \n"
448 "shll v6.4s, v6.4h, #16 \n"
449 "shll v7.4s, v7.4h, #16 \n"
450
451 "fmla v8.4s, v0.4s, v4.s[0] \n"
452 "fmla v11.4s, v0.4s, v4.s[1] \n"
453 "fmla v14.4s, v0.4s, v4.s[2] \n"
454 "fmla v17.4s, v0.4s, v4.s[3] \n"
455 "fmla v20.4s, v0.4s, v5.s[0] \n"
456 "fmla v23.4s, v0.4s, v5.s[1] \n"
457 "fmla v26.4s, v0.4s, v5.s[2] \n"
458 "fmla v29.4s, v0.4s, v5.s[3] \n"
459
460 "fmla v9.4s, v1.4s, v4.s[0] \n"
461 "fmla v12.4s, v1.4s, v4.s[1] \n"
462 "fmla v15.4s, v1.4s, v4.s[2] \n"
463 "fmla v18.4s, v1.4s, v4.s[3] \n"
464 "fmla v21.4s, v1.4s, v5.s[0] \n"
465 "fmla v24.4s, v1.4s, v5.s[1] \n"
466 "fmla v27.4s, v1.4s, v5.s[2] \n"
467 "fmla v30.4s, v1.4s, v5.s[3] \n"
468
469 "fmla v10.4s, v2.4s, v4.s[0] \n"
470 "fmla v13.4s, v2.4s, v4.s[1] \n"
471 "fmla v16.4s, v2.4s, v4.s[2] \n"
472 "fmla v19.4s, v2.4s, v4.s[3] \n"
473 "fmla v22.4s, v2.4s, v5.s[0] \n"
474 "fmla v25.4s, v2.4s, v5.s[1] \n"
475 "fmla v28.4s, v2.4s, v5.s[2] \n"
476 "fmla v31.4s, v2.4s, v5.s[3] \n"
477
478 "fmla v8.4s, v3.4s, v6.s[0] \n"
479 "fmla v11.4s, v3.4s, v6.s[1] \n"
480 "fmla v14.4s, v3.4s, v6.s[2] \n"
481 "fmla v17.4s, v3.4s, v6.s[3] \n"
482 "fmla v20.4s, v3.4s, v7.s[0] \n"
483 "fmla v23.4s, v3.4s, v7.s[1] \n"
484 "fmla v26.4s, v3.4s, v7.s[2] \n"
485 "fmla v29.4s, v3.4s, v7.s[3] \n"
486
487 "prfm pldl1keep, [%9, #256] \n"
488 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%9], #32 \n"
489
490 "shll v0.4s, v0.4h, #16 \n"
491 "shll v1.4s, v1.4h, #16 \n"
492 "shll v2.4s, v2.4h, #16 \n"
493 "shll v3.4s, v3.4h, #16 \n"
494
495 "fmla v9.4s, v0.4s, v6.s[0] \n"
496 "fmla v12.4s, v0.4s, v6.s[1] \n"
497 "fmla v15.4s, v0.4s, v6.s[2] \n"
498 "fmla v18.4s, v0.4s, v6.s[3] \n"
499 "fmla v21.4s, v0.4s, v7.s[0] \n"
500 "fmla v24.4s, v0.4s, v7.s[1] \n"
501 "fmla v27.4s, v0.4s, v7.s[2] \n"
502 "fmla v30.4s, v0.4s, v7.s[3] \n"
503
504 "fmla v10.4s, v1.4s, v6.s[0] \n"
505 "fmla v13.4s, v1.4s, v6.s[1] \n"
506 "fmla v16.4s, v1.4s, v6.s[2] \n"
507 "fmla v19.4s, v1.4s, v6.s[3] \n"
508 "fmla v22.4s, v1.4s, v7.s[0] \n"
509 "fmla v25.4s, v1.4s, v7.s[1] \n"
510 "fmla v28.4s, v1.4s, v7.s[2] \n"
511 "fmla v31.4s, v1.4s, v7.s[3] \n"
512
513 "prfm pldl1keep, [%10, #256] \n"
514 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%10], #32 \n"
515
516 "shll v4.4s, v4.4h, #16 \n"
517 "shll v5.4s, v5.4h, #16 \n"
518 "shll v6.4s, v6.4h, #16 \n"
519 "shll v7.4s, v7.4h, #16 \n"
520
521 "fmla v8.4s, v2.4s, v4.s[0] \n"
522 "fmla v11.4s, v2.4s, v4.s[1] \n"
523 "fmla v14.4s, v2.4s, v4.s[2] \n"
524 "fmla v17.4s, v2.4s, v4.s[3] \n"
525 "fmla v20.4s, v2.4s, v5.s[0] \n"
526 "fmla v23.4s, v2.4s, v5.s[1] \n"
527 "fmla v26.4s, v2.4s, v5.s[2] \n"
528 "fmla v29.4s, v2.4s, v5.s[3] \n"
529
530 "fmla v9.4s, v3.4s, v4.s[0] \n"
531 "fmla v12.4s, v3.4s, v4.s[1] \n"
532 "fmla v15.4s, v3.4s, v4.s[2] \n"
533 "fmla v18.4s, v3.4s, v4.s[3] \n"
534 "fmla v21.4s, v3.4s, v5.s[0] \n"
535 "fmla v24.4s, v3.4s, v5.s[1] \n"
536 "fmla v27.4s, v3.4s, v5.s[2] \n"
537 "fmla v30.4s, v3.4s, v5.s[3] \n"
538
539 "prfm pldl1keep, [%9, #256] \n"
540 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%9], #32 \n"
541
542 "shll v0.4s, v0.4h, #16 \n"
543 "shll v1.4s, v1.4h, #16 \n"
544 "shll v2.4s, v2.4h, #16 \n"
545 "shll v3.4s, v3.4h, #16 \n"
546
547 "fmla v10.4s, v0.4s, v4.s[0] \n"
548 "fmla v13.4s, v0.4s, v4.s[1] \n"
549 "fmla v16.4s, v0.4s, v4.s[2] \n"
550 "fmla v19.4s, v0.4s, v4.s[3] \n"
551 "fmla v22.4s, v0.4s, v5.s[0] \n"
552 "fmla v25.4s, v0.4s, v5.s[1] \n"
553 "fmla v28.4s, v0.4s, v5.s[2] \n"
554 "fmla v31.4s, v0.4s, v5.s[3] \n"
555
556 "fmla v8.4s, v1.4s, v6.s[0] \n"
557 "fmla v11.4s, v1.4s, v6.s[1] \n"
558 "fmla v14.4s, v1.4s, v6.s[2] \n"
559 "fmla v17.4s, v1.4s, v6.s[3] \n"
560 "fmla v20.4s, v1.4s, v7.s[0] \n"
561 "fmla v23.4s, v1.4s, v7.s[1] \n"
562 "fmla v26.4s, v1.4s, v7.s[2] \n"
563 "fmla v29.4s, v1.4s, v7.s[3] \n"
564
565 "fmla v9.4s, v2.4s, v6.s[0] \n"
566 "fmla v12.4s, v2.4s, v6.s[1] \n"
567 "fmla v15.4s, v2.4s, v6.s[2] \n"
568 "fmla v18.4s, v2.4s, v6.s[3] \n"
569 "fmla v21.4s, v2.4s, v7.s[0] \n"
570 "fmla v24.4s, v2.4s, v7.s[1] \n"
571 "fmla v27.4s, v2.4s, v7.s[2] \n"
572 "fmla v30.4s, v2.4s, v7.s[3] \n"
573
574 "subs %w0, %w0, #1 \n"
575
576 "fmla v10.4s, v3.4s, v6.s[0] \n"
577 "fmla v13.4s, v3.4s, v6.s[1] \n"
578 "fmla v16.4s, v3.4s, v6.s[2] \n"
579 "fmla v19.4s, v3.4s, v6.s[3] \n"
580 "fmla v22.4s, v3.4s, v7.s[0] \n"
581 "fmla v25.4s, v3.4s, v7.s[1] \n"
582 "fmla v28.4s, v3.4s, v7.s[2] \n"
583 "fmla v31.4s, v3.4s, v7.s[3] \n"
584
585 "bne 0b \n"
586
587 "shrn v8.4h, v8.4s, #16 \n"
588 "shrn v9.4h, v9.4s, #16 \n"
589 "shrn v10.4h, v10.4s, #16 \n"
590 "shrn v11.4h, v11.4s, #16 \n"
591
592 "shrn v12.4h, v12.4s, #16 \n"
593 "shrn v13.4h, v13.4s, #16 \n"
594 "shrn v14.4h, v14.4s, #16 \n"
595 "shrn v15.4h, v15.4s, #16 \n"
596
597 "shrn v16.4h, v16.4s, #16 \n"
598 "shrn v17.4h, v17.4s, #16 \n"
599 "shrn v18.4h, v18.4s, #16 \n"
600 "shrn v19.4h, v19.4s, #16 \n"
601
602 "shrn v20.4h, v20.4s, #16 \n"
603 "shrn v21.4h, v21.4s, #16 \n"
604 "shrn v22.4h, v22.4s, #16 \n"
605 "shrn v23.4h, v23.4s, #16 \n"
606
607 "shrn v24.4h, v24.4s, #16 \n"
608 "shrn v25.4h, v25.4s, #16 \n"
609 "shrn v26.4h, v26.4s, #16 \n"
610 "shrn v27.4h, v27.4s, #16 \n"
611
612 "shrn v28.4h, v28.4s, #16 \n"
613 "shrn v29.4h, v29.4s, #16 \n"
614 "shrn v30.4h, v30.4s, #16 \n"
615 "shrn v31.4h, v31.4s, #16 \n"
616
617 "st1 {v8.4h, v9.4h, v10.4h}, [%1], #24 \n"
618 "st1 {v11.4h, v12.4h, v13.4h}, [%2], #24 \n"
619 "st1 {v14.4h, v15.4h, v16.4h}, [%3], #24 \n"
620 "st1 {v17.4h, v18.4h, v19.4h}, [%4], #24 \n"
621 "st1 {v20.4h, v21.4h, v22.4h}, [%5], #24 \n"
622 "st1 {v23.4h, v24.4h, v25.4h}, [%6], #24 \n"
623 "st1 {v26.4h, v27.4h, v28.4h}, [%7], #24 \n"
624 "st1 {v29.4h, v30.4h, v31.4h}, [%8], #24 \n"
625
626 : "=r"(nn), // %0
627 "=r"(outptr0), // %1
628 "=r"(outptr1), // %2
629 "=r"(outptr2), // %3
630 "=r"(outptr3), // %4
631 "=r"(outptr4), // %5
632 "=r"(outptr5), // %6
633 "=r"(outptr6), // %7
634 "=r"(outptr7), // %8
635 "=r"(tmpptr), // %9
636 "=r"(kptr) // %10
637 : "0"(nn),
638 "1"(outptr0),
639 "2"(outptr1),
640 "3"(outptr2),
641 "4"(outptr3),
642 "5"(outptr4),
643 "6"(outptr5),
644 "7"(outptr6),
645 "8"(outptr7),
646 "9"(tmpptr),
647 "10"(kptr),
648 "r"(biasptr) // %22
649 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
650 }
651 for (; i + 7 < size; i += 8)
652 {
653 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
654 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8);
655
656 int nn = inch; // inch always > 0
657
658 asm volatile(
659 "ld1 {v30.4s, v31.4s}, [%22] \n"
660 "dup v16.4s, v30.s[0] \n"
661 "dup v17.4s, v30.s[0] \n"
662 "dup v18.4s, v30.s[1] \n"
663 "dup v19.4s, v30.s[1] \n"
664 "dup v20.4s, v30.s[2] \n"
665 "dup v21.4s, v30.s[2] \n"
666 "dup v22.4s, v30.s[3] \n"
667 "dup v23.4s, v30.s[3] \n"
668 "dup v24.4s, v31.s[0] \n"
669 "dup v25.4s, v31.s[0] \n"
670 "dup v26.4s, v31.s[1] \n"
671 "dup v27.4s, v31.s[1] \n"
672 "dup v28.4s, v31.s[2] \n"
673 "dup v29.4s, v31.s[2] \n"
674 "dup v30.4s, v31.s[3] \n"
675 "dup v31.4s, v31.s[3] \n"
676
677 "0: \n"
678
679 "prfm pldl1keep, [%9, #256] \n"
680 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%9], #32 \n"
681
682 "prfm pldl1keep, [%10, #256] \n"
683 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%10], #32 \n"
684
685 "shll v0.4s, v0.4h, #16 \n"
686 "shll v1.4s, v1.4h, #16 \n"
687 "shll v2.4s, v2.4h, #16 \n"
688 "shll v3.4s, v3.4h, #16 \n"
689
690 "shll v4.4s, v4.4h, #16 \n"
691 "shll v5.4s, v5.4h, #16 \n"
692 "shll v6.4s, v6.4h, #16 \n"
693 "shll v7.4s, v7.4h, #16 \n"
694
695 "fmla v16.4s, v0.4s, v4.s[0] \n"
696 "fmla v18.4s, v0.4s, v4.s[1] \n"
697 "fmla v20.4s, v0.4s, v4.s[2] \n"
698 "fmla v22.4s, v0.4s, v4.s[3] \n"
699 "fmla v24.4s, v0.4s, v5.s[0] \n"
700 "fmla v26.4s, v0.4s, v5.s[1] \n"
701 "fmla v28.4s, v0.4s, v5.s[2] \n"
702 "fmla v30.4s, v0.4s, v5.s[3] \n"
703 "fmla v17.4s, v1.4s, v4.s[0] \n"
704 "fmla v19.4s, v1.4s, v4.s[1] \n"
705 "fmla v21.4s, v1.4s, v4.s[2] \n"
706 "fmla v23.4s, v1.4s, v4.s[3] \n"
707 "fmla v25.4s, v1.4s, v5.s[0] \n"
708 "fmla v27.4s, v1.4s, v5.s[1] \n"
709 "fmla v29.4s, v1.4s, v5.s[2] \n"
710 "fmla v31.4s, v1.4s, v5.s[3] \n"
711
712 "fmla v16.4s, v2.4s, v6.s[0] \n"
713 "fmla v18.4s, v2.4s, v6.s[1] \n"
714 "fmla v20.4s, v2.4s, v6.s[2] \n"
715 "fmla v22.4s, v2.4s, v6.s[3] \n"
716 "fmla v24.4s, v2.4s, v7.s[0] \n"
717 "fmla v26.4s, v2.4s, v7.s[1] \n"
718 "fmla v28.4s, v2.4s, v7.s[2] \n"
719 "fmla v30.4s, v2.4s, v7.s[3] \n"
720 "fmla v17.4s, v3.4s, v6.s[0] \n"
721 "fmla v19.4s, v3.4s, v6.s[1] \n"
722 "fmla v21.4s, v3.4s, v6.s[2] \n"
723 "fmla v23.4s, v3.4s, v6.s[3] \n"
724 "fmla v25.4s, v3.4s, v7.s[0] \n"
725 "fmla v27.4s, v3.4s, v7.s[1] \n"
726 "fmla v29.4s, v3.4s, v7.s[2] \n"
727 "fmla v31.4s, v3.4s, v7.s[3] \n"
728
729 "prfm pldl1keep, [%9, #256] \n"
730 "ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [%9], #32 \n"
731
732 "prfm pldl1keep, [%10, #256] \n"
733 "ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [%10], #32 \n"
734
735 "shll v12.4s, v12.4h, #16 \n"
736 "shll v13.4s, v13.4h, #16 \n"
737 "shll v14.4s, v14.4h, #16 \n"
738 "shll v15.4s, v15.4h, #16 \n"
739
740 "shll v8.4s, v8.4h, #16 \n"
741 "shll v9.4s, v9.4h, #16 \n"
742 "shll v10.4s, v10.4h, #16 \n"
743 "shll v11.4s, v11.4h, #16 \n"
744
745 "fmla v16.4s, v12.4s, v8.s[0] \n"
746 "fmla v18.4s, v12.4s, v8.s[1] \n"
747 "fmla v20.4s, v12.4s, v8.s[2] \n"
748 "fmla v22.4s, v12.4s, v8.s[3] \n"
749 "fmla v24.4s, v12.4s, v9.s[0] \n"
750 "fmla v26.4s, v12.4s, v9.s[1] \n"
751 "fmla v28.4s, v12.4s, v9.s[2] \n"
752 "fmla v30.4s, v12.4s, v9.s[3] \n"
753 "fmla v17.4s, v13.4s, v8.s[0] \n"
754 "fmla v19.4s, v13.4s, v8.s[1] \n"
755 "fmla v21.4s, v13.4s, v8.s[2] \n"
756 "fmla v23.4s, v13.4s, v8.s[3] \n"
757 "fmla v25.4s, v13.4s, v9.s[0] \n"
758 "fmla v27.4s, v13.4s, v9.s[1] \n"
759 "fmla v29.4s, v13.4s, v9.s[2] \n"
760 "fmla v31.4s, v13.4s, v9.s[3] \n"
761
762 "subs %w0, %w0, #1 \n"
763
764 "fmla v16.4s, v14.4s, v10.s[0] \n"
765 "fmla v18.4s, v14.4s, v10.s[1] \n"
766 "fmla v20.4s, v14.4s, v10.s[2] \n"
767 "fmla v22.4s, v14.4s, v10.s[3] \n"
768 "fmla v24.4s, v14.4s, v11.s[0] \n"
769 "fmla v26.4s, v14.4s, v11.s[1] \n"
770 "fmla v28.4s, v14.4s, v11.s[2] \n"
771 "fmla v30.4s, v14.4s, v11.s[3] \n"
772 "fmla v17.4s, v15.4s, v10.s[0] \n"
773 "fmla v19.4s, v15.4s, v10.s[1] \n"
774 "fmla v21.4s, v15.4s, v10.s[2] \n"
775 "fmla v23.4s, v15.4s, v10.s[3] \n"
776 "fmla v25.4s, v15.4s, v11.s[0] \n"
777 "fmla v27.4s, v15.4s, v11.s[1] \n"
778 "fmla v29.4s, v15.4s, v11.s[2] \n"
779 "fmla v31.4s, v15.4s, v11.s[3] \n"
780
781 "bne 0b \n"
782
783 "shrn v16.4h, v16.4s, #16 \n"
784 "shrn v17.4h, v17.4s, #16 \n"
785 "shrn v18.4h, v18.4s, #16 \n"
786 "shrn v19.4h, v19.4s, #16 \n"
787
788 "shrn v20.4h, v20.4s, #16 \n"
789 "shrn v21.4h, v21.4s, #16 \n"
790 "shrn v22.4h, v22.4s, #16 \n"
791 "shrn v23.4h, v23.4s, #16 \n"
792
793 "shrn v24.4h, v24.4s, #16 \n"
794 "shrn v25.4h, v25.4s, #16 \n"
795 "shrn v26.4h, v26.4s, #16 \n"
796 "shrn v27.4h, v27.4s, #16 \n"
797
798 "shrn v28.4h, v28.4s, #16 \n"
799 "shrn v29.4h, v29.4s, #16 \n"
800 "shrn v30.4h, v30.4s, #16 \n"
801 "shrn v31.4h, v31.4s, #16 \n"
802
803 "st1 {v16.4h, v17.4h}, [%1], #16 \n"
804 "st1 {v18.4h, v19.4h}, [%2], #16 \n"
805 "st1 {v20.4h, v21.4h}, [%3], #16 \n"
806 "st1 {v22.4h, v23.4h}, [%4], #16 \n"
807 "st1 {v24.4h, v25.4h}, [%5], #16 \n"
808 "st1 {v26.4h, v27.4h}, [%6], #16 \n"
809 "st1 {v28.4h, v29.4h}, [%7], #16 \n"
810 "st1 {v30.4h, v31.4h}, [%8], #16 \n"
811
812 : "=r"(nn), // %0
813 "=r"(outptr0), // %1
814 "=r"(outptr1), // %2
815 "=r"(outptr2), // %3
816 "=r"(outptr3), // %4
817 "=r"(outptr4), // %5
818 "=r"(outptr5), // %6
819 "=r"(outptr6), // %7
820 "=r"(outptr7), // %8
821 "=r"(tmpptr), // %9
822 "=r"(kptr) // %10
823 : "0"(nn),
824 "1"(outptr0),
825 "2"(outptr1),
826 "3"(outptr2),
827 "4"(outptr3),
828 "5"(outptr4),
829 "6"(outptr5),
830 "7"(outptr6),
831 "8"(outptr7),
832 "9"(tmpptr),
833 "10"(kptr),
834 "r"(biasptr) // %22
835 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
836 }
837 for (; i + 3 < size; i += 4)
838 {
839 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
840 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8);
841
842 int nn = inch; // inch always > 0
843
844 asm volatile(
845 "ld1 {v22.4s, v23.4s}, [%22] \n"
846 "dup v16.4s, v22.s[0] \n"
847 "dup v17.4s, v22.s[1] \n"
848 "dup v18.4s, v22.s[2] \n"
849 "dup v19.4s, v22.s[3] \n"
850 "dup v20.4s, v23.s[0] \n"
851 "dup v21.4s, v23.s[1] \n"
852 "dup v22.4s, v23.s[2] \n"
853 "dup v23.4s, v23.s[3] \n"
854
855 "0: \n"
856
857 "prfm pldl1keep, [%9, #256] \n"
858 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%9], #32 \n"
859
860 "prfm pldl1keep, [%10, #256] \n"
861 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%10], #32 \n"
862
863 "shll v0.4s, v0.4h, #16 \n"
864 "shll v1.4s, v1.4h, #16 \n"
865 "shll v2.4s, v2.4h, #16 \n"
866 "shll v3.4s, v3.4h, #16 \n"
867
868 "shll v4.4s, v4.4h, #16 \n"
869 "shll v5.4s, v5.4h, #16 \n"
870 "shll v6.4s, v6.4h, #16 \n"
871 "shll v7.4s, v7.4h, #16 \n"
872
873 "fmla v16.4s, v0.4s, v4.s[0] \n"
874 "fmla v17.4s, v0.4s, v4.s[1] \n"
875 "fmla v18.4s, v0.4s, v4.s[2] \n"
876 "fmla v19.4s, v0.4s, v4.s[3] \n"
877 "fmla v20.4s, v0.4s, v5.s[0] \n"
878 "fmla v21.4s, v0.4s, v5.s[1] \n"
879 "fmla v22.4s, v0.4s, v5.s[2] \n"
880 "fmla v23.4s, v0.4s, v5.s[3] \n"
881
882 "prfm pldl1keep, [%10, #256] \n"
883 "ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [%10], #32 \n"
884
885 "shll v8.4s, v8.4h, #16 \n"
886 "shll v9.4s, v9.4h, #16 \n"
887 "shll v10.4s, v10.4h, #16 \n"
888 "shll v11.4s, v11.4h, #16 \n"
889
890 "fmla v16.4s, v1.4s, v6.s[0] \n"
891 "fmla v17.4s, v1.4s, v6.s[1] \n"
892 "fmla v18.4s, v1.4s, v6.s[2] \n"
893 "fmla v19.4s, v1.4s, v6.s[3] \n"
894 "fmla v20.4s, v1.4s, v7.s[0] \n"
895 "fmla v21.4s, v1.4s, v7.s[1] \n"
896 "fmla v22.4s, v1.4s, v7.s[2] \n"
897 "fmla v23.4s, v1.4s, v7.s[3] \n"
898
899 "fmla v16.4s, v2.4s, v8.s[0] \n"
900 "fmla v17.4s, v2.4s, v8.s[1] \n"
901 "fmla v18.4s, v2.4s, v8.s[2] \n"
902 "fmla v19.4s, v2.4s, v8.s[3] \n"
903 "fmla v20.4s, v2.4s, v9.s[0] \n"
904 "fmla v21.4s, v2.4s, v9.s[1] \n"
905 "fmla v22.4s, v2.4s, v9.s[2] \n"
906 "fmla v23.4s, v2.4s, v9.s[3] \n"
907
908 "subs %w0, %w0, #1 \n"
909
910 "fmla v16.4s, v3.4s, v10.s[0] \n"
911 "fmla v17.4s, v3.4s, v10.s[1] \n"
912 "fmla v18.4s, v3.4s, v10.s[2] \n"
913 "fmla v19.4s, v3.4s, v10.s[3] \n"
914 "fmla v20.4s, v3.4s, v11.s[0] \n"
915 "fmla v21.4s, v3.4s, v11.s[1] \n"
916 "fmla v22.4s, v3.4s, v11.s[2] \n"
917 "fmla v23.4s, v3.4s, v11.s[3] \n"
918
919 "bne 0b \n"
920
921 "shrn v16.4h, v16.4s, #16 \n"
922 "shrn v17.4h, v17.4s, #16 \n"
923 "shrn v18.4h, v18.4s, #16 \n"
924 "shrn v19.4h, v19.4s, #16 \n"
925
926 "shrn v20.4h, v20.4s, #16 \n"
927 "shrn v21.4h, v21.4s, #16 \n"
928 "shrn v22.4h, v22.4s, #16 \n"
929 "shrn v23.4h, v23.4s, #16 \n"
930
931 "st1 {v16.4h}, [%1], #8 \n"
932 "st1 {v17.4h}, [%2], #8 \n"
933 "st1 {v18.4h}, [%3], #8 \n"
934 "st1 {v19.4h}, [%4], #8 \n"
935 "st1 {v20.4h}, [%5], #8 \n"
936 "st1 {v21.4h}, [%6], #8 \n"
937 "st1 {v22.4h}, [%7], #8 \n"
938 "st1 {v23.4h}, [%8], #8 \n"
939
940 : "=r"(nn), // %0
941 "=r"(outptr0), // %1
942 "=r"(outptr1), // %2
943 "=r"(outptr2), // %3
944 "=r"(outptr3), // %4
945 "=r"(outptr4), // %5
946 "=r"(outptr5), // %6
947 "=r"(outptr6), // %7
948 "=r"(outptr7), // %8
949 "=r"(tmpptr), // %9
950 "=r"(kptr) // %10
951 : "0"(nn),
952 "1"(outptr0),
953 "2"(outptr1),
954 "3"(outptr2),
955 "4"(outptr3),
956 "5"(outptr4),
957 "6"(outptr5),
958 "7"(outptr6),
959 "8"(outptr7),
960 "9"(tmpptr),
961 "10"(kptr),
962 "r"(biasptr) // %22
963 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
964 }
965 for (; i < size; i++)
966 {
967 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + i % 12 % 4);
968 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8);
969
970 int nn = inch; // inch always > 0
971
972 asm volatile(
973 "ld1 {v16.4s, v17.4s}, [%22] \n"
974 "eor v18.16b, v18.16b, v18.16b \n"
975 "eor v19.16b, v19.16b, v19.16b \n"
976
977 "0: \n"
978
979 "prfm pldl1keep, [%9, #64] \n"
980 "ld1 {v0.4h}, [%9], #8 \n"
981
982 "prfm pldl1keep, [%10, #256] \n"
983 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%10], #32 \n"
984
985 "shll v0.4s, v0.4h, #16 \n"
986
987 "shll v4.4s, v4.4h, #16 \n"
988 "shll v5.4s, v5.4h, #16 \n"
989 "shll v6.4s, v6.4h, #16 \n"
990 "shll v7.4s, v7.4h, #16 \n"
991
992 "fmla v16.4s, v4.4s, v0.s[0] \n"
993 "fmla v17.4s, v5.4s, v0.s[0] \n"
994 "fmla v18.4s, v6.4s, v0.s[1] \n"
995 "fmla v19.4s, v7.4s, v0.s[1] \n"
996
997 "prfm pldl1keep, [%10, #256] \n"
998 "ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [%10], #32 \n"
999
1000 "shll v8.4s, v8.4h, #16 \n"
1001 "shll v9.4s, v9.4h, #16 \n"
1002 "shll v10.4s, v10.4h, #16 \n"
1003 "shll v11.4s, v11.4h, #16 \n"
1004
1005 "fmla v16.4s, v8.4s, v0.s[2] \n"
1006 "fmla v17.4s, v9.4s, v0.s[2] \n"
1007
1008 "subs %w0, %w0, #1 \n"
1009
1010 "fmla v18.4s, v10.4s, v0.s[3] \n"
1011 "fmla v19.4s, v11.4s, v0.s[3] \n"
1012
1013 "bne 0b \n"
1014
1015 "fadd v16.4s, v16.4s, v18.4s \n"
1016 "fadd v17.4s, v17.4s, v19.4s \n"
1017
1018 "shrn v16.4h, v16.4s, #16 \n"
1019 "shrn v17.4h, v17.4s, #16 \n"
1020
1021 "st1 {v16.h}[0], [%1], #2 \n"
1022 "st1 {v16.h}[1], [%2], #2 \n"
1023 "st1 {v16.h}[2], [%3], #2 \n"
1024 "st1 {v16.h}[3], [%4], #2 \n"
1025 "st1 {v17.h}[0], [%5], #2 \n"
1026 "st1 {v17.h}[1], [%6], #2 \n"
1027 "st1 {v17.h}[2], [%7], #2 \n"
1028 "st1 {v17.h}[3], [%8], #2 \n"
1029
1030 : "=r"(nn), // %0
1031 "=r"(outptr0), // %1
1032 "=r"(outptr1), // %2
1033 "=r"(outptr2), // %3
1034 "=r"(outptr3), // %4
1035 "=r"(outptr4), // %5
1036 "=r"(outptr5), // %6
1037 "=r"(outptr6), // %7
1038 "=r"(outptr7), // %8
1039 "=r"(tmpptr), // %9
1040 "=r"(kptr) // %10
1041 : "0"(nn),
1042 "1"(outptr0),
1043 "2"(outptr1),
1044 "3"(outptr2),
1045 "4"(outptr3),
1046 "5"(outptr4),
1047 "6"(outptr5),
1048 "7"(outptr6),
1049 "8"(outptr7),
1050 "9"(tmpptr),
1051 "10"(kptr),
1052 "r"(biasptr) // %22
1053 : "cc", "memory", "v0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19");
1054 }
1055 }
1056
1057 remain_outch_start += nn_outch << 3;
1058 nn_outch = (outch - remain_outch_start) >> 2;
1059 #else // __aarch64__
1060 nn_outch = outch >> 2;
1061 #endif // __aarch64__
1062
1063 #pragma omp parallel for num_threads(opt.num_threads)
1064 for (int pp = 0; pp < nn_outch; pp++)
1065 {
1066 int p = remain_outch_start + pp * 4;
1067
1068 unsigned short* outptr0 = top_blob.channel(p);
1069 unsigned short* outptr1 = top_blob.channel(p + 1);
1070 unsigned short* outptr2 = top_blob.channel(p + 2);
1071 unsigned short* outptr3 = top_blob.channel(p + 3);
1072
1073 const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
1074 const float* biasptr = bias ? bias + p : zeros;
1075
1076 int i = 0;
1077 #if __aarch64__
1078 for (; i + 11 < size; i += 12)
1079 {
1080 unsigned short* tmpptr = tmp.channel(i / 12);
1081 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8 + (p % 8) / 4);
1082
1083 int nn = inch; // inch always > 0
1084
1085 asm volatile(
1086 "ld1 {v19.4s}, [%14] \n"
1087 "dup v8.4s, v19.s[0] \n"
1088 "dup v9.4s, v19.s[0] \n"
1089 "dup v10.4s, v19.s[0] \n"
1090 "dup v11.4s, v19.s[1] \n"
1091 "dup v12.4s, v19.s[1] \n"
1092 "dup v13.4s, v19.s[1] \n"
1093 "dup v14.4s, v19.s[2] \n"
1094 "dup v15.4s, v19.s[2] \n"
1095 "dup v16.4s, v19.s[2] \n"
1096 "dup v17.4s, v19.s[3] \n"
1097 "dup v18.4s, v19.s[3] \n"
1098 "dup v19.4s, v19.s[3] \n"
1099
1100 "0: \n"
1101
1102 "prfm pldl1keep, [%5, #256] \n"
1103 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%5], #32 \n"
1104
1105 "prfm pldl1keep, [%6, #256] \n"
1106 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%6], #32 \n"
1107
1108 "shll v0.4s, v0.4h, #16 \n"
1109 "shll v1.4s, v1.4h, #16 \n"
1110 "shll v2.4s, v2.4h, #16 \n"
1111 "shll v3.4s, v3.4h, #16 \n"
1112
1113 "shll v4.4s, v4.4h, #16 \n"
1114 "shll v5.4s, v5.4h, #16 \n"
1115 "shll v6.4s, v6.4h, #16 \n"
1116 "shll v7.4s, v7.4h, #16 \n"
1117
1118 "fmla v8.4s, v0.4s, v4.s[0] \n"
1119 "fmla v11.4s, v0.4s, v4.s[1] \n"
1120 "fmla v14.4s, v0.4s, v4.s[2] \n"
1121 "fmla v17.4s, v0.4s, v4.s[3] \n"
1122 "fmla v9.4s, v1.4s, v4.s[0] \n"
1123 "fmla v12.4s, v1.4s, v4.s[1] \n"
1124 "fmla v15.4s, v1.4s, v4.s[2] \n"
1125 "fmla v18.4s, v1.4s, v4.s[3] \n"
1126 "fmla v10.4s, v2.4s, v4.s[0] \n"
1127 "fmla v13.4s, v2.4s, v4.s[1] \n"
1128 "fmla v16.4s, v2.4s, v4.s[2] \n"
1129 "fmla v19.4s, v2.4s, v4.s[3] \n"
1130
1131 "prfm pldl1keep, [%5, #256] \n"
1132 "ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%5], #32 \n"
1133
1134 "shll v20.4s, v20.4h, #16 \n"
1135 "shll v21.4s, v21.4h, #16 \n"
1136 "shll v22.4s, v22.4h, #16 \n"
1137 "shll v23.4s, v23.4h, #16 \n"
1138
1139 "fmla v8.4s, v3.4s, v5.s[0] \n"
1140 "fmla v11.4s, v3.4s, v5.s[1] \n"
1141 "fmla v14.4s, v3.4s, v5.s[2] \n"
1142 "fmla v17.4s, v3.4s, v5.s[3] \n"
1143 "fmla v9.4s, v20.4s, v5.s[0] \n"
1144 "fmla v12.4s, v20.4s, v5.s[1] \n"
1145 "fmla v15.4s, v20.4s, v5.s[2] \n"
1146 "fmla v18.4s, v20.4s, v5.s[3] \n"
1147 "fmla v10.4s, v21.4s, v5.s[0] \n"
1148 "fmla v13.4s, v21.4s, v5.s[1] \n"
1149 "fmla v16.4s, v21.4s, v5.s[2] \n"
1150 "fmla v19.4s, v21.4s, v5.s[3] \n"
1151
1152 "prfm pldl1keep, [%5, #256] \n"
1153 "ld1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%5], #32 \n"
1154
1155 "shll v24.4s, v24.4h, #16 \n"
1156 "shll v25.4s, v25.4h, #16 \n"
1157 "shll v26.4s, v26.4h, #16 \n"
1158 "shll v27.4s, v27.4h, #16 \n"
1159
1160 "fmla v8.4s, v22.4s, v6.s[0] \n"
1161 "fmla v11.4s, v22.4s, v6.s[1] \n"
1162 "fmla v14.4s, v22.4s, v6.s[2] \n"
1163 "fmla v17.4s, v22.4s, v6.s[3] \n"
1164 "fmla v9.4s, v23.4s, v6.s[0] \n"
1165 "fmla v12.4s, v23.4s, v6.s[1] \n"
1166 "fmla v15.4s, v23.4s, v6.s[2] \n"
1167 "fmla v18.4s, v23.4s, v6.s[3] \n"
1168 "fmla v10.4s, v24.4s, v6.s[0] \n"
1169 "fmla v13.4s, v24.4s, v6.s[1] \n"
1170 "fmla v16.4s, v24.4s, v6.s[2] \n"
1171 "fmla v19.4s, v24.4s, v6.s[3] \n"
1172
1173 "subs %w0, %w0, #1 \n"
1174
1175 "fmla v8.4s, v25.4s, v7.s[0] \n"
1176 "fmla v11.4s, v25.4s, v7.s[1] \n"
1177 "fmla v14.4s, v25.4s, v7.s[2] \n"
1178 "fmla v17.4s, v25.4s, v7.s[3] \n"
1179 "fmla v9.4s, v26.4s, v7.s[0] \n"
1180 "fmla v12.4s, v26.4s, v7.s[1] \n"
1181 "fmla v15.4s, v26.4s, v7.s[2] \n"
1182 "fmla v18.4s, v26.4s, v7.s[3] \n"
1183 "fmla v10.4s, v27.4s, v7.s[0] \n"
1184 "fmla v13.4s, v27.4s, v7.s[1] \n"
1185 "fmla v16.4s, v27.4s, v7.s[2] \n"
1186 "fmla v19.4s, v27.4s, v7.s[3] \n"
1187
1188 "bne 0b \n"
1189
1190 "shrn v8.4h, v8.4s, #16 \n"
1191 "shrn v9.4h, v9.4s, #16 \n"
1192 "shrn v10.4h, v10.4s, #16 \n"
1193 "shrn v11.4h, v11.4s, #16 \n"
1194
1195 "shrn v12.4h, v12.4s, #16 \n"
1196 "shrn v13.4h, v13.4s, #16 \n"
1197 "shrn v14.4h, v14.4s, #16 \n"
1198 "shrn v15.4h, v15.4s, #16 \n"
1199
1200 "shrn v16.4h, v16.4s, #16 \n"
1201 "shrn v17.4h, v17.4s, #16 \n"
1202 "shrn v18.4h, v18.4s, #16 \n"
1203 "shrn v19.4h, v19.4s, #16 \n"
1204
1205 "st1 {v8.4h, v9.4h, v10.4h}, [%1], #24 \n"
1206 "st1 {v11.4h, v12.4h, v13.4h}, [%2], #24 \n"
1207 "st1 {v14.4h, v15.4h, v16.4h}, [%3], #24 \n"
1208 "st1 {v17.4h, v18.4h, v19.4h}, [%4], #24 \n"
1209
1210 : "=r"(nn), // %0
1211 "=r"(outptr0), // %1
1212 "=r"(outptr1), // %2
1213 "=r"(outptr2), // %3
1214 "=r"(outptr3), // %4
1215 "=r"(tmpptr), // %5
1216 "=r"(kptr) // %6
1217 : "0"(nn),
1218 "1"(outptr0),
1219 "2"(outptr1),
1220 "3"(outptr2),
1221 "4"(outptr3),
1222 "5"(tmpptr),
1223 "6"(kptr),
1224 "r"(biasptr) // %14
1225 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
1226 }
1227 #endif // __aarch64__
1228 for (; i + 7 < size; i += 8)
1229 {
1230 #if __aarch64__
1231 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
1232 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8 + (p % 8) / 4);
1233 #else
1234 unsigned short* tmpptr = tmp.channel(i / 8);
1235 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 4);
1236 #endif
1237
1238 int nn = inch; // inch always > 0
1239
1240 #if __aarch64__
1241 asm volatile(
1242 "ld1 {v15.4s}, [%14] \n"
1243 "dup v8.4s, v15.s[0] \n"
1244 "dup v9.4s, v15.s[0] \n"
1245 "dup v10.4s, v15.s[1] \n"
1246 "dup v11.4s, v15.s[1] \n"
1247 "dup v12.4s, v15.s[2] \n"
1248 "dup v13.4s, v15.s[2] \n"
1249 "dup v14.4s, v15.s[3] \n"
1250 "dup v15.4s, v15.s[3] \n"
1251
1252 "0: \n"
1253
1254 "prfm pldl1keep, [%5, #256] \n"
1255 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%5], #32 \n"
1256
1257 "prfm pldl1keep, [%6, #256] \n"
1258 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%6], #32 \n"
1259
1260 "shll v0.4s, v0.4h, #16 \n"
1261 "shll v1.4s, v1.4h, #16 \n"
1262 "shll v2.4s, v2.4h, #16 \n"
1263 "shll v3.4s, v3.4h, #16 \n"
1264
1265 "shll v4.4s, v4.4h, #16 \n"
1266 "shll v5.4s, v5.4h, #16 \n"
1267 "shll v6.4s, v6.4h, #16 \n"
1268 "shll v7.4s, v7.4h, #16 \n"
1269
1270 "fmla v8.4s, v0.4s, v4.s[0] \n"
1271 "fmla v10.4s, v0.4s, v4.s[1] \n"
1272 "fmla v12.4s, v0.4s, v4.s[2] \n"
1273 "fmla v14.4s, v0.4s, v4.s[3] \n"
1274 "fmla v9.4s, v1.4s, v4.s[0] \n"
1275 "fmla v11.4s, v1.4s, v4.s[1] \n"
1276 "fmla v13.4s, v1.4s, v4.s[2] \n"
1277 "fmla v15.4s, v1.4s, v4.s[3] \n"
1278
1279 "fmla v8.4s, v2.4s, v5.s[0] \n"
1280 "fmla v10.4s, v2.4s, v5.s[1] \n"
1281 "fmla v12.4s, v2.4s, v5.s[2] \n"
1282 "fmla v14.4s, v2.4s, v5.s[3] \n"
1283 "fmla v9.4s, v3.4s, v5.s[0] \n"
1284 "fmla v11.4s, v3.4s, v5.s[1] \n"
1285 "fmla v13.4s, v3.4s, v5.s[2] \n"
1286 "fmla v15.4s, v3.4s, v5.s[3] \n"
1287
1288 "prfm pldl1keep, [%5, #256] \n"
1289 "ld1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%5], #32 \n"
1290
1291 "shll v16.4s, v16.4h, #16 \n"
1292 "shll v17.4s, v17.4h, #16 \n"
1293 "shll v18.4s, v18.4h, #16 \n"
1294 "shll v19.4s, v19.4h, #16 \n"
1295
1296 "fmla v8.4s, v16.4s, v6.s[0] \n"
1297 "fmla v10.4s, v16.4s, v6.s[1] \n"
1298 "fmla v12.4s, v16.4s, v6.s[2] \n"
1299 "fmla v14.4s, v16.4s, v6.s[3] \n"
1300 "fmla v9.4s, v17.4s, v6.s[0] \n"
1301 "fmla v11.4s, v17.4s, v6.s[1] \n"
1302 "fmla v13.4s, v17.4s, v6.s[2] \n"
1303 "fmla v15.4s, v17.4s, v6.s[3] \n"
1304
1305 "subs %w0, %w0, #1 \n"
1306
1307 "fmla v8.4s, v18.4s, v7.s[0] \n"
1308 "fmla v10.4s, v18.4s, v7.s[1] \n"
1309 "fmla v12.4s, v18.4s, v7.s[2] \n"
1310 "fmla v14.4s, v18.4s, v7.s[3] \n"
1311 "fmla v9.4s, v19.4s, v7.s[0] \n"
1312 "fmla v11.4s, v19.4s, v7.s[1] \n"
1313 "fmla v13.4s, v19.4s, v7.s[2] \n"
1314 "fmla v15.4s, v19.4s, v7.s[3] \n"
1315
1316 "bne 0b \n"
1317
1318 "shrn v8.4h, v8.4s, #16 \n"
1319 "shrn v9.4h, v9.4s, #16 \n"
1320 "shrn v10.4h, v10.4s, #16 \n"
1321 "shrn v11.4h, v11.4s, #16 \n"
1322
1323 "shrn v12.4h, v12.4s, #16 \n"
1324 "shrn v13.4h, v13.4s, #16 \n"
1325 "shrn v14.4h, v14.4s, #16 \n"
1326 "shrn v15.4h, v15.4s, #16 \n"
1327
1328 "st1 {v8.4h, v9.4h}, [%1], #16 \n"
1329 "st1 {v10.4h, v11.4h}, [%2], #16 \n"
1330 "st1 {v12.4h, v13.4h}, [%3], #16 \n"
1331 "st1 {v14.4h, v15.4h}, [%4], #16 \n"
1332
1333 : "=r"(nn), // %0
1334 "=r"(outptr0), // %1
1335 "=r"(outptr1), // %2
1336 "=r"(outptr2), // %3
1337 "=r"(outptr3), // %4
1338 "=r"(tmpptr), // %5
1339 "=r"(kptr) // %6
1340 : "0"(nn),
1341 "1"(outptr0),
1342 "2"(outptr1),
1343 "3"(outptr2),
1344 "4"(outptr3),
1345 "5"(tmpptr),
1346 "6"(kptr),
1347 "r"(biasptr) // %14
1348 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
1349 #else // __aarch64__
1350 asm volatile(
1351 "vld1.f32 {d30-d31}, [%14] \n"
1352 "vdup.f32 q8, d30[0] \n"
1353 "vdup.f32 q9, d30[0] \n"
1354 "vdup.f32 q10, d30[1] \n"
1355 "vdup.f32 q11, d30[1] \n"
1356 "vdup.f32 q12, d31[0] \n"
1357 "vdup.f32 q13, d31[0] \n"
1358 "vdup.f32 q14, d31[1] \n"
1359 "vdup.f32 q15, d31[1] \n"
1360
1361 "0: \n"
1362
1363 "pld [%5, #256] \n"
1364 "vld1.u16 {d4-d7}, [%5]! \n"
1365
1366 "pld [%6, #256] \n"
1367 "vld1.u16 {d12-d15}, [%6]! \n"
1368
1369 "vshll.u16 q0, d4, #16 \n"
1370 "vshll.u16 q1, d5, #16 \n"
1371 "vshll.u16 q2, d6, #16 \n"
1372 "vshll.u16 q3, d7, #16 \n"
1373
1374 "vshll.u16 q4, d12, #16 \n"
1375 "vshll.u16 q5, d13, #16 \n"
1376 "vshll.u16 q6, d14, #16 \n"
1377 "vshll.u16 q7, d15, #16 \n"
1378
1379 "vmla.f32 q8, q0, d8[0] \n"
1380 "vmla.f32 q10, q0, d8[1] \n"
1381 "vmla.f32 q12, q0, d9[0] \n"
1382 "vmla.f32 q14, q0, d9[1] \n"
1383 "vmla.f32 q9, q1, d8[0] \n"
1384 "vmla.f32 q11, q1, d8[1] \n"
1385 "vmla.f32 q13, q1, d9[0] \n"
1386 "vmla.f32 q15, q1, d9[1] \n"
1387
1388 "vmla.f32 q8, q2, d10[0] \n"
1389 "vmla.f32 q10, q2, d10[1] \n"
1390 "vmla.f32 q12, q2, d11[0] \n"
1391 "vmla.f32 q14, q2, d11[1] \n"
1392 "vmla.f32 q9, q3, d10[0] \n"
1393 "vmla.f32 q11, q3, d10[1] \n"
1394 "vmla.f32 q13, q3, d11[0] \n"
1395 "vmla.f32 q15, q3, d11[1] \n"
1396
1397 "pld [%5, #256] \n"
1398 "vld1.u16 {d4-d7}, [%5]! \n"
1399
1400 "vshll.u16 q0, d4, #16 \n"
1401 "vshll.u16 q1, d5, #16 \n"
1402 "vshll.u16 q2, d6, #16 \n"
1403 "vshll.u16 q3, d7, #16 \n"
1404
1405 "vmla.f32 q8, q0, d12[0] \n"
1406 "vmla.f32 q10, q0, d12[1] \n"
1407 "vmla.f32 q12, q0, d13[0] \n"
1408 "vmla.f32 q14, q0, d13[1] \n"
1409 "vmla.f32 q9, q1, d12[0] \n"
1410 "vmla.f32 q11, q1, d12[1] \n"
1411 "vmla.f32 q13, q1, d13[0] \n"
1412 "vmla.f32 q15, q1, d13[1] \n"
1413
1414 "subs %0, %0, #1 \n"
1415
1416 "vmla.f32 q8, q2, d14[0] \n"
1417 "vmla.f32 q10, q2, d14[1] \n"
1418 "vmla.f32 q12, q2, d15[0] \n"
1419 "vmla.f32 q14, q2, d15[1] \n"
1420 "vmla.f32 q9, q3, d14[0] \n"
1421 "vmla.f32 q11, q3, d14[1] \n"
1422 "vmla.f32 q13, q3, d15[0] \n"
1423 "vmla.f32 q15, q3, d15[1] \n"
1424
1425 "bne 0b \n"
1426
1427 "vshrn.u32 d16, q8, #16 \n"
1428 "vshrn.u32 d17, q9, #16 \n"
1429 "vshrn.u32 d20, q10, #16 \n"
1430 "vshrn.u32 d21, q11, #16 \n"
1431
1432 "vshrn.u32 d24, q12, #16 \n"
1433 "vshrn.u32 d25, q13, #16 \n"
1434 "vshrn.u32 d28, q14, #16 \n"
1435 "vshrn.u32 d29, q15, #16 \n"
1436
1437 "vst1.u16 {d16-d17}, [%1 :64]! \n"
1438 "vst1.u16 {d20-d21}, [%2 :64]! \n"
1439 "vst1.u16 {d24-d25}, [%3 :64]! \n"
1440 "vst1.u16 {d28-d29}, [%4 :64]! \n"
1441
1442 : "=r"(nn), // %0
1443 "=r"(outptr0), // %1
1444 "=r"(outptr1), // %2
1445 "=r"(outptr2), // %3
1446 "=r"(outptr3), // %4
1447 "=r"(tmpptr), // %5
1448 "=r"(kptr) // %6
1449 : "0"(nn),
1450 "1"(outptr0),
1451 "2"(outptr1),
1452 "3"(outptr2),
1453 "4"(outptr3),
1454 "5"(tmpptr),
1455 "6"(kptr),
1456 "r"(biasptr) // %14
1457 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1458 #endif // __aarch64__
1459 }
1460 for (; i + 3 < size; i += 4)
1461 {
1462 #if __aarch64__
1463 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
1464 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8 + (p % 8) / 4);
1465 #else
1466 unsigned short* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
1467 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 4);
1468 #endif
1469
1470 int nn = inch; // inch always > 0
1471
1472 #if __aarch64__
1473 asm volatile(
1474 "ld1 {v11.4s}, [%14] \n"
1475 "dup v8.4s, v11.s[0] \n"
1476 "dup v9.4s, v11.s[1] \n"
1477 "dup v10.4s, v11.s[2] \n"
1478 "dup v11.4s, v11.s[3] \n"
1479
1480 "0: \n"
1481
1482 "prfm pldl1keep, [%5, #256] \n"
1483 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%5], #32 \n"
1484
1485 "prfm pldl1keep, [%6, #256] \n"
1486 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%6], #32 \n"
1487
1488 "shll v0.4s, v0.4h, #16 \n"
1489 "shll v1.4s, v1.4h, #16 \n"
1490 "shll v2.4s, v2.4h, #16 \n"
1491 "shll v3.4s, v3.4h, #16 \n"
1492
1493 "shll v4.4s, v4.4h, #16 \n"
1494 "shll v5.4s, v5.4h, #16 \n"
1495 "shll v6.4s, v6.4h, #16 \n"
1496 "shll v7.4s, v7.4h, #16 \n"
1497
1498 "fmla v8.4s, v0.4s, v4.s[0] \n"
1499 "fmla v9.4s, v0.4s, v4.s[1] \n"
1500 "fmla v10.4s, v0.4s, v4.s[2] \n"
1501 "fmla v11.4s, v0.4s, v4.s[3] \n"
1502
1503 "fmla v8.4s, v1.4s, v5.s[0] \n"
1504 "fmla v9.4s, v1.4s, v5.s[1] \n"
1505 "fmla v10.4s, v1.4s, v5.s[2] \n"
1506 "fmla v11.4s, v1.4s, v5.s[3] \n"
1507
1508 "subs %w0, %w0, #1 \n"
1509
1510 "fmla v8.4s, v2.4s, v6.s[0] \n"
1511 "fmla v9.4s, v2.4s, v6.s[1] \n"
1512 "fmla v10.4s, v2.4s, v6.s[2] \n"
1513 "fmla v11.4s, v2.4s, v6.s[3] \n"
1514
1515 "fmla v8.4s, v3.4s, v7.s[0] \n"
1516 "fmla v9.4s, v3.4s, v7.s[1] \n"
1517 "fmla v10.4s, v3.4s, v7.s[2] \n"
1518 "fmla v11.4s, v3.4s, v7.s[3] \n"
1519
1520 "bne 0b \n"
1521
1522 "shrn v8.4h, v8.4s, #16 \n"
1523 "shrn v9.4h, v9.4s, #16 \n"
1524 "shrn v10.4h, v10.4s, #16 \n"
1525 "shrn v11.4h, v11.4s, #16 \n"
1526
1527 "st1 {v8.4h}, [%1], #8 \n"
1528 "st1 {v9.4h}, [%2], #8 \n"
1529 "st1 {v10.4h}, [%3], #8 \n"
1530 "st1 {v11.4h}, [%4], #8 \n"
1531
1532 : "=r"(nn), // %0
1533 "=r"(outptr0), // %1
1534 "=r"(outptr1), // %2
1535 "=r"(outptr2), // %3
1536 "=r"(outptr3), // %4
1537 "=r"(tmpptr), // %5
1538 "=r"(kptr) // %6
1539 : "0"(nn),
1540 "1"(outptr0),
1541 "2"(outptr1),
1542 "3"(outptr2),
1543 "4"(outptr3),
1544 "5"(tmpptr),
1545 "6"(kptr),
1546 "r"(biasptr) // %14
1547 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
1548 #else // __aarch64__
1549 asm volatile(
1550 "vld1.f32 {d22-d23}, [%14] \n"
1551 "vdup.f32 q8, d22[0] \n"
1552 "vdup.f32 q9, d22[1] \n"
1553 "vdup.f32 q10, d23[0] \n"
1554 "vdup.f32 q11, d23[1] \n"
1555
1556 "0: \n"
1557
1558 "pld [%5, #256] \n"
1559 "vld1.u16 {d4-d7}, [%5]! \n"
1560
1561 "pld [%6, #256] \n"
1562 "vld1.u16 {d12-d15}, [%6]! \n"
1563
1564 "vshll.u16 q0, d4, #16 \n"
1565 "vshll.u16 q1, d5, #16 \n"
1566 "vshll.u16 q2, d6, #16 \n"
1567 "vshll.u16 q3, d7, #16 \n"
1568
1569 "vshll.u16 q4, d12, #16 \n"
1570 "vshll.u16 q5, d13, #16 \n"
1571 "vshll.u16 q6, d14, #16 \n"
1572 "vshll.u16 q7, d15, #16 \n"
1573
1574 "vmla.f32 q8, q0, d8[0] \n"
1575 "vmla.f32 q9, q0, d8[1] \n"
1576 "vmla.f32 q10, q0, d9[0] \n"
1577 "vmla.f32 q11, q0, d9[1] \n"
1578
1579 "vmla.f32 q8, q1, d10[0] \n"
1580 "vmla.f32 q9, q1, d10[1] \n"
1581 "vmla.f32 q10, q1, d11[0] \n"
1582 "vmla.f32 q11, q1, d11[1] \n"
1583
1584 "subs %0, %0, #1 \n"
1585
1586 "vmla.f32 q8, q2, d12[0] \n"
1587 "vmla.f32 q9, q2, d12[1] \n"
1588 "vmla.f32 q10, q2, d13[0] \n"
1589 "vmla.f32 q11, q2, d13[1] \n"
1590
1591 "vmla.f32 q8, q3, d14[0] \n"
1592 "vmla.f32 q9, q3, d14[1] \n"
1593 "vmla.f32 q10, q3, d15[0] \n"
1594 "vmla.f32 q11, q3, d15[1] \n"
1595
1596 "bne 0b \n"
1597
1598 "vshrn.u32 d16, q8, #16 \n"
1599 "vshrn.u32 d18, q9, #16 \n"
1600 "vshrn.u32 d20, q10, #16 \n"
1601 "vshrn.u32 d22, q11, #16 \n"
1602
1603 "vst1.u16 {d16}, [%1 :64]! \n"
1604 "vst1.u16 {d18}, [%2 :64]! \n"
1605 "vst1.u16 {d20}, [%3 :64]! \n"
1606 "vst1.u16 {d22}, [%4 :64]! \n"
1607
1608 : "=r"(nn), // %0
1609 "=r"(outptr0), // %1
1610 "=r"(outptr1), // %2
1611 "=r"(outptr2), // %3
1612 "=r"(outptr3), // %4
1613 "=r"(tmpptr), // %5
1614 "=r"(kptr) // %6
1615 : "0"(nn),
1616 "1"(outptr0),
1617 "2"(outptr1),
1618 "3"(outptr2),
1619 "4"(outptr3),
1620 "5"(tmpptr),
1621 "6"(kptr),
1622 "r"(biasptr) // %14
1623 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
1624 #endif // __aarch64__
1625 }
1626 for (; i < size; i++)
1627 {
1628 #if __aarch64__
1629 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + i % 12 % 4);
1630 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8 + (p % 8) / 4);
1631 #else
1632 unsigned short* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
1633 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 4);
1634 #endif
1635
1636 int nn = inch; // inch always > 0
1637
1638 #if __aarch64__
1639 asm volatile(
1640 "ld1 {v8.4s}, [%14] \n"
1641 "eor v9.16b, v9.16b, v9.16b \n"
1642 "eor v10.16b, v10.16b, v10.16b \n"
1643 "eor v11.16b, v11.16b, v11.16b \n"
1644
1645 "0: \n"
1646
1647 "prfm pldl1keep, [%5, #64] \n"
1648 "ld1 {v0.4h}, [%5], #8 \n"
1649
1650 "prfm pldl1keep, [%6, #256] \n"
1651 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%6], #32 \n"
1652
1653 "shll v0.4s, v0.4h, #16 \n"
1654
1655 "shll v4.4s, v4.4h, #16 \n"
1656 "shll v5.4s, v5.4h, #16 \n"
1657 "shll v6.4s, v6.4h, #16 \n"
1658 "shll v7.4s, v7.4h, #16 \n"
1659
1660 "fmla v8.4s, v4.4s, v0.s[0] \n"
1661 "fmla v9.4s, v5.4s, v0.s[1] \n"
1662
1663 "subs %w0, %w0, #1 \n"
1664
1665 "fmla v10.4s, v6.4s, v0.s[2] \n"
1666 "fmla v11.4s, v7.4s, v0.s[3] \n"
1667
1668 "bne 0b \n"
1669
1670 "fadd v8.4s, v8.4s, v9.4s \n"
1671 "fadd v10.4s, v10.4s, v11.4s \n"
1672 "fadd v8.4s, v8.4s, v10.4s \n"
1673
1674 "shrn v8.4h, v8.4s, #16 \n"
1675
1676 "st1 {v8.h}[0], [%1], #2 \n"
1677 "st1 {v8.h}[1], [%2], #2 \n"
1678 "st1 {v8.h}[2], [%3], #2 \n"
1679 "st1 {v8.h}[3], [%4], #2 \n"
1680
1681 : "=r"(nn), // %0
1682 "=r"(outptr0), // %1
1683 "=r"(outptr1), // %2
1684 "=r"(outptr2), // %3
1685 "=r"(outptr3), // %4
1686 "=r"(tmpptr), // %5
1687 "=r"(kptr) // %6
1688 : "0"(nn),
1689 "1"(outptr0),
1690 "2"(outptr1),
1691 "3"(outptr2),
1692 "4"(outptr3),
1693 "5"(tmpptr),
1694 "6"(kptr),
1695 "r"(biasptr) // %14
1696 : "cc", "memory", "v0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
1697 #else // __aarch64__
1698 asm volatile(
1699 "vld1.f32 {d16-d17}, [%14] \n"
1700 "veor q9, q9 \n"
1701 "veor q10, q10 \n"
1702 "veor q11, q11 \n"
1703
1704 "0: \n"
1705
1706 "pld [%5, #64] \n"
1707 "vld1.u16 {d1}, [%5]! \n"
1708
1709 "pld [%6, #256] \n"
1710 "vld1.u16 {d12-d15}, [%6]! \n"
1711
1712 "vshll.u16 q0, d1, #16 \n"
1713
1714 "vshll.u16 q4, d12, #16 \n"
1715 "vshll.u16 q5, d13, #16 \n"
1716 "vshll.u16 q6, d14, #16 \n"
1717 "vshll.u16 q7, d15, #16 \n"
1718
1719 "vmla.f32 q8, q4, d0[0] \n"
1720 "vmla.f32 q9, q5, d0[1] \n"
1721
1722 "subs %0, %0, #1 \n"
1723
1724 "vmla.f32 q10, q6, d1[0] \n"
1725 "vmla.f32 q11, q7, d1[1] \n"
1726
1727 "bne 0b \n"
1728
1729 "vadd.f32 q8, q8, q9 \n"
1730 "vadd.f32 q10, q10, q11 \n"
1731 "vadd.f32 q8, q8, q10 \n"
1732
1733 "vshrn.u32 d16, q8, #16 \n"
1734
1735 "vst1.u16 {d16[0]}, [%1]! \n"
1736 "vst1.u16 {d16[1]}, [%2]! \n"
1737 "vst1.u16 {d16[2]}, [%3]! \n"
1738 "vst1.u16 {d16[3]}, [%4]! \n"
1739
1740 : "=r"(nn), // %0
1741 "=r"(outptr0), // %1
1742 "=r"(outptr1), // %2
1743 "=r"(outptr2), // %3
1744 "=r"(outptr3), // %4
1745 "=r"(tmpptr), // %5
1746 "=r"(kptr) // %6
1747 : "0"(nn),
1748 "1"(outptr0),
1749 "2"(outptr1),
1750 "3"(outptr2),
1751 "4"(outptr3),
1752 "5"(tmpptr),
1753 "6"(kptr),
1754 "r"(biasptr) // %14
1755 : "cc", "memory", "q0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
1756 #endif // __aarch64__
1757 }
1758 }
1759
1760 remain_outch_start += nn_outch << 2;
1761
1762 #pragma omp parallel for num_threads(opt.num_threads)
1763 for (int p = remain_outch_start; p < outch; p++)
1764 {
1765 unsigned short* outptr0 = top_blob.channel(p);
1766
1767 const float bias0 = bias ? bias[p] : 0.f;
1768
1769 int i = 0;
1770 #if __aarch64__
1771 for (; i + 11 < size; i += 12)
1772 {
1773 unsigned short* tmpptr = tmp.channel(i / 12);
1774 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
1775
1776 int nn = inch; // inch always > 0
1777
1778 asm volatile(
1779 "dup v8.4s, %w8 \n"
1780 "dup v9.4s, %w8 \n"
1781 "dup v10.4s, %w8 \n"
1782 "eor v5.16b, v5.16b, v5.16b \n"
1783 "eor v6.16b, v6.16b, v6.16b \n"
1784 "eor v7.16b, v7.16b, v7.16b \n"
1785
1786 "0: \n"
1787
1788 "prfm pldl1keep, [%2, #256] \n"
1789 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%2], #32 \n"
1790
1791 "prfm pldl1keep, [%3, #64] \n"
1792 "ld1 {v4.4h}, [%3], #8 \n"
1793
1794 "shll v0.4s, v0.4h, #16 \n"
1795 "shll v1.4s, v1.4h, #16 \n"
1796 "shll v2.4s, v2.4h, #16 \n"
1797 "shll v3.4s, v3.4h, #16 \n"
1798
1799 "shll v4.4s, v4.4h, #16 \n"
1800
1801 "fmla v8.4s, v0.4s, v4.s[0] \n"
1802 "fmla v9.4s, v1.4s, v4.s[0] \n"
1803 "fmla v10.4s, v2.4s, v4.s[0] \n"
1804
1805 "prfm pldl1keep, [%2, #256] \n"
1806 "ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [%2], #32 \n"
1807
1808 "shll v12.4s, v12.4h, #16 \n"
1809 "shll v13.4s, v13.4h, #16 \n"
1810 "shll v14.4s, v14.4h, #16 \n"
1811 "shll v15.4s, v15.4h, #16 \n"
1812
1813 "fmla v5.4s, v3.4s, v4.s[1] \n"
1814 "fmla v6.4s, v12.4s, v4.s[1] \n"
1815 "fmla v7.4s, v13.4s, v4.s[1] \n"
1816
1817 "prfm pldl1keep, [%2, #256] \n"
1818 "ld1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%2], #32 \n"
1819
1820 "shll v16.4s, v16.4h, #16 \n"
1821 "shll v17.4s, v17.4h, #16 \n"
1822 "shll v18.4s, v18.4h, #16 \n"
1823 "shll v19.4s, v19.4h, #16 \n"
1824
1825 "fmla v8.4s, v14.4s, v4.s[2] \n"
1826 "fmla v9.4s, v15.4s, v4.s[2] \n"
1827 "fmla v10.4s, v16.4s, v4.s[2] \n"
1828
1829 "subs %w0, %w0, #1 \n"
1830
1831 "fmla v5.4s, v17.4s, v4.s[3] \n"
1832 "fmla v6.4s, v18.4s, v4.s[3] \n"
1833 "fmla v7.4s, v19.4s, v4.s[3] \n"
1834
1835 "bne 0b \n"
1836
1837 "fadd v8.4s, v8.4s, v5.4s \n"
1838 "fadd v9.4s, v9.4s, v6.4s \n"
1839 "fadd v10.4s, v10.4s, v7.4s \n"
1840
1841 "shrn v8.4h, v8.4s, #16 \n"
1842 "shrn v9.4h, v9.4s, #16 \n"
1843 "shrn v10.4h, v10.4s, #16 \n"
1844
1845 "st1 {v8.4h, v9.4h, v10.4h}, [%1], #24 \n"
1846
1847 : "=r"(nn), // %0
1848 "=r"(outptr0), // %1
1849 "=r"(tmpptr), // %2
1850 "=r"(kptr) // %3
1851 : "0"(nn),
1852 "1"(outptr0),
1853 "2"(tmpptr),
1854 "3"(kptr),
1855 "r"(bias0) // %8
1856 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
1857 }
1858 #endif // __aarch64__
1859 for (; i + 7 < size; i += 8)
1860 {
1861 #if __aarch64__
1862 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
1863 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
1864 #else
1865 unsigned short* tmpptr = tmp.channel(i / 8);
1866 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 4 + p % 4);
1867 #endif
1868
1869 int nn = inch; // inch always > 0
1870
1871 #if __aarch64__
1872 asm volatile(
1873 "dup v8.4s, %w8 \n"
1874 "dup v9.4s, %w8 \n"
1875 "eor v10.16b, v10.16b, v10.16b \n"
1876 "eor v11.16b, v11.16b, v11.16b \n"
1877
1878 "0: \n"
1879
1880 "prfm pldl1keep, [%2, #256] \n"
1881 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%2], #32 \n"
1882
1883 "prfm pldl1keep, [%3, #64] \n"
1884 "ld1 {v4.4h}, [%3], #8 \n"
1885
1886 "shll v0.4s, v0.4h, #16 \n"
1887 "shll v1.4s, v1.4h, #16 \n"
1888 "shll v2.4s, v2.4h, #16 \n"
1889 "shll v3.4s, v3.4h, #16 \n"
1890
1891 "shll v4.4s, v4.4h, #16 \n"
1892
1893 "fmla v8.4s, v0.4s, v4.s[0] \n"
1894 "fmla v9.4s, v1.4s, v4.s[0] \n"
1895 "fmla v10.4s, v2.4s, v4.s[1] \n"
1896 "fmla v11.4s, v3.4s, v4.s[1] \n"
1897
1898 "prfm pldl1keep, [%2, #256] \n"
1899 "ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [%2], #32 \n"
1900
1901 "shll v12.4s, v12.4h, #16 \n"
1902 "shll v13.4s, v13.4h, #16 \n"
1903 "shll v14.4s, v14.4h, #16 \n"
1904 "shll v15.4s, v15.4h, #16 \n"
1905
1906 "fmla v8.4s, v12.4s, v4.s[2] \n"
1907 "fmla v9.4s, v13.4s, v4.s[2] \n"
1908
1909 "subs %w0, %w0, #1 \n"
1910
1911 "fmla v10.4s, v14.4s, v4.s[3] \n"
1912 "fmla v11.4s, v15.4s, v4.s[3] \n"
1913
1914 "bne 0b \n"
1915
1916 "fadd v8.4s, v8.4s, v10.4s \n"
1917 "fadd v9.4s, v9.4s, v11.4s \n"
1918
1919 "shrn v8.4h, v8.4s, #16 \n"
1920 "shrn v9.4h, v9.4s, #16 \n"
1921
1922 "st1 {v8.4h, v9.4h}, [%1], #16 \n"
1923
1924 : "=r"(nn), // %0
1925 "=r"(outptr0), // %1
1926 "=r"(tmpptr), // %2
1927 "=r"(kptr) // %3
1928 : "0"(nn),
1929 "1"(outptr0),
1930 "2"(tmpptr),
1931 "3"(kptr),
1932 "r"(bias0) // %8
1933 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");
1934 #else // __aarch64__
1935 asm volatile(
1936 "vdup.f32 q8, %8 \n"
1937 "vdup.f32 q9, %8 \n"
1938 "veor q10, q10 \n"
1939 "veor q11, q11 \n"
1940
1941 "0: \n"
1942
1943 "pld [%2, #256] \n"
1944 "vld1.u16 {d4-d7}, [%2]! \n"
1945
1946 "pld [%3, #64] \n"
1947 "vld1.u16 {d9}, [%3]! \n"
1948
1949 "vshll.u16 q0, d4, #16 \n"
1950 "vshll.u16 q1, d5, #16 \n"
1951 "vshll.u16 q2, d6, #16 \n"
1952 "vshll.u16 q3, d7, #16 \n"
1953
1954 "vshll.u16 q4, d9, #16 \n"
1955
1956 "vmla.f32 q8, q0, d8[0] \n"
1957 "vmla.f32 q9, q1, d8[0] \n"
1958 "vmla.f32 q10, q2, d8[1] \n"
1959 "vmla.f32 q11, q3, d8[1] \n"
1960
1961 "pld [%2, #256] \n"
1962 "vld1.u16 {d28-d31}, [%2]! \n"
1963
1964 "vshll.u16 q12, d28, #16 \n"
1965 "vshll.u16 q13, d29, #16 \n"
1966 "vshll.u16 q14, d30, #16 \n"
1967 "vshll.u16 q15, d31, #16 \n"
1968
1969 "vmla.f32 q8, q12, d9[0] \n"
1970 "vmla.f32 q9, q13, d9[0] \n"
1971
1972 "subs %0, %0, #1 \n"
1973
1974 "vmla.f32 q10, q14, d9[1] \n"
1975 "vmla.f32 q11, q15, d9[1] \n"
1976
1977 "bne 0b \n"
1978
1979 "vadd.f32 q8, q8, q10 \n"
1980 "vadd.f32 q9, q9, q11 \n"
1981
1982 "vshrn.u32 d16, q8, #16 \n"
1983 "vshrn.u32 d17, q9, #16 \n"
1984
1985 "vst1.u16 {d16-d17}, [%1 :64]! \n"
1986
1987 : "=r"(nn), // %0
1988 "=r"(outptr0), // %1
1989 "=r"(tmpptr), // %2
1990 "=r"(kptr) // %3
1991 : "0"(nn),
1992 "1"(outptr0),
1993 "2"(tmpptr),
1994 "3"(kptr),
1995 "r"(bias0) // %8
1996 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1997 #endif // __aarch64__
1998 }
1999 for (; i + 3 < size; i += 4)
2000 {
2001 #if __aarch64__
2002 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
2003 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
2004 #else
2005 unsigned short* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
2006 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 4 + p % 4);
2007 #endif
2008
2009 int nn = inch; // inch always > 0
2010
2011 #if __aarch64__
2012 asm volatile(
2013 "dup v8.4s, %w8 \n"
2014 "eor v9.16b, v9.16b, v9.16b \n"
2015 "eor v10.16b, v10.16b, v10.16b \n"
2016 "eor v11.16b, v11.16b, v11.16b \n"
2017
2018 "0: \n"
2019
2020 "prfm pldl1keep, [%2, #256] \n"
2021 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%2], #32 \n"
2022
2023 "prfm pldl1keep, [%3, #64] \n"
2024 "ld1 {v4.4h}, [%3], #8 \n"
2025
2026 "shll v0.4s, v0.4h, #16 \n"
2027 "shll v1.4s, v1.4h, #16 \n"
2028 "shll v2.4s, v2.4h, #16 \n"
2029 "shll v3.4s, v3.4h, #16 \n"
2030
2031 "shll v4.4s, v4.4h, #16 \n"
2032
2033 "fmla v8.4s, v0.4s, v4.s[0] \n"
2034 "fmla v9.4s, v1.4s, v4.s[1] \n"
2035
2036 "subs %w0, %w0, #1 \n"
2037
2038 "fmla v10.4s, v2.4s, v4.s[2] \n"
2039 "fmla v11.4s, v3.4s, v4.s[3] \n"
2040
2041 "bne 0b \n"
2042
2043 "fadd v8.4s, v8.4s, v9.4s \n"
2044 "fadd v10.4s, v10.4s, v11.4s \n"
2045 "fadd v8.4s, v8.4s, v10.4s \n"
2046
2047 "shrn v8.4h, v8.4s, #16 \n"
2048
2049 "st1 {v8.4h}, [%1], #8 \n"
2050
2051 : "=r"(nn), // %0
2052 "=r"(outptr0), // %1
2053 "=r"(tmpptr), // %2
2054 "=r"(kptr) // %3
2055 : "0"(nn),
2056 "1"(outptr0),
2057 "2"(tmpptr),
2058 "3"(kptr),
2059 "r"(bias0) // %8
2060 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v8", "v9", "v10", "v11");
2061 #else // __aarch64__
2062 asm volatile(
2063 "vdup.f32 q8, %8 \n"
2064 "veor q9, q9 \n"
2065 "veor q10, q10 \n"
2066 "veor q11, q11 \n"
2067
2068 "0: \n"
2069
2070 "pld [%2, #256] \n"
2071 "vld1.u16 {d4-d7}, [%2]! \n"
2072
2073 "pld [%3, #64] \n"
2074 "vld1.u16 {d9}, [%3]! \n"
2075
2076 "vshll.u16 q0, d4, #16 \n"
2077 "vshll.u16 q1, d5, #16 \n"
2078 "vshll.u16 q2, d6, #16 \n"
2079 "vshll.u16 q3, d7, #16 \n"
2080
2081 "vshll.u16 q4, d9, #16 \n"
2082
2083 "vmla.f32 q8, q0, d8[0] \n"
2084 "vmla.f32 q9, q1, d8[1] \n"
2085
2086 "subs %0, %0, #1 \n"
2087
2088 "vmla.f32 q10, q2, d9[0] \n"
2089 "vmla.f32 q11, q3, d9[1] \n"
2090
2091 "bne 0b \n"
2092
2093 "vadd.f32 q8, q8, q9 \n"
2094 "vadd.f32 q10, q10, q11 \n"
2095 "vadd.f32 q8, q8, q10 \n"
2096
2097 "vshrn.u32 d16, q8, #16 \n"
2098
2099 "vst1.u16 {d16}, [%1]! \n"
2100
2101 : "=r"(nn), // %0
2102 "=r"(outptr0), // %1
2103 "=r"(tmpptr), // %2
2104 "=r"(kptr) // %3
2105 : "0"(nn),
2106 "1"(outptr0),
2107 "2"(tmpptr),
2108 "3"(kptr),
2109 "r"(bias0) // %8
2110 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q8", "q9", "q10", "q11");
2111 #endif // __aarch64__
2112 }
2113 for (; i < size; i++)
2114 {
2115 #if __aarch64__
2116 unsigned short* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + i % 12 % 4);
2117 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 8 + (p % 8) / 4 + p % 4);
2118 #else
2119 unsigned short* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
2120 const unsigned short* kptr = (const unsigned short*)kernel.channel(p / 4 + p % 4);
2121 #endif
2122
2123 float32x4_t _sum0 = vdupq_n_f32(0.f);
2124
2125 for (int q = 0; q < inch; q++)
2126 {
2127 float32x4_t _r0 = vcvt_f32_bf16(vld1_u16(tmpptr));
2128
2129 float32x4_t _k0 = vcvt_f32_bf16(vld1_u16(kptr));
2130
2131 _sum0 = vmlaq_f32(_sum0, _r0, _k0);
2132
2133 kptr += 4;
2134 tmpptr += 4;
2135 }
2136
2137 #if __aarch64__
2138 float sum0 = vaddvq_f32(_sum0);
2139 #else
2140 float32x2_t _ss = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0));
2141 float32x2_t _ss2 = vpadd_f32(_ss, _ss);
2142 float sum0 = vget_lane_f32(_ss2, 0);
2143 #endif
2144
2145 outptr0[0] = float32_to_bfloat16(bias0 + sum0);
2146
2147 outptr0++;
2148 }
2149 }
2150
2151 // // NOTE sgemm
2152 // for (; p<outch; p++)
2153 // {
2154 // Mat out0 = top_blob.channel(p);
2155 //
2156 // const float bias0 = bias ? bias[p] : 0.f;
2157 //
2158 // unsigned short* outptr0 = out0;
2159 //
2160 // for (int i=0; i<size; i++)
2161 // {
2162 // float sum = bias0;
2163 //
2164 // const unsigned short* kptr = _kernel.channel(p);
2165 //
2166 // for (int q=0; q<inch; q++)
2167 // {
2168 // const unsigned short* img0 = bottom_blob.channel(q);
2169 //
2170 // sum += img0[i] * kptr[0];
2171 // kptr ++;
2172 // }
2173 //
2174 // outptr0[i] = sum;
2175 // }
2176 // }
2177 }
2178
// 1x1 convolution with stride 2, pack4 input -> pack1 output, bf16 storage.
//
// A stride-2 1x1 convolution only reads every second pixel of every second
// row. This routine gathers exactly those pixels into a compact blob of size
// (top_blob.w, top_blob.h) and then delegates the arithmetic to the stride-1
// kernel conv1x1s1_sgemm_pack4to1_bf16s_neon.
//
// Each input pixel occupies 4 unsigned-short lanes (pack4), so moving two
// pixels to the right advances the source pointer by 8 elements.
static void conv1x1s2_pack4to1_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    const int inw = bottom_blob.w;
    const int inch = bottom_blob.c;

    const int outw = top_blob.w;
    const int outh = top_blob.h;

    // After a row of outw stride-2 reads the cursor sits 2*outw pixels into an
    // even source row: skip the remainder of that row plus the whole odd row
    // below it (4 lanes per pixel).
    const int row_skip = (inw - 2 * outw + inw) * 4;

    Mat gathered;
    gathered.create(outw, outh, inch, bottom_blob.elemsize, bottom_blob.elempack, opt.workspace_allocator);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int c = 0; c < inch; c++)
    {
        const unsigned short* src = bottom_blob.channel(c);
        unsigned short* dst = gathered.channel(c);

        for (int y = 0; y < outh; y++)
        {
            int x = 0;

            // four output pixels <- source pixels 0, 2, 4, 6 of this span
            while (x + 3 < outw)
            {
                uint16x8_t _front = vcombine_u16(vld1_u16(src), vld1_u16(src + 8));
                uint16x8_t _back = vcombine_u16(vld1_u16(src + 16), vld1_u16(src + 24));
                vst1q_u16(dst, _front);
                vst1q_u16(dst + 8, _back);

                src += 32;
                dst += 16;
                x += 4;
            }

            // two output pixels <- source pixels 0, 2
            while (x + 1 < outw)
            {
                uint16x8_t _pair = vcombine_u16(vld1_u16(src), vld1_u16(src + 8));
                vst1q_u16(dst, _pair);

                src += 16;
                dst += 8;
                x += 2;
            }

            // single leftover output pixel
            while (x < outw)
            {
                vst1_u16(dst, vld1_u16(src));

                src += 8;
                dst += 4;
                x += 1;
            }

            src += row_skip;
        }
    }

    conv1x1s1_sgemm_pack4to1_bf16s_neon(gathered, top_blob, kernel, _bias, opt);
}
2242