1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
conv3x3s1_winograd64_transform_kernel_pack4_fp16sa_neon(const Mat & kernel,Mat & kernel_tm_pack4,int inch,int outch)15 static void conv3x3s1_winograd64_transform_kernel_pack4_fp16sa_neon(const Mat& kernel, Mat& kernel_tm_pack4, int inch, int outch)
16 {
17 // winograd63 transform kernel
18 Mat kernel_tm;
19 kernel_tm.create(8 * 8, inch, outch);
20
21 const float ktm[8][3] = {
22 {1.0f, 0.0f, 0.0f},
23 {-2.0f / 9, -2.0f / 9, -2.0f / 9},
24 {-2.0f / 9, 2.0f / 9, -2.0f / 9},
25 {1.0f / 90, 1.0f / 45, 2.0f / 45},
26 {1.0f / 90, -1.0f / 45, 2.0f / 45},
27 {1.0f / 45, 1.0f / 90, 1.0f / 180},
28 {1.0f / 45, -1.0f / 90, 1.0f / 180},
29 {0.0f, 0.0f, 1.0f}
30 };
31
32 #pragma omp parallel for
33 for (int p = 0; p < outch; p++)
34 {
35 for (int q = 0; q < inch; q++)
36 {
37 const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
38 float* kernel_tm0 = kernel_tm.channel(p).row(q);
39
40 // transform kernel, transposed
41 const float* k0 = kernel0;
42 const float* k1 = kernel0 + 3;
43 const float* k2 = kernel0 + 6;
44
45 // h
46 float tmp[8][3];
47 for (int i = 0; i < 8; i++)
48 {
49 tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
50 tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
51 tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
52 }
53
54 // v
55 for (int j = 0; j < 8; j++)
56 {
57 float* tmpp = &tmp[j][0];
58
59 for (int i = 0; i < 8; i++)
60 {
61 kernel_tm0[j * 8 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
62 }
63 }
64 }
65 }
66
67 // interleave
68 // src = 64-inch-outch
69 // dst = 4b-4a-inch/4a-64-outch/4b;
70 kernel_tm_pack4.create(2 * inch / 4, 64, (outch / 4) / 2 + (outch / 4) % 2, (size_t)2u * 16, 16);
71
72 int q = 0;
73 for (; q + 7 < outch; q += 8)
74 {
75 const Mat k0 = kernel_tm.channel(q);
76 const Mat k1 = kernel_tm.channel(q + 1);
77 const Mat k2 = kernel_tm.channel(q + 2);
78 const Mat k3 = kernel_tm.channel(q + 3);
79 const Mat k4 = kernel_tm.channel(q + 4);
80 const Mat k5 = kernel_tm.channel(q + 5);
81 const Mat k6 = kernel_tm.channel(q + 6);
82 const Mat k7 = kernel_tm.channel(q + 7);
83
84 Mat g0 = kernel_tm_pack4.channel(q / 8);
85
86 for (int k = 0; k < 64; k++)
87 {
88 __fp16* g00 = g0.row<__fp16>(k);
89
90 for (int p = 0; p + 3 < inch; p += 4)
91 {
92 const float* k00 = k0.row(p);
93 const float* k01 = k0.row(p + 1);
94 const float* k02 = k0.row(p + 2);
95 const float* k03 = k0.row(p + 3);
96
97 const float* k10 = k1.row(p);
98 const float* k11 = k1.row(p + 1);
99 const float* k12 = k1.row(p + 2);
100 const float* k13 = k1.row(p + 3);
101
102 const float* k20 = k2.row(p);
103 const float* k21 = k2.row(p + 1);
104 const float* k22 = k2.row(p + 2);
105 const float* k23 = k2.row(p + 3);
106
107 const float* k30 = k3.row(p);
108 const float* k31 = k3.row(p + 1);
109 const float* k32 = k3.row(p + 2);
110 const float* k33 = k3.row(p + 3);
111
112 const float* k40 = k4.row(p);
113 const float* k41 = k4.row(p + 1);
114 const float* k42 = k4.row(p + 2);
115 const float* k43 = k4.row(p + 3);
116
117 const float* k50 = k5.row(p);
118 const float* k51 = k5.row(p + 1);
119 const float* k52 = k5.row(p + 2);
120 const float* k53 = k5.row(p + 3);
121
122 const float* k60 = k6.row(p);
123 const float* k61 = k6.row(p + 1);
124 const float* k62 = k6.row(p + 2);
125 const float* k63 = k6.row(p + 3);
126
127 const float* k70 = k7.row(p);
128 const float* k71 = k7.row(p + 1);
129 const float* k72 = k7.row(p + 2);
130 const float* k73 = k7.row(p + 3);
131
132 g00[0] = (__fp16)k00[k];
133 g00[1] = (__fp16)k10[k];
134 g00[2] = (__fp16)k20[k];
135 g00[3] = (__fp16)k30[k];
136
137 g00[4] = (__fp16)k40[k];
138 g00[5] = (__fp16)k50[k];
139 g00[6] = (__fp16)k60[k];
140 g00[7] = (__fp16)k70[k];
141
142 g00[8] = (__fp16)k01[k];
143 g00[9] = (__fp16)k11[k];
144 g00[10] = (__fp16)k21[k];
145 g00[11] = (__fp16)k31[k];
146
147 g00[12] = (__fp16)k41[k];
148 g00[13] = (__fp16)k51[k];
149 g00[14] = (__fp16)k61[k];
150 g00[15] = (__fp16)k71[k];
151
152 g00[16] = (__fp16)k02[k];
153 g00[17] = (__fp16)k12[k];
154 g00[18] = (__fp16)k22[k];
155 g00[19] = (__fp16)k32[k];
156
157 g00[20] = (__fp16)k42[k];
158 g00[21] = (__fp16)k52[k];
159 g00[22] = (__fp16)k62[k];
160 g00[23] = (__fp16)k72[k];
161
162 g00[24] = (__fp16)k03[k];
163 g00[25] = (__fp16)k13[k];
164 g00[26] = (__fp16)k23[k];
165 g00[27] = (__fp16)k33[k];
166
167 g00[28] = (__fp16)k43[k];
168 g00[29] = (__fp16)k53[k];
169 g00[30] = (__fp16)k63[k];
170 g00[31] = (__fp16)k73[k];
171
172 g00 += 32;
173 }
174 }
175 }
176 for (; q + 3 < outch; q += 4)
177 {
178 const Mat k0 = kernel_tm.channel(q);
179 const Mat k1 = kernel_tm.channel(q + 1);
180 const Mat k2 = kernel_tm.channel(q + 2);
181 const Mat k3 = kernel_tm.channel(q + 3);
182
183 Mat g0 = kernel_tm_pack4.channel(q / 8 + (q % 8) / 4);
184
185 for (int k = 0; k < 64; k++)
186 {
187 __fp16* g00 = g0.row<__fp16>(k);
188
189 for (int p = 0; p + 3 < inch; p += 4)
190 {
191 const float* k00 = k0.row(p);
192 const float* k01 = k0.row(p + 1);
193 const float* k02 = k0.row(p + 2);
194 const float* k03 = k0.row(p + 3);
195
196 const float* k10 = k1.row(p);
197 const float* k11 = k1.row(p + 1);
198 const float* k12 = k1.row(p + 2);
199 const float* k13 = k1.row(p + 3);
200
201 const float* k20 = k2.row(p);
202 const float* k21 = k2.row(p + 1);
203 const float* k22 = k2.row(p + 2);
204 const float* k23 = k2.row(p + 3);
205
206 const float* k30 = k3.row(p);
207 const float* k31 = k3.row(p + 1);
208 const float* k32 = k3.row(p + 2);
209 const float* k33 = k3.row(p + 3);
210
211 g00[0] = (__fp16)k00[k];
212 g00[1] = (__fp16)k10[k];
213 g00[2] = (__fp16)k20[k];
214 g00[3] = (__fp16)k30[k];
215
216 g00[4] = (__fp16)k01[k];
217 g00[5] = (__fp16)k11[k];
218 g00[6] = (__fp16)k21[k];
219 g00[7] = (__fp16)k31[k];
220
221 g00[8] = (__fp16)k02[k];
222 g00[9] = (__fp16)k12[k];
223 g00[10] = (__fp16)k22[k];
224 g00[11] = (__fp16)k32[k];
225
226 g00[12] = (__fp16)k03[k];
227 g00[13] = (__fp16)k13[k];
228 g00[14] = (__fp16)k23[k];
229 g00[15] = (__fp16)k33[k];
230
231 g00 += 16;
232 }
233 }
234 }
235 }
236
conv3x3s1_winograd64_pack4_fp16sa_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel_tm,const Mat & _bias,const Option & opt)237 static void conv3x3s1_winograd64_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
238 {
239 int w = bottom_blob.w;
240 int h = bottom_blob.h;
241 int inch = bottom_blob.c;
242 //size_t elemsize = bottom_blob.elemsize;
243 int elempack = bottom_blob.elempack;
244
245 int outw = top_blob.w;
246 int outh = top_blob.h;
247 int outch = top_blob.c;
248
249 // pad to 6n+2
250 Mat bottom_blob_bordered = bottom_blob;
251
252 outw = (outw + 5) / 6 * 6;
253 outh = (outh + 5) / 6 * 6;
254
255 w = outw + 2;
256 h = outh + 2;
257 copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt);
258
259 const float* bias = _bias;
260
261 // BEGIN transform input
262 Mat bottom_blob_tm;
263 {
264 int w_tm = outw / 6 * 8;
265 int h_tm = outh / 6 * 8;
266
267 const int tiles = w_tm / 8 * h_tm / 8;
268
269 // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
270 bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator);
271
272 // const float itm[8][8] = {
273 // {1.0f, 0.0f, -5.25f, 0.00f, 5.25f, 0.00f, -1.0f, 0.0f},
274 //
275 // {0.0f, 1.0f, 1.00f, -4.25f, -4.25f, 1.00f, 1.0f, 0.0f},
276 // {0.0f, -1.0f, 1.00f, 4.25f, -4.25f, -1.00f, 1.0f, 0.0f},
277 //
278 // {0.0f, 0.5f, 0.25f, -2.50f, -1.25f, 2.00f, 1.0f, 0.0f},
279 // {0.0f, -0.5f, 0.25f, 2.50f, -1.25f, -2.00f, 1.0f, 0.0f},
280 //
281 // {0.0f, 2.0f, 4.00f, -2.50f, -5.00f, 0.50f, 1.0f, 0.0f},
282 // {0.0f, -2.0f, 4.00f, 2.50f, -5.00f, -0.50f, 1.0f, 0.0f},
283 //
284 // {0.0f, -1.0f, 0.00f, 5.25f, 0.00f, -5.25f, 0.0f, 1.0f}
285 // };
286
287 // 0 = r00 - r06 + (r04 - r02) * 5.25
288 // 7 = r07 - r01 + (r03 - r05) * 5.25
289
290 // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05)
291 // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05)
292
293 // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2)
294 // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2)
295
296 // reuse r04 * 1.25
297 // reuse r03 * 2.5
298 // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
299 // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)
300
301 #pragma omp parallel for num_threads(opt.num_threads)
302 for (int q = 0; q < inch; q++)
303 {
304 const Mat img0 = bottom_blob_bordered.channel(q);
305 Mat img0_tm = bottom_blob_tm.channel(q);
306
307 __fp16 tmp[8][8][4];
308
309 // tile
310 for (int i = 0; i < h_tm / 8; i++)
311 {
312 for (int j = 0; j < w_tm / 8; j++)
313 {
314 const __fp16* r0 = img0.row<const __fp16>(i * 6) + (j * 6) * 4;
315
316 for (int m = 0; m < 8; m++)
317 {
318 float16x4_t _r00 = vld1_f16(r0);
319 float16x4_t _r01 = vld1_f16(r0 + 4);
320 float16x4_t _r02 = vld1_f16(r0 + 8);
321 float16x4_t _r03 = vld1_f16(r0 + 12);
322 float16x4_t _r04 = vld1_f16(r0 + 16);
323 float16x4_t _r05 = vld1_f16(r0 + 20);
324 float16x4_t _r06 = vld1_f16(r0 + 24);
325 float16x4_t _r07 = vld1_f16(r0 + 28);
326
327 float16x4_t _tmp0m = vfma_n_f16(vsub_f16(_r00, _r06), vsub_f16(_r04, _r02), 5.25f);
328 float16x4_t _tmp7m = vfma_n_f16(vsub_f16(_r07, _r01), vsub_f16(_r03, _r05), 5.25f);
329 vst1_f16(tmp[0][m], _tmp0m);
330 vst1_f16(tmp[7][m], _tmp7m);
331
332 // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25;
333 // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25;
334
335 float16x4_t _tmp12a = vfms_n_f16(vadd_f16(_r02, _r06), _r04, 4.25f);
336 float16x4_t _tmp12b = vfms_n_f16(vadd_f16(_r01, _r05), _r03, 4.25f);
337
338 // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25);
339 // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25);
340
341 float16x4_t _tmp1m = vadd_f16(_tmp12a, _tmp12b);
342 float16x4_t _tmp2m = vsub_f16(_tmp12a, _tmp12b);
343 vst1_f16(tmp[1][m], _tmp1m);
344 vst1_f16(tmp[2][m], _tmp2m);
345
346 // tmp[1][m] = tmp12a + tmp12b;
347 // tmp[2][m] = tmp12a - tmp12b;
348
349 float16x4_t _tmp34a = vfms_n_f16(vfma_n_f16(_r06, _r02, 0.25f), _r04, 1.25f);
350 float16x4_t _tmp34b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f);
351
352 // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25);
353 // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2);
354
355 float16x4_t _tmp3m = vadd_f16(_tmp34a, _tmp34b);
356 float16x4_t _tmp4m = vsub_f16(_tmp34a, _tmp34b);
357 vst1_f16(tmp[3][m], _tmp3m);
358 vst1_f16(tmp[4][m], _tmp4m);
359
360 // tmp[3][m] = tmp34a + tmp34b;
361 // tmp[4][m] = tmp34a - tmp34b;
362
363 float16x4_t _tmp56a = vfma_n_f16(_r06, vfms_n_f16(_r02, _r04, 1.25f), 4.f);
364 float16x4_t _tmp56b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f);
365
366 // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4);
367 // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5);
368
369 float16x4_t _tmp5m = vadd_f16(_tmp56a, _tmp56b);
370 float16x4_t _tmp6m = vsub_f16(_tmp56a, _tmp56b);
371 vst1_f16(tmp[5][m], _tmp5m);
372 vst1_f16(tmp[6][m], _tmp6m);
373
374 // tmp[5][m] = tmp56a + tmp56b;
375 // tmp[6][m] = tmp56a - tmp56b;
376
377 r0 += w * 4;
378 }
379
380 __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * 4;
381 __fp16* r0_tm_1 = r0_tm_0 + tiles * 4;
382 __fp16* r0_tm_2 = r0_tm_0 + tiles * 8;
383 __fp16* r0_tm_3 = r0_tm_0 + tiles * 12;
384 __fp16* r0_tm_4 = r0_tm_0 + tiles * 16;
385 __fp16* r0_tm_5 = r0_tm_0 + tiles * 20;
386 __fp16* r0_tm_6 = r0_tm_0 + tiles * 24;
387 __fp16* r0_tm_7 = r0_tm_0 + tiles * 28;
388
389 for (int m = 0; m < 8; m++)
390 {
391 float16x4_t _tmp00 = vld1_f16(tmp[m][0]);
392 float16x4_t _tmp01 = vld1_f16(tmp[m][1]);
393 float16x4_t _tmp02 = vld1_f16(tmp[m][2]);
394 float16x4_t _tmp03 = vld1_f16(tmp[m][3]);
395 float16x4_t _tmp04 = vld1_f16(tmp[m][4]);
396 float16x4_t _tmp05 = vld1_f16(tmp[m][5]);
397 float16x4_t _tmp06 = vld1_f16(tmp[m][6]);
398 float16x4_t _tmp07 = vld1_f16(tmp[m][7]);
399
400 float16x4_t _r0tm0 = vfma_n_f16(vsub_f16(_tmp00, _tmp06), vsub_f16(_tmp04, _tmp02), 5.25f);
401 float16x4_t _r0tm7 = vfma_n_f16(vsub_f16(_tmp07, _tmp01), vsub_f16(_tmp03, _tmp05), 5.25f);
402
403 // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25;
404 // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25;
405
406 float16x4_t _tmp12a = vfms_n_f16(vadd_f16(_tmp02, _tmp06), _tmp04, 4.25f);
407 float16x4_t _tmp12b = vfms_n_f16(vadd_f16(_tmp01, _tmp05), _tmp03, 4.25f);
408
409 // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25);
410 // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25);
411
412 float16x4_t _r0tm1 = vadd_f16(_tmp12a, _tmp12b);
413 float16x4_t _r0tm2 = vsub_f16(_tmp12a, _tmp12b);
414
415 // r0_tm[1] = tmp12a + tmp12b;
416 // r0_tm[2] = tmp12a - tmp12b;
417
418 float16x4_t _tmp34a = vfms_n_f16(vfma_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f);
419 float16x4_t _tmp34b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f);
420
421 // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25);
422 // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2);
423
424 float16x4_t _r0tm3 = vadd_f16(_tmp34a, _tmp34b);
425 float16x4_t _r0tm4 = vsub_f16(_tmp34a, _tmp34b);
426
427 // r0_tm[3] = tmp34a + tmp34b;
428 // r0_tm[4] = tmp34a - tmp34b;
429
430 float16x4_t _tmp56a = vfma_n_f16(_tmp06, vfms_n_f16(_tmp02, _tmp04, 1.25f), 4.f);
431 float16x4_t _tmp56b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f);
432
433 // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4);
434 // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5);
435
436 float16x4_t _r0tm5 = vadd_f16(_tmp56a, _tmp56b);
437 float16x4_t _r0tm6 = vsub_f16(_tmp56a, _tmp56b);
438
439 // r0_tm[5] = tmp56a + tmp56b;
440 // r0_tm[6] = tmp56a - tmp56b;
441
442 vst1_f16(r0_tm_0, _r0tm0);
443 vst1_f16(r0_tm_1, _r0tm1);
444 vst1_f16(r0_tm_2, _r0tm2);
445 vst1_f16(r0_tm_3, _r0tm3);
446 vst1_f16(r0_tm_4, _r0tm4);
447 vst1_f16(r0_tm_5, _r0tm5);
448 vst1_f16(r0_tm_6, _r0tm6);
449 vst1_f16(r0_tm_7, _r0tm7);
450
451 r0_tm_0 += tiles * 32;
452 r0_tm_1 += tiles * 32;
453 r0_tm_2 += tiles * 32;
454 r0_tm_3 += tiles * 32;
455 r0_tm_4 += tiles * 32;
456 r0_tm_5 += tiles * 32;
457 r0_tm_6 += tiles * 32;
458 r0_tm_7 += tiles * 32;
459 }
460 }
461 }
462 }
463 }
464 bottom_blob_bordered = Mat();
465 // END transform input
466
467 // BEGIN dot
468 Mat top_blob_tm;
469 {
470 int w_tm = outw / 6 * 8;
471 int h_tm = outh / 6 * 8;
472
473 const int tiles = h_tm / 8 * w_tm / 8;
474
475 // permute
476 // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
477 Mat bottom_blob_tm2;
478 if (tiles >= 8)
479 bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + tiles % 4, 64, 2u * elempack, elempack, opt.workspace_allocator);
480 else if (tiles >= 4)
481 bottom_blob_tm2.create(4 * inch, tiles / 4 + tiles % 4, 64, 2u * elempack, elempack, opt.workspace_allocator);
482 else // if (tiles >= 1)
483 bottom_blob_tm2.create(1 * inch, tiles, 64, 2u * elempack, elempack, opt.workspace_allocator);
484
485 #pragma omp parallel for num_threads(opt.num_threads)
486 for (int r = 0; r < 64; r++)
487 {
488 Mat tm2 = bottom_blob_tm2.channel(r);
489
490 // tile
491 int i = 0;
492 for (; i + 7 < tiles; i += 8)
493 {
494 __fp16* tm2p = tm2.row<__fp16>(i / 8);
495
496 const __fp16* r0 = bottom_blob_tm;
497
498 r0 += (r * tiles + i) * 4;
499
500 for (int q = 0; q < inch; q++)
501 {
502 // transpose 4x8
503 asm volatile(
504 "prfm pldl1keep, [%0, #512] \n"
505 "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n"
506 "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
507 : "=r"(r0), // %0
508 "=r"(tm2p) // %1
509 : "0"(r0),
510 "1"(tm2p)
511 : "memory", "v0", "v1", "v2", "v3");
512
513 r0 += bottom_blob_tm.cstep * 4;
514 }
515 }
516 for (; i + 3 < tiles; i += 4)
517 {
518 __fp16* tm2p = tm2.row<__fp16>(i / 8 + (i % 8) / 4);
519
520 const __fp16* r0 = bottom_blob_tm;
521
522 r0 += (r * tiles + i) * 4;
523
524 for (int q = 0; q < inch; q++)
525 {
526 // transpose 4x4
527 asm volatile(
528 "prfm pldl1keep, [%0, #256] \n"
529 "ld4 {v0.4h, v1.4h, v2.4h, v3.4h}, [%0] \n"
530 "st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
531 : "=r"(r0), // %0
532 "=r"(tm2p) // %1
533 : "0"(r0),
534 "1"(tm2p)
535 : "memory", "v0", "v1", "v2", "v3");
536
537 r0 += bottom_blob_tm.cstep * 4;
538 }
539 }
540 for (; i < tiles; i++)
541 {
542 __fp16* tm2p = tm2.row<__fp16>(i / 8 + (i % 8) / 4 + i % 4);
543
544 const __fp16* r0 = bottom_blob_tm;
545
546 r0 += (r * tiles + i) * 4;
547
548 for (int q = 0; q < inch; q++)
549 {
550 asm volatile(
551 "prfm pldl1keep, [%0, #64] \n"
552 "ld1 {v0.4h}, [%0] \n"
553 "st1 {v0.4h}, [%1], #8 \n"
554 : "=r"(r0), // %0
555 "=r"(tm2p) // %1
556 : "0"(r0),
557 "1"(tm2p)
558 : "memory", "v0");
559
560 r0 += bottom_blob_tm.cstep * 4;
561 }
562 }
563 }
564
565 bottom_blob_tm = Mat();
566 // permute end
567
568 top_blob_tm.create(tiles, 64, outch, 2u * elempack, elempack, opt.workspace_allocator);
569
570 int nn_outch = 0;
571 int remain_outch_start = 0;
572
573 nn_outch = outch >> 1;
574 remain_outch_start = nn_outch << 1;
575
576 #pragma omp parallel for num_threads(opt.num_threads)
577 for (int pp = 0; pp < nn_outch; pp++)
578 {
579 int p = pp * 2;
580
581 __fp16* output0_tm = top_blob_tm.channel(p);
582 __fp16* output1_tm = top_blob_tm.channel(p + 1);
583
584 const Mat kernel01_tm = kernel_tm.channel(pp);
585
586 for (int r = 0; r < 64; r++)
587 {
588 const Mat bb2 = bottom_blob_tm2.channel(r);
589
590 int i = 0;
591 for (; i + 7 < tiles; i += 8)
592 {
593 const __fp16* r0 = bb2.row<const __fp16>(i / 8);
594
595 const __fp16* kptr = kernel01_tm.row<const __fp16>(r);
596
597 int nn = inch; // inch always > 0
598
599 asm volatile(
600 "eor v24.16b, v24.16b, v24.16b \n"
601 "eor v25.16b, v25.16b, v25.16b \n"
602 "eor v26.16b, v26.16b, v26.16b \n"
603 "eor v27.16b, v27.16b, v27.16b \n"
604 "eor v28.16b, v28.16b, v28.16b \n"
605 "eor v29.16b, v29.16b, v29.16b \n"
606 "eor v30.16b, v30.16b, v30.16b \n"
607 "eor v31.16b, v31.16b, v31.16b \n"
608
609 "0: \n"
610
611 "prfm pldl1keep, [%3, #512] \n"
612 "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%3], #64 \n" // r01 r23 r45 r67
613
614 "prfm pldl1keep, [%4, #512] \n"
615 "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%4], #64 \n" // k0123
616
617 "fmla v24.8h, v4.8h, v0.h[0] \n"
618 "fmla v25.8h, v4.8h, v0.h[1] \n"
619 "fmla v26.8h, v4.8h, v0.h[2] \n"
620 "fmla v27.8h, v4.8h, v0.h[3] \n"
621 "fmla v28.8h, v4.8h, v0.h[4] \n"
622 "fmla v29.8h, v4.8h, v0.h[5] \n"
623 "fmla v30.8h, v4.8h, v0.h[6] \n"
624 "fmla v31.8h, v4.8h, v0.h[7] \n"
625
626 "fmla v24.8h, v5.8h, v1.h[0] \n"
627 "fmla v25.8h, v5.8h, v1.h[1] \n"
628 "fmla v26.8h, v5.8h, v1.h[2] \n"
629 "fmla v27.8h, v5.8h, v1.h[3] \n"
630 "fmla v28.8h, v5.8h, v1.h[4] \n"
631 "fmla v29.8h, v5.8h, v1.h[5] \n"
632 "fmla v30.8h, v5.8h, v1.h[6] \n"
633 "fmla v31.8h, v5.8h, v1.h[7] \n"
634
635 "fmla v24.8h, v6.8h, v2.h[0] \n"
636 "fmla v25.8h, v6.8h, v2.h[1] \n"
637 "fmla v26.8h, v6.8h, v2.h[2] \n"
638 "fmla v27.8h, v6.8h, v2.h[3] \n"
639 "fmla v28.8h, v6.8h, v2.h[4] \n"
640 "fmla v29.8h, v6.8h, v2.h[5] \n"
641 "fmla v30.8h, v6.8h, v2.h[6] \n"
642 "fmla v31.8h, v6.8h, v2.h[7] \n"
643
644 "subs %w0, %w0, #1 \n"
645
646 "fmla v24.8h, v7.8h, v3.h[0] \n"
647 "fmla v25.8h, v7.8h, v3.h[1] \n"
648 "fmla v26.8h, v7.8h, v3.h[2] \n"
649 "fmla v27.8h, v7.8h, v3.h[3] \n"
650 "fmla v28.8h, v7.8h, v3.h[4] \n"
651 "fmla v29.8h, v7.8h, v3.h[5] \n"
652 "fmla v30.8h, v7.8h, v3.h[6] \n"
653 "fmla v31.8h, v7.8h, v3.h[7] \n"
654
655 "bne 0b \n"
656
657 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
658 "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"
659
660 "ext v24.16b, v24.16b, v24.16b, #8 \n"
661 "ext v25.16b, v25.16b, v25.16b, #8 \n"
662 "ext v26.16b, v26.16b, v26.16b, #8 \n"
663 "ext v27.16b, v27.16b, v27.16b, #8 \n"
664 "ext v28.16b, v28.16b, v28.16b, #8 \n"
665 "ext v29.16b, v29.16b, v29.16b, #8 \n"
666 "ext v30.16b, v30.16b, v30.16b, #8 \n"
667 "ext v31.16b, v31.16b, v31.16b, #8 \n"
668
669 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"
670 "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%2], #32 \n"
671
672 : "=r"(nn), // %0
673 "=r"(output0_tm), // %1
674 "=r"(output1_tm), // %2
675 "=r"(r0), // %3
676 "=r"(kptr) // %4
677 : "0"(nn),
678 "1"(output0_tm),
679 "2"(output1_tm),
680 "3"(r0),
681 "4"(kptr)
682 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
683 }
684 for (; i + 3 < tiles; i += 4)
685 {
686 const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4);
687
688 const __fp16* kptr = kernel01_tm.row<const __fp16>(r);
689
690 int nn = inch; // inch always > 0
691
692 asm volatile(
693 "eor v24.16b, v24.16b, v24.16b \n"
694 "eor v25.16b, v25.16b, v25.16b \n"
695 "eor v26.16b, v26.16b, v26.16b \n"
696 "eor v27.16b, v27.16b, v27.16b \n"
697
698 "0: \n"
699
700 "prfm pldl1keep, [%3, #256] \n"
701 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%3], #32 \n" // r01 r23 r45 r67
702
703 "prfm pldl1keep, [%4, #512] \n"
704 "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%4], #64 \n" // k0123
705
706 "fmla v24.8h, v4.8h, v0.h[0] \n"
707 "fmla v25.8h, v4.8h, v0.h[1] \n"
708 "fmla v26.8h, v4.8h, v0.h[2] \n"
709 "fmla v27.8h, v4.8h, v0.h[3] \n"
710
711 "fmla v24.8h, v5.8h, v1.h[0] \n"
712 "fmla v25.8h, v5.8h, v1.h[1] \n"
713 "fmla v26.8h, v5.8h, v1.h[2] \n"
714 "fmla v27.8h, v5.8h, v1.h[3] \n"
715
716 "fmla v24.8h, v6.8h, v2.h[0] \n"
717 "fmla v25.8h, v6.8h, v2.h[1] \n"
718 "fmla v26.8h, v6.8h, v2.h[2] \n"
719 "fmla v27.8h, v6.8h, v2.h[3] \n"
720
721 "subs %w0, %w0, #1 \n"
722
723 "fmla v24.8h, v7.8h, v3.h[0] \n"
724 "fmla v25.8h, v7.8h, v3.h[1] \n"
725 "fmla v26.8h, v7.8h, v3.h[2] \n"
726 "fmla v27.8h, v7.8h, v3.h[3] \n"
727
728 "bne 0b \n"
729
730 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
731
732 "ext v24.16b, v24.16b, v24.16b, #8 \n"
733 "ext v25.16b, v25.16b, v25.16b, #8 \n"
734 "ext v26.16b, v26.16b, v26.16b, #8 \n"
735 "ext v27.16b, v27.16b, v27.16b, #8 \n"
736
737 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"
738
739 : "=r"(nn), // %0
740 "=r"(output0_tm), // %1
741 "=r"(output1_tm), // %2
742 "=r"(r0), // %3
743 "=r"(kptr) // %4
744 : "0"(nn),
745 "1"(output0_tm),
746 "2"(output1_tm),
747 "3"(r0),
748 "4"(kptr)
749 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27");
750 }
751 for (; i < tiles; i++)
752 {
753 const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4 + i % 4);
754
755 const __fp16* kptr = kernel01_tm.row<const __fp16>(r);
756
757 float16x8_t _sum0 = vdupq_n_f16(0.f);
758
759 for (int q = 0; q < inch; q++)
760 {
761 float16x4_t _r0 = vld1_f16(r0);
762
763 float16x8_t _k0 = vld1q_f16(kptr);
764 float16x8_t _k1 = vld1q_f16(kptr + 8);
765 float16x8_t _k2 = vld1q_f16(kptr + 16);
766 float16x8_t _k3 = vld1q_f16(kptr + 24);
767
768 _sum0 = vfmaq_lane_f16(_sum0, _k0, _r0, 0);
769 _sum0 = vfmaq_lane_f16(_sum0, _k1, _r0, 1);
770 _sum0 = vfmaq_lane_f16(_sum0, _k2, _r0, 2);
771 _sum0 = vfmaq_lane_f16(_sum0, _k3, _r0, 3);
772
773 kptr += 32;
774 r0 += 4;
775 }
776
777 vst1_f16(output0_tm, vget_low_f16(_sum0));
778 vst1_f16(output1_tm, vget_high_f16(_sum0));
779
780 output0_tm += 4;
781 output1_tm += 4;
782 }
783 }
784 }
785
786 #pragma omp parallel for num_threads(opt.num_threads)
787 for (int p = remain_outch_start; p < outch; p++)
788 {
789 __fp16* output0_tm = top_blob_tm.channel(p);
790
791 const Mat kernel0_tm = kernel_tm.channel(p / 2 + p % 2);
792
793 for (int r = 0; r < 64; r++)
794 {
795 const Mat bb2 = bottom_blob_tm2.channel(r);
796
797 int i = 0;
798 for (; i + 7 < tiles; i += 8)
799 {
800 const __fp16* r0 = bb2.row<const __fp16>(i / 8);
801
802 const __fp16* kptr = kernel0_tm.row<const __fp16>(r);
803
804 int nn = inch; // inch always > 0
805
806 asm volatile(
807 "eor v24.16b, v24.16b, v24.16b \n"
808 "eor v25.16b, v25.16b, v25.16b \n"
809 "eor v26.16b, v26.16b, v26.16b \n"
810 "eor v27.16b, v27.16b, v27.16b \n"
811 "eor v28.16b, v28.16b, v28.16b \n"
812 "eor v29.16b, v29.16b, v29.16b \n"
813 "eor v30.16b, v30.16b, v30.16b \n"
814 "eor v31.16b, v31.16b, v31.16b \n"
815
816 "0: \n"
817
818 "prfm pldl1keep, [%2, #512] \n"
819 "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%2], #64 \n" // r01 r23 r45 r67
820
821 "prfm pldl1keep, [%3, #256] \n"
822 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%3], #32 \n" // k0123
823
824 "fmla v24.4h, v4.4h, v0.h[0] \n"
825 "fmla v25.4h, v4.4h, v0.h[1] \n"
826 "fmla v26.4h, v4.4h, v0.h[2] \n"
827 "fmla v27.4h, v4.4h, v0.h[3] \n"
828 "fmla v28.4h, v4.4h, v0.h[4] \n"
829 "fmla v29.4h, v4.4h, v0.h[5] \n"
830 "fmla v30.4h, v4.4h, v0.h[6] \n"
831 "fmla v31.4h, v4.4h, v0.h[7] \n"
832
833 "fmla v24.4h, v5.4h, v1.h[0] \n"
834 "fmla v25.4h, v5.4h, v1.h[1] \n"
835 "fmla v26.4h, v5.4h, v1.h[2] \n"
836 "fmla v27.4h, v5.4h, v1.h[3] \n"
837 "fmla v28.4h, v5.4h, v1.h[4] \n"
838 "fmla v29.4h, v5.4h, v1.h[5] \n"
839 "fmla v30.4h, v5.4h, v1.h[6] \n"
840 "fmla v31.4h, v5.4h, v1.h[7] \n"
841
842 "fmla v24.4h, v6.4h, v2.h[0] \n"
843 "fmla v25.4h, v6.4h, v2.h[1] \n"
844 "fmla v26.4h, v6.4h, v2.h[2] \n"
845 "fmla v27.4h, v6.4h, v2.h[3] \n"
846 "fmla v28.4h, v6.4h, v2.h[4] \n"
847 "fmla v29.4h, v6.4h, v2.h[5] \n"
848 "fmla v30.4h, v6.4h, v2.h[6] \n"
849 "fmla v31.4h, v6.4h, v2.h[7] \n"
850
851 "subs %w0, %w0, #1 \n"
852
853 "fmla v24.4h, v7.4h, v3.h[0] \n"
854 "fmla v25.4h, v7.4h, v3.h[1] \n"
855 "fmla v26.4h, v7.4h, v3.h[2] \n"
856 "fmla v27.4h, v7.4h, v3.h[3] \n"
857 "fmla v28.4h, v7.4h, v3.h[4] \n"
858 "fmla v29.4h, v7.4h, v3.h[5] \n"
859 "fmla v30.4h, v7.4h, v3.h[6] \n"
860 "fmla v31.4h, v7.4h, v3.h[7] \n"
861
862 "bne 0b \n"
863
864 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
865 "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"
866
867 : "=r"(nn), // %0
868 "=r"(output0_tm), // %1
869 "=r"(r0), // %2
870 "=r"(kptr) // %3
871 : "0"(nn),
872 "1"(output0_tm),
873 "2"(r0),
874 "3"(kptr)
875 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
876 }
877 for (; i + 3 < tiles; i += 4)
878 {
879 const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4);
880
881 const __fp16* kptr = kernel0_tm.row<const __fp16>(r);
882
883 int nn = inch; // inch always > 0
884
885 asm volatile(
886 "eor v24.16b, v24.16b, v24.16b \n"
887 "eor v25.16b, v25.16b, v25.16b \n"
888 "eor v26.16b, v26.16b, v26.16b \n"
889 "eor v27.16b, v27.16b, v27.16b \n"
890
891 "0: \n"
892
893 "prfm pldl1keep, [%2, #256] \n"
894 "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%2], #32 \n" // r01 r23 r45 r67
895
896 "prfm pldl1keep, [%3, #256] \n"
897 "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%3], #32 \n" // k0123
898
899 "fmla v24.4h, v4.4h, v0.h[0] \n"
900 "fmla v25.4h, v4.4h, v0.h[1] \n"
901 "fmla v26.4h, v4.4h, v0.h[2] \n"
902 "fmla v27.4h, v4.4h, v0.h[3] \n"
903
904 "fmla v24.4h, v5.4h, v1.h[0] \n"
905 "fmla v25.4h, v5.4h, v1.h[1] \n"
906 "fmla v26.4h, v5.4h, v1.h[2] \n"
907 "fmla v27.4h, v5.4h, v1.h[3] \n"
908
909 "fmla v24.4h, v6.4h, v2.h[0] \n"
910 "fmla v25.4h, v6.4h, v2.h[1] \n"
911 "fmla v26.4h, v6.4h, v2.h[2] \n"
912 "fmla v27.4h, v6.4h, v2.h[3] \n"
913
914 "subs %w0, %w0, #1 \n"
915
916 "fmla v24.4h, v7.4h, v3.h[0] \n"
917 "fmla v25.4h, v7.4h, v3.h[1] \n"
918 "fmla v26.4h, v7.4h, v3.h[2] \n"
919 "fmla v27.4h, v7.4h, v3.h[3] \n"
920
921 "bne 0b \n"
922
923 "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
924
925 : "=r"(nn), // %0
926 "=r"(output0_tm), // %1
927 "=r"(r0), // %2
928 "=r"(kptr) // %3
929 : "0"(nn),
930 "1"(output0_tm),
931 "2"(r0),
932 "3"(kptr)
933 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27");
934 }
935 for (; i < tiles; i++)
936 {
937 const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4 + i % 4);
938
939 const __fp16* kptr = kernel0_tm.row<const __fp16>(r);
940
941 float16x4_t _sum0 = vdup_n_f16(0.f);
942
943 for (int q = 0; q < inch; q++)
944 {
945 float16x4_t _r0 = vld1_f16(r0);
946
947 float16x4_t _k0 = vld1_f16(kptr);
948 float16x4_t _k1 = vld1_f16(kptr + 4);
949 float16x4_t _k2 = vld1_f16(kptr + 8);
950 float16x4_t _k3 = vld1_f16(kptr + 12);
951
952 _sum0 = vfma_lane_f16(_sum0, _k0, _r0, 0);
953 _sum0 = vfma_lane_f16(_sum0, _k1, _r0, 1);
954 _sum0 = vfma_lane_f16(_sum0, _k2, _r0, 2);
955 _sum0 = vfma_lane_f16(_sum0, _k3, _r0, 3);
956
957 kptr += 16;
958 r0 += 4;
959 }
960
961 vst1_f16(output0_tm, _sum0);
962
963 output0_tm += 4;
964 }
965 }
966 }
967 }
968 bottom_blob_tm = Mat();
969 // END dot
970
971 // BEGIN transform output
972 Mat top_blob_bordered;
973 if (outw == top_blob.w && outh == top_blob.h)
974 {
975 top_blob_bordered = top_blob;
976 }
977 else
978 {
979 top_blob_bordered.create(outw, outh, outch, 2u * 4, 4, opt.workspace_allocator);
980 }
981 {
982 // const float otm[6][8] = {
983 // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 32.0f, 32.0f, 0.0f},
984 // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 16.0f,-16.0f, 0.0f},
985 // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 8.0f, 8.0f, 0.0f},
986 // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 4.0f, -4.0f, 0.0f},
987 // {0.0f, 1.0f, 1.0f, 16.0f, 16.0f, 2.0f, 2.0f, 0.0f},
988 // {0.0f, 1.0f, -1.0f, 32.0f, -32.0f, 1.0f, -1.0f, 1.0f}
989 // };
990
991 // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32
992 // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
993 // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
994 // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
995 // 4 = (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2
996 // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6)
997
998 int w_tm = outw / 6 * 8;
999 int h_tm = outh / 6 * 8;
1000 const int tiles = w_tm / 8 * h_tm / 8;
1001
1002 #pragma omp parallel for num_threads(opt.num_threads)
1003 for (int p = 0; p < outch; p++)
1004 {
1005 const Mat out0_tm = top_blob_tm.channel(p);
1006 Mat out0 = top_blob_bordered.channel(p);
1007
1008 // const float bias0 = bias ? bias[p] : 0.f;
1009 float16x4_t _bias0 = bias ? vld1_f16((const __fp16*)bias + p * 4) : vdup_n_f16(0.f);
1010
1011 __fp16 tmp[6][8][4];
1012
1013 // tile
1014 for (int i = 0; i < outh / 6; i++)
1015 {
1016 for (int j = 0; j < outw / 6; j++)
1017 {
1018 // top_blob_tm.create(tiles, 64, outch, elemsize, elempack);
1019
1020 const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * 4;
1021 const __fp16* output0_tm_1 = output0_tm_0 + tiles * 4;
1022 const __fp16* output0_tm_2 = output0_tm_0 + tiles * 8;
1023 const __fp16* output0_tm_3 = output0_tm_0 + tiles * 12;
1024 const __fp16* output0_tm_4 = output0_tm_0 + tiles * 16;
1025 const __fp16* output0_tm_5 = output0_tm_0 + tiles * 20;
1026 const __fp16* output0_tm_6 = output0_tm_0 + tiles * 24;
1027 const __fp16* output0_tm_7 = output0_tm_0 + tiles * 28;
1028
1029 __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * 4;
1030
1031 // TODO neon optimize
1032 for (int m = 0; m < 8; m++)
1033 {
1034 float16x4_t _out0tm0 = vld1_f16(output0_tm_0);
1035 float16x4_t _out0tm1 = vld1_f16(output0_tm_1);
1036 float16x4_t _out0tm2 = vld1_f16(output0_tm_2);
1037 float16x4_t _out0tm3 = vld1_f16(output0_tm_3);
1038 float16x4_t _out0tm4 = vld1_f16(output0_tm_4);
1039 float16x4_t _out0tm5 = vld1_f16(output0_tm_5);
1040 float16x4_t _out0tm6 = vld1_f16(output0_tm_6);
1041 float16x4_t _out0tm7 = vld1_f16(output0_tm_7);
1042
1043 float16x4_t _tmp024a = vadd_f16(_out0tm1, _out0tm2);
1044 float16x4_t _tmp135a = vsub_f16(_out0tm1, _out0tm2);
1045
1046 // float tmp024a = output0_tm[1] + output0_tm[2];
1047 // float tmp135a = output0_tm[1] - output0_tm[2];
1048
1049 float16x4_t _tmp024b = vadd_f16(_out0tm3, _out0tm4);
1050 float16x4_t _tmp135b = vsub_f16(_out0tm3, _out0tm4);
1051
1052 // float tmp024b = output0_tm[3] + output0_tm[4];
1053 // float tmp135b = output0_tm[3] - output0_tm[4];
1054
1055 float16x4_t _tmp024c = vadd_f16(_out0tm5, _out0tm6);
1056 float16x4_t _tmp135c = vsub_f16(_out0tm5, _out0tm6);
1057
1058 // float tmp024c = output0_tm[5] + output0_tm[6];
1059 // float tmp135c = output0_tm[5] - output0_tm[6];
1060
1061 float16x4_t _tmp0m = vadd_f16(vadd_f16(_out0tm0, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f));
1062 float16x4_t _tmp2m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f);
1063 float16x4_t _tmp4m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f);
1064 vst1_f16(tmp[0][m], _tmp0m);
1065 vst1_f16(tmp[2][m], _tmp2m);
1066 vst1_f16(tmp[4][m], _tmp4m);
1067
1068 // tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32;
1069 // tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8;
1070 // tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c;
1071
1072 float16x4_t _tmp1m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f);
1073 float16x4_t _tmp3m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f);
1074 float16x4_t _tmp5m = vadd_f16(vadd_f16(_out0tm7, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f));
1075 vst1_f16(tmp[1][m], _tmp1m);
1076 vst1_f16(tmp[3][m], _tmp3m);
1077 vst1_f16(tmp[5][m], _tmp5m);
1078
1079 // tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
1080 // tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4;
1081 // tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c;
1082
1083 output0_tm_0 += tiles * 32;
1084 output0_tm_1 += tiles * 32;
1085 output0_tm_2 += tiles * 32;
1086 output0_tm_3 += tiles * 32;
1087 output0_tm_4 += tiles * 32;
1088 output0_tm_5 += tiles * 32;
1089 output0_tm_6 += tiles * 32;
1090 output0_tm_7 += tiles * 32;
1091 }
1092
1093 for (int m = 0; m < 6; m++)
1094 {
1095 float16x4_t _tmp00 = vld1_f16(tmp[m][0]);
1096 float16x4_t _tmp01 = vld1_f16(tmp[m][1]);
1097 float16x4_t _tmp02 = vld1_f16(tmp[m][2]);
1098 float16x4_t _tmp03 = vld1_f16(tmp[m][3]);
1099 float16x4_t _tmp04 = vld1_f16(tmp[m][4]);
1100 float16x4_t _tmp05 = vld1_f16(tmp[m][5]);
1101 float16x4_t _tmp06 = vld1_f16(tmp[m][6]);
1102 float16x4_t _tmp07 = vld1_f16(tmp[m][7]);
1103
1104 float16x4_t _tmp024a = vadd_f16(_tmp01, _tmp02);
1105 float16x4_t _tmp135a = vsub_f16(_tmp01, _tmp02);
1106
1107 // float tmp024a = tmp0[1] + tmp0[2];
1108 // float tmp135a = tmp0[1] - tmp0[2];
1109
1110 float16x4_t _tmp024b = vadd_f16(_tmp03, _tmp04);
1111 float16x4_t _tmp135b = vsub_f16(_tmp03, _tmp04);
1112
1113 // float tmp024b = tmp0[3] + tmp0[4];
1114 // float tmp135b = tmp0[3] - tmp0[4];
1115
1116 float16x4_t _tmp024c = vadd_f16(_tmp05, _tmp06);
1117 float16x4_t _tmp135c = vsub_f16(_tmp05, _tmp06);
1118
1119 // float tmp024c = tmp0[5] + tmp0[6];
1120 // float tmp135c = tmp0[5] - tmp0[6];
1121
1122 float16x4_t _out00 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp00, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f)));
1123 float16x4_t _out02 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f));
1124 float16x4_t _out04 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f));
1125 vst1_f16(output0, _out00);
1126 vst1_f16(output0 + 8, _out02);
1127 vst1_f16(output0 + 16, _out04);
1128
1129 // output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32;
1130 // output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8;
1131 // output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c;
1132
1133 float16x4_t _out01 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f));
1134 float16x4_t _out03 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f));
1135 float16x4_t _out05 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp07, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f)));
1136 vst1_f16(output0 + 4, _out01);
1137 vst1_f16(output0 + 12, _out03);
1138 vst1_f16(output0 + 20, _out05);
1139
1140 // output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16;
1141 // output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4;
1142 // output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c;
1143
1144 output0 += outw * 4;
1145 }
1146 }
1147 }
1148 }
1149 }
1150 // END transform output
1151
1152 // cut result pad
1153 copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
1154 }
1155
// 3x3 stride-1 convolution for pack4 fp16 storage with fp16 arithmetic (fp16sa),
// hand-written aarch64 NEON using ARMv8.2-A half-precision FMLA.
//
// Data layout: elempack = 4, i.e. every "pixel" of each blob is 4 interleaved
// __fp16 channel values, so a single float16x4_t load/store moves one pixel of
// one 4-channel group.
//
// bottom_blob : padded input, inch 4-packed channels.
//               NOTE(review): the hard-coded `r0 += 8` row advance below
//               (2 pixels * 4 halfwords) implies bottom_blob.w == outw + 2,
//               i.e. the caller must have applied 3x3/s1 border padding -- confirm.
// top_blob    : pre-created output; outw/outh/outch are read from it.
// kernel      : packed weights; kernel.channel(p).row(q) holds 9 taps *
//               (4 input channels * 4 output channels) = 144 __fp16 values.
// _bias       : optional bias, 4 __fp16 per packed output channel (may be empty).
static void conv3x3s1_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    // NOTE(review): w is declared but never used; the input width is implicitly
    // assumed via the fixed `+= 8` row-end pointer advance (see above).
    int w = bottom_blob.w;
    int inch = bottom_blob.c;
    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const __fp16* bias = _bias;

    // One output channel group (4 channels) per task.
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        Mat out0 = top_blob.channel(p);

        // Seed the whole output channel with the bias (or zero); the per-input-
        // channel passes below then accumulate partial sums in place.
        float16x4_t _bias0 = bias ? vld1_f16(bias + p * 4) : vdup_n_f16((__fp16)0.f);
        out0.fill(_bias0);

        int q = 0;
        for (; q < inch; q++)
        {
            __fp16* outptr0 = out0.row<__fp16>(0);

            const Mat img0 = bottom_blob.channel(q);

            // Three consecutive input rows feeding the 3x3 window.
            const __fp16* r0 = img0.row<const __fp16>(0);
            const __fp16* r1 = img0.row<const __fp16>(1);
            const __fp16* r2 = img0.row<const __fp16>(2);

            const __fp16* kptr = kernel.channel(p).row<const __fp16>(q);

            // 16 * 9
            // Load all nine 4x4 tap matrices, two float16x8_t per tap (r,c):
            //   _krc_01 lanes 0-3 = weights for input lane 0, lanes 4-7 = input lane 1
            //   _krc_23 lanes 0-3 = weights for input lane 2, lanes 4-7 = input lane 3
            // Inside the asm the low half is used directly and
            // `ext ..., #8` extracts the high half into a scratch register.
            float16x8_t _k00_01 = vld1q_f16(kptr);
            float16x8_t _k00_23 = vld1q_f16(kptr + 8);
            float16x8_t _k01_01 = vld1q_f16(kptr + 16);
            float16x8_t _k01_23 = vld1q_f16(kptr + 24);
            float16x8_t _k02_01 = vld1q_f16(kptr + 32);
            float16x8_t _k02_23 = vld1q_f16(kptr + 40);
            float16x8_t _k10_01 = vld1q_f16(kptr + 48);
            float16x8_t _k10_23 = vld1q_f16(kptr + 56);
            float16x8_t _k11_01 = vld1q_f16(kptr + 64);
            float16x8_t _k11_23 = vld1q_f16(kptr + 72);
            float16x8_t _k12_01 = vld1q_f16(kptr + 80);
            float16x8_t _k12_23 = vld1q_f16(kptr + 88);
            float16x8_t _k20_01 = vld1q_f16(kptr + 96);
            float16x8_t _k20_23 = vld1q_f16(kptr + 104);
            float16x8_t _k21_01 = vld1q_f16(kptr + 112);
            float16x8_t _k21_23 = vld1q_f16(kptr + 120);
            float16x8_t _k22_01 = vld1q_f16(kptr + 128);
            float16x8_t _k22_23 = vld1q_f16(kptr + 136);

            int i = 0;
            for (; i < outh; i++)
            {
                int j = 0;
                // Main tail-free path: 4 output pixels per iteration.
                // Register plan: v10-v13 = running sums (loaded from outptr0),
                // v0-v5 = input pixels (6 per row: 4 outputs + 2 overhang),
                // v6-v9 = ext-rotated high halves of the kernel columns.
                for (; j + 3 < outw; j += 4)
                {
                    asm volatile(
                        "prfm   pldl1keep, [%0, #256]       \n"
                        "ld1    {v10.4h, v11.4h, v12.4h, v13.4h}, [%0] \n" // sum0 sum1 sum2 sum3

                        "prfm   pldl1keep, [%1, #384]       \n"
                        "ld1    {v0.8h, v1.8h, v2.8h}, [%1] \n" // r00 r01 r02 r03 r04 r05

                        "ext    v6.16b, %8.16b, %8.16b, #8  \n"
                        "fmla   v10.4h, %8.4h, v0.h[0]      \n"
                        "fmla   v11.4h, %8.4h, v0.h[4]      \n"
                        "fmla   v12.4h, %8.4h, v1.h[0]      \n"
                        "fmla   v13.4h, %8.4h, v1.h[4]      \n"
                        "fmla   v10.4h, v6.4h, v0.h[1]      \n"
                        "fmla   v11.4h, v6.4h, v0.h[5]      \n"
                        "fmla   v12.4h, v6.4h, v1.h[1]      \n"
                        "fmla   v13.4h, v6.4h, v1.h[5]      \n"
                        "ext    v7.16b, %9.16b, %9.16b, #8  \n"
                        "fmla   v10.4h, %9.4h, v0.h[2]      \n"
                        "fmla   v11.4h, %9.4h, v0.h[6]      \n"
                        "fmla   v12.4h, %9.4h, v1.h[2]      \n"
                        "fmla   v13.4h, %9.4h, v1.h[6]      \n"
                        "fmla   v10.4h, v7.4h, v0.h[3]      \n"
                        "fmla   v11.4h, v7.4h, v0.h[7]      \n"
                        "fmla   v12.4h, v7.4h, v1.h[3]      \n"
                        "fmla   v13.4h, v7.4h, v1.h[7]      \n"

                        "ext    v8.16b, %10.16b, %10.16b, #8 \n"
                        "fmla   v10.4h, %10.4h, v0.h[4]     \n"
                        "fmla   v11.4h, %10.4h, v1.h[0]     \n"
                        "fmla   v12.4h, %10.4h, v1.h[4]     \n"
                        "fmla   v13.4h, %10.4h, v2.h[0]     \n"
                        "fmla   v10.4h, v8.4h, v0.h[5]      \n"
                        "fmla   v11.4h, v8.4h, v1.h[1]      \n"
                        "fmla   v12.4h, v8.4h, v1.h[5]      \n"
                        "fmla   v13.4h, v8.4h, v2.h[1]      \n"
                        "ext    v9.16b, %11.16b, %11.16b, #8 \n"
                        "fmla   v10.4h, %11.4h, v0.h[6]     \n"
                        "fmla   v11.4h, %11.4h, v1.h[2]     \n"
                        "fmla   v12.4h, %11.4h, v1.h[6]     \n"
                        "fmla   v13.4h, %11.4h, v2.h[2]     \n"
                        "fmla   v10.4h, v9.4h, v0.h[7]      \n"
                        "fmla   v11.4h, v9.4h, v1.h[3]      \n"
                        "fmla   v12.4h, v9.4h, v1.h[7]      \n"
                        "fmla   v13.4h, v9.4h, v2.h[3]      \n"

                        "prfm   pldl1keep, [%2, #384]       \n"
                        "ld1    {v3.8h, v4.8h, v5.8h}, [%2] \n" // r10 r11 r12 r13 r14 r15

                        "ext    v6.16b, %12.16b, %12.16b, #8 \n"
                        "fmla   v10.4h, %12.4h, v1.h[0]     \n"
                        "fmla   v11.4h, %12.4h, v1.h[4]     \n"
                        "fmla   v12.4h, %12.4h, v2.h[0]     \n"
                        "fmla   v13.4h, %12.4h, v2.h[4]     \n"
                        "fmla   v10.4h, v6.4h, v1.h[1]      \n"
                        "fmla   v11.4h, v6.4h, v1.h[5]      \n"
                        "fmla   v12.4h, v6.4h, v2.h[1]      \n"
                        "fmla   v13.4h, v6.4h, v2.h[5]      \n"
                        "ext    v7.16b, %13.16b, %13.16b, #8 \n"
                        "fmla   v10.4h, %13.4h, v1.h[2]     \n"
                        "fmla   v11.4h, %13.4h, v1.h[6]     \n"
                        "fmla   v12.4h, %13.4h, v2.h[2]     \n"
                        "fmla   v13.4h, %13.4h, v2.h[6]     \n"
                        "fmla   v10.4h, v7.4h, v1.h[3]      \n"
                        "fmla   v11.4h, v7.4h, v1.h[7]      \n"
                        "fmla   v12.4h, v7.4h, v2.h[3]      \n"
                        "fmla   v13.4h, v7.4h, v2.h[7]      \n"

                        "ext    v8.16b, %14.16b, %14.16b, #8 \n"
                        "fmla   v10.4h, %14.4h, v3.h[0]     \n"
                        "fmla   v11.4h, %14.4h, v3.h[4]     \n"
                        "fmla   v12.4h, %14.4h, v4.h[0]     \n"
                        "fmla   v13.4h, %14.4h, v4.h[4]     \n"
                        "fmla   v10.4h, v8.4h, v3.h[1]      \n"
                        "fmla   v11.4h, v8.4h, v3.h[5]      \n"
                        "fmla   v12.4h, v8.4h, v4.h[1]      \n"
                        "fmla   v13.4h, v8.4h, v4.h[5]      \n"
                        "ext    v9.16b, %15.16b, %15.16b, #8 \n"
                        "fmla   v10.4h, %15.4h, v3.h[2]     \n"
                        "fmla   v11.4h, %15.4h, v3.h[6]     \n"
                        "fmla   v12.4h, %15.4h, v4.h[2]     \n"
                        "fmla   v13.4h, %15.4h, v4.h[6]     \n"
                        "fmla   v10.4h, v9.4h, v3.h[3]      \n"
                        "fmla   v11.4h, v9.4h, v3.h[7]      \n"
                        "fmla   v12.4h, v9.4h, v4.h[3]      \n"
                        "fmla   v13.4h, v9.4h, v4.h[7]      \n"

                        "ext    v6.16b, %16.16b, %16.16b, #8 \n"
                        "fmla   v10.4h, %16.4h, v3.h[4]     \n"
                        "fmla   v11.4h, %16.4h, v4.h[0]     \n"
                        "fmla   v12.4h, %16.4h, v4.h[4]     \n"
                        "fmla   v13.4h, %16.4h, v5.h[0]     \n"
                        "fmla   v10.4h, v6.4h, v3.h[5]      \n"
                        "fmla   v11.4h, v6.4h, v4.h[1]      \n"
                        "fmla   v12.4h, v6.4h, v4.h[5]      \n"
                        "fmla   v13.4h, v6.4h, v5.h[1]      \n"
                        "ext    v7.16b, %17.16b, %17.16b, #8 \n"
                        "fmla   v10.4h, %17.4h, v3.h[6]     \n"
                        "fmla   v11.4h, %17.4h, v4.h[2]     \n"
                        "fmla   v12.4h, %17.4h, v4.h[6]     \n"
                        "fmla   v13.4h, %17.4h, v5.h[2]     \n"
                        "fmla   v10.4h, v7.4h, v3.h[7]      \n"
                        "fmla   v11.4h, v7.4h, v4.h[3]      \n"
                        "fmla   v12.4h, v7.4h, v4.h[7]      \n"
                        "fmla   v13.4h, v7.4h, v5.h[3]      \n"

                        "prfm   pldl1keep, [%3, #384]       \n"
                        "ld1    {v0.8h, v1.8h, v2.8h}, [%3] \n" // r20 r21 r22 r23 r24 r25

                        "ext    v8.16b, %18.16b, %18.16b, #8 \n"
                        "fmla   v10.4h, %18.4h, v4.h[0]     \n"
                        "fmla   v11.4h, %18.4h, v4.h[4]     \n"
                        "fmla   v12.4h, %18.4h, v5.h[0]     \n"
                        "fmla   v13.4h, %18.4h, v5.h[4]     \n"
                        "fmla   v10.4h, v8.4h, v4.h[1]      \n"
                        "fmla   v11.4h, v8.4h, v4.h[5]      \n"
                        "fmla   v12.4h, v8.4h, v5.h[1]      \n"
                        "fmla   v13.4h, v8.4h, v5.h[5]      \n"
                        "ext    v9.16b, %19.16b, %19.16b, #8 \n"
                        "fmla   v10.4h, %19.4h, v4.h[2]     \n"
                        "fmla   v11.4h, %19.4h, v4.h[6]     \n"
                        "fmla   v12.4h, %19.4h, v5.h[2]     \n"
                        "fmla   v13.4h, %19.4h, v5.h[6]     \n"
                        "fmla   v10.4h, v9.4h, v4.h[3]      \n"
                        "fmla   v11.4h, v9.4h, v4.h[7]      \n"
                        "fmla   v12.4h, v9.4h, v5.h[3]      \n"
                        "fmla   v13.4h, v9.4h, v5.h[7]      \n"

                        "ext    v6.16b, %20.16b, %20.16b, #8 \n"
                        "fmla   v10.4h, %20.4h, v0.h[0]     \n"
                        "fmla   v11.4h, %20.4h, v0.h[4]     \n"
                        "fmla   v12.4h, %20.4h, v1.h[0]     \n"
                        "fmla   v13.4h, %20.4h, v1.h[4]     \n"
                        "fmla   v10.4h, v6.4h, v0.h[1]      \n"
                        "fmla   v11.4h, v6.4h, v0.h[5]      \n"
                        "fmla   v12.4h, v6.4h, v1.h[1]      \n"
                        "fmla   v13.4h, v6.4h, v1.h[5]      \n"
                        "ext    v7.16b, %21.16b, %21.16b, #8 \n"
                        "fmla   v10.4h, %21.4h, v0.h[2]     \n"
                        "fmla   v11.4h, %21.4h, v0.h[6]     \n"
                        "fmla   v12.4h, %21.4h, v1.h[2]     \n"
                        "fmla   v13.4h, %21.4h, v1.h[6]     \n"
                        "fmla   v10.4h, v7.4h, v0.h[3]      \n"
                        "fmla   v11.4h, v7.4h, v0.h[7]      \n"
                        "fmla   v12.4h, v7.4h, v1.h[3]      \n"
                        "fmla   v13.4h, v7.4h, v1.h[7]      \n"

                        "ext    v8.16b, %22.16b, %22.16b, #8 \n"
                        "fmla   v10.4h, %22.4h, v0.h[4]     \n"
                        "fmla   v11.4h, %22.4h, v1.h[0]     \n"
                        "fmla   v12.4h, %22.4h, v1.h[4]     \n"
                        "fmla   v13.4h, %22.4h, v2.h[0]     \n"
                        "fmla   v10.4h, v8.4h, v0.h[5]      \n"
                        "fmla   v11.4h, v8.4h, v1.h[1]      \n"
                        "fmla   v12.4h, v8.4h, v1.h[5]      \n"
                        "fmla   v13.4h, v8.4h, v2.h[1]      \n"
                        "ext    v9.16b, %23.16b, %23.16b, #8 \n"
                        "fmla   v10.4h, %23.4h, v0.h[6]     \n"
                        "fmla   v11.4h, %23.4h, v1.h[2]     \n"
                        "fmla   v12.4h, %23.4h, v1.h[6]     \n"
                        "fmla   v13.4h, %23.4h, v2.h[2]     \n"
                        "fmla   v10.4h, v9.4h, v0.h[7]      \n"
                        "fmla   v11.4h, v9.4h, v1.h[3]      \n"
                        "fmla   v12.4h, v9.4h, v1.h[7]      \n"
                        "fmla   v13.4h, v9.4h, v2.h[3]      \n"

                        "ext    v6.16b, %24.16b, %24.16b, #8 \n"
                        "fmla   v10.4h, %24.4h, v1.h[0]     \n"
                        "fmla   v11.4h, %24.4h, v1.h[4]     \n"
                        "fmla   v12.4h, %24.4h, v2.h[0]     \n"
                        "fmla   v13.4h, %24.4h, v2.h[4]     \n"

                        // advance row pointers: 4 pixels * 4 __fp16 * 2 bytes = 32
                        "add    %1, %1, #32                 \n"

                        "fmla   v10.4h, v6.4h, v1.h[1]      \n"
                        "fmla   v11.4h, v6.4h, v1.h[5]      \n"
                        "fmla   v12.4h, v6.4h, v2.h[1]      \n"
                        "fmla   v13.4h, v6.4h, v2.h[5]      \n"
                        "ext    v7.16b, %25.16b, %25.16b, #8 \n"
                        "fmla   v10.4h, %25.4h, v1.h[2]     \n"
                        "fmla   v11.4h, %25.4h, v1.h[6]     \n"
                        "fmla   v12.4h, %25.4h, v2.h[2]     \n"
                        "fmla   v13.4h, %25.4h, v2.h[6]     \n"

                        "add    %2, %2, #32                 \n"

                        "fmla   v10.4h, v7.4h, v1.h[3]      \n"
                        "fmla   v11.4h, v7.4h, v1.h[7]      \n"
                        "fmla   v12.4h, v7.4h, v2.h[3]      \n"
                        "fmla   v13.4h, v7.4h, v2.h[7]      \n"

                        "add    %3, %3, #32                 \n"

                        "st1    {v10.4h, v11.4h, v12.4h, v13.4h}, [%0], #32 \n"

                        : "=r"(outptr0), // %0
                        "=r"(r0),      // %1
                        "=r"(r1),      // %2
                        "=r"(r2)       // %3
                        : "0"(outptr0),
                        "1"(r0),
                        "2"(r1),
                        "3"(r2),
                        "w"(_k00_01), // %8
                        "w"(_k00_23), // %9
                        "w"(_k01_01), // %10
                        "w"(_k01_23), // %11
                        "w"(_k02_01), // %12
                        "w"(_k02_23), // %13
                        "w"(_k10_01), // %14
                        "w"(_k10_23), // %15
                        "w"(_k11_01), // %16
                        "w"(_k11_23), // %17
                        "w"(_k12_01), // %18
                        "w"(_k12_23), // %19
                        "w"(_k20_01), // %20
                        "w"(_k20_23), // %21
                        "w"(_k21_01), // %22
                        "w"(_k21_23), // %23
                        "w"(_k22_01), // %24
                        "w"(_k22_23) // %25
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13");
                }
                // Tail path: 2 output pixels. Partial sums are split across two
                // accumulator pairs (v10/v11 started fresh with fmul, v12/v13
                // seeded from the stored sums) and folded together at the end.
                for (; j + 1 < outw; j += 2)
                {
                    asm volatile(
                        "prfm   pldl1keep, [%1, #256]       \n"
                        "ld1    {v0.8h, v1.8h}, [%1]        \n" // r00 r01 r02 r03

                        "prfm   pldl1keep, [%0, #128]       \n"
                        "ld1    {v12.4h, v13.4h}, [%0]      \n" // sum0 sum1

                        "ext    v4.16b, %8.16b, %8.16b, #8  \n"
                        "fmul   v10.4h, %8.4h, v0.h[0]      \n"
                        "fmul   v11.4h, %8.4h, v0.h[4]      \n"
                        "fmla   v12.4h, v4.4h, v0.h[1]      \n"
                        "fmla   v13.4h, v4.4h, v0.h[5]      \n"
                        "ext    v5.16b, %9.16b, %9.16b, #8  \n"
                        "fmla   v10.4h, %9.4h, v0.h[2]      \n"
                        "fmla   v11.4h, %9.4h, v0.h[6]      \n"
                        "fmla   v12.4h, v5.4h, v0.h[3]      \n"
                        "fmla   v13.4h, v5.4h, v0.h[7]      \n"

                        "ext    v6.16b, %10.16b, %10.16b, #8 \n"
                        "fmla   v10.4h, %10.4h, v0.h[4]     \n"
                        "fmla   v11.4h, %10.4h, v1.h[0]     \n"
                        "fmla   v12.4h, v6.4h, v0.h[5]      \n"
                        "fmla   v13.4h, v6.4h, v1.h[1]      \n"
                        "ext    v7.16b, %11.16b, %11.16b, #8 \n"
                        "fmla   v10.4h, %11.4h, v0.h[6]     \n"
                        "fmla   v11.4h, %11.4h, v1.h[2]     \n"
                        "fmla   v12.4h, v7.4h, v0.h[7]      \n"
                        "fmla   v13.4h, v7.4h, v1.h[3]      \n"

                        "prfm   pldl1keep, [%2, #256]       \n"
                        "ld1    {v2.8h, v3.8h}, [%2]        \n" // r10 r11 r12 r13

                        "ext    v8.16b, %12.16b, %12.16b, #8 \n"
                        "fmla   v10.4h, %12.4h, v1.h[0]     \n"
                        "fmla   v11.4h, %12.4h, v1.h[4]     \n"
                        "fmla   v12.4h, v8.4h, v1.h[1]      \n"
                        "fmla   v13.4h, v8.4h, v1.h[5]      \n"
                        "ext    v9.16b, %13.16b, %13.16b, #8 \n"
                        "fmla   v10.4h, %13.4h, v1.h[2]     \n"
                        "fmla   v11.4h, %13.4h, v1.h[6]     \n"
                        "fmla   v12.4h, v9.4h, v1.h[3]      \n"
                        "fmla   v13.4h, v9.4h, v1.h[7]      \n"

                        "ext    v4.16b, %14.16b, %14.16b, #8 \n"
                        "fmla   v10.4h, %14.4h, v2.h[0]     \n"
                        "fmla   v11.4h, %14.4h, v2.h[4]     \n"
                        "fmla   v12.4h, v4.4h, v2.h[1]      \n"
                        "fmla   v13.4h, v4.4h, v2.h[5]      \n"
                        "ext    v5.16b, %15.16b, %15.16b, #8 \n"
                        "fmla   v10.4h, %15.4h, v2.h[2]     \n"
                        "fmla   v11.4h, %15.4h, v2.h[6]     \n"
                        "fmla   v12.4h, v5.4h, v2.h[3]      \n"
                        "fmla   v13.4h, v5.4h, v2.h[7]      \n"

                        "ext    v6.16b, %16.16b, %16.16b, #8 \n"
                        "fmla   v10.4h, %16.4h, v2.h[4]     \n"
                        "fmla   v11.4h, %16.4h, v3.h[0]     \n"
                        "fmla   v12.4h, v6.4h, v2.h[5]      \n"
                        "fmla   v13.4h, v6.4h, v3.h[1]      \n"
                        "ext    v7.16b, %17.16b, %17.16b, #8 \n"
                        "fmla   v10.4h, %17.4h, v2.h[6]     \n"
                        "fmla   v11.4h, %17.4h, v3.h[2]     \n"
                        "fmla   v12.4h, v7.4h, v2.h[7]      \n"
                        "fmla   v13.4h, v7.4h, v3.h[3]      \n"

                        "prfm   pldl1keep, [%3, #256]       \n"
                        "ld1    {v0.8h, v1.8h}, [%3]        \n" // r20 r21 r22 r23

                        "ext    v8.16b, %18.16b, %18.16b, #8 \n"
                        "fmla   v10.4h, %18.4h, v3.h[0]     \n"
                        "fmla   v11.4h, %18.4h, v3.h[4]     \n"
                        "fmla   v12.4h, v8.4h, v3.h[1]      \n"
                        "fmla   v13.4h, v8.4h, v3.h[5]      \n"
                        "ext    v9.16b, %19.16b, %19.16b, #8 \n"
                        "fmla   v10.4h, %19.4h, v3.h[2]     \n"
                        "fmla   v11.4h, %19.4h, v3.h[6]     \n"
                        "fmla   v12.4h, v9.4h, v3.h[3]      \n"
                        "fmla   v13.4h, v9.4h, v3.h[7]      \n"

                        "ext    v4.16b, %20.16b, %20.16b, #8 \n"
                        "fmla   v10.4h, %20.4h, v0.h[0]     \n"
                        "fmla   v11.4h, %20.4h, v0.h[4]     \n"
                        "fmla   v12.4h, v4.4h, v0.h[1]      \n"
                        "fmla   v13.4h, v4.4h, v0.h[5]      \n"
                        "ext    v5.16b, %21.16b, %21.16b, #8 \n"
                        "fmla   v10.4h, %21.4h, v0.h[2]     \n"
                        "fmla   v11.4h, %21.4h, v0.h[6]     \n"
                        "fmla   v12.4h, v5.4h, v0.h[3]      \n"
                        "fmla   v13.4h, v5.4h, v0.h[7]      \n"

                        "ext    v6.16b, %22.16b, %22.16b, #8 \n"
                        "fmla   v10.4h, %22.4h, v0.h[4]     \n"
                        "fmla   v11.4h, %22.4h, v1.h[0]     \n"
                        "fmla   v12.4h, v6.4h, v0.h[5]      \n"
                        "fmla   v13.4h, v6.4h, v1.h[1]      \n"
                        "ext    v7.16b, %23.16b, %23.16b, #8 \n"
                        "fmla   v10.4h, %23.4h, v0.h[6]     \n"
                        "fmla   v11.4h, %23.4h, v1.h[2]     \n"
                        "fmla   v12.4h, v7.4h, v0.h[7]      \n"
                        "fmla   v13.4h, v7.4h, v1.h[3]      \n"

                        "ext    v8.16b, %24.16b, %24.16b, #8 \n"
                        "fmla   v10.4h, %24.4h, v1.h[0]     \n"
                        "fmla   v11.4h, %24.4h, v1.h[4]     \n"
                        "fmla   v12.4h, v8.4h, v1.h[1]      \n"
                        "fmla   v13.4h, v8.4h, v1.h[5]      \n"
                        "ext    v9.16b, %25.16b, %25.16b, #8 \n"
                        "fmla   v10.4h, %25.4h, v1.h[2]     \n"
                        "fmla   v11.4h, %25.4h, v1.h[6]     \n"
                        "fmla   v12.4h, v9.4h, v1.h[3]      \n"
                        "fmla   v13.4h, v9.4h, v1.h[7]      \n"

                        // advance row pointers: 2 pixels * 4 __fp16 * 2 bytes = 16
                        "add    %1, %1, #16                 \n"

                        "fadd   v10.4h, v10.4h, v12.4h      \n"

                        "add    %2, %2, #16                 \n"

                        "fadd   v11.4h, v11.4h, v13.4h      \n"

                        "add    %3, %3, #16                 \n"

                        "st1    {v10.4h, v11.4h}, [%0], #16 \n"

                        : "=r"(outptr0), // %0
                        "=r"(r0),      // %1
                        "=r"(r1),      // %2
                        "=r"(r2)       // %3
                        : "0"(outptr0),
                        "1"(r0),
                        "2"(r1),
                        "3"(r2),
                        "w"(_k00_01), // %8
                        "w"(_k00_23), // %9
                        "w"(_k01_01), // %10
                        "w"(_k01_23), // %11
                        "w"(_k02_01), // %12
                        "w"(_k02_23), // %13
                        "w"(_k10_01), // %14
                        "w"(_k10_23), // %15
                        "w"(_k11_01), // %16
                        "w"(_k11_23), // %17
                        "w"(_k12_01), // %18
                        "w"(_k12_23), // %19
                        "w"(_k20_01), // %20
                        "w"(_k20_23), // %21
                        "w"(_k21_01), // %22
                        "w"(_k21_23), // %23
                        "w"(_k22_01), // %24
                        "w"(_k22_23) // %25
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13");
                }
                // Tail path: 1 output pixel. Four partial accumulators
                // (v10-v12 fresh, v13 seeded from the stored sum) are reduced
                // with fadd at the end to shorten the FMLA dependency chains.
                for (; j < outw; j++)
                {
                    asm volatile(
                        "prfm   pldl1keep, [%1, #192]       \n"
                        "ld1    {v0.4h, v1.4h, v2.4h}, [%1] \n" // r00 r01 r02

                        "prfm   pldl1keep, [%0, #64]        \n"
                        "ld1    {v13.4h}, [%0]              \n" // sum0

                        "ext    v6.16b, %8.16b, %8.16b, #8  \n"
                        "fmul   v10.4h, %8.4h, v0.h[0]      \n"
                        "fmul   v11.4h, v6.4h, v0.h[1]      \n"
                        "ext    v7.16b, %9.16b, %9.16b, #8  \n"
                        "fmul   v12.4h, %9.4h, v0.h[2]      \n"
                        "fmla   v13.4h, v7.4h, v0.h[3]      \n"

                        "ext    v8.16b, %10.16b, %10.16b, #8 \n"
                        "fmla   v10.4h, %10.4h, v1.h[0]     \n"
                        "fmla   v11.4h, v8.4h, v1.h[1]      \n"
                        "ext    v9.16b, %11.16b, %11.16b, #8 \n"
                        "fmla   v12.4h, %11.4h, v1.h[2]     \n"
                        "fmla   v13.4h, v9.4h, v1.h[3]      \n"

                        "prfm   pldl1keep, [%2, #192]       \n"
                        "ld1    {v3.4h, v4.4h, v5.4h}, [%2] \n" // r10 r11 r12

                        "ext    v6.16b, %12.16b, %12.16b, #8 \n"
                        "fmla   v10.4h, %12.4h, v2.h[0]     \n"
                        "fmla   v11.4h, v6.4h, v2.h[1]      \n"
                        "ext    v7.16b, %13.16b, %13.16b, #8 \n"
                        "fmla   v12.4h, %13.4h, v2.h[2]     \n"
                        "fmla   v13.4h, v7.4h, v2.h[3]      \n"

                        "ext    v8.16b, %14.16b, %14.16b, #8 \n"
                        "fmla   v10.4h, %14.4h, v3.h[0]     \n"
                        "fmla   v11.4h, v8.4h, v3.h[1]      \n"
                        "ext    v9.16b, %15.16b, %15.16b, #8 \n"
                        "fmla   v12.4h, %15.4h, v3.h[2]     \n"
                        "fmla   v13.4h, v9.4h, v3.h[3]      \n"

                        "ext    v6.16b, %16.16b, %16.16b, #8 \n"
                        "fmla   v10.4h, %16.4h, v4.h[0]     \n"
                        "fmla   v11.4h, v6.4h, v4.h[1]      \n"
                        "ext    v7.16b, %17.16b, %17.16b, #8 \n"
                        "fmla   v12.4h, %17.4h, v4.h[2]     \n"
                        "fmla   v13.4h, v7.4h, v4.h[3]      \n"

                        "prfm   pldl1keep, [%3, #192]       \n"
                        "ld1    {v0.4h, v1.4h, v2.4h}, [%3] \n" // r20 r21 r22

                        "ext    v8.16b, %18.16b, %18.16b, #8 \n"
                        "fmla   v10.4h, %18.4h, v5.h[0]     \n"
                        "fmla   v11.4h, v8.4h, v5.h[1]      \n"
                        "ext    v9.16b, %19.16b, %19.16b, #8 \n"
                        "fmla   v12.4h, %19.4h, v5.h[2]     \n"
                        "fmla   v13.4h, v9.4h, v5.h[3]      \n"

                        "ext    v6.16b, %20.16b, %20.16b, #8 \n"
                        "fmla   v10.4h, %20.4h, v0.h[0]     \n"
                        "fmla   v11.4h, v6.4h, v0.h[1]      \n"
                        "ext    v7.16b, %21.16b, %21.16b, #8 \n"
                        "fmla   v12.4h, %21.4h, v0.h[2]     \n"
                        "fmla   v13.4h, v7.4h, v0.h[3]      \n"

                        "ext    v8.16b, %22.16b, %22.16b, #8 \n"
                        "fmla   v10.4h, %22.4h, v1.h[0]     \n"
                        "fmla   v11.4h, v8.4h, v1.h[1]      \n"
                        "ext    v9.16b, %23.16b, %23.16b, #8 \n"
                        "fmla   v12.4h, %23.4h, v1.h[2]     \n"
                        "fmla   v13.4h, v9.4h, v1.h[3]      \n"

                        "ext    v6.16b, %24.16b, %24.16b, #8 \n"
                        "fmla   v10.4h, %24.4h, v2.h[0]     \n"
                        "fmla   v11.4h, v6.4h, v2.h[1]      \n"
                        "ext    v7.16b, %25.16b, %25.16b, #8 \n"
                        "fmla   v12.4h, %25.4h, v2.h[2]     \n"
                        "fmla   v13.4h, v7.4h, v2.h[3]      \n"

                        "fadd   v10.4h, v10.4h, v11.4h      \n"

                        // advance row pointers: 1 pixel * 4 __fp16 * 2 bytes = 8
                        "add    %1, %1, #8                  \n"

                        "fadd   v12.4h, v12.4h, v13.4h      \n"

                        "add    %2, %2, #8                  \n"

                        "fadd   v10.4h, v10.4h, v12.4h      \n"

                        "add    %3, %3, #8                  \n"

                        "st1    {v10.4h}, [%0], #8          \n"

                        : "=r"(outptr0), // %0
                        "=r"(r0),      // %1
                        "=r"(r1),      // %2
                        "=r"(r2)       // %3
                        : "0"(outptr0),
                        "1"(r0),
                        "2"(r1),
                        "3"(r2),
                        "w"(_k00_01), // %8
                        "w"(_k00_23), // %9
                        "w"(_k01_01), // %10
                        "w"(_k01_23), // %11
                        "w"(_k02_01), // %12
                        "w"(_k02_23), // %13
                        "w"(_k10_01), // %14
                        "w"(_k10_23), // %15
                        "w"(_k11_01), // %16
                        "w"(_k11_23), // %17
                        "w"(_k12_01), // %18
                        "w"(_k12_23), // %19
                        "w"(_k20_01), // %20
                        "w"(_k20_23), // %21
                        "w"(_k21_01), // %22
                        "w"(_k21_23), // %23
                        "w"(_k22_01), // %24
                        "w"(_k22_23) // %25
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13");
                }

                // skip the 2 right-border pixels (2 * 4 __fp16) to reach the
                // next input row; requires bottom_blob.w == outw + 2 (see above)
                r0 += 8;
                r1 += 8;
                r2 += 8;
            }
        }
    }
}
1717