// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void conv3x3s1_winograd64_transform_kernel_pack8to4_fp16sa_neon(const Mat& kernel, Mat& kernel_tm_pack8to4, int inch, int outch, const Option& opt)
{
    // winograd63 transform kernel
    Mat kernel_tm;
    kernel_tm.create(8 * 8, inch, outch);

    const float ktm[8][3] = {
        {1.0f, 0.0f, 0.0f},
        {-2.0f / 9, -2.0f / 9, -2.0f / 9},
        {-2.0f / 9, 2.0f / 9, -2.0f / 9},
        {1.0f / 90, 1.0f / 45, 2.0f / 45},
        {1.0f / 90, -1.0f / 45, 2.0f / 45},
        {1.0f / 45, 1.0f / 90, 1.0f / 180},
        {1.0f / 45, -1.0f / 90, 1.0f / 180},
        {0.0f, 0.0f, 1.0f}
    };
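    // Note: ktm is the kernel transform matrix G of Winograd F(6x6, 3x3); the two
    // passes below compute G * g * G^T for each 3x3 kernel g (stored transposed,
    // as noted below), first along rows into tmp[8][3], then along columns into
    // one 8x8 tile of kernel_tm.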

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        for (int q = 0; q < inch; q++)
        {
            const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
            float* kernel_tm0 = kernel_tm.channel(p).row(q);

            // transform kernel, transposed
            const float* k0 = kernel0;
            const float* k1 = kernel0 + 3;
            const float* k2 = kernel0 + 6;

            // h
            float tmp[8][3];
            for (int i = 0; i < 8; i++)
            {
                tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
                tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
                tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
            }

            // v
            for (int j = 0; j < 8; j++)
            {
                float* tmpp = &tmp[j][0];

                for (int i = 0; i < 8; i++)
                {
                    kernel_tm0[j * 8 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
                }
            }
        }
    }

    // interleave
    // src = 64-inch-outch
    // dst = 4b-8a-inch/8a-64-outch/4b
    kernel_tm_pack8to4.create(2 * inch / 8, 64, outch / 8 + (outch % 8) / 4, (size_t)2u * 32, 32);
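    // layout sketch: for one transform position k, a packed row holds inch/8
    // groups of 8 input channels, each group interleaved across the block's
    // 8 (or, in the tail, 4) output channels, so the dot kernels below read
    // it with unit stride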

    int p = 0;
    for (; p + 7 < outch; p += 8)
    {
        const Mat k0 = kernel_tm.channel(p);
        const Mat k1 = kernel_tm.channel(p + 1);
        const Mat k2 = kernel_tm.channel(p + 2);
        const Mat k3 = kernel_tm.channel(p + 3);
        const Mat k4 = kernel_tm.channel(p + 4);
        const Mat k5 = kernel_tm.channel(p + 5);
        const Mat k6 = kernel_tm.channel(p + 6);
        const Mat k7 = kernel_tm.channel(p + 7);

        Mat g0 = kernel_tm_pack8to4.channel(p / 8);

        for (int k = 0; k < 64; k++)
        {
            __fp16* g00 = g0.row<__fp16>(k);

            for (int q = 0; q + 7 < inch; q += 8)
            {
                for (int i = 0; i < 8; i++)
                {
                    g00[0] = (__fp16)k0.row(q + i)[k];
                    g00[1] = (__fp16)k1.row(q + i)[k];
                    g00[2] = (__fp16)k2.row(q + i)[k];
                    g00[3] = (__fp16)k3.row(q + i)[k];
                    g00[4] = (__fp16)k4.row(q + i)[k];
                    g00[5] = (__fp16)k5.row(q + i)[k];
                    g00[6] = (__fp16)k6.row(q + i)[k];
                    g00[7] = (__fp16)k7.row(q + i)[k];

                    g00 += 8;
                }
            }
        }
    }
    for (; p + 3 < outch; p += 4)
    {
        const Mat k0 = kernel_tm.channel(p);
        const Mat k1 = kernel_tm.channel(p + 1);
        const Mat k2 = kernel_tm.channel(p + 2);
        const Mat k3 = kernel_tm.channel(p + 3);

        Mat g0 = kernel_tm_pack8to4.channel(p / 8 + (p % 8) / 4);

        for (int k = 0; k < 64; k++)
        {
            __fp16* g00 = g0.row<__fp16>(k);

            for (int q = 0; q + 7 < inch; q += 8)
            {
                for (int i = 0; i < 8; i++)
                {
                    g00[0] = (__fp16)k0.row(q + i)[k];
                    g00[1] = (__fp16)k1.row(q + i)[k];
                    g00[2] = (__fp16)k2.row(q + i)[k];
                    g00[3] = (__fp16)k3.row(q + i)[k];

                    g00 += 4;
                }
            }
        }
    }
}

static void conv3x3s1_winograd64_pack8to4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    //size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 6n+2
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 5) / 6 * 6;
    outh = (outh + 5) / 6 * 6;

    w = outw + 2;
    h = outh + 2;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt);
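    // e.g. a 10x10 output rounds up to 12x12, so the input is padded to 14x14:
    // each 8x8 input tile overlaps its neighbour by 2 pixels and yields a 6x6
    // output tile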

    const __fp16* bias = _bias;

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = w_tm / 8 * h_tm / 8;

        //bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
        bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator);

        // const float itm[8][8] = {
        //     {1.0f,  0.0f, -5.25f,  0.00f,  5.25f,  0.00f, -1.0f, 0.0f},
        //
        //     {0.0f,  1.0f,  1.00f, -4.25f, -4.25f,  1.00f,  1.0f, 0.0f},
        //     {0.0f, -1.0f,  1.00f,  4.25f, -4.25f, -1.00f,  1.0f, 0.0f},
        //
        //     {0.0f,  0.5f,  0.25f, -2.50f, -1.25f,  2.00f,  1.0f, 0.0f},
        //     {0.0f, -0.5f,  0.25f,  2.50f, -1.25f, -2.00f,  1.0f, 0.0f},
        //
        //     {0.0f,  2.0f,  4.00f, -2.50f, -5.00f,  0.50f,  1.0f, 0.0f},
        //     {0.0f, -2.0f,  4.00f,  2.50f, -5.00f, -0.50f,  1.0f, 0.0f},
        //
        //     {0.0f, -1.0f,  0.00f,  5.25f,  0.00f, -5.25f,  0.0f, 1.0f}
        // };

        // 0 = r00 - r06 + (r04 - r02) * 5.25
        // 7 = r07 - r01 + (r03 - r05) * 5.25

        // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05)
        // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05)

        // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2)
        // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2)

        // reuse r04 * 1.25
        // reuse r03 * 2.5
        // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
        // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)
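        // (these are the rows of B^T for F(6x6, 3x3): the first pass over m below
        // transforms each row of the 8x8 tile into tmp, the second transforms each
        // row of tmp, giving V = B^T * d * B per input tile d)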

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const Mat img0 = bottom_blob_bordered.channel(q);
            Mat img0_tm = bottom_blob_tm.channel(q);

            __fp16 tmp[8][8][8];

            // tile
            for (int i = 0; i < h_tm / 8; i++)
            {
                for (int j = 0; j < w_tm / 8; j++)
                {
                    const __fp16* r0 = img0.row<const __fp16>(i * 6) + (j * 6) * 8;

                    for (int m = 0; m < 8; m++)
                    {
                        float16x8_t _r00 = vld1q_f16(r0);
                        float16x8_t _r01 = vld1q_f16(r0 + 8);
                        float16x8_t _r02 = vld1q_f16(r0 + 16);
                        float16x8_t _r03 = vld1q_f16(r0 + 24);
                        float16x8_t _r04 = vld1q_f16(r0 + 32);
                        float16x8_t _r05 = vld1q_f16(r0 + 40);
                        float16x8_t _r06 = vld1q_f16(r0 + 48);
                        float16x8_t _r07 = vld1q_f16(r0 + 56);

                        float16x8_t _tmp0m = vfmaq_n_f16(vsubq_f16(_r00, _r06), vsubq_f16(_r04, _r02), 5.25f);
                        float16x8_t _tmp7m = vfmaq_n_f16(vsubq_f16(_r07, _r01), vsubq_f16(_r03, _r05), 5.25f);
                        vst1q_f16(tmp[0][m], _tmp0m);
                        vst1q_f16(tmp[7][m], _tmp7m);

                        // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25;
                        // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25;

                        float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_r02, _r06), _r04, 4.25f);
                        float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_r01, _r05), _r03, 4.25f);

                        // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25);
                        // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25);

                        float16x8_t _tmp1m = vaddq_f16(_tmp12a, _tmp12b);
                        float16x8_t _tmp2m = vsubq_f16(_tmp12a, _tmp12b);
                        vst1q_f16(tmp[1][m], _tmp1m);
                        vst1q_f16(tmp[2][m], _tmp2m);

                        // tmp[1][m] = tmp12a + tmp12b;
                        // tmp[2][m] = tmp12a - tmp12b;

                        float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_r06, _r02, 0.25f), _r04, 1.25f);
                        float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f);

                        // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25);
                        // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2);

                        float16x8_t _tmp3m = vaddq_f16(_tmp34a, _tmp34b);
                        float16x8_t _tmp4m = vsubq_f16(_tmp34a, _tmp34b);
                        vst1q_f16(tmp[3][m], _tmp3m);
                        vst1q_f16(tmp[4][m], _tmp4m);

                        // tmp[3][m] = tmp34a + tmp34b;
                        // tmp[4][m] = tmp34a - tmp34b;

                        float16x8_t _tmp56a = vfmaq_n_f16(_r06, vfmsq_n_f16(_r02, _r04, 1.25f), 4.f);
                        float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f);

                        // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4);
                        // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5);

                        float16x8_t _tmp5m = vaddq_f16(_tmp56a, _tmp56b);
                        float16x8_t _tmp6m = vsubq_f16(_tmp56a, _tmp56b);
                        vst1q_f16(tmp[5][m], _tmp5m);
                        vst1q_f16(tmp[6][m], _tmp6m);

                        // tmp[5][m] = tmp56a + tmp56b;
                        // tmp[6][m] = tmp56a - tmp56b;

                        r0 += w * 8;
                    }

                    __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * 8;
                    __fp16* r0_tm_1 = r0_tm_0 + tiles * 8;
                    __fp16* r0_tm_2 = r0_tm_0 + tiles * 16;
                    __fp16* r0_tm_3 = r0_tm_0 + tiles * 24;
                    __fp16* r0_tm_4 = r0_tm_0 + tiles * 32;
                    __fp16* r0_tm_5 = r0_tm_0 + tiles * 40;
                    __fp16* r0_tm_6 = r0_tm_0 + tiles * 48;
                    __fp16* r0_tm_7 = r0_tm_0 + tiles * 56;

                    for (int m = 0; m < 8; m++)
                    {
                        float16x8_t _tmp00 = vld1q_f16(tmp[m][0]);
                        float16x8_t _tmp01 = vld1q_f16(tmp[m][1]);
                        float16x8_t _tmp02 = vld1q_f16(tmp[m][2]);
                        float16x8_t _tmp03 = vld1q_f16(tmp[m][3]);
                        float16x8_t _tmp04 = vld1q_f16(tmp[m][4]);
                        float16x8_t _tmp05 = vld1q_f16(tmp[m][5]);
                        float16x8_t _tmp06 = vld1q_f16(tmp[m][6]);
                        float16x8_t _tmp07 = vld1q_f16(tmp[m][7]);

                        float16x8_t _r0tm0 = vfmaq_n_f16(vsubq_f16(_tmp00, _tmp06), vsubq_f16(_tmp04, _tmp02), 5.25f);
                        float16x8_t _r0tm7 = vfmaq_n_f16(vsubq_f16(_tmp07, _tmp01), vsubq_f16(_tmp03, _tmp05), 5.25f);

                        // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25;
                        // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25;

                        float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_tmp02, _tmp06), _tmp04, 4.25f);
                        float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_tmp01, _tmp05), _tmp03, 4.25f);

                        // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25);
                        // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25);

                        float16x8_t _r0tm1 = vaddq_f16(_tmp12a, _tmp12b);
                        float16x8_t _r0tm2 = vsubq_f16(_tmp12a, _tmp12b);

                        // r0_tm[1] = tmp12a + tmp12b;
                        // r0_tm[2] = tmp12a - tmp12b;

                        float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f);
                        float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f);

                        // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25);
                        // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2);

                        float16x8_t _r0tm3 = vaddq_f16(_tmp34a, _tmp34b);
                        float16x8_t _r0tm4 = vsubq_f16(_tmp34a, _tmp34b);

                        // r0_tm[3] = tmp34a + tmp34b;
                        // r0_tm[4] = tmp34a - tmp34b;

                        float16x8_t _tmp56a = vfmaq_n_f16(_tmp06, vfmsq_n_f16(_tmp02, _tmp04, 1.25f), 4.f);
                        float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f);

                        // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4);
                        // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5);

                        float16x8_t _r0tm5 = vaddq_f16(_tmp56a, _tmp56b);
                        float16x8_t _r0tm6 = vsubq_f16(_tmp56a, _tmp56b);

                        // r0_tm[5] = tmp56a + tmp56b;
                        // r0_tm[6] = tmp56a - tmp56b;

                        vst1q_f16(r0_tm_0, _r0tm0);
                        vst1q_f16(r0_tm_1, _r0tm1);
                        vst1q_f16(r0_tm_2, _r0tm2);
                        vst1q_f16(r0_tm_3, _r0tm3);
                        vst1q_f16(r0_tm_4, _r0tm4);
                        vst1q_f16(r0_tm_5, _r0tm5);
                        vst1q_f16(r0_tm_6, _r0tm6);
                        vst1q_f16(r0_tm_7, _r0tm7);

                        r0_tm_0 += tiles * 64;
                        r0_tm_1 += tiles * 64;
                        r0_tm_2 += tiles * 64;
                        r0_tm_3 += tiles * 64;
                        r0_tm_4 += tiles * 64;
                        r0_tm_5 += tiles * 64;
                        r0_tm_6 += tiles * 64;
                        r0_tm_7 += tiles * 64;
                    }
                }
            }
        }
    }
    bottom_blob_bordered = Mat();
    // END transform input

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = h_tm / 8 * w_tm / 8;

        // permute
        //bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
        Mat bottom_blob_tm2;
        if (tiles >= 8)
            bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + tiles % 4, 64, 2u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 4)
            bottom_blob_tm2.create(4 * inch, tiles / 4 + tiles % 4, 64, 2u * elempack, elempack, opt.workspace_allocator);
        else // if (tiles >= 1)
            bottom_blob_tm2.create(1 * inch, tiles, 64, 2u * elempack, elempack, opt.workspace_allocator);
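        // tiles are regrouped into runs of 8, then 4, then singles, so each row
        // of bottom_blob_tm2 feeds one pass of the matching micro-kernel below
        // with unit-stride loads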

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 64; r++)
        {
            Mat tm2 = bottom_blob_tm2.channel(r);

            // tile
            int i = 0;
            for (; i + 7 < tiles; i += 8)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 8x8
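                    // (8 tiles x 8 channels -> 8 channels x 8 tiles: ld4
                    // de-interleaves with stride 4 and the uzp1/uzp2 pairs finish
                    // the permutation, so each stored vector holds one input
                    // channel across the 8 tiles)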
                    asm volatile(
                        "prfm pldl1keep, [%0, #512] \n"
                        "ld4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n"
                        "ld4 {v4.8h, v5.8h, v6.8h, v7.8h}, [%0] \n"
                        "sub %0, %0, #64 \n"

                        "uzp1 v16.8h, v0.8h, v4.8h \n"
                        "uzp2 v20.8h, v0.8h, v4.8h \n"
                        "uzp1 v17.8h, v1.8h, v5.8h \n"
                        "uzp2 v21.8h, v1.8h, v5.8h \n"
                        "uzp1 v18.8h, v2.8h, v6.8h \n"
                        "uzp2 v22.8h, v2.8h, v6.8h \n"
                        "uzp1 v19.8h, v3.8h, v7.8h \n"
                        "uzp2 v23.8h, v3.8h, v7.8h \n"

                        "st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [%1], #64 \n"
                        "st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p)   // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");

                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
            for (; i + 3 < tiles; i += 4)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8 + (i % 8) / 4);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 8x4
                    asm volatile(
                        "prfm pldl1keep, [%0, #256] \n"
                        "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n"
                        "st4 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p)   // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0", "v1", "v2", "v3");

                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
            for (; i < tiles; i++)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8 + (i % 8) / 4 + i % 4);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    asm volatile(
                        "prfm pldl1keep, [%0, #128] \n"
                        "ld1 {v0.8h}, [%0] \n"
                        "st1 {v0.8h}, [%1], #16 \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p)   // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0");

                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
        }

        bottom_blob_tm = Mat();
        // permute end

        top_blob_tm.create(tiles, 64, outch, 2u * 4, 4, opt.workspace_allocator);
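        // pack8to4: input tiles arrive pack8, results are stored pack4
        // (elemsize 2u * 4, elempack 4), i.e. 4 output channels per element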

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 1;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int p = pp * 2;

            __fp16* output0_tm = top_blob_tm.channel(p);
            __fp16* output1_tm = top_blob_tm.channel(p + 1);

            const Mat kernel01_tm = kernel_tm.channel(p / 2);

            for (int r = 0; r < 64; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 7 < tiles; i += 8)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

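                    // 8 tiles x 8 output channels per pass: v24..v31 each
                    // accumulate one tile's 8 outputs; each fmla broadcasts one
                    // input lane against 8 kernel weights, and the loop consumes
                    // 8 input channels per iteration; the low/high accumulator
                    // halves are the two pack4 output channels (hence the ext
                    // before the second store)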
                    asm volatile(
                        "eor v24.16b, v24.16b, v24.16b \n"
                        "eor v25.16b, v25.16b, v25.16b \n"
                        "eor v26.16b, v26.16b, v26.16b \n"
                        "eor v27.16b, v27.16b, v27.16b \n"
                        "eor v28.16b, v28.16b, v28.16b \n"
                        "eor v29.16b, v29.16b, v29.16b \n"
                        "eor v30.16b, v30.16b, v30.16b \n"
                        "eor v31.16b, v31.16b, v31.16b \n"

                        "0: \n"

                        "prfm pldl1keep, [%4, #512] \n"
                        "ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [%4], #64 \n"

                        "prfm pldl1keep, [%3, #512] \n"
                        "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%3], #64 \n"

                        "fmla v24.8h, v16.8h, v0.h[0] \n"
                        "fmla v25.8h, v16.8h, v0.h[1] \n"
                        "fmla v26.8h, v16.8h, v0.h[2] \n"
                        "fmla v27.8h, v16.8h, v0.h[3] \n"
                        "fmla v28.8h, v16.8h, v0.h[4] \n"
                        "fmla v29.8h, v16.8h, v0.h[5] \n"
                        "fmla v30.8h, v16.8h, v0.h[6] \n"
                        "fmla v31.8h, v16.8h, v0.h[7] \n"

                        "fmla v24.8h, v17.8h, v1.h[0] \n"
                        "fmla v25.8h, v17.8h, v1.h[1] \n"
                        "fmla v26.8h, v17.8h, v1.h[2] \n"
                        "fmla v27.8h, v17.8h, v1.h[3] \n"
                        "fmla v28.8h, v17.8h, v1.h[4] \n"
                        "fmla v29.8h, v17.8h, v1.h[5] \n"
                        "fmla v30.8h, v17.8h, v1.h[6] \n"
                        "fmla v31.8h, v17.8h, v1.h[7] \n"

                        "prfm pldl1keep, [%4, #512] \n"
                        "ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%4], #64 \n"

                        "fmla v24.8h, v18.8h, v2.h[0] \n"
                        "fmla v25.8h, v18.8h, v2.h[1] \n"
                        "fmla v26.8h, v18.8h, v2.h[2] \n"
                        "fmla v27.8h, v18.8h, v2.h[3] \n"
                        "fmla v28.8h, v18.8h, v2.h[4] \n"
                        "fmla v29.8h, v18.8h, v2.h[5] \n"
                        "fmla v30.8h, v18.8h, v2.h[6] \n"
                        "fmla v31.8h, v18.8h, v2.h[7] \n"

                        "prfm pldl1keep, [%3, #512] \n"
                        "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%3], #64 \n"

                        "fmla v24.8h, v19.8h, v3.h[0] \n"
                        "fmla v25.8h, v19.8h, v3.h[1] \n"
                        "fmla v26.8h, v19.8h, v3.h[2] \n"
                        "fmla v27.8h, v19.8h, v3.h[3] \n"
                        "fmla v28.8h, v19.8h, v3.h[4] \n"
                        "fmla v29.8h, v19.8h, v3.h[5] \n"
                        "fmla v30.8h, v19.8h, v3.h[6] \n"
                        "fmla v31.8h, v19.8h, v3.h[7] \n"

                        "fmla v24.8h, v20.8h, v4.h[0] \n"
                        "fmla v25.8h, v20.8h, v4.h[1] \n"
                        "fmla v26.8h, v20.8h, v4.h[2] \n"
                        "fmla v27.8h, v20.8h, v4.h[3] \n"
                        "fmla v28.8h, v20.8h, v4.h[4] \n"
                        "fmla v29.8h, v20.8h, v4.h[5] \n"
                        "fmla v30.8h, v20.8h, v4.h[6] \n"
                        "fmla v31.8h, v20.8h, v4.h[7] \n"

                        "fmla v24.8h, v21.8h, v5.h[0] \n"
                        "fmla v25.8h, v21.8h, v5.h[1] \n"
                        "fmla v26.8h, v21.8h, v5.h[2] \n"
                        "fmla v27.8h, v21.8h, v5.h[3] \n"
                        "fmla v28.8h, v21.8h, v5.h[4] \n"
                        "fmla v29.8h, v21.8h, v5.h[5] \n"
                        "fmla v30.8h, v21.8h, v5.h[6] \n"
                        "fmla v31.8h, v21.8h, v5.h[7] \n"

                        "fmla v24.8h, v22.8h, v6.h[0] \n"
                        "fmla v25.8h, v22.8h, v6.h[1] \n"
                        "fmla v26.8h, v22.8h, v6.h[2] \n"
                        "fmla v27.8h, v22.8h, v6.h[3] \n"
                        "fmla v28.8h, v22.8h, v6.h[4] \n"
                        "fmla v29.8h, v22.8h, v6.h[5] \n"
                        "fmla v30.8h, v22.8h, v6.h[6] \n"
                        "fmla v31.8h, v22.8h, v6.h[7] \n"

                        "subs %w0, %w0, #1 \n"

                        "fmla v24.8h, v23.8h, v7.h[0] \n"
                        "fmla v25.8h, v23.8h, v7.h[1] \n"
                        "fmla v26.8h, v23.8h, v7.h[2] \n"
                        "fmla v27.8h, v23.8h, v7.h[3] \n"
                        "fmla v28.8h, v23.8h, v7.h[4] \n"
                        "fmla v29.8h, v23.8h, v7.h[5] \n"
                        "fmla v30.8h, v23.8h, v7.h[6] \n"
                        "fmla v31.8h, v23.8h, v7.h[7] \n"

                        "bne 0b \n"

                        "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
                        "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"

                        "ext v24.16b, v24.16b, v24.16b, #8 \n"
                        "ext v25.16b, v25.16b, v25.16b, #8 \n"
                        "ext v26.16b, v26.16b, v26.16b, #8 \n"
                        "ext v27.16b, v27.16b, v27.16b, #8 \n"
                        "ext v28.16b, v28.16b, v28.16b, #8 \n"
                        "ext v29.16b, v29.16b, v29.16b, #8 \n"
                        "ext v30.16b, v30.16b, v30.16b, #8 \n"
                        "ext v31.16b, v31.16b, v31.16b, #8 \n"

                        "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"
                        "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%2], #32 \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm),   // %1
                        "=r"(output1_tm),   // %2
                        "=r"(r0),           // %3
                        "=r"(kptr)          // %4
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(output1_tm),
                        "3"(r0),
                        "4"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

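                    // same kernel, 4 tiles wide: each input vector now carries two
                    // channels (4 tiles each), so accumulators v24..v27 suffice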
                    asm volatile(
                        "eor v24.16b, v24.16b, v24.16b \n"
                        "eor v25.16b, v25.16b, v25.16b \n"
                        "eor v26.16b, v26.16b, v26.16b \n"
                        "eor v27.16b, v27.16b, v27.16b \n"
                        "eor v28.16b, v28.16b, v28.16b \n"
                        "eor v29.16b, v29.16b, v29.16b \n"
                        "eor v30.16b, v30.16b, v30.16b \n"
                        "eor v31.16b, v31.16b, v31.16b \n"

                        "0: \n"

                        "prfm pldl1keep, [%4, #512] \n"
                        "ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [%4], #64 \n"

                        "prfm pldl1keep, [%3, #512] \n"
                        "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%3], #64 \n"

                        "fmla v24.8h, v16.8h, v0.h[0] \n"
                        "fmla v25.8h, v16.8h, v0.h[1] \n"
                        "fmla v26.8h, v16.8h, v0.h[2] \n"
                        "fmla v27.8h, v16.8h, v0.h[3] \n"
                        "fmla v24.8h, v17.8h, v0.h[4] \n"
                        "fmla v25.8h, v17.8h, v0.h[5] \n"
                        "fmla v26.8h, v17.8h, v0.h[6] \n"
                        "fmla v27.8h, v17.8h, v0.h[7] \n"

                        "prfm pldl1keep, [%4, #512] \n"
                        "ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [%4], #64 \n"

                        "fmla v24.8h, v18.8h, v1.h[0] \n"
                        "fmla v25.8h, v18.8h, v1.h[1] \n"
                        "fmla v26.8h, v18.8h, v1.h[2] \n"
                        "fmla v27.8h, v18.8h, v1.h[3] \n"
                        "fmla v24.8h, v19.8h, v1.h[4] \n"
                        "fmla v25.8h, v19.8h, v1.h[5] \n"
                        "fmla v26.8h, v19.8h, v1.h[6] \n"
                        "fmla v27.8h, v19.8h, v1.h[7] \n"

                        "fmla v24.8h, v20.8h, v2.h[0] \n"
                        "fmla v25.8h, v20.8h, v2.h[1] \n"
                        "fmla v26.8h, v20.8h, v2.h[2] \n"
                        "fmla v27.8h, v20.8h, v2.h[3] \n"
                        "fmla v24.8h, v21.8h, v2.h[4] \n"
                        "fmla v25.8h, v21.8h, v2.h[5] \n"
                        "fmla v26.8h, v21.8h, v2.h[6] \n"
                        "fmla v27.8h, v21.8h, v2.h[7] \n"

                        "subs %w0, %w0, #1 \n"

                        "fmla v24.8h, v22.8h, v3.h[0] \n"
                        "fmla v25.8h, v22.8h, v3.h[1] \n"
                        "fmla v26.8h, v22.8h, v3.h[2] \n"
                        "fmla v27.8h, v22.8h, v3.h[3] \n"
                        "fmla v24.8h, v23.8h, v3.h[4] \n"
                        "fmla v25.8h, v23.8h, v3.h[5] \n"
                        "fmla v26.8h, v23.8h, v3.h[6] \n"
                        "fmla v27.8h, v23.8h, v3.h[7] \n"

                        "bne 0b \n"

                        "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"

                        "ext v24.16b, v24.16b, v24.16b, #8 \n"
                        "ext v25.16b, v25.16b, v25.16b, #8 \n"
                        "ext v26.16b, v26.16b, v26.16b, #8 \n"
                        "ext v27.16b, v27.16b, v27.16b, #8 \n"

                        "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm),   // %1
                        "=r"(output1_tm),   // %2
                        "=r"(r0),           // %3
                        "=r"(kptr)          // %4
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(output1_tm),
                        "3"(r0),
                        "4"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
                }
                for (; i < tiles; i++)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4 + i % 4);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

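                    // single-tile tail in intrinsics: one lane-broadcast fmla per
                    // input channel; the low half of _sum0 goes to output channels
                    // 0..3, the high half to 4..7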
                    float16x8_t _sum0 = vdupq_n_f16((__fp16)0.f);

                    for (int q = 0; q < inch; q++)
                    {
                        float16x8_t _r0 = vld1q_f16(r0);

                        float16x8_t _k0 = vld1q_f16(kptr);
                        float16x8_t _k1 = vld1q_f16(kptr + 8);
                        float16x8_t _k2 = vld1q_f16(kptr + 16);
                        float16x8_t _k3 = vld1q_f16(kptr + 24);
                        float16x8_t _k4 = vld1q_f16(kptr + 32);
                        float16x8_t _k5 = vld1q_f16(kptr + 40);
                        float16x8_t _k6 = vld1q_f16(kptr + 48);
                        float16x8_t _k7 = vld1q_f16(kptr + 56);

                        _sum0 = vfmaq_laneq_f16(_sum0, _k0, _r0, 0);
                        _sum0 = vfmaq_laneq_f16(_sum0, _k1, _r0, 1);
                        _sum0 = vfmaq_laneq_f16(_sum0, _k2, _r0, 2);
                        _sum0 = vfmaq_laneq_f16(_sum0, _k3, _r0, 3);
                        _sum0 = vfmaq_laneq_f16(_sum0, _k4, _r0, 4);
                        _sum0 = vfmaq_laneq_f16(_sum0, _k5, _r0, 5);
                        _sum0 = vfmaq_laneq_f16(_sum0, _k6, _r0, 6);
                        _sum0 = vfmaq_laneq_f16(_sum0, _k7, _r0, 7);

                        kptr += 64;
                        r0 += 8;
                    }

                    vst1_f16(output0_tm, vget_low_f16(_sum0));
                    vst1_f16(output1_tm, vget_high_f16(_sum0));

                    output0_tm += 4;
                    output1_tm += 4;
                }
            }
        }

        remain_outch_start += nn_outch << 1;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_outch_start; p < outch; p++)
        {
            __fp16* output0_tm = top_blob_tm.channel(p);

            const Mat kernel0_tm = kernel_tm.channel(p / 2 + p % 2);
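            // p / 2 + p % 2 indexes past the 8-outch blocks consumed by the
            // paired loop above to the trailing 4-outch block (kernel_tm here
            // is assumed to be the packed matrix built by the transform above)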

            for (int r = 0; r < 64; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 7 < tiles; i += 8)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    asm volatile(
                        "eor v24.16b, v24.16b, v24.16b \n"
                        "eor v25.16b, v25.16b, v25.16b \n"
                        "eor v26.16b, v26.16b, v26.16b \n"
                        "eor v27.16b, v27.16b, v27.16b \n"
                        "eor v28.16b, v28.16b, v28.16b \n"
                        "eor v29.16b, v29.16b, v29.16b \n"
                        "eor v30.16b, v30.16b, v30.16b \n"
                        "eor v31.16b, v31.16b, v31.16b \n"

                        "0: \n"

                        "prfm pldl1keep, [%3, #256] \n"
                        "ld1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%3], #32 \n"

                        "prfm pldl1keep, [%2, #512] \n"
                        "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%2], #64 \n"

                        "fmla v24.4h, v16.4h, v0.h[0] \n"
                        "fmla v25.4h, v16.4h, v0.h[1] \n"
                        "fmla v26.4h, v16.4h, v0.h[2] \n"
                        "fmla v27.4h, v16.4h, v0.h[3] \n"
                        "fmla v28.4h, v16.4h, v0.h[4] \n"
                        "fmla v29.4h, v16.4h, v0.h[5] \n"
                        "fmla v30.4h, v16.4h, v0.h[6] \n"
                        "fmla v31.4h, v16.4h, v0.h[7] \n"

                        "fmla v24.4h, v17.4h, v1.h[0] \n"
                        "fmla v25.4h, v17.4h, v1.h[1] \n"
                        "fmla v26.4h, v17.4h, v1.h[2] \n"
                        "fmla v27.4h, v17.4h, v1.h[3] \n"
                        "fmla v28.4h, v17.4h, v1.h[4] \n"
                        "fmla v29.4h, v17.4h, v1.h[5] \n"
                        "fmla v30.4h, v17.4h, v1.h[6] \n"
                        "fmla v31.4h, v17.4h, v1.h[7] \n"

                        "prfm pldl1keep, [%3, #256] \n"
                        "ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%3], #32 \n"

                        "fmla v24.4h, v18.4h, v2.h[0] \n"
                        "fmla v25.4h, v18.4h, v2.h[1] \n"
                        "fmla v26.4h, v18.4h, v2.h[2] \n"
                        "fmla v27.4h, v18.4h, v2.h[3] \n"
                        "fmla v28.4h, v18.4h, v2.h[4] \n"
                        "fmla v29.4h, v18.4h, v2.h[5] \n"
                        "fmla v30.4h, v18.4h, v2.h[6] \n"
                        "fmla v31.4h, v18.4h, v2.h[7] \n"

                        "prfm pldl1keep, [%2, #512] \n"
                        "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [%2], #64 \n"

                        "fmla v24.4h, v19.4h, v3.h[0] \n"
                        "fmla v25.4h, v19.4h, v3.h[1] \n"
                        "fmla v26.4h, v19.4h, v3.h[2] \n"
                        "fmla v27.4h, v19.4h, v3.h[3] \n"
                        "fmla v28.4h, v19.4h, v3.h[4] \n"
                        "fmla v29.4h, v19.4h, v3.h[5] \n"
                        "fmla v30.4h, v19.4h, v3.h[6] \n"
                        "fmla v31.4h, v19.4h, v3.h[7] \n"

                        "fmla v24.4h, v20.4h, v4.h[0] \n"
                        "fmla v25.4h, v20.4h, v4.h[1] \n"
                        "fmla v26.4h, v20.4h, v4.h[2] \n"
                        "fmla v27.4h, v20.4h, v4.h[3] \n"
                        "fmla v28.4h, v20.4h, v4.h[4] \n"
                        "fmla v29.4h, v20.4h, v4.h[5] \n"
                        "fmla v30.4h, v20.4h, v4.h[6] \n"
                        "fmla v31.4h, v20.4h, v4.h[7] \n"

                        "fmla v24.4h, v21.4h, v5.h[0] \n"
                        "fmla v25.4h, v21.4h, v5.h[1] \n"
                        "fmla v26.4h, v21.4h, v5.h[2] \n"
                        "fmla v27.4h, v21.4h, v5.h[3] \n"
                        "fmla v28.4h, v21.4h, v5.h[4] \n"
                        "fmla v29.4h, v21.4h, v5.h[5] \n"
                        "fmla v30.4h, v21.4h, v5.h[6] \n"
                        "fmla v31.4h, v21.4h, v5.h[7] \n"

                        "fmla v24.4h, v22.4h, v6.h[0] \n"
                        "fmla v25.4h, v22.4h, v6.h[1] \n"
                        "fmla v26.4h, v22.4h, v6.h[2] \n"
                        "fmla v27.4h, v22.4h, v6.h[3] \n"
                        "fmla v28.4h, v22.4h, v6.h[4] \n"
                        "fmla v29.4h, v22.4h, v6.h[5] \n"
                        "fmla v30.4h, v22.4h, v6.h[6] \n"
                        "fmla v31.4h, v22.4h, v6.h[7] \n"

                        "subs %w0, %w0, #1 \n"

                        "fmla v24.4h, v23.4h, v7.h[0] \n"
                        "fmla v25.4h, v23.4h, v7.h[1] \n"
                        "fmla v26.4h, v23.4h, v7.h[2] \n"
                        "fmla v27.4h, v23.4h, v7.h[3] \n"
                        "fmla v28.4h, v23.4h, v7.h[4] \n"
                        "fmla v29.4h, v23.4h, v7.h[5] \n"
                        "fmla v30.4h, v23.4h, v7.h[6] \n"
                        "fmla v31.4h, v23.4h, v7.h[7] \n"

                        "bne 0b \n"

                        "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
                        "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm),   // %1
                        "=r"(r0),           // %2
                        "=r"(kptr)          // %3
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(r0),
                        "3"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    asm volatile(
                        "eor v24.16b, v24.16b, v24.16b \n"
                        "eor v25.16b, v25.16b, v25.16b \n"
                        "eor v26.16b, v26.16b, v26.16b \n"
                        "eor v27.16b, v27.16b, v27.16b \n"

                        "0: \n"

                        "prfm pldl1keep, [%3, #256] \n"
                        "ld1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%3], #32 \n"

                        "prfm pldl1keep, [%2, #512] \n"
                        "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%2], #64 \n"

                        "fmla v24.4h, v16.4h, v0.h[0] \n"
                        "fmla v25.4h, v16.4h, v0.h[1] \n"
                        "fmla v26.4h, v16.4h, v0.h[2] \n"
                        "fmla v27.4h, v16.4h, v0.h[3] \n"
                        "fmla v24.4h, v17.4h, v0.h[4] \n"
                        "fmla v25.4h, v17.4h, v0.h[5] \n"
                        "fmla v26.4h, v17.4h, v0.h[6] \n"
                        "fmla v27.4h, v17.4h, v0.h[7] \n"

                        "prfm pldl1keep, [%3, #256] \n"
                        "ld1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%3], #32 \n"

                        "fmla v24.4h, v18.4h, v1.h[0] \n"
                        "fmla v25.4h, v18.4h, v1.h[1] \n"
                        "fmla v26.4h, v18.4h, v1.h[2] \n"
                        "fmla v27.4h, v18.4h, v1.h[3] \n"
                        "fmla v24.4h, v19.4h, v1.h[4] \n"
                        "fmla v25.4h, v19.4h, v1.h[5] \n"
                        "fmla v26.4h, v19.4h, v1.h[6] \n"
                        "fmla v27.4h, v19.4h, v1.h[7] \n"

                        "fmla v24.4h, v20.4h, v2.h[0] \n"
                        "fmla v25.4h, v20.4h, v2.h[1] \n"
                        "fmla v26.4h, v20.4h, v2.h[2] \n"
                        "fmla v27.4h, v20.4h, v2.h[3] \n"
                        "fmla v24.4h, v21.4h, v2.h[4] \n"
                        "fmla v25.4h, v21.4h, v2.h[5] \n"
                        "fmla v26.4h, v21.4h, v2.h[6] \n"
                        "fmla v27.4h, v21.4h, v2.h[7] \n"

                        "subs %w0, %w0, #1 \n"

                        "fmla v24.4h, v22.4h, v3.h[0] \n"
                        "fmla v25.4h, v22.4h, v3.h[1] \n"
                        "fmla v26.4h, v22.4h, v3.h[2] \n"
                        "fmla v27.4h, v22.4h, v3.h[3] \n"
                        "fmla v24.4h, v23.4h, v3.h[4] \n"
                        "fmla v25.4h, v23.4h, v3.h[5] \n"
                        "fmla v26.4h, v23.4h, v3.h[6] \n"
                        "fmla v27.4h, v23.4h, v3.h[7] \n"

                        "bne 0b \n"

                        "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm),   // %1
                        "=r"(r0),           // %2
                        "=r"(kptr)          // %3
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(r0),
                        "3"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
                }
                for (; i < tiles; i++)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4 + i % 4);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    float16x4_t _sum0 = vdup_n_f16((__fp16)0.f);

                    for (int q = 0; q < inch; q++)
                    {
                        float16x8_t _r0 = vld1q_f16(r0);

                        float16x4_t _k0 = vld1_f16(kptr);
                        float16x4_t _k1 = vld1_f16(kptr + 4);
                        float16x4_t _k2 = vld1_f16(kptr + 8);
                        float16x4_t _k3 = vld1_f16(kptr + 12);
                        float16x4_t _k4 = vld1_f16(kptr + 16);
                        float16x4_t _k5 = vld1_f16(kptr + 20);
                        float16x4_t _k6 = vld1_f16(kptr + 24);
                        float16x4_t _k7 = vld1_f16(kptr + 28);

                        _sum0 = vfma_laneq_f16(_sum0, _k0, _r0, 0);
                        _sum0 = vfma_laneq_f16(_sum0, _k1, _r0, 1);
                        _sum0 = vfma_laneq_f16(_sum0, _k2, _r0, 2);
                        _sum0 = vfma_laneq_f16(_sum0, _k3, _r0, 3);
                        _sum0 = vfma_laneq_f16(_sum0, _k4, _r0, 4);
                        _sum0 = vfma_laneq_f16(_sum0, _k5, _r0, 5);
                        _sum0 = vfma_laneq_f16(_sum0, _k6, _r0, 6);
                        _sum0 = vfma_laneq_f16(_sum0, _k7, _r0, 7);

                        kptr += 32;
                        r0 += 8;
                    }

                    vst1_f16(output0_tm, _sum0);

                    output0_tm += 4;
                }
            }
        }
    }
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    if (outw == top_blob.w && outh == top_blob.h)
    {
        top_blob_bordered = top_blob;
    }
    else
    {
        top_blob_bordered.create(outw, outh, outch, 2u * 4, 4, opt.workspace_allocator);
    }
    {
        // const float otm[6][8] = {
        //     {1.0f, 1.0f,  1.0f,  1.0f,   1.0f, 32.0f,  32.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f,  2.0f,  -2.0f, 16.0f, -16.0f, 0.0f},
        //     {0.0f, 1.0f,  1.0f,  4.0f,   4.0f,  8.0f,   8.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f,  8.0f,  -8.0f,  4.0f,  -4.0f, 0.0f},
        //     {0.0f, 1.0f,  1.0f, 16.0f,  16.0f,  2.0f,   2.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 32.0f, -32.0f,  1.0f,  -1.0f, 1.0f}
        // };

        // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32
        // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
        // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
        // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
        // 4 = (r1 + r2) + (r3 + r4) * 16 + (r5 + r6) * 2
        // 5 = r7 + (r1 - r2) + (r3 - r4) * 32 + (r5 - r6)
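        // (rows of A^T for F(6x6, 3x3): each 8x8 tile M of dot results collapses
        // to a 6x6 output tile A^T * M * A, with the per-channel bias added in
        // the second pass below)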

        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;
        const int tiles = w_tm / 8 * h_tm / 8;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            const Mat out0_tm = top_blob_tm.channel(p);
            Mat out0 = top_blob_bordered.channel(p);

            //const float bias0 = bias ? bias[p] : 0.f;
            float16x4_t _bias0 = bias ? vld1_f16((const __fp16*)bias + p * 4) : vdup_n_f16(0.f);

            __fp16 tmp[6][8][4];

            // tile
            for (int i = 0; i < outh / 6; i++)
            {
                for (int j = 0; j < outw / 6; j++)
                {
                    //top_blob_tm.create(tiles, 64, outch, elemsize, elempack);

                    const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * 4;
                    const __fp16* output0_tm_1 = output0_tm_0 + tiles * 4;
                    const __fp16* output0_tm_2 = output0_tm_0 + tiles * 8;
                    const __fp16* output0_tm_3 = output0_tm_0 + tiles * 12;
                    const __fp16* output0_tm_4 = output0_tm_0 + tiles * 16;
                    const __fp16* output0_tm_5 = output0_tm_0 + tiles * 20;
                    const __fp16* output0_tm_6 = output0_tm_0 + tiles * 24;
                    const __fp16* output0_tm_7 = output0_tm_0 + tiles * 28;

                    __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * 4;

                    // TODO neon optimize
                    for (int m = 0; m < 8; m++)
                    {
                        float16x4_t _out0tm0 = vld1_f16(output0_tm_0);
                        float16x4_t _out0tm1 = vld1_f16(output0_tm_1);
                        float16x4_t _out0tm2 = vld1_f16(output0_tm_2);
                        float16x4_t _out0tm3 = vld1_f16(output0_tm_3);
                        float16x4_t _out0tm4 = vld1_f16(output0_tm_4);
                        float16x4_t _out0tm5 = vld1_f16(output0_tm_5);
                        float16x4_t _out0tm6 = vld1_f16(output0_tm_6);
                        float16x4_t _out0tm7 = vld1_f16(output0_tm_7);

                        float16x4_t _tmp024a = vadd_f16(_out0tm1, _out0tm2);
                        float16x4_t _tmp135a = vsub_f16(_out0tm1, _out0tm2);

                        // float tmp024a = output0_tm[1] + output0_tm[2];
                        // float tmp135a = output0_tm[1] - output0_tm[2];

                        float16x4_t _tmp024b = vadd_f16(_out0tm3, _out0tm4);
                        float16x4_t _tmp135b = vsub_f16(_out0tm3, _out0tm4);

                        // float tmp024b = output0_tm[3] + output0_tm[4];
                        // float tmp135b = output0_tm[3] - output0_tm[4];

                        float16x4_t _tmp024c = vadd_f16(_out0tm5, _out0tm6);
                        float16x4_t _tmp135c = vsub_f16(_out0tm5, _out0tm6);

                        // float tmp024c = output0_tm[5] + output0_tm[6];
                        // float tmp135c = output0_tm[5] - output0_tm[6];

                        float16x4_t _tmp0m = vadd_f16(vadd_f16(_out0tm0, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f));
                        float16x4_t _tmp2m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f);
                        float16x4_t _tmp4m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f);
                        vst1_f16(tmp[0][m], _tmp0m);
                        vst1_f16(tmp[2][m], _tmp2m);
                        vst1_f16(tmp[4][m], _tmp4m);

                        // tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32;
                        // tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8;
                        // tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c;

                        float16x4_t _tmp1m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f);
                        float16x4_t _tmp3m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f);
                        float16x4_t _tmp5m = vadd_f16(vadd_f16(_out0tm7, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f));
                        vst1_f16(tmp[1][m], _tmp1m);
                        vst1_f16(tmp[3][m], _tmp3m);
                        vst1_f16(tmp[5][m], _tmp5m);

                        // tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
                        // tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4;
                        // tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c;

                        output0_tm_0 += tiles * 32;
                        output0_tm_1 += tiles * 32;
                        output0_tm_2 += tiles * 32;
                        output0_tm_3 += tiles * 32;
                        output0_tm_4 += tiles * 32;
                        output0_tm_5 += tiles * 32;
                        output0_tm_6 += tiles * 32;
                        output0_tm_7 += tiles * 32;
                    }

                    for (int m = 0; m < 6; m++)
                    {
                        float16x4_t _tmp00 = vld1_f16(tmp[m][0]);
                        float16x4_t _tmp01 = vld1_f16(tmp[m][1]);
                        float16x4_t _tmp02 = vld1_f16(tmp[m][2]);
                        float16x4_t _tmp03 = vld1_f16(tmp[m][3]);
                        float16x4_t _tmp04 = vld1_f16(tmp[m][4]);
                        float16x4_t _tmp05 = vld1_f16(tmp[m][5]);
                        float16x4_t _tmp06 = vld1_f16(tmp[m][6]);
                        float16x4_t _tmp07 = vld1_f16(tmp[m][7]);

                        float16x4_t _tmp024a = vadd_f16(_tmp01, _tmp02);
                        float16x4_t _tmp135a = vsub_f16(_tmp01, _tmp02);

                        // float tmp024a = tmp0[1] + tmp0[2];
                        // float tmp135a = tmp0[1] - tmp0[2];

                        float16x4_t _tmp024b = vadd_f16(_tmp03, _tmp04);
                        float16x4_t _tmp135b = vsub_f16(_tmp03, _tmp04);

                        // float tmp024b = tmp0[3] + tmp0[4];
                        // float tmp135b = tmp0[3] - tmp0[4];

                        float16x4_t _tmp024c = vadd_f16(_tmp05, _tmp06);
                        float16x4_t _tmp135c = vsub_f16(_tmp05, _tmp06);

                        // float tmp024c = tmp0[5] + tmp0[6];
                        // float tmp135c = tmp0[5] - tmp0[6];

                        float16x4_t _out00 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp00, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f)));
                        float16x4_t _out02 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f));
                        float16x4_t _out04 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f));
                        vst1_f16(output0, _out00);
                        vst1_f16(output0 + 8, _out02);
                        vst1_f16(output0 + 16, _out04);

                        // output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32;
                        // output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8;
                        // output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c;

                        float16x4_t _out01 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f));
                        float16x4_t _out03 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f));
                        float16x4_t _out05 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp07, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f)));
                        vst1_f16(output0 + 4, _out01);
                        vst1_f16(output0 + 12, _out03);
                        vst1_f16(output0 + 20, _out05);

                        // output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16;
                        // output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4;
                        // output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c;

                        output0 += outw * 4;
                    }
                }
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}