// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

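// 1x1 convolution kernels for AArch64 with pack4 storage and fp16 storage/arithmetic (fp16sa):
// a packed-kernel transform, a stride-1 sgemm convolution, and a stride-2 wrapper that
// shrinks the input and reuses the stride-1 path.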
static void conv1x1s1_sgemm_transform_kernel_pack4_fp16sa_neon(const Mat& kernel, Mat& kernel_tm_pack4, int inch, int outch)
{
    // interleave
    // src = inch-outch
    // dst = 4b-4a-inch/4a-outch/4b
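    // e.g. with outch >= 8, the first 32 __fp16 of a packed channel hold input channels 0..3
    // of output channels q..q+7, fastest along the output channel (see the copy loop below)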
    kernel_tm_pack4.create(2 * 1, inch / 4, (outch / 4) / 2 + (outch / 4) % 2, (size_t)2u * 16, 16);

    int q = 0;
    for (; q + 7 < outch; q += 8)
    {
        const float* k0 = (const float*)kernel + (q + 0) * inch;
        const float* k1 = (const float*)kernel + (q + 1) * inch;
        const float* k2 = (const float*)kernel + (q + 2) * inch;
        const float* k3 = (const float*)kernel + (q + 3) * inch;
        const float* k4 = (const float*)kernel + (q + 4) * inch;
        const float* k5 = (const float*)kernel + (q + 5) * inch;
        const float* k6 = (const float*)kernel + (q + 6) * inch;
        const float* k7 = (const float*)kernel + (q + 7) * inch;

        __fp16* g0 = kernel_tm_pack4.channel(q / 8);

        for (int p = 0; p + 3 < inch; p += 4)
        {
            g0[0] = (__fp16)k0[0];
            g0[1] = (__fp16)k1[0];
            g0[2] = (__fp16)k2[0];
            g0[3] = (__fp16)k3[0];
            g0[4] = (__fp16)k4[0];
            g0[5] = (__fp16)k5[0];
            g0[6] = (__fp16)k6[0];
            g0[7] = (__fp16)k7[0];

            g0[8] = (__fp16)k0[1];
            g0[9] = (__fp16)k1[1];
            g0[10] = (__fp16)k2[1];
            g0[11] = (__fp16)k3[1];
            g0[12] = (__fp16)k4[1];
            g0[13] = (__fp16)k5[1];
            g0[14] = (__fp16)k6[1];
            g0[15] = (__fp16)k7[1];

            g0[16] = (__fp16)k0[2];
            g0[17] = (__fp16)k1[2];
            g0[18] = (__fp16)k2[2];
            g0[19] = (__fp16)k3[2];
            g0[20] = (__fp16)k4[2];
            g0[21] = (__fp16)k5[2];
            g0[22] = (__fp16)k6[2];
            g0[23] = (__fp16)k7[2];

            g0[24] = (__fp16)k0[3];
            g0[25] = (__fp16)k1[3];
            g0[26] = (__fp16)k2[3];
            g0[27] = (__fp16)k3[3];
            g0[28] = (__fp16)k4[3];
            g0[29] = (__fp16)k5[3];
            g0[30] = (__fp16)k6[3];
            g0[31] = (__fp16)k7[3];

            k0 += 4;
            k1 += 4;
            k2 += 4;
            k3 += 4;
            k4 += 4;
            k5 += 4;
            k6 += 4;
            k7 += 4;
            g0 += 32;
        }
    }
    for (; q + 3 < outch; q += 4)
    {
        const float* k0 = (const float*)kernel + (q + 0) * inch;
        const float* k1 = (const float*)kernel + (q + 1) * inch;
        const float* k2 = (const float*)kernel + (q + 2) * inch;
        const float* k3 = (const float*)kernel + (q + 3) * inch;

        __fp16* g0 = kernel_tm_pack4.channel(q / 8 + (q % 8) / 4);

        for (int p = 0; p + 3 < inch; p += 4)
        {
            g0[0] = (__fp16)k0[0];
            g0[1] = (__fp16)k1[0];
            g0[2] = (__fp16)k2[0];
            g0[3] = (__fp16)k3[0];

            g0[4] = (__fp16)k0[1];
            g0[5] = (__fp16)k1[1];
            g0[6] = (__fp16)k2[1];
            g0[7] = (__fp16)k3[1];

            g0[8] = (__fp16)k0[2];
            g0[9] = (__fp16)k1[2];
            g0[10] = (__fp16)k2[2];
            g0[11] = (__fp16)k3[2];

            g0[12] = (__fp16)k0[3];
            g0[13] = (__fp16)k1[3];
            g0[14] = (__fp16)k2[3];
            g0[15] = (__fp16)k3[3];

            k0 += 4;
            k1 += 4;
            k2 += 4;
            k3 += 4;
            g0 += 16;
        }
    }
}

static void conv1x1s1_sgemm_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    int outch = top_blob.c;

    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    const int size = w * h;

    const __fp16* bias = _bias;

    // interleave
    Mat tmp;
    if (size >= 8)
        tmp.create(8, inch, size / 8 + (size % 8) / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
    else if (size >= 4)
        tmp.create(4, inch, size / 4 + size % 4, elemsize, elempack, opt.workspace_allocator);
    else // if (size >= 1)
        tmp.create(1, inch, size, elemsize, elempack, opt.workspace_allocator);
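    // columns are retiled into blocks of 8, then 4, then single pixels;
    // a column index i therefore lives in tmp.channel(i / 8 + (i % 8) / 4 + i % 4)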
    {
        int nn_size;
        int remain_size_start = 0;

        nn_size = (size - remain_size_start) >> 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 8;

            const __fp16* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            __fp16* tmpptr = tmp.channel(i / 8);

            for (int q = 0; q < inch; q++)
            {
                // transpose 4x8
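                // ld4 de-interleaves 8 pixels x 4 channels into v0..v3 (one register per
                // channel lane); storing them back-to-back writes the transposed 8x4 tile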
                asm volatile(
                    "prfm pldl1keep, [%0, #512] \n"
                    "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n"
                    "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
                    : "=r"(img0),  // %0
                      "=r"(tmpptr) // %1
                    : "0"(img0),
                      "1"(tmpptr)
                    : "memory", "v0", "v1", "v2", "v3");

                img0 += bottom_blob.cstep * 4;
            }
        }

        remain_size_start += nn_size << 3;
        nn_size = (size - remain_size_start) >> 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 4;

            const __fp16* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);

            for (int q = 0; q < inch; q++)
            {
                // transpose 4x4
                asm volatile(
                    "prfm pldl1keep, [%0, #256] \n"
                    "ld4 {v0.4h, v1.4h, v2.4h, v3.4h}, [%0] \n"
                    "st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
                    : "=r"(img0),  // %0
                      "=r"(tmpptr) // %1
                    : "0"(img0),
                      "1"(tmpptr)
                    : "memory", "v0", "v1", "v2", "v3");

                img0 += bottom_blob.cstep * 4;
            }
        }

        remain_size_start += nn_size << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < size; i++)
        {
            const __fp16* img0 = bottom_blob.channel(0);
            img0 += i * 4;

            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);

            for (int q = 0; q < inch; q++)
            {
                asm volatile(
                    "prfm pldl1keep, [%0, #64] \n"
                    "ld1 {v0.4h}, [%0] \n"
                    "st1 {v0.4h}, [%1], #8 \n"
                    : "=r"(img0),  // %0
                      "=r"(tmpptr) // %1
                    : "0"(img0),
                      "1"(tmpptr)
                    : "memory", "v0");

                img0 += bottom_blob.cstep * 4;
            }
        }
    }

    int nn_outch = 0;
    int remain_outch_start = 0;

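    // consume output pack4 channel groups in pairs, i.e. 8 original output channels at a time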
    nn_outch = outch >> 1;
    remain_outch_start = nn_outch << 1;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = pp * 2;

        __fp16* outptr0 = top_blob.channel(p);
        __fp16* outptr1 = top_blob.channel(p + 1);

        const __fp16 zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
        const __fp16* biasptr = bias ? bias + p * 4 : zeros;
        float16x8_t _bias0 = vld1q_f16(biasptr);

        int i = 0;
        for (; i + 7 < size; i += 8)
        {
            __fp16* tmpptr = tmp.channel(i / 8);
            const __fp16* kptr = kernel.channel(pp);

            int nn = inch; // inch always > 0

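            // 8 pixels x 8 output channels accumulate in v24..v31; each 8h kernel register
            // holds both pack4 output groups: low half -> outptr0, high half -> outptr1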
            asm volatile(
                "mov v24.16b, %10.16b \n"
                "mov v25.16b, %10.16b \n"
                "mov v26.16b, %10.16b \n"
                "mov v27.16b, %10.16b \n"
                "mov v28.16b, %10.16b \n"
                "mov v29.16b, %10.16b \n"
                "mov v30.16b, %10.16b \n"
                "mov v31.16b, %10.16b \n"

                "0: \n"

                "prfm pldl1keep, [%3, #512] \n"
                "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%3], #64 \n" // r01 r23 r45 r67

                "prfm pldl1keep, [%4, #512] \n"
                "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%4], #64 \n" // k0123

                "fmla v24.8h, v4.8h, v0.h[0] \n"
                "fmla v25.8h, v4.8h, v0.h[1] \n"
                "fmla v26.8h, v4.8h, v0.h[2] \n"
                "fmla v27.8h, v4.8h, v0.h[3] \n"
                "fmla v28.8h, v4.8h, v0.h[4] \n"
                "fmla v29.8h, v4.8h, v0.h[5] \n"
                "fmla v30.8h, v4.8h, v0.h[6] \n"
                "fmla v31.8h, v4.8h, v0.h[7] \n"

                "fmla v24.8h, v5.8h, v1.h[0] \n"
                "fmla v25.8h, v5.8h, v1.h[1] \n"
                "fmla v26.8h, v5.8h, v1.h[2] \n"
                "fmla v27.8h, v5.8h, v1.h[3] \n"
                "fmla v28.8h, v5.8h, v1.h[4] \n"
                "fmla v29.8h, v5.8h, v1.h[5] \n"
                "fmla v30.8h, v5.8h, v1.h[6] \n"
                "fmla v31.8h, v5.8h, v1.h[7] \n"

                "fmla v24.8h, v6.8h, v2.h[0] \n"
                "fmla v25.8h, v6.8h, v2.h[1] \n"
                "fmla v26.8h, v6.8h, v2.h[2] \n"
                "fmla v27.8h, v6.8h, v2.h[3] \n"
                "fmla v28.8h, v6.8h, v2.h[4] \n"
                "fmla v29.8h, v6.8h, v2.h[5] \n"
                "fmla v30.8h, v6.8h, v2.h[6] \n"
                "fmla v31.8h, v6.8h, v2.h[7] \n"

                "subs %w0, %w0, #1 \n"

                "fmla v24.8h, v7.8h, v3.h[0] \n"
                "fmla v25.8h, v7.8h, v3.h[1] \n"
                "fmla v26.8h, v7.8h, v3.h[2] \n"
                "fmla v27.8h, v7.8h, v3.h[3] \n"
                "fmla v28.8h, v7.8h, v3.h[4] \n"
                "fmla v29.8h, v7.8h, v3.h[5] \n"
                "fmla v30.8h, v7.8h, v3.h[6] \n"
                "fmla v31.8h, v7.8h, v3.h[7] \n"

                "bne 0b \n"

                "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
                "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"

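                // rotate each accumulator by 8 bytes so its high 4 halfwords (the second
                // pack4 output group) move to the low half for the stores to outptr1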
324 "ext v24.16b, v24.16b, v24.16b, #8 \n"
325 "ext v25.16b, v25.16b, v25.16b, #8 \n"
326 "ext v26.16b, v26.16b, v26.16b, #8 \n"
327 "ext v27.16b, v27.16b, v27.16b, #8 \n"
328 "ext v28.16b, v28.16b, v28.16b, #8 \n"
329 "ext v29.16b, v29.16b, v29.16b, #8 \n"
330 "ext v30.16b, v30.16b, v30.16b, #8 \n"
331 "ext v31.16b, v31.16b, v31.16b, #8 \n"
332
333 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"
334 "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%2], #32 \n"
335
336 : "=r"(nn), // %0
337 "=r"(outptr0), // %1
338 "=r"(outptr1), // %2
339 "=r"(tmpptr), // %3
340 "=r"(kptr) // %4
341 : "0"(nn),
342 "1"(outptr0),
343 "2"(outptr1),
344 "3"(tmpptr),
345 "4"(kptr),
346 "w"(_bias0) // %10
347 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
348 }
349 for (; i + 3 < size; i += 4)
350 {
351 __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
352 const __fp16* kptr = kernel.channel(pp);
353
354 int nn = inch; // inch always > 0
355
356 asm volatile(
357 "mov v24.16b, %10.16b \n"
358 "mov v25.16b, %10.16b \n"
359 "mov v26.16b, %10.16b \n"
360 "mov v27.16b, %10.16b \n"
361
362 "0: \n"
363
364 "prfm pldl1keep, [%3, #256] \n"
365 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%3], #32 \n" // r01 r23 r45 r67
366
367 "prfm pldl1keep, [%4, #512] \n"
368 "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%4], #64 \n" // k0123
369
370 "fmla v24.8h, v4.8h, v0.h[0] \n"
371 "fmla v25.8h, v4.8h, v0.h[1] \n"
372 "fmla v26.8h, v4.8h, v0.h[2] \n"
373 "fmla v27.8h, v4.8h, v0.h[3] \n"
374
375 "fmla v24.8h, v5.8h, v1.h[0] \n"
376 "fmla v25.8h, v5.8h, v1.h[1] \n"
377 "fmla v26.8h, v5.8h, v1.h[2] \n"
378 "fmla v27.8h, v5.8h, v1.h[3] \n"
379
380 "fmla v24.8h, v6.8h, v2.h[0] \n"
381 "fmla v25.8h, v6.8h, v2.h[1] \n"
382 "fmla v26.8h, v6.8h, v2.h[2] \n"
383 "fmla v27.8h, v6.8h, v2.h[3] \n"
384
385 "subs %w0, %w0, #1 \n"
386
387 "fmla v24.8h, v7.8h, v3.h[0] \n"
388 "fmla v25.8h, v7.8h, v3.h[1] \n"
389 "fmla v26.8h, v7.8h, v3.h[2] \n"
390 "fmla v27.8h, v7.8h, v3.h[3] \n"
391
392 "bne 0b \n"
393
394 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
395
396 "ext v24.16b, v24.16b, v24.16b, #8 \n"
397 "ext v25.16b, v25.16b, v25.16b, #8 \n"
398 "ext v26.16b, v26.16b, v26.16b, #8 \n"
399 "ext v27.16b, v27.16b, v27.16b, #8 \n"
400
401 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"
402
403 : "=r"(nn), // %0
404 "=r"(outptr0), // %1
405 "=r"(outptr1), // %2
406 "=r"(tmpptr), // %3
407 "=r"(kptr) // %4
408 : "0"(nn),
409 "1"(outptr0),
410 "2"(outptr1),
411 "3"(tmpptr),
412 "4"(kptr),
413 "w"(_bias0) // %10
414 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27");
415 }
416 for (; i < size; i++)
417 {
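            // remainder: one pixel at a time with intrinsics; the low/high halves of
            // _sum0 are the two pack4 output groups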
            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
            const __fp16* kptr = kernel.channel(pp);

            float16x8_t _sum0 = _bias0;

            for (int q = 0; q < inch; q++)
            {
                float16x4_t _r0 = vld1_f16(tmpptr);

                float16x8_t _k0 = vld1q_f16(kptr);
                float16x8_t _k1 = vld1q_f16(kptr + 8);
                float16x8_t _k2 = vld1q_f16(kptr + 16);
                float16x8_t _k3 = vld1q_f16(kptr + 24);

                _sum0 = vfmaq_lane_f16(_sum0, _k0, _r0, 0);
                _sum0 = vfmaq_lane_f16(_sum0, _k1, _r0, 1);
                _sum0 = vfmaq_lane_f16(_sum0, _k2, _r0, 2);
                _sum0 = vfmaq_lane_f16(_sum0, _k3, _r0, 3);

                kptr += 32;
                tmpptr += 4;
            }

            vst1_f16(outptr0, vget_low_f16(_sum0));
            vst1_f16(outptr1, vget_high_f16(_sum0));

            outptr0 += 4;
            outptr1 += 4;
        }
    }

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = remain_outch_start; p < outch; p++)
    {
        __fp16* outptr0 = top_blob.channel(p);

        const __fp16 zeros[4] = {0.f, 0.f, 0.f, 0.f};
        const __fp16* biasptr = bias ? bias + p * 4 : zeros;
        float16x4_t _bias0 = vld1_f16(biasptr);

        int i = 0;
        for (; i + 7 < size; i += 8)
        {
            __fp16* tmpptr = tmp.channel(i / 8);
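            // the packed kernel stores the pack-8 pairs first; at most one pack4 group
            // remains, so p is even here and p / 2 + p % 2 lands just past those pairs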
            const __fp16* kptr = kernel.channel(p / 2 + p % 2);

            int nn = inch; // inch always > 0

            asm volatile(
                "mov v24.16b, %8.16b \n"
                "mov v25.16b, %8.16b \n"
                "mov v26.16b, %8.16b \n"
                "mov v27.16b, %8.16b \n"
                "mov v28.16b, %8.16b \n"
                "mov v29.16b, %8.16b \n"
                "mov v30.16b, %8.16b \n"
                "mov v31.16b, %8.16b \n"

                "0: \n"

                "prfm pldl1keep, [%2, #512] \n"
                "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%2], #64 \n" // r01 r23 r45 r67

                "prfm pldl1keep, [%3, #256] \n"
                "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%3], #32 \n" // k0123

                "fmla v24.4h, v4.4h, v0.h[0] \n"
                "fmla v25.4h, v4.4h, v0.h[1] \n"
                "fmla v26.4h, v4.4h, v0.h[2] \n"
                "fmla v27.4h, v4.4h, v0.h[3] \n"
                "fmla v28.4h, v4.4h, v0.h[4] \n"
                "fmla v29.4h, v4.4h, v0.h[5] \n"
                "fmla v30.4h, v4.4h, v0.h[6] \n"
                "fmla v31.4h, v4.4h, v0.h[7] \n"

                "fmla v24.4h, v5.4h, v1.h[0] \n"
                "fmla v25.4h, v5.4h, v1.h[1] \n"
                "fmla v26.4h, v5.4h, v1.h[2] \n"
                "fmla v27.4h, v5.4h, v1.h[3] \n"
                "fmla v28.4h, v5.4h, v1.h[4] \n"
                "fmla v29.4h, v5.4h, v1.h[5] \n"
                "fmla v30.4h, v5.4h, v1.h[6] \n"
                "fmla v31.4h, v5.4h, v1.h[7] \n"

                "fmla v24.4h, v6.4h, v2.h[0] \n"
                "fmla v25.4h, v6.4h, v2.h[1] \n"
                "fmla v26.4h, v6.4h, v2.h[2] \n"
                "fmla v27.4h, v6.4h, v2.h[3] \n"
                "fmla v28.4h, v6.4h, v2.h[4] \n"
                "fmla v29.4h, v6.4h, v2.h[5] \n"
                "fmla v30.4h, v6.4h, v2.h[6] \n"
                "fmla v31.4h, v6.4h, v2.h[7] \n"

                "subs %w0, %w0, #1 \n"

                "fmla v24.4h, v7.4h, v3.h[0] \n"
                "fmla v25.4h, v7.4h, v3.h[1] \n"
                "fmla v26.4h, v7.4h, v3.h[2] \n"
                "fmla v27.4h, v7.4h, v3.h[3] \n"
                "fmla v28.4h, v7.4h, v3.h[4] \n"
                "fmla v29.4h, v7.4h, v3.h[5] \n"
                "fmla v30.4h, v7.4h, v3.h[6] \n"
                "fmla v31.4h, v7.4h, v3.h[7] \n"

                "bne 0b \n"

                "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
                "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"

                : "=r"(nn),      // %0
                  "=r"(outptr0), // %1
                  "=r"(tmpptr),  // %2
                  "=r"(kptr)     // %3
                : "0"(nn),
                  "1"(outptr0),
                  "2"(tmpptr),
                  "3"(kptr),
                  "w"(_bias0) // %8
                : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
        }
        for (; i + 3 < size; i += 4)
        {
            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
            const __fp16* kptr = kernel.channel(p / 2 + p % 2);

            int nn = inch; // inch always > 0

            asm volatile(
                "mov v24.16b, %8.16b \n"
                "mov v25.16b, %8.16b \n"
                "mov v26.16b, %8.16b \n"
                "mov v27.16b, %8.16b \n"

                "0: \n"

                "prfm pldl1keep, [%2, #256] \n"
                "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%2], #32 \n" // r01 r23 r45 r67

                "prfm pldl1keep, [%3, #256] \n"
                "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%3], #32 \n" // k0123

                "fmla v24.4h, v4.4h, v0.h[0] \n"
                "fmla v25.4h, v4.4h, v0.h[1] \n"
                "fmla v26.4h, v4.4h, v0.h[2] \n"
                "fmla v27.4h, v4.4h, v0.h[3] \n"

                "fmla v24.4h, v5.4h, v1.h[0] \n"
                "fmla v25.4h, v5.4h, v1.h[1] \n"
                "fmla v26.4h, v5.4h, v1.h[2] \n"
                "fmla v27.4h, v5.4h, v1.h[3] \n"

                "fmla v24.4h, v6.4h, v2.h[0] \n"
                "fmla v25.4h, v6.4h, v2.h[1] \n"
                "fmla v26.4h, v6.4h, v2.h[2] \n"
                "fmla v27.4h, v6.4h, v2.h[3] \n"

                "subs %w0, %w0, #1 \n"

                "fmla v24.4h, v7.4h, v3.h[0] \n"
                "fmla v25.4h, v7.4h, v3.h[1] \n"
                "fmla v26.4h, v7.4h, v3.h[2] \n"
                "fmla v27.4h, v7.4h, v3.h[3] \n"

                "bne 0b \n"

                "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"

                : "=r"(nn),      // %0
                  "=r"(outptr0), // %1
                  "=r"(tmpptr),  // %2
                  "=r"(kptr)     // %3
                : "0"(nn),
                  "1"(outptr0),
                  "2"(tmpptr),
                  "3"(kptr),
                  "w"(_bias0) // %8
                : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27");
        }
        for (; i < size; i++)
        {
            __fp16* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + i % 4);
            const __fp16* kptr = kernel.channel(p / 2 + p % 2);

            float16x4_t _sum0 = _bias0;

            for (int q = 0; q < inch; q++)
            {
                float16x4_t _r0 = vld1_f16(tmpptr);

                float16x4_t _k0 = vld1_f16(kptr);
                float16x4_t _k1 = vld1_f16(kptr + 4);
                float16x4_t _k2 = vld1_f16(kptr + 8);
                float16x4_t _k3 = vld1_f16(kptr + 12);

                _sum0 = vfma_lane_f16(_sum0, _k0, _r0, 0);
                _sum0 = vfma_lane_f16(_sum0, _k1, _r0, 1);
                _sum0 = vfma_lane_f16(_sum0, _k2, _r0, 2);
                _sum0 = vfma_lane_f16(_sum0, _k3, _r0, 3);

                kptr += 16;
                tmpptr += 4;
            }

            vst1_f16(outptr0, _sum0);

            outptr0 += 4;
        }
    }

    // // NOTE sgemm
    // for (; p<outch; p++)
    // {
    //     Mat out0 = top_blob.channel(p);
    //
    //     const __fp16 bias0 = bias ? bias[p] : (__fp16)0.f;
    //
    //     __fp16* outptr0 = out0;
    //
    //     for (int i=0; i<size; i++)
    //     {
    //         __fp16 sum = bias0;
    //
    //         const __fp16* kptr = kernel.channel(p);
    //
    //         for (int q=0; q<inch; q++)
    //         {
    //             const __fp16* img0 = bottom_blob.channel(q);
    //
    //             sum += img0[i] * kptr[0];
    //             kptr ++;
    //         }
    //
    //         outptr0[i] = sum;
    //     }
    // }
}

static void conv1x1s2_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;

    const int tailstep = (w - 2 * outw + w) * 4;
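    // after reading 2 * outw pixels per row (horizontal stride 2), tailstep skips the
    // remainder of the row plus one more row (vertical stride 2), in __fp16 units (x4 pack)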

    Mat bottom_blob_shrinked;
    bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);
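    // gather every second pixel into a dense outw x outh blob, then run the
    // stride-1 sgemm path on it (see the call at the end of this function)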

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < channels; p++)
    {
        const __fp16* r0 = bottom_blob.channel(p);
        __fp16* outptr = bottom_blob_shrinked.channel(p);

        for (int i = 0; i < outh; i++)
        {
            int j = 0;
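            // stride 2 with pack4: consecutive output pixels sit 8 __fp16 apart in the input row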
            for (; j + 3 < outw; j += 4)
            {
                float16x4_t _v0 = vld1_f16(r0);
                float16x4_t _v1 = vld1_f16(r0 + 8);
                float16x4_t _v2 = vld1_f16(r0 + 16);
                float16x4_t _v3 = vld1_f16(r0 + 24);
                float16x8_t _v01 = vcombine_f16(_v0, _v1);
                float16x8_t _v23 = vcombine_f16(_v2, _v3);
                vst1q_f16(outptr, _v01);
                vst1q_f16(outptr + 8, _v23);

                r0 += 32;
                outptr += 16;
            }
            for (; j + 1 < outw; j += 2)
            {
                float16x4_t _v0 = vld1_f16(r0);
                float16x4_t _v1 = vld1_f16(r0 + 8);
                float16x8_t _v = vcombine_f16(_v0, _v1);
                vst1q_f16(outptr, _v);

                r0 += 16;
                outptr += 8;
            }
            for (; j < outw; j++)
            {
                float16x4_t _v = vld1_f16(r0);
                vst1_f16(outptr, _v);

                r0 += 8;
                outptr += 4;
            }

            r0 += tailstep;
        }
    }

    conv1x1s1_sgemm_pack4_fp16sa_neon(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
}