// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

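// bf16 storage trick used throughout this file: a bf16 value is the high 16 bits
// of an IEEE-754 fp32 value, so conversion in either direction is a pure bit shift.
// An illustrative scalar sketch (hypothetical helpers, not part of this file):
//   float bf16_to_f32(unsigned short v) { uint32_t u = (uint32_t)v << 16; float f; memcpy(&f, &u, 4); return f; }
//   unsigned short f32_to_bf16(float f) { uint32_t u; memcpy(&u, &f, 4); return (unsigned short)(u >> 16); }
// The asm below does the same with "shll #16" (bf16 -> fp32) and the truncating
// "shrn #16" (fp32 -> bf16); the intrinsic tail loops use vcvt_f32_bf16 / vcvt_bf16_f32.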
static void convdw3x3s1_pack4_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
#if __aarch64__
    const int w = bottom_blob.w;
#endif

    const int outw = top_blob.w;
    const int outh = top_blob.h;

    const int group = bottom_blob.c;

    const float* bias = _bias;

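    // depthwise: each input channel maps to exactly one output channel,
    // so the per-group work is independent and parallelizes cleanly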
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int g = 0; g < group; g++)
    {
        Mat out = top_blob.channel(g);

        float32x4_t _bias0 = bias ? vld1q_f32((const float*)bias + g * 4) : vdupq_n_f32(0.f);

        const unsigned short* k0 = kernel.row<const unsigned short>(g);

        unsigned short* outptr0 = out.row<unsigned short>(0);

        const Mat img0 = bottom_blob.channel(g);

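        // input row pointers for the 3x3 window (stride 1: output row i reads rows i, i+1, i+2)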
        const unsigned short* r0 = img0.row<const unsigned short>(0);
        const unsigned short* r1 = img0.row<const unsigned short>(1);
        const unsigned short* r2 = img0.row<const unsigned short>(2);

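        // decode the nine 3x3 kernel taps from bf16 to fp32 once per group;
        // they stay resident in NEON registers across all inner loops below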
        float32x4_t _k00 = vcvt_f32_bf16(vld1_u16(k0));
        float32x4_t _k01 = vcvt_f32_bf16(vld1_u16(k0 + 4));
        float32x4_t _k02 = vcvt_f32_bf16(vld1_u16(k0 + 8));
        float32x4_t _k10 = vcvt_f32_bf16(vld1_u16(k0 + 12));
        float32x4_t _k11 = vcvt_f32_bf16(vld1_u16(k0 + 16));
        float32x4_t _k12 = vcvt_f32_bf16(vld1_u16(k0 + 20));
        float32x4_t _k20 = vcvt_f32_bf16(vld1_u16(k0 + 24));
        float32x4_t _k21 = vcvt_f32_bf16(vld1_u16(k0 + 28));
        float32x4_t _k22 = vcvt_f32_bf16(vld1_u16(k0 + 32));

        int i = 0;

#if __aarch64__
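        // aarch64 path: produce two output rows per iteration, reading four input rows r0..r3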
        unsigned short* outptr1 = out.row<unsigned short>(1);
        const unsigned short* r3 = img0.row<const unsigned short>(3);

        for (; i + 1 < outh; i += 2)
        {
            int j = 0;

            for (; j + 3 < outw; j += 4)
            {
                asm volatile(
                    "prfm pldl1keep, [%3, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%3], #32 \n" // r10 r11 r12 r13

                    "mov v16.16b, %21.16b \n" // sum00
                    "mov v17.16b, %21.16b \n" // sum01

                    "prfm pldl1keep, [%3, #128] \n"
                    "ld1 {v28.4h, v29.4h}, [%3] \n" // r14 r15

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"

                    "mov v18.16b, %21.16b \n" // sum02
                    "mov v19.16b, %21.16b \n" // sum03

                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "mov v20.16b, %21.16b \n" // sum10

                    "fmla v16.4s, %15.4s, v10.4s \n"
                    "fmla v17.4s, %15.4s, v11.4s \n"

                    "mov v21.16b, %21.16b \n" // sum11

                    "fmla v18.4s, %15.4s, v12.4s \n"
                    "fmla v19.4s, %15.4s, v13.4s \n"

                    "mov v22.16b, %21.16b \n" // sum12

                    "fmla v20.4s, %12.4s, v10.4s \n"
                    "fmla v21.4s, %12.4s, v11.4s \n"

                    "mov v23.16b, %21.16b \n" // sum13

                    "fmla v22.4s, %12.4s, v12.4s \n"
                    "fmla v23.4s, %12.4s, v13.4s \n"

                    "shll v28.4s, v28.4h, #16 \n"

                    "fmla v16.4s, %16.4s, v11.4s \n"
                    "fmla v17.4s, %16.4s, v12.4s \n"

                    "shll v29.4s, v29.4h, #16 \n"

                    "fmla v18.4s, %16.4s, v13.4s \n"
                    "fmla v19.4s, %16.4s, v28.4s \n"

                    "prfm pldl1keep, [%4, #256] \n"
                    "ld1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%4], #32 \n" // r20 r21 r22 r23

                    "fmla v20.4s, %13.4s, v11.4s \n"
                    "fmla v21.4s, %13.4s, v12.4s \n"
                    "fmla v22.4s, %13.4s, v13.4s \n"
                    "fmla v23.4s, %13.4s, v28.4s \n"

                    "prfm pldl1keep, [%4, #128] \n"
                    "ld1 {v14.4h, v15.4h}, [%4] \n" // r24 r25

                    "fmla v16.4s, %17.4s, v12.4s \n"
                    "fmla v17.4s, %17.4s, v13.4s \n"

                    "shll v24.4s, v24.4h, #16 \n"

                    "fmla v18.4s, %17.4s, v28.4s \n"
                    "fmla v19.4s, %17.4s, v29.4s \n"

                    "shll v25.4s, v25.4h, #16 \n"

                    "fmla v20.4s, %14.4s, v12.4s \n"
                    "fmla v21.4s, %14.4s, v13.4s \n"

                    "prfm pldl1keep, [%2, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%2], #32 \n" // r00 r01 r02 r03

                    "fmla v22.4s, %14.4s, v28.4s \n"
                    "fmla v23.4s, %14.4s, v29.4s \n"

                    "shll v26.4s, v26.4h, #16 \n"

                    "fmla v16.4s, %18.4s, v24.4s \n"
                    "fmla v17.4s, %18.4s, v25.4s \n"

                    "shll v27.4s, v27.4h, #16 \n"

                    "fmla v18.4s, %18.4s, v26.4s \n"
                    "fmla v19.4s, %18.4s, v27.4s \n"

                    "prfm pldl1keep, [%5, #256] \n"
                    "ld1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%5], #32 \n" // r30 r31 r32 r33

                    "fmla v20.4s, %15.4s, v24.4s \n"
                    "fmla v21.4s, %15.4s, v25.4s \n"

                    "shll v14.4s, v14.4h, #16 \n"

                    "fmla v22.4s, %15.4s, v26.4s \n"
                    "fmla v23.4s, %15.4s, v27.4s \n"

                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v16.4s, %19.4s, v25.4s \n"
                    "fmla v17.4s, %19.4s, v26.4s \n"

                    "fmla v18.4s, %19.4s, v27.4s \n"
                    "fmla v19.4s, %19.4s, v14.4s \n"

                    "fmla v20.4s, %16.4s, v25.4s \n"
                    "fmla v21.4s, %16.4s, v26.4s \n"

                    "prfm pldl1keep, [%2, #128] \n"
                    "ld1 {v24.4h, v25.4h}, [%2] \n" // r04 r05

                    "fmla v22.4s, %16.4s, v27.4s \n"
                    "fmla v23.4s, %16.4s, v14.4s \n"

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"

                    "fmla v16.4s, %20.4s, v26.4s \n"
                    "fmla v17.4s, %20.4s, v27.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"

                    "fmla v18.4s, %20.4s, v14.4s \n"
                    "fmla v19.4s, %20.4s, v15.4s \n"

                    "shll v13.4s, v13.4h, #16 \n"

                    "fmla v20.4s, %17.4s, v26.4s \n"
                    "fmla v21.4s, %17.4s, v27.4s \n"

                    "prfm pldl1keep, [%5, #128] \n"
                    "ld1 {v26.4h, v27.4h}, [%5] \n" // r34 r35

                    "fmla v22.4s, %17.4s, v14.4s \n"
                    "fmla v23.4s, %17.4s, v15.4s \n"

                    "shll v28.4s, v28.4h, #16 \n"

                    "fmla v16.4s, %12.4s, v10.4s \n"
                    "fmla v17.4s, %12.4s, v11.4s \n"

                    "shll v29.4s, v29.4h, #16 \n"

                    "fmla v18.4s, %12.4s, v12.4s \n"
                    "fmla v19.4s, %12.4s, v13.4s \n"

                    "shll v30.4s, v30.4h, #16 \n"

                    "fmla v20.4s, %18.4s, v28.4s \n"
                    "fmla v21.4s, %18.4s, v29.4s \n"

                    "shll v31.4s, v31.4h, #16 \n"

                    "fmla v22.4s, %18.4s, v30.4s \n"
                    "fmla v23.4s, %18.4s, v31.4s \n"

                    "shll v24.4s, v24.4h, #16 \n"

                    "fmla v16.4s, %13.4s, v11.4s \n"
                    "fmla v17.4s, %13.4s, v12.4s \n"
                    "fmla v18.4s, %13.4s, v13.4s \n"
                    "fmla v19.4s, %13.4s, v24.4s \n"

                    "shll v26.4s, v26.4h, #16 \n"

                    "fmla v20.4s, %19.4s, v29.4s \n"
                    "fmla v21.4s, %19.4s, v30.4s \n"
                    "fmla v22.4s, %19.4s, v31.4s \n"
                    "fmla v23.4s, %19.4s, v26.4s \n"

                    "shll v25.4s, v25.4h, #16 \n"

                    "fmla v16.4s, %14.4s, v12.4s \n"
                    "fmla v17.4s, %14.4s, v13.4s \n"
                    "fmla v18.4s, %14.4s, v24.4s \n"
                    "fmla v19.4s, %14.4s, v25.4s \n"

                    "shll v27.4s, v27.4h, #16 \n"

                    "fmla v20.4s, %20.4s, v30.4s \n"
                    "fmla v21.4s, %20.4s, v31.4s \n"
                    "fmla v22.4s, %20.4s, v26.4s \n"
                    "fmla v23.4s, %20.4s, v27.4s \n"

                    "shrn v16.4h, v16.4s, #16 \n"
                    "shrn v17.4h, v17.4s, #16 \n"
                    "shrn v18.4h, v18.4s, #16 \n"
                    "shrn v19.4h, v19.4s, #16 \n"
                    "shrn v20.4h, v20.4s, #16 \n"
                    "shrn v21.4h, v21.4s, #16 \n"
                    "shrn v22.4h, v22.4s, #16 \n"
                    "shrn v23.4h, v23.4s, #16 \n"

                    "st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%0], #32 \n"
                    "st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%1], #32 \n"

                    : "=r"(outptr0), // %0
                    "=r"(outptr1), // %1
                    "=r"(r0),      // %2
                    "=r"(r1),      // %3
                    "=r"(r2),      // %4
                    "=r"(r3)       // %5
                    : "0"(outptr0),
                    "1"(outptr1),
                    "2"(r0),
                    "3"(r1),
                    "4"(r2),
                    "5"(r3),
                    "w"(_k00),  // %12
                    "w"(_k01),  // %13
                    "w"(_k02),  // %14
                    "w"(_k10),  // %15
                    "w"(_k11),  // %16
                    "w"(_k12),  // %17
                    "w"(_k20),  // %18
                    "w"(_k21),  // %19
                    "w"(_k22),  // %20
                    "w"(_bias0) // %21
                    : "memory", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
            }
            for (; j + 1 < outw; j += 2)
            {
                asm volatile(
                    "prfm pldl1keep, [%3, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%3] \n" // r10 r11 r12 r13

                    "mov v16.16b, %21.16b \n" // sum00
                    "mov v17.16b, %21.16b \n" // sum01

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"

                    "mov v18.16b, %21.16b \n" // sum10
                    "mov v19.16b, %21.16b \n" // sum11

                    "fmla v16.4s, %15.4s, v10.4s \n"
                    "fmla v17.4s, %15.4s, v11.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"

                    "fmla v18.4s, %12.4s, v10.4s \n"
                    "fmla v19.4s, %12.4s, v11.4s \n"

                    "shll v13.4s, v13.4h, #16 \n"

                    "fmla v16.4s, %16.4s, v11.4s \n"
                    "fmla v17.4s, %16.4s, v12.4s \n"

                    "prfm pldl1keep, [%4, #256] \n"
                    "ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%4] \n" // r20 r21 r22 r23

                    "fmla v18.4s, %13.4s, v11.4s \n"
                    "fmla v19.4s, %13.4s, v12.4s \n"

                    "shll v20.4s, v20.4h, #16 \n"

                    "fmla v16.4s, %17.4s, v12.4s \n"
                    "fmla v17.4s, %17.4s, v13.4s \n"

                    "shll v21.4s, v21.4h, #16 \n"

                    "fmla v18.4s, %14.4s, v12.4s \n"
                    "fmla v19.4s, %14.4s, v13.4s \n"

                    "shll v22.4s, v22.4h, #16 \n"

                    "fmla v16.4s, %18.4s, v20.4s \n"
                    "fmla v17.4s, %18.4s, v21.4s \n"

                    "shll v23.4s, v23.4h, #16 \n"

                    "fmla v18.4s, %15.4s, v20.4s \n"
                    "fmla v19.4s, %15.4s, v21.4s \n"

                    "prfm pldl1keep, [%2, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%2] \n" // r00 r01 r02 r03

                    "fmla v16.4s, %19.4s, v21.4s \n"
                    "fmla v17.4s, %19.4s, v22.4s \n"

                    "prfm pldl1keep, [%5, #256] \n"
                    "ld1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%5] \n" // r30 r31 r32 r33

                    "fmla v18.4s, %16.4s, v21.4s \n"
                    "fmla v19.4s, %16.4s, v22.4s \n"

                    "shll v10.4s, v10.4h, #16 \n"

                    "fmla v16.4s, %20.4s, v22.4s \n"
                    "fmla v17.4s, %20.4s, v23.4s \n"

                    "shll v24.4s, v24.4h, #16 \n"

                    "fmla v18.4s, %17.4s, v22.4s \n"
                    "fmla v19.4s, %17.4s, v23.4s \n"

                    "shll v11.4s, v11.4h, #16 \n"
                    "shll v25.4s, v25.4h, #16 \n"

                    "fmla v16.4s, %12.4s, v10.4s \n"
                    "fmla v17.4s, %12.4s, v11.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"

                    "fmla v18.4s, %18.4s, v24.4s \n"
                    "fmla v19.4s, %18.4s, v25.4s \n"

                    "shll v26.4s, v26.4h, #16 \n"

                    "fmla v16.4s, %13.4s, v11.4s \n"
                    "fmla v17.4s, %13.4s, v12.4s \n"

                    "shll v13.4s, v13.4h, #16 \n"

                    "fmla v18.4s, %19.4s, v25.4s \n"
                    "fmla v19.4s, %19.4s, v26.4s \n"

                    "shll v27.4s, v27.4h, #16 \n"

                    "fmla v16.4s, %14.4s, v12.4s \n"
                    "fmla v17.4s, %14.4s, v13.4s \n"

                    "add %3, %3, #16 \n"

                    "fmla v18.4s, %20.4s, v26.4s \n"
                    "fmla v19.4s, %20.4s, v27.4s \n"

                    "add %4, %4, #16 \n"

                    "shrn v16.4h, v16.4s, #16 \n"
                    "shrn v17.4h, v17.4s, #16 \n"

                    "add %2, %2, #16 \n"

                    "shrn v18.4h, v18.4s, #16 \n"
                    "shrn v19.4h, v19.4s, #16 \n"

                    "add %5, %5, #16 \n"

                    "st1 {v16.4h, v17.4h}, [%0], #16 \n"
                    "st1 {v18.4h, v19.4h}, [%1], #16 \n"

                    : "=r"(outptr0), // %0
                    "=r"(outptr1), // %1
                    "=r"(r0),      // %2
                    "=r"(r1),      // %3
                    "=r"(r2),      // %4
                    "=r"(r3)       // %5
                    : "0"(outptr0),
                    "1"(outptr1),
                    "2"(r0),
                    "3"(r1),
                    "4"(r2),
                    "5"(r3),
                    "w"(_k00),  // %12
                    "w"(_k01),  // %13
                    "w"(_k02),  // %14
                    "w"(_k10),  // %15
                    "w"(_k11),  // %16
                    "w"(_k12),  // %17
                    "w"(_k20),  // %18
                    "w"(_k21),  // %19
                    "w"(_k22),  // %20
                    "w"(_bias0) // %21
                    : "memory", "v10", "v11", "v12", "v13", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
            }
            for (; j < outw; j++)
            {
                asm volatile(
                    "prfm pldl1keep, [%3, #192] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h}, [%3] \n" // r10 r11 r12

                    "mov v18.16b, %21.16b \n" // sum0
                    "mov v19.16b, %21.16b \n" // sum1

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"

                    "fmul v16.4s, %15.4s, v10.4s \n"
                    "fmul v17.4s, %12.4s, v10.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"

                    "fmla v18.4s, %16.4s, v11.4s \n"
                    "fmla v19.4s, %13.4s, v11.4s \n"

                    "prfm pldl1keep, [%4, #192] \n"
                    "ld1 {v20.4h, v21.4h, v22.4h}, [%4] \n" // r20 r21 r22

                    "fmla v16.4s, %17.4s, v12.4s \n"
                    "fmla v17.4s, %14.4s, v12.4s \n"

                    "shll v20.4s, v20.4h, #16 \n"
                    "shll v21.4s, v21.4h, #16 \n"

                    "fmla v18.4s, %18.4s, v20.4s \n"
                    "fmla v19.4s, %15.4s, v20.4s \n"

                    "prfm pldl1keep, [%2, #192] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h}, [%2] \n" // r00 r01 r02

                    "shll v22.4s, v22.4h, #16 \n"

                    "prfm pldl1keep, [%5, #192] \n"
                    "ld1 {v24.4h, v25.4h, v26.4h}, [%5] \n" // r30 r31 r32

                    "fmla v16.4s, %19.4s, v21.4s \n"
                    "fmla v17.4s, %16.4s, v21.4s \n"

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v24.4s, v24.4h, #16 \n"

                    "fmla v18.4s, %20.4s, v22.4s \n"
                    "fmla v19.4s, %17.4s, v22.4s \n"

                    "shll v11.4s, v11.4h, #16 \n"
                    "shll v25.4s, v25.4h, #16 \n"

                    "fmla v16.4s, %12.4s, v10.4s \n"
                    "fmla v17.4s, %18.4s, v24.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v26.4s, v26.4h, #16 \n"

                    "fmla v18.4s, %13.4s, v11.4s \n"
                    "fmla v19.4s, %19.4s, v25.4s \n"

                    "add %3, %3, #8 \n"

                    "fmla v16.4s, %14.4s, v12.4s \n"
                    "fmla v17.4s, %20.4s, v26.4s \n"

                    "add %4, %4, #8 \n"

                    "fadd v18.4s, v18.4s, v16.4s \n"
                    "fadd v19.4s, v19.4s, v17.4s \n"

                    "add %2, %2, #8 \n"

                    "shrn v18.4h, v18.4s, #16 \n"
                    "shrn v19.4h, v19.4s, #16 \n"

                    "add %5, %5, #8 \n"

                    "st1 {v18.4h}, [%0], #8 \n"
                    "st1 {v19.4h}, [%1], #8 \n"

                    : "=r"(outptr0), // %0
                    "=r"(outptr1), // %1
                    "=r"(r0),      // %2
                    "=r"(r1),      // %3
                    "=r"(r2),      // %4
                    "=r"(r3)       // %5
                    : "0"(outptr0),
                    "1"(outptr1),
                    "2"(r0),
                    "3"(r1),
                    "4"(r2),
                    "5"(r3),
                    "w"(_k00),  // %12
                    "w"(_k01),  // %13
                    "w"(_k02),  // %14
                    "w"(_k10),  // %15
                    "w"(_k11),  // %16
                    "w"(_k12),  // %17
                    "w"(_k20),  // %18
                    "w"(_k21),  // %19
                    "w"(_k22),  // %20
                    "w"(_bias0) // %21
                    : "memory", "v10", "v11", "v12", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", "v25", "v26");
            }

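            // two input rows consumed: skip the 2-pixel kernel overlap at the row end
            // (2 * 4 packed elements) plus one extra full input row (w * 4)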
            r0 += 2 * 4 + w * 4;
            r1 += 2 * 4 + w * 4;
            r2 += 2 * 4 + w * 4;
            r3 += 2 * 4 + w * 4;

            outptr0 += outw * 4;
            outptr1 += outw * 4;
        }
#endif // __aarch64__
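        // remaining single output row (on armv7 this loop handles all rows)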
        for (; i < outh; i++)
        {
            int j = 0;

            for (; j + 3 < outw; j += 4)
            {
#if __aarch64__
                asm volatile(
                    "prfm pldl1keep, [%1, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%1], #32 \n" // r00 r01 r02 r03

                    "mov v16.16b, %17.16b \n" // sum00
                    "mov v17.16b, %17.16b \n" // sum01
                    "mov v18.16b, %17.16b \n" // sum02
                    "mov v19.16b, %17.16b \n" // sum03

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"

                    "fmla v16.4s, %8.4s, v10.4s \n"
                    "fmla v17.4s, %8.4s, v11.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "fmla v18.4s, %8.4s, v12.4s \n"
                    "fmla v19.4s, %8.4s, v13.4s \n"

                    "prfm pldl1keep, [%1, #128] \n"
                    "ld1 {v14.4h, v15.4h}, [%1] \n" // r04 r05

                    "fmla v16.4s, %9.4s, v11.4s \n"
                    "fmla v17.4s, %9.4s, v12.4s \n"

                    "shll v14.4s, v14.4h, #16 \n"

                    "fmla v18.4s, %9.4s, v13.4s \n"
                    "fmla v19.4s, %9.4s, v14.4s \n"

                    "prfm pldl1keep, [%2, #256] \n"
                    "ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%2], #32 \n" // r10 r11 r12 r13

                    "fmla v16.4s, %10.4s, v12.4s \n"
                    "fmla v17.4s, %10.4s, v13.4s \n"

                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v18.4s, %10.4s, v14.4s \n"
                    "fmla v19.4s, %10.4s, v15.4s \n"

                    "shll v20.4s, v20.4h, #16 \n"
                    "shll v21.4s, v21.4h, #16 \n"

                    "fmla v16.4s, %11.4s, v20.4s \n"
                    "fmla v17.4s, %11.4s, v21.4s \n"

                    "shll v22.4s, v22.4h, #16 \n"
                    "shll v23.4s, v23.4h, #16 \n"

                    "fmla v18.4s, %11.4s, v22.4s \n"
                    "fmla v19.4s, %11.4s, v23.4s \n"

                    "prfm pldl1keep, [%2, #128] \n"
                    "ld1 {v14.4h, v15.4h}, [%2] \n" // r14 r15

                    "fmla v16.4s, %12.4s, v21.4s \n"
                    "fmla v17.4s, %12.4s, v22.4s \n"

                    "shll v14.4s, v14.4h, #16 \n"

                    "fmla v18.4s, %12.4s, v23.4s \n"
                    "fmla v19.4s, %12.4s, v14.4s \n"

                    "prfm pldl1keep, [%3, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%3], #32 \n" // r20 r21 r22 r23

                    "fmla v16.4s, %13.4s, v22.4s \n"
                    "fmla v17.4s, %13.4s, v23.4s \n"

                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v18.4s, %13.4s, v14.4s \n"
                    "fmla v19.4s, %13.4s, v15.4s \n"

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"

                    "fmla v16.4s, %14.4s, v10.4s \n"
                    "fmla v17.4s, %14.4s, v11.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "fmla v18.4s, %14.4s, v12.4s \n"
                    "fmla v19.4s, %14.4s, v13.4s \n"

                    "prfm pldl1keep, [%3, #128] \n"
                    "ld1 {v14.4h, v15.4h}, [%3] \n" // r24 r25

                    "fmla v16.4s, %15.4s, v11.4s \n"
                    "fmla v17.4s, %15.4s, v12.4s \n"

                    "shll v14.4s, v14.4h, #16 \n"

                    "fmla v18.4s, %15.4s, v13.4s \n"
                    "fmla v19.4s, %15.4s, v14.4s \n"

                    "fmla v16.4s, %16.4s, v12.4s \n"
                    "fmla v17.4s, %16.4s, v13.4s \n"

                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v18.4s, %16.4s, v14.4s \n"
                    "fmla v19.4s, %16.4s, v15.4s \n"

                    "shrn v16.4h, v16.4s, #16 \n"
                    "shrn v17.4h, v17.4s, #16 \n"
                    "shrn v18.4h, v18.4s, #16 \n"
                    "shrn v19.4h, v19.4s, #16 \n"

                    "st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%0], #32 \n"

                    : "=r"(outptr0), // %0
                    "=r"(r0), // %1
                    "=r"(r1), // %2
                    "=r"(r2)  // %3
                    : "0"(outptr0),
                    "1"(r0),
                    "2"(r1),
                    "3"(r2),
                    "w"(_k00),  // %8
                    "w"(_k01),  // %9
                    "w"(_k02),  // %10
                    "w"(_k10),  // %11
                    "w"(_k11),  // %12
                    "w"(_k12),  // %13
                    "w"(_k20),  // %14
                    "w"(_k21),  // %15
                    "w"(_k22),  // %16
                    "w"(_bias0) // %17
                    : "memory", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                asm volatile(
                    "pld [%1, #128] \n"
                    "vld1.u16 {d30-d31}, [%1 :64]! \n" // r00 r01

                    "vmov q10, %q17 \n" // sum00
                    "vmov q11, %q17 \n" // sum01

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q10, %q8, q14 \n"
                    "vmla.f32 q11, %q8, q15 \n"
                    "vmla.f32 q10, %q9, q15 \n"

                    "pld [%1, #128] \n"
                    "vld1.u16 {d30-d31}, [%1 :64]! \n" // r02 r03

                    "vmov q12, %q17 \n" // sum02
                    "vmov q13, %q17 \n" // sum03

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q12, %q8, q14 \n"
                    "vmla.f32 q11, %q9, q14 \n"
                    "vmla.f32 q13, %q8, q15 \n"
                    "vmla.f32 q10, %q10, q14 \n"
                    "vmla.f32 q12, %q9, q15 \n"
                    "vmla.f32 q11, %q10, q15 \n"

                    // "pld [%1, #128] \n"
                    "vld1.u16 {d30-d31}, [%1 :64] \n" // r04 r05

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q13, %q9, q14 \n"
                    "vmla.f32 q12, %q10, q14 \n"
                    "vmla.f32 q13, %q10, q15 \n"

                    "pld [%2, #128] \n"
                    "vld1.u16 {d30-d31}, [%2 :64]! \n" // r10 r11

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q10, %q11, q14 \n"
                    "vmla.f32 q11, %q11, q15 \n"
                    "vmla.f32 q10, %q12, q15 \n"

                    "pld [%2, #128] \n"
                    "vld1.u16 {d30-d31}, [%2 :64]! \n" // r12 r13

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q12, %q11, q14 \n"
                    "vmla.f32 q11, %q12, q14 \n"
                    "vmla.f32 q13, %q11, q15 \n"
                    "vmla.f32 q10, %q13, q14 \n"
                    "vmla.f32 q12, %q12, q15 \n"
                    "vmla.f32 q11, %q13, q15 \n"

                    // "pld [%2, #128] \n"
                    "vld1.u16 {d30-d31}, [%2 :64] \n" // r14 r15

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q13, %q12, q14 \n"
                    "vmla.f32 q12, %q13, q14 \n"
                    "vmla.f32 q13, %q13, q15 \n"

                    "pld [%3, #128] \n"
                    "vld1.u16 {d30-d31}, [%3 :64]! \n" // r20 r21

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q10, %q14, q14 \n"
                    "vmla.f32 q11, %q14, q15 \n"
                    "vmla.f32 q10, %q15, q15 \n"

                    "pld [%3, #128] \n"
                    "vld1.u16 {d30-d31}, [%3 :64]! \n" // r22 r23

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q12, %q14, q14 \n"
                    "vmla.f32 q11, %q15, q14 \n"
                    "vmla.f32 q13, %q14, q15 \n"
                    "vmla.f32 q10, %q16, q14 \n"
                    "vmla.f32 q12, %q15, q15 \n"
                    "vmla.f32 q11, %q16, q15 \n"

                    // "pld [%3, #128] \n"
                    "vld1.u16 {d30-d31}, [%3 :64] \n" // r24 r25

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q13, %q15, q14 \n"
                    "vmla.f32 q12, %q16, q14 \n"
                    "vmla.f32 q13, %q16, q15 \n"

                    "vshrn.u32 d20, q10, #16 \n"
                    "vshrn.u32 d21, q11, #16 \n"
                    "vshrn.u32 d22, q12, #16 \n"
                    "vshrn.u32 d23, q13, #16 \n"

                    "vst1.u16 {d20-d23}, [%0 :64]! \n"

                    : "=r"(outptr0), // %0
                    "=r"(r0), // %1
                    "=r"(r1), // %2
                    "=r"(r2)  // %3
                    : "0"(outptr0),
                    "1"(r0),
                    "2"(r1),
                    "3"(r2),
                    "w"(_k00),  // %8
                    "w"(_k01),  // %9
                    "w"(_k02),  // %10
                    "w"(_k10),  // %11
                    "w"(_k11),  // %12
                    "w"(_k12),  // %13
                    "w"(_k20),  // %14
                    "w"(_k21),  // %15
                    "w"(_k22),  // %16
                    "w"(_bias0) // %17
                    : "memory", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
            }
            for (; j + 1 < outw; j += 2)
            {
#if __aarch64__
                asm volatile(
                    "prfm pldl1keep, [%1, #256] \n"
                    "ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [%1] \n" // r00 r01 r02 r03

                    "mov v18.16b, %17.16b \n" // sum00
                    "mov v19.16b, %17.16b \n" // sum01

                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "fmul v16.4s, %8.4s, v12.4s \n"
                    "fmul v17.4s, %8.4s, v13.4s \n"

                    "shll v14.4s, v14.4h, #16 \n"
                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v18.4s, %9.4s, v13.4s \n"
                    "fmla v19.4s, %9.4s, v14.4s \n"

                    "prfm pldl1keep, [%2, #256] \n"
                    "ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%2] \n" // r10 r11 r12 r13

                    "fmla v16.4s, %10.4s, v14.4s \n"
                    "fmla v17.4s, %10.4s, v15.4s \n"

                    "shll v20.4s, v20.4h, #16 \n"
                    "shll v21.4s, v21.4h, #16 \n"

                    "fmla v18.4s, %11.4s, v20.4s \n"
                    "fmla v19.4s, %11.4s, v21.4s \n"

                    "shll v22.4s, v22.4h, #16 \n"
                    "shll v23.4s, v23.4h, #16 \n"

                    "fmla v16.4s, %12.4s, v21.4s \n"
                    "fmla v17.4s, %12.4s, v22.4s \n"

                    "prfm pldl1keep, [%3, #256] \n"
                    "ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [%3] \n" // r20 r21 r22 r23

                    "fmla v18.4s, %13.4s, v22.4s \n"
                    "fmla v19.4s, %13.4s, v23.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "fmla v16.4s, %14.4s, v12.4s \n"
                    "fmla v17.4s, %14.4s, v13.4s \n"

                    "shll v14.4s, v14.4h, #16 \n"
                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v18.4s, %15.4s, v13.4s \n"
                    "fmla v19.4s, %15.4s, v14.4s \n"

                    "add %1, %1, #16 \n"

                    "fmla v16.4s, %16.4s, v14.4s \n"
                    "fmla v17.4s, %16.4s, v15.4s \n"

                    "add %2, %2, #16 \n"

                    "fadd v18.4s, v18.4s, v16.4s \n"
                    "fadd v19.4s, v19.4s, v17.4s \n"

                    "add %3, %3, #16 \n"

                    "shrn v18.4h, v18.4s, #16 \n"
                    "shrn v19.4h, v19.4s, #16 \n"

                    "st1 {v18.4h, v19.4h}, [%0], #16 \n"

                    : "=r"(outptr0), // %0
                    "=r"(r0), // %1
                    "=r"(r1), // %2
                    "=r"(r2)  // %3
                    : "0"(outptr0),
                    "1"(r0),
                    "2"(r1),
                    "3"(r2),
                    "w"(_k00),  // %8
                    "w"(_k01),  // %9
                    "w"(_k02),  // %10
                    "w"(_k10),  // %11
                    "w"(_k11),  // %12
                    "w"(_k12),  // %13
                    "w"(_k20),  // %14
                    "w"(_k21),  // %15
                    "w"(_k22),  // %16
                    "w"(_bias0) // %17
                    : "memory", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                asm volatile(
                    "pld [%1, #256] \n"
                    "vld1.u16 {d28-d31}, [%1 :64] \n" // r00 r01 r02 r03

                    "vmov q10, %q17 \n" // sum00
                    "vmov q11, %q17 \n" // sum01

                    "vshll.u16 q12, d28, #16 \n"
                    "vshll.u16 q13, d29, #16 \n"

                    "vmla.f32 q10, %q8, q12 \n"
                    "vmla.f32 q11, %q8, q13 \n"

                    "vshll.u16 q14, d30, #16 \n"

                    "vmla.f32 q10, %q9, q13 \n"
                    "vmla.f32 q11, %q9, q14 \n"

                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q10, %q10, q14 \n"
                    "vmla.f32 q11, %q10, q15 \n"

                    "pld [%2, #256] \n"
                    "vld1.u16 {d28-d31}, [%2 :64] \n" // r10 r11 r12 r13

                    "vshll.u16 q12, d28, #16 \n"
                    "vshll.u16 q13, d29, #16 \n"

                    "vmla.f32 q10, %q11, q12 \n"
                    "vmla.f32 q11, %q11, q13 \n"

                    "vshll.u16 q14, d30, #16 \n"

                    "vmla.f32 q10, %q12, q13 \n"
                    "vmla.f32 q11, %q12, q14 \n"

                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q10, %q13, q14 \n"
                    "vmla.f32 q11, %q13, q15 \n"

                    "pld [%3, #256] \n"
                    "vld1.u16 {d28-d31}, [%3 :64] \n" // r20 r21 r22 r23

                    "vshll.u16 q12, d28, #16 \n"
                    "vshll.u16 q13, d29, #16 \n"

                    "vmla.f32 q10, %q14, q12 \n"
                    "vmla.f32 q11, %q14, q13 \n"

                    "vshll.u16 q14, d30, #16 \n"

                    "vmla.f32 q10, %q15, q13 \n"
                    "vmla.f32 q11, %q15, q14 \n"

                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q10, %q16, q14 \n"
                    "vmla.f32 q11, %q16, q15 \n"

                    "add %1, %1, #16 \n"
                    "add %2, %2, #16 \n"

                    "vshrn.u32 d20, q10, #16 \n"
                    "vshrn.u32 d21, q11, #16 \n"

                    "add %3, %3, #16 \n"

                    "vst1.u16 {d20-d21}, [%0 :64]! \n"

                    : "=r"(outptr0), // %0
                    "=r"(r0), // %1
                    "=r"(r1), // %2
                    "=r"(r2)  // %3
                    : "0"(outptr0),
                    "1"(r0),
                    "2"(r1),
                    "3"(r2),
                    "w"(_k00),  // %8
                    "w"(_k01),  // %9
                    "w"(_k02),  // %10
                    "w"(_k10),  // %11
                    "w"(_k11),  // %12
                    "w"(_k12),  // %13
                    "w"(_k20),  // %14
                    "w"(_k21),  // %15
                    "w"(_k22),  // %16
                    "w"(_bias0) // %17
                    : "memory", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
            }
            for (; j < outw; j++)
            {
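                // scalar tail: one packed output pixel at a time via intrinsics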
                float32x4_t _sum0 = _bias0;

                float32x4_t _r00 = vcvt_f32_bf16(vld1_u16(r0));
                float32x4_t _r01 = vcvt_f32_bf16(vld1_u16(r0 + 4));
                float32x4_t _r02 = vcvt_f32_bf16(vld1_u16(r0 + 8));
                float32x4_t _r10 = vcvt_f32_bf16(vld1_u16(r1));
                float32x4_t _r11 = vcvt_f32_bf16(vld1_u16(r1 + 4));
                float32x4_t _r12 = vcvt_f32_bf16(vld1_u16(r1 + 8));
                float32x4_t _r20 = vcvt_f32_bf16(vld1_u16(r2));
                float32x4_t _r21 = vcvt_f32_bf16(vld1_u16(r2 + 4));
                float32x4_t _r22 = vcvt_f32_bf16(vld1_u16(r2 + 8));

                _sum0 = vmlaq_f32(_sum0, _k00, _r00);
                _sum0 = vmlaq_f32(_sum0, _k01, _r01);
                _sum0 = vmlaq_f32(_sum0, _k02, _r02);
                _sum0 = vmlaq_f32(_sum0, _k10, _r10);
                _sum0 = vmlaq_f32(_sum0, _k11, _r11);
                _sum0 = vmlaq_f32(_sum0, _k12, _r12);
                _sum0 = vmlaq_f32(_sum0, _k20, _r20);
                _sum0 = vmlaq_f32(_sum0, _k21, _r21);
                _sum0 = vmlaq_f32(_sum0, _k22, _r22);

                vst1_u16(outptr0, vcvt_bf16_f32(_sum0));

                r0 += 4;
                r1 += 4;
                r2 += 4;
                outptr0 += 4;
            }

            r0 += 2 * 4;
            r1 += 2 * 4;
            r2 += 2 * 4;
        }
    }
}

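// stride-2 variant: each output pixel reads a 3x3 window that advances by two
// input pixels, so each output row consumes two input rows of progress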
static void convdw3x3s2_pack4_bf16s_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;

    int outw = top_blob.w;
    int outh = top_blob.h;

    const int group = bottom_blob.c;

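    // per-row pointer advance: (w - 2 * outw) leftover input pixels in the current
    // row plus one whole row skipped by the vertical stride of 2, times 4 elements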
    const int tailstep = (w - 2 * outw + w) * 4;

    const float* bias = _bias;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int g = 0; g < group; g++)
    {
        Mat out = top_blob.channel(g);

        float32x4_t _bias0 = bias ? vld1q_f32((const float*)bias + g * 4) : vdupq_n_f32(0.f);

        const unsigned short* k0 = kernel.row<const unsigned short>(g);

        unsigned short* outptr0 = out;

        const Mat img0 = bottom_blob.channel(g);

        const unsigned short* r0 = img0.row<const unsigned short>(0);
        const unsigned short* r1 = img0.row<const unsigned short>(1);
        const unsigned short* r2 = img0.row<const unsigned short>(2);

        float32x4_t _k00 = vcvt_f32_bf16(vld1_u16(k0));
        float32x4_t _k01 = vcvt_f32_bf16(vld1_u16(k0 + 4));
        float32x4_t _k02 = vcvt_f32_bf16(vld1_u16(k0 + 8));
        float32x4_t _k10 = vcvt_f32_bf16(vld1_u16(k0 + 12));
        float32x4_t _k11 = vcvt_f32_bf16(vld1_u16(k0 + 16));
        float32x4_t _k12 = vcvt_f32_bf16(vld1_u16(k0 + 20));
        float32x4_t _k20 = vcvt_f32_bf16(vld1_u16(k0 + 24));
        float32x4_t _k21 = vcvt_f32_bf16(vld1_u16(k0 + 28));
        float32x4_t _k22 = vcvt_f32_bf16(vld1_u16(k0 + 32));

        int i = 0;

        for (; i < outh; i++)
        {
            int j = 0;

#if __aarch64__
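            // aarch64-only 4-wide unroll; armv7 starts at the 2-wide loop below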
            for (; j + 3 < outw; j += 4)
            {
                asm volatile(
                    "prfm pldl1keep, [%1, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%1], #32 \n" // r00 r01 r02 r03

                    "mov v28.16b, %17.16b \n" // sum00
                    "mov v29.16b, %17.16b \n" // sum01
                    "mov v30.16b, %17.16b \n" // sum02
                    "mov v31.16b, %17.16b \n" // sum03

                    "prfm pldl1keep, [%1, #256] \n"
                    "ld1 {v14.4h, v15.4h, v16.4h, v17.4h}, [%1], #32 \n" // r04 r05 r06 r07

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"
                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "prfm pldl1keep, [%1, #64] \n"
                    "ld1 {v18.4h}, [%1] \n" // r08

                    "shll v14.4s, v14.4h, #16 \n"
                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v28.4s, %8.4s, v10.4s \n"
                    "fmla v29.4s, %8.4s, v12.4s \n"

                    "shll v16.4s, v16.4h, #16 \n"

                    "fmla v30.4s, %8.4s, v14.4s \n"
                    "fmla v31.4s, %8.4s, v16.4s \n"

                    "shll v17.4s, v17.4h, #16 \n"

                    "fmla v28.4s, %9.4s, v11.4s \n"
                    "fmla v29.4s, %9.4s, v13.4s \n"
                    "fmla v30.4s, %9.4s, v15.4s \n"
                    "fmla v31.4s, %9.4s, v17.4s \n"

                    "prfm pldl1keep, [%2, #256] \n"
                    "ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%2], #32 \n" // r10 r11 r12 r13

                    "fmla v28.4s, %10.4s, v12.4s \n"
                    "fmla v29.4s, %10.4s, v14.4s \n"

                    "shll v18.4s, v18.4h, #16 \n"

                    "fmla v30.4s, %10.4s, v16.4s \n"
                    "fmla v31.4s, %10.4s, v18.4s \n"

                    "prfm pldl1keep, [%2, #256] \n"
                    "ld1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n" // r14 r15 r16 r17

                    "shll v20.4s, v20.4h, #16 \n"
                    "shll v21.4s, v21.4h, #16 \n"
                    "shll v22.4s, v22.4h, #16 \n"
                    "shll v23.4s, v23.4h, #16 \n"

                    "prfm pldl1keep, [%2, #64] \n"
                    "ld1 {v19.4h}, [%2] \n" // r18

                    "shll v24.4s, v24.4h, #16 \n"
                    "shll v25.4s, v25.4h, #16 \n"

                    "fmla v28.4s, %11.4s, v20.4s \n"
                    "fmla v29.4s, %11.4s, v22.4s \n"

                    "shll v26.4s, v26.4h, #16 \n"

                    "fmla v30.4s, %11.4s, v24.4s \n"
                    "fmla v31.4s, %11.4s, v26.4s \n"

                    "shll v27.4s, v27.4h, #16 \n"

                    "fmla v28.4s, %12.4s, v21.4s \n"
                    "fmla v29.4s, %12.4s, v23.4s \n"
                    "fmla v30.4s, %12.4s, v25.4s \n"
                    "fmla v31.4s, %12.4s, v27.4s \n"

                    "prfm pldl1keep, [%3, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%3], #32 \n" // r20 r21 r22 r23

                    "fmla v28.4s, %13.4s, v22.4s \n"
                    "fmla v29.4s, %13.4s, v24.4s \n"

                    "shll v19.4s, v19.4h, #16 \n"

                    "fmla v30.4s, %13.4s, v26.4s \n"
                    "fmla v31.4s, %13.4s, v19.4s \n"

                    "prfm pldl1keep, [%3, #256] \n"
                    "ld1 {v14.4h, v15.4h, v16.4h, v17.4h}, [%3], #32 \n" // r24 r25 r26 r27

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"
                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "prfm pldl1keep, [%3, #64] \n"
                    "ld1 {v18.4h}, [%3] \n" // r28

                    "shll v14.4s, v14.4h, #16 \n"
                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v28.4s, %14.4s, v10.4s \n"
                    "fmla v29.4s, %14.4s, v12.4s \n"

                    "shll v16.4s, v16.4h, #16 \n"

                    "fmla v30.4s, %14.4s, v14.4s \n"
                    "fmla v31.4s, %14.4s, v16.4s \n"

                    "shll v17.4s, v17.4h, #16 \n"

                    "fmla v28.4s, %15.4s, v11.4s \n"
                    "fmla v29.4s, %15.4s, v13.4s \n"
                    "fmla v30.4s, %15.4s, v15.4s \n"
                    "fmla v31.4s, %15.4s, v17.4s \n"

                    "fmla v28.4s, %16.4s, v12.4s \n"
                    "fmla v29.4s, %16.4s, v14.4s \n"

                    "shll v18.4s, v18.4h, #16 \n"

                    "fmla v30.4s, %16.4s, v16.4s \n"
                    "fmla v31.4s, %16.4s, v18.4s \n"

                    "shrn v28.4h, v28.4s, #16 \n"
                    "shrn v29.4h, v29.4s, #16 \n"
                    "shrn v30.4h, v30.4s, #16 \n"
                    "shrn v31.4h, v31.4s, #16 \n"

                    "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%0], #32 \n"

                    : "=r"(outptr0), // %0
                    "=r"(r0), // %1
                    "=r"(r1), // %2
                    "=r"(r2)  // %3
                    : "0"(outptr0),
                    "1"(r0),
                    "2"(r1),
                    "3"(r2),
                    "w"(_k00),  // %8
                    "w"(_k01),  // %9
                    "w"(_k02),  // %10
                    "w"(_k10),  // %11
                    "w"(_k11),  // %12
                    "w"(_k12),  // %13
                    "w"(_k20),  // %14
                    "w"(_k21),  // %15
                    "w"(_k22),  // %16
                    "w"(_bias0) // %17
                    : "memory", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
            }
#endif // __aarch64__
            for (; j + 1 < outw; j += 2)
            {
#if __aarch64__
                asm volatile(
                    "prfm pldl1keep, [%1, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%1], #32 \n" // r00 r01 r02 r03

                    "mov v22.16b, %17.16b \n" // sum00
                    "mov v23.16b, %17.16b \n" // sum01

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"

                    "fmul v20.4s, %8.4s, v10.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "fmul v21.4s, %8.4s, v12.4s \n"

                    "prfm pldl1keep, [%1, #64] \n"
                    "ld1 {v14.4h}, [%1] \n" // r04

                    "fmla v22.4s, %9.4s, v11.4s \n"
                    "fmla v23.4s, %9.4s, v13.4s \n"

                    "prfm pldl1keep, [%2, #256] \n"
                    "ld1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%2], #32 \n" // r10 r11 r12 r13

                    "shll v14.4s, v14.4h, #16 \n"

                    "fmla v20.4s, %10.4s, v12.4s \n"
                    "fmla v21.4s, %10.4s, v14.4s \n"

                    "shll v16.4s, v16.4h, #16 \n"
                    "shll v17.4s, v17.4h, #16 \n"

                    "fmla v22.4s, %11.4s, v16.4s \n"

                    "shll v18.4s, v18.4h, #16 \n"
                    "shll v19.4s, v19.4h, #16 \n"

                    "fmla v23.4s, %11.4s, v18.4s \n"

                    "prfm pldl1keep, [%2, #64] \n"
                    "ld1 {v15.4h}, [%2] \n" // r14

                    "fmla v20.4s, %12.4s, v17.4s \n"
                    "fmla v21.4s, %12.4s, v19.4s \n"

                    "prfm pldl1keep, [%3, #256] \n"
                    "ld1 {v10.4h, v11.4h, v12.4h, v13.4h}, [%3], #32 \n" // r20 r21 r22 r23

                    "shll v15.4s, v15.4h, #16 \n"

                    "fmla v22.4s, %13.4s, v18.4s \n"
                    "fmla v23.4s, %13.4s, v15.4s \n"

                    "shll v10.4s, v10.4h, #16 \n"
                    "shll v11.4s, v11.4h, #16 \n"

                    "fmla v20.4s, %14.4s, v10.4s \n"

                    "shll v12.4s, v12.4h, #16 \n"
                    "shll v13.4s, v13.4h, #16 \n"

                    "fmla v21.4s, %14.4s, v12.4s \n"

                    "prfm pldl1keep, [%3, #64] \n"
                    "ld1 {v14.4h}, [%3] \n" // r24

                    "fmla v22.4s, %15.4s, v11.4s \n"
                    "fmla v23.4s, %15.4s, v13.4s \n"

                    "shll v14.4s, v14.4h, #16 \n"

                    "fmla v20.4s, %16.4s, v12.4s \n"
                    "fmla v21.4s, %16.4s, v14.4s \n"

                    "fadd v22.4s, v20.4s, v22.4s \n"
                    "fadd v23.4s, v21.4s, v23.4s \n"

                    "shrn v22.4h, v22.4s, #16 \n"
                    "shrn v23.4h, v23.4s, #16 \n"

                    "st1 {v22.4h, v23.4h}, [%0], #16 \n"

                    : "=r"(outptr0), // %0
                    "=r"(r0), // %1
                    "=r"(r1), // %2
                    "=r"(r2)  // %3
                    : "0"(outptr0),
                    "1"(r0),
                    "2"(r1),
                    "3"(r2),
                    "w"(_k00),  // %8
                    "w"(_k01),  // %9
                    "w"(_k02),  // %10
                    "w"(_k10),  // %11
                    "w"(_k11),  // %12
                    "w"(_k12),  // %13
                    "w"(_k20),  // %14
                    "w"(_k21),  // %15
                    "w"(_k22),  // %16
                    "w"(_bias0) // %17
                    : "memory", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
#else
                asm volatile(
                    "pld [%1, #256] \n"
                    "vld1.u16 {d28-d31}, [%1 :64]! \n" // r00 r01 r02 r03

                    "vmov q10, %q17 \n" // sum00
                    "vmov q11, %q17 \n" // sum01

                    "vshll.u16 q12, d28, #16 \n"
                    "vshll.u16 q13, d29, #16 \n"

                    "vmla.f32 q10, %q8, q12 \n"

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q11, %q8, q14 \n"

                    "vld1.u16 {d25}, [%1] \n" // r04

                    "vmla.f32 q10, %q9, q13 \n"
                    "vmla.f32 q11, %q9, q15 \n"

                    "vshll.u16 q12, d25, #16 \n"

                    "vmla.f32 q10, %q10, q14 \n"

                    "pld [%2, #256] \n"
                    "vld1.u16 {d28-d31}, [%2 :64]! \n" // r10 r11 r12 r13

                    "vmla.f32 q11, %q10, q12 \n"

                    "vshll.u16 q12, d28, #16 \n"
                    "vshll.u16 q13, d29, #16 \n"

                    "vmla.f32 q10, %q11, q12 \n"

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q11, %q11, q14 \n"

                    "vld1.u16 {d25}, [%2] \n" // r14

                    "vmla.f32 q10, %q12, q13 \n"
                    "vmla.f32 q11, %q12, q15 \n"

                    "vshll.u16 q12, d25, #16 \n"

                    "vmla.f32 q10, %q13, q14 \n"

                    "pld [%3, #256] \n"
                    "vld1.u16 {d28-d31}, [%3 :64]! \n" // r20 r21 r22 r23

                    "vmla.f32 q11, %q13, q12 \n"

                    "vshll.u16 q12, d28, #16 \n"
                    "vshll.u16 q13, d29, #16 \n"

                    "vmla.f32 q10, %q14, q12 \n"

                    "vshll.u16 q14, d30, #16 \n"
                    "vshll.u16 q15, d31, #16 \n"

                    "vmla.f32 q11, %q14, q14 \n"

                    "vld1.u16 {d25}, [%3] \n" // r24

                    "vmla.f32 q10, %q15, q13 \n"
                    "vmla.f32 q11, %q15, q15 \n"

                    "vshll.u16 q12, d25, #16 \n"

                    "vmla.f32 q10, %q16, q14 \n"
                    "vmla.f32 q11, %q16, q12 \n"

                    "vshrn.u32 d20, q10, #16 \n"
                    "vshrn.u32 d21, q11, #16 \n"

                    "vst1.u16 {d20-d21}, [%0 :64]! \n"

                    : "=r"(outptr0), // %0
                    "=r"(r0), // %1
                    "=r"(r1), // %2
                    "=r"(r2)  // %3
                    : "0"(outptr0),
                    "1"(r0),
                    "2"(r1),
                    "3"(r2),
                    "w"(_k00),  // %8
                    "w"(_k01),  // %9
                    "w"(_k02),  // %10
                    "w"(_k10),  // %11
                    "w"(_k11),  // %12
                    "w"(_k12),  // %13
                    "w"(_k20),  // %14
                    "w"(_k21),  // %15
                    "w"(_k22),  // %16
                    "w"(_bias0) // %17
                    : "memory", "q10", "q11", "q12", "q13", "q14", "q15");
#endif
            }
            for (; j < outw; j++)
            {
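                // scalar tail: one packed output pixel, stepping two input pixels per output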
                float32x4_t _sum0 = _bias0;

                float32x4_t _r00 = vcvt_f32_bf16(vld1_u16(r0));
                float32x4_t _r01 = vcvt_f32_bf16(vld1_u16(r0 + 4));
                float32x4_t _r02 = vcvt_f32_bf16(vld1_u16(r0 + 8));
                float32x4_t _r10 = vcvt_f32_bf16(vld1_u16(r1));
                float32x4_t _r11 = vcvt_f32_bf16(vld1_u16(r1 + 4));
                float32x4_t _r12 = vcvt_f32_bf16(vld1_u16(r1 + 8));
                float32x4_t _r20 = vcvt_f32_bf16(vld1_u16(r2));
                float32x4_t _r21 = vcvt_f32_bf16(vld1_u16(r2 + 4));
                float32x4_t _r22 = vcvt_f32_bf16(vld1_u16(r2 + 8));

                _sum0 = vmlaq_f32(_sum0, _k00, _r00);
                _sum0 = vmlaq_f32(_sum0, _k01, _r01);
                _sum0 = vmlaq_f32(_sum0, _k02, _r02);
                _sum0 = vmlaq_f32(_sum0, _k10, _r10);
                _sum0 = vmlaq_f32(_sum0, _k11, _r11);
                _sum0 = vmlaq_f32(_sum0, _k12, _r12);
                _sum0 = vmlaq_f32(_sum0, _k20, _r20);
                _sum0 = vmlaq_f32(_sum0, _k21, _r21);
                _sum0 = vmlaq_f32(_sum0, _k22, _r22);

                vst1_u16(outptr0, vcvt_bf16_f32(_sum0));

                r0 += 2 * 4;
                r1 += 2 * 4;
                r2 += 2 * 4;
                outptr0 += 4;
            }

            r0 += tailstep;
            r1 += tailstep;
            r2 += tailstep;
        }
    }
}