// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

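// transform the 3x3 kernel into the 8x8 winograd F(6,3) domain: for each
// (outch, inch) pair compute U = G * g * G^T with the 8x3 G matrix below,
// then repack the 64 coefficients into fp16 blocks of 8 (or 4) output
// channels x 4 input channels for the pack4 gemm kernels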
static void conv3x3s1_winograd64_transform_kernel_pack4_fp16sa_neon(const Mat& kernel, Mat& kernel_tm_pack4, int inch, int outch)
{
    // winograd F(6,3) transform kernel
    Mat kernel_tm;
    kernel_tm.create(8 * 8, inch, outch);

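    // ktm holds the rows of the F(6,3) transform matrix G (8x3);
    // applying it along rows and then along columns expands the 3x3 kernel
    // to an 8x8 tile of transform coefficients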
    const float ktm[8][3] = {
        {1.0f, 0.0f, 0.0f},
        {-2.0f / 9, -2.0f / 9, -2.0f / 9},
        {-2.0f / 9, 2.0f / 9, -2.0f / 9},
        {1.0f / 90, 1.0f / 45, 2.0f / 45},
        {1.0f / 90, -1.0f / 45, 2.0f / 45},
        {1.0f / 45, 1.0f / 90, 1.0f / 180},
        {1.0f / 45, -1.0f / 90, 1.0f / 180},
        {0.0f, 0.0f, 1.0f}
    };

    #pragma omp parallel for
    for (int p = 0; p < outch; p++)
    {
        for (int q = 0; q < inch; q++)
        {
            const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
            float* kernel_tm0 = kernel_tm.channel(p).row(q);

            // transform kernel, transposed
            const float* k0 = kernel0;
            const float* k1 = kernel0 + 3;
            const float* k2 = kernel0 + 6;

            // h
            float tmp[8][3];
            for (int i = 0; i < 8; i++)
            {
                tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
                tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
                tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
            }

            // v
            for (int j = 0; j < 8; j++)
            {
                float* tmpp = &tmp[j][0];

                for (int i = 0; i < 8; i++)
                {
                    kernel_tm0[j * 8 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
                }
            }
        }
    }

    // interleave
    // src = 64-inch-outch
    // dst = 4b-4a-inch/4a-64-outch/4b;
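    // each row (one of the 64 transform positions) stores inch/4 groups of
    // 8-outch x 4-inch fp16 blocks, i.e. 2 * inch / 4 elements of
    // elempack 16 (16 fp16 = 32 bytes per element)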
    kernel_tm_pack4.create(2 * inch / 4, 64, (outch / 4) / 2 + (outch / 4) % 2, (size_t)2u * 16, 16);

    int q = 0;
    for (; q + 7 < outch; q += 8)
    {
        const Mat k0 = kernel_tm.channel(q);
        const Mat k1 = kernel_tm.channel(q + 1);
        const Mat k2 = kernel_tm.channel(q + 2);
        const Mat k3 = kernel_tm.channel(q + 3);
        const Mat k4 = kernel_tm.channel(q + 4);
        const Mat k5 = kernel_tm.channel(q + 5);
        const Mat k6 = kernel_tm.channel(q + 6);
        const Mat k7 = kernel_tm.channel(q + 7);

        Mat g0 = kernel_tm_pack4.channel(q / 8);

        for (int k = 0; k < 64; k++)
        {
            __fp16* g00 = g0.row<__fp16>(k);

            for (int p = 0; p + 3 < inch; p += 4)
            {
                const float* k00 = k0.row(p);
                const float* k01 = k0.row(p + 1);
                const float* k02 = k0.row(p + 2);
                const float* k03 = k0.row(p + 3);

                const float* k10 = k1.row(p);
                const float* k11 = k1.row(p + 1);
                const float* k12 = k1.row(p + 2);
                const float* k13 = k1.row(p + 3);

                const float* k20 = k2.row(p);
                const float* k21 = k2.row(p + 1);
                const float* k22 = k2.row(p + 2);
                const float* k23 = k2.row(p + 3);

                const float* k30 = k3.row(p);
                const float* k31 = k3.row(p + 1);
                const float* k32 = k3.row(p + 2);
                const float* k33 = k3.row(p + 3);

                const float* k40 = k4.row(p);
                const float* k41 = k4.row(p + 1);
                const float* k42 = k4.row(p + 2);
                const float* k43 = k4.row(p + 3);

                const float* k50 = k5.row(p);
                const float* k51 = k5.row(p + 1);
                const float* k52 = k5.row(p + 2);
                const float* k53 = k5.row(p + 3);

                const float* k60 = k6.row(p);
                const float* k61 = k6.row(p + 1);
                const float* k62 = k6.row(p + 2);
                const float* k63 = k6.row(p + 3);

                const float* k70 = k7.row(p);
                const float* k71 = k7.row(p + 1);
                const float* k72 = k7.row(p + 2);
                const float* k73 = k7.row(p + 3);

                g00[0] = (__fp16)k00[k];
                g00[1] = (__fp16)k10[k];
                g00[2] = (__fp16)k20[k];
                g00[3] = (__fp16)k30[k];

                g00[4] = (__fp16)k40[k];
                g00[5] = (__fp16)k50[k];
                g00[6] = (__fp16)k60[k];
                g00[7] = (__fp16)k70[k];

                g00[8] = (__fp16)k01[k];
                g00[9] = (__fp16)k11[k];
                g00[10] = (__fp16)k21[k];
                g00[11] = (__fp16)k31[k];

                g00[12] = (__fp16)k41[k];
                g00[13] = (__fp16)k51[k];
                g00[14] = (__fp16)k61[k];
                g00[15] = (__fp16)k71[k];

                g00[16] = (__fp16)k02[k];
                g00[17] = (__fp16)k12[k];
                g00[18] = (__fp16)k22[k];
                g00[19] = (__fp16)k32[k];

                g00[20] = (__fp16)k42[k];
                g00[21] = (__fp16)k52[k];
                g00[22] = (__fp16)k62[k];
                g00[23] = (__fp16)k72[k];

                g00[24] = (__fp16)k03[k];
                g00[25] = (__fp16)k13[k];
                g00[26] = (__fp16)k23[k];
                g00[27] = (__fp16)k33[k];

                g00[28] = (__fp16)k43[k];
                g00[29] = (__fp16)k53[k];
                g00[30] = (__fp16)k63[k];
                g00[31] = (__fp16)k73[k];

                g00 += 32;
            }
        }
    }
    for (; q + 3 < outch; q += 4)
    {
        const Mat k0 = kernel_tm.channel(q);
        const Mat k1 = kernel_tm.channel(q + 1);
        const Mat k2 = kernel_tm.channel(q + 2);
        const Mat k3 = kernel_tm.channel(q + 3);

        Mat g0 = kernel_tm_pack4.channel(q / 8 + (q % 8) / 4);

        for (int k = 0; k < 64; k++)
        {
            __fp16* g00 = g0.row<__fp16>(k);

            for (int p = 0; p + 3 < inch; p += 4)
            {
                const float* k00 = k0.row(p);
                const float* k01 = k0.row(p + 1);
                const float* k02 = k0.row(p + 2);
                const float* k03 = k0.row(p + 3);

                const float* k10 = k1.row(p);
                const float* k11 = k1.row(p + 1);
                const float* k12 = k1.row(p + 2);
                const float* k13 = k1.row(p + 3);

                const float* k20 = k2.row(p);
                const float* k21 = k2.row(p + 1);
                const float* k22 = k2.row(p + 2);
                const float* k23 = k2.row(p + 3);

                const float* k30 = k3.row(p);
                const float* k31 = k3.row(p + 1);
                const float* k32 = k3.row(p + 2);
                const float* k33 = k3.row(p + 3);

                g00[0] = (__fp16)k00[k];
                g00[1] = (__fp16)k10[k];
                g00[2] = (__fp16)k20[k];
                g00[3] = (__fp16)k30[k];

                g00[4] = (__fp16)k01[k];
                g00[5] = (__fp16)k11[k];
                g00[6] = (__fp16)k21[k];
                g00[7] = (__fp16)k31[k];

                g00[8] = (__fp16)k02[k];
                g00[9] = (__fp16)k12[k];
                g00[10] = (__fp16)k22[k];
                g00[11] = (__fp16)k32[k];

                g00[12] = (__fp16)k03[k];
                g00[13] = (__fp16)k13[k];
                g00[14] = (__fp16)k23[k];
                g00[15] = (__fp16)k33[k];

                g00 += 16;
            }
        }
    }
}
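
// usage sketch (hypothetical caller, for illustration only; the names
// weight_data and weight_data_tm are assumptions, not part of this file):
//     Mat weight_data_tm;
//     conv3x3s1_winograd64_transform_kernel_pack4_fp16sa_neon(weight_data, weight_data_tm, inch, outch);
// weight_data is expected to hold the raw 3x3 kernel as outch-inch-9 floats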

static void conv3x3s1_winograd64_pack4_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    //size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 6n+2
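    // each 8x8 input tile yields a 6x6 output tile and neighbouring tiles
    // overlap by 2 pixels, so the padded input must be 6n + 2 in each dimension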
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 5) / 6 * 6;
    outh = (outh + 5) / 6 * 6;

    w = outw + 2;
    h = outh + 2;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt);

    const float* bias = _bias;

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = w_tm / 8 * h_tm / 8;

        //         bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
        bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator);

        //         const float itm[8][8] = {
        //             {1.0f,  0.0f, -5.25f,  0.00f,  5.25f,  0.00f, -1.0f, 0.0f},
        //
        //             {0.0f,  1.0f,  1.00f, -4.25f, -4.25f,  1.00f,  1.0f, 0.0f},
        //             {0.0f, -1.0f,  1.00f,  4.25f, -4.25f, -1.00f,  1.0f, 0.0f},
        //
        //             {0.0f,  0.5f,  0.25f, -2.50f, -1.25f,  2.00f,  1.0f, 0.0f},
        //             {0.0f, -0.5f,  0.25f,  2.50f, -1.25f, -2.00f,  1.0f, 0.0f},
        //
        //             {0.0f,  2.0f,  4.00f, -2.50f, -5.00f,  0.50f,  1.0f, 0.0f},
        //             {0.0f, -2.0f,  4.00f,  2.50f, -5.00f, -0.50f,  1.0f, 0.0f},
        //
        //             {0.0f, -1.0f,  0.00f,  5.25f,  0.00f, -5.25f,  0.0f, 1.0f}
        //         };

        // 0 = r00 - r06 + (r04 - r02) * 5.25
        // 7 = r07 - r01 + (r03 - r05) * 5.25

        // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05)
        // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05)

        // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2)
        // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2)

        // reuse r04 * 1.25
        // reuse r03 * 2.5
        // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
        // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)

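        // the expressions above are the rows of B^T for F(6,3), factorized to
        // reuse common subterms; they are applied first along tile rows (into
        // tmp) and then along columns (into bottom_blob_tm)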
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const Mat img0 = bottom_blob_bordered.channel(q);
            Mat img0_tm = bottom_blob_tm.channel(q);

            __fp16 tmp[8][8][4];

            // tile
            for (int i = 0; i < h_tm / 8; i++)
            {
                for (int j = 0; j < w_tm / 8; j++)
                {
                    const __fp16* r0 = img0.row<const __fp16>(i * 6) + (j * 6) * 4;

                    for (int m = 0; m < 8; m++)
                    {
                        float16x4_t _r00 = vld1_f16(r0);
                        float16x4_t _r01 = vld1_f16(r0 + 4);
                        float16x4_t _r02 = vld1_f16(r0 + 8);
                        float16x4_t _r03 = vld1_f16(r0 + 12);
                        float16x4_t _r04 = vld1_f16(r0 + 16);
                        float16x4_t _r05 = vld1_f16(r0 + 20);
                        float16x4_t _r06 = vld1_f16(r0 + 24);
                        float16x4_t _r07 = vld1_f16(r0 + 28);

                        float16x4_t _tmp0m = vfma_n_f16(vsub_f16(_r00, _r06), vsub_f16(_r04, _r02), 5.25f);
                        float16x4_t _tmp7m = vfma_n_f16(vsub_f16(_r07, _r01), vsub_f16(_r03, _r05), 5.25f);
                        vst1_f16(tmp[0][m], _tmp0m);
                        vst1_f16(tmp[7][m], _tmp7m);

                        //                         tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25;
                        //                         tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25;

                        float16x4_t _tmp12a = vfms_n_f16(vadd_f16(_r02, _r06), _r04, 4.25f);
                        float16x4_t _tmp12b = vfms_n_f16(vadd_f16(_r01, _r05), _r03, 4.25f);

                        //                         float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25);
                        //                         float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25);

                        float16x4_t _tmp1m = vadd_f16(_tmp12a, _tmp12b);
                        float16x4_t _tmp2m = vsub_f16(_tmp12a, _tmp12b);
                        vst1_f16(tmp[1][m], _tmp1m);
                        vst1_f16(tmp[2][m], _tmp2m);

                        //                         tmp[1][m] = tmp12a + tmp12b;
                        //                         tmp[2][m] = tmp12a - tmp12b;

                        float16x4_t _tmp34a = vfms_n_f16(vfma_n_f16(_r06, _r02, 0.25f), _r04, 1.25f);
                        float16x4_t _tmp34b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f);

                        //                         float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25);
                        //                         float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2);

                        float16x4_t _tmp3m = vadd_f16(_tmp34a, _tmp34b);
                        float16x4_t _tmp4m = vsub_f16(_tmp34a, _tmp34b);
                        vst1_f16(tmp[3][m], _tmp3m);
                        vst1_f16(tmp[4][m], _tmp4m);

                        //                         tmp[3][m] = tmp34a + tmp34b;
                        //                         tmp[4][m] = tmp34a - tmp34b;

                        float16x4_t _tmp56a = vfma_n_f16(_r06, vfms_n_f16(_r02, _r04, 1.25f), 4.f);
                        float16x4_t _tmp56b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f);

                        //                         float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4);
                        //                         float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5);

                        float16x4_t _tmp5m = vadd_f16(_tmp56a, _tmp56b);
                        float16x4_t _tmp6m = vsub_f16(_tmp56a, _tmp56b);
                        vst1_f16(tmp[5][m], _tmp5m);
                        vst1_f16(tmp[6][m], _tmp6m);

                        //                         tmp[5][m] = tmp56a + tmp56b;
                        //                         tmp[6][m] = tmp56a - tmp56b;

                        r0 += w * 4;
                    }

                    __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * 4;
                    __fp16* r0_tm_1 = r0_tm_0 + tiles * 4;
                    __fp16* r0_tm_2 = r0_tm_0 + tiles * 8;
                    __fp16* r0_tm_3 = r0_tm_0 + tiles * 12;
                    __fp16* r0_tm_4 = r0_tm_0 + tiles * 16;
                    __fp16* r0_tm_5 = r0_tm_0 + tiles * 20;
                    __fp16* r0_tm_6 = r0_tm_0 + tiles * 24;
                    __fp16* r0_tm_7 = r0_tm_0 + tiles * 28;
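                    // bottom_blob_tm stores the 64 transform positions of each
                    // channel as 64 rows of tiles * 4 fp16; the eight pointers
                    // cover one group of 8 positions and step tiles * 32
                    // (8 rows) per m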

                    for (int m = 0; m < 8; m++)
                    {
                        float16x4_t _tmp00 = vld1_f16(tmp[m][0]);
                        float16x4_t _tmp01 = vld1_f16(tmp[m][1]);
                        float16x4_t _tmp02 = vld1_f16(tmp[m][2]);
                        float16x4_t _tmp03 = vld1_f16(tmp[m][3]);
                        float16x4_t _tmp04 = vld1_f16(tmp[m][4]);
                        float16x4_t _tmp05 = vld1_f16(tmp[m][5]);
                        float16x4_t _tmp06 = vld1_f16(tmp[m][6]);
                        float16x4_t _tmp07 = vld1_f16(tmp[m][7]);

                        float16x4_t _r0tm0 = vfma_n_f16(vsub_f16(_tmp00, _tmp06), vsub_f16(_tmp04, _tmp02), 5.25f);
                        float16x4_t _r0tm7 = vfma_n_f16(vsub_f16(_tmp07, _tmp01), vsub_f16(_tmp03, _tmp05), 5.25f);

                        //                         r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25;
                        //                         r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25;

                        float16x4_t _tmp12a = vfms_n_f16(vadd_f16(_tmp02, _tmp06), _tmp04, 4.25f);
                        float16x4_t _tmp12b = vfms_n_f16(vadd_f16(_tmp01, _tmp05), _tmp03, 4.25f);

                        //                         float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25);
                        //                         float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25);

                        float16x4_t _r0tm1 = vadd_f16(_tmp12a, _tmp12b);
                        float16x4_t _r0tm2 = vsub_f16(_tmp12a, _tmp12b);

                        //                         r0_tm[1] = tmp12a + tmp12b;
                        //                         r0_tm[2] = tmp12a - tmp12b;

                        float16x4_t _tmp34a = vfms_n_f16(vfma_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f);
                        float16x4_t _tmp34b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f);

                        //                         float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25);
                        //                         float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2);

                        float16x4_t _r0tm3 = vadd_f16(_tmp34a, _tmp34b);
                        float16x4_t _r0tm4 = vsub_f16(_tmp34a, _tmp34b);

                        //                         r0_tm[3] = tmp34a + tmp34b;
                        //                         r0_tm[4] = tmp34a - tmp34b;

                        float16x4_t _tmp56a = vfma_n_f16(_tmp06, vfms_n_f16(_tmp02, _tmp04, 1.25f), 4.f);
                        float16x4_t _tmp56b = vfma_n_f16(vfms_n_f16(vmul_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f);

                        //                         float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4);
                        //                         float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5);

                        float16x4_t _r0tm5 = vadd_f16(_tmp56a, _tmp56b);
                        float16x4_t _r0tm6 = vsub_f16(_tmp56a, _tmp56b);

                        //                         r0_tm[5] = tmp56a + tmp56b;
                        //                         r0_tm[6] = tmp56a - tmp56b;

                        vst1_f16(r0_tm_0, _r0tm0);
                        vst1_f16(r0_tm_1, _r0tm1);
                        vst1_f16(r0_tm_2, _r0tm2);
                        vst1_f16(r0_tm_3, _r0tm3);
                        vst1_f16(r0_tm_4, _r0tm4);
                        vst1_f16(r0_tm_5, _r0tm5);
                        vst1_f16(r0_tm_6, _r0tm6);
                        vst1_f16(r0_tm_7, _r0tm7);

                        r0_tm_0 += tiles * 32;
                        r0_tm_1 += tiles * 32;
                        r0_tm_2 += tiles * 32;
                        r0_tm_3 += tiles * 32;
                        r0_tm_4 += tiles * 32;
                        r0_tm_5 += tiles * 32;
                        r0_tm_6 += tiles * 32;
                        r0_tm_7 += tiles * 32;
                    }
                }
            }
        }
    }
    bottom_blob_bordered = Mat();
    // END transform input

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = h_tm / 8 * w_tm / 8;

        // permute
        //         bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
        Mat bottom_blob_tm2;
        if (tiles >= 8)
            bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + tiles % 4, 64, 2u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 4)
            bottom_blob_tm2.create(4 * inch, tiles / 4 + tiles % 4, 64, 2u * elempack, elempack, opt.workspace_allocator);
        else // if (tiles >= 1)
            bottom_blob_tm2.create(1 * inch, tiles, 64, 2u * elempack, elempack, opt.workspace_allocator);
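        // tm2 rows hold 8 tiles, then 4 tiles, then single tiles, matching the
        // i += 8 / i += 4 / i++ loops below, so the gemm reads each row linearly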

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 64; r++)
        {
            Mat tm2 = bottom_blob_tm2.channel(r);

            // tile
            int i = 0;
            for (; i + 7 < tiles; i += 8)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x8
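                    // ld4 de-interleaves 8 tiles x 4 channels into four
                    // 8-lane vectors (one per channel); storing them back
                    // contiguously with st1 completes the 4x8 transpose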
                    asm volatile(
                        "prfm   pldl1keep, [%0, #512]   \n"
                        "ld4    {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n"
                        "st1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p) // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0", "v1", "v2", "v3");

                    r0 += bottom_blob_tm.cstep * 4;
                }
            }
            for (; i + 3 < tiles; i += 4)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8 + (i % 8) / 4);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x4
                    asm volatile(
                        "prfm   pldl1keep, [%0, #256]   \n"
                        "ld4    {v0.4h, v1.4h, v2.4h, v3.4h}, [%0] \n"
                        "st1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p) // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0", "v1", "v2", "v3");

                    r0 += bottom_blob_tm.cstep * 4;
                }
            }
            for (; i < tiles; i++)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8 + (i % 8) / 4 + i % 4);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    asm volatile(
                        "prfm   pldl1keep, [%0, #64]    \n"
                        "ld1    {v0.4h}, [%0]           \n"
                        "st1    {v0.4h}, [%1], #8       \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p) // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0");

                    r0 += bottom_blob_tm.cstep * 4;
                }
            }
        }

        bottom_blob_tm = Mat();
        // permute end

        top_blob_tm.create(tiles, 64, outch, 2u * elempack, elempack, opt.workspace_allocator);

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 1;
        remain_outch_start = nn_outch << 1;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int p = pp * 2;

            __fp16* output0_tm = top_blob_tm.channel(p);
            __fp16* output1_tm = top_blob_tm.channel(p + 1);

            const Mat kernel01_tm = kernel_tm.channel(pp);

            for (int r = 0; r < 64; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 7 < tiles; i += 8)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

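                    // microkernel: v24-v31 each accumulate one tile's 8 output
                    // channels, packed as channel group p (low 4h) and p + 1
                    // (high 4h); v4-v7 hold the kernel for 4 input channels x
                    // 8 output channels per loop iteration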
                    asm volatile(
                        "eor    v24.16b, v24.16b, v24.16b   \n"
                        "eor    v25.16b, v25.16b, v25.16b   \n"
                        "eor    v26.16b, v26.16b, v26.16b   \n"
                        "eor    v27.16b, v27.16b, v27.16b   \n"
                        "eor    v28.16b, v28.16b, v28.16b   \n"
                        "eor    v29.16b, v29.16b, v29.16b   \n"
                        "eor    v30.16b, v30.16b, v30.16b   \n"
                        "eor    v31.16b, v31.16b, v31.16b   \n"

                        "0:                                 \n"

                        "prfm   pldl1keep, [%3, #512]       \n"
                        "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%3], #64 \n" // r01 r23 r45 r67

                        "prfm   pldl1keep, [%4, #512]       \n"
                        "ld1    {v4.8h, v5.8h, v6.8h, v7.8h}, [%4], #64 \n" // k0123

                        "fmla   v24.8h, v4.8h, v0.h[0]      \n"
                        "fmla   v25.8h, v4.8h, v0.h[1]      \n"
                        "fmla   v26.8h, v4.8h, v0.h[2]      \n"
                        "fmla   v27.8h, v4.8h, v0.h[3]      \n"
                        "fmla   v28.8h, v4.8h, v0.h[4]      \n"
                        "fmla   v29.8h, v4.8h, v0.h[5]      \n"
                        "fmla   v30.8h, v4.8h, v0.h[6]      \n"
                        "fmla   v31.8h, v4.8h, v0.h[7]      \n"

                        "fmla   v24.8h, v5.8h, v1.h[0]      \n"
                        "fmla   v25.8h, v5.8h, v1.h[1]      \n"
                        "fmla   v26.8h, v5.8h, v1.h[2]      \n"
                        "fmla   v27.8h, v5.8h, v1.h[3]      \n"
                        "fmla   v28.8h, v5.8h, v1.h[4]      \n"
                        "fmla   v29.8h, v5.8h, v1.h[5]      \n"
                        "fmla   v30.8h, v5.8h, v1.h[6]      \n"
                        "fmla   v31.8h, v5.8h, v1.h[7]      \n"

                        "fmla   v24.8h, v6.8h, v2.h[0]      \n"
                        "fmla   v25.8h, v6.8h, v2.h[1]      \n"
                        "fmla   v26.8h, v6.8h, v2.h[2]      \n"
                        "fmla   v27.8h, v6.8h, v2.h[3]      \n"
                        "fmla   v28.8h, v6.8h, v2.h[4]      \n"
                        "fmla   v29.8h, v6.8h, v2.h[5]      \n"
                        "fmla   v30.8h, v6.8h, v2.h[6]      \n"
                        "fmla   v31.8h, v6.8h, v2.h[7]      \n"

                        "subs   %w0, %w0, #1                \n"

                        "fmla   v24.8h, v7.8h, v3.h[0]      \n"
                        "fmla   v25.8h, v7.8h, v3.h[1]      \n"
                        "fmla   v26.8h, v7.8h, v3.h[2]      \n"
                        "fmla   v27.8h, v7.8h, v3.h[3]      \n"
                        "fmla   v28.8h, v7.8h, v3.h[4]      \n"
                        "fmla   v29.8h, v7.8h, v3.h[5]      \n"
                        "fmla   v30.8h, v7.8h, v3.h[6]      \n"
                        "fmla   v31.8h, v7.8h, v3.h[7]      \n"

                        "bne    0b                          \n"

                        "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
                        "st1    {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"

                        "ext    v24.16b, v24.16b, v24.16b, #8 \n"
                        "ext    v25.16b, v25.16b, v25.16b, #8 \n"
                        "ext    v26.16b, v26.16b, v26.16b, #8 \n"
                        "ext    v27.16b, v27.16b, v27.16b, #8 \n"
                        "ext    v28.16b, v28.16b, v28.16b, #8 \n"
                        "ext    v29.16b, v29.16b, v29.16b, #8 \n"
                        "ext    v30.16b, v30.16b, v30.16b, #8 \n"
                        "ext    v31.16b, v31.16b, v31.16b, #8 \n"

                        "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"
                        "st1    {v28.4h, v29.4h, v30.4h, v31.4h}, [%2], #32 \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(output1_tm), // %2
                        "=r"(r0),         // %3
                        "=r"(kptr)        // %4
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(output1_tm),
                        "3"(r0),
                        "4"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    asm volatile(
                        "eor    v24.16b, v24.16b, v24.16b   \n"
                        "eor    v25.16b, v25.16b, v25.16b   \n"
                        "eor    v26.16b, v26.16b, v26.16b   \n"
                        "eor    v27.16b, v27.16b, v27.16b   \n"

                        "0:                                 \n"

                        "prfm   pldl1keep, [%3, #256]       \n"
                        "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%3], #32 \n" // r01 r23 r45 r67

                        "prfm   pldl1keep, [%4, #512]       \n"
                        "ld1    {v4.8h, v5.8h, v6.8h, v7.8h}, [%4], #64 \n" // k0123

                        "fmla   v24.8h, v4.8h, v0.h[0]      \n"
                        "fmla   v25.8h, v4.8h, v0.h[1]      \n"
                        "fmla   v26.8h, v4.8h, v0.h[2]      \n"
                        "fmla   v27.8h, v4.8h, v0.h[3]      \n"

                        "fmla   v24.8h, v5.8h, v1.h[0]      \n"
                        "fmla   v25.8h, v5.8h, v1.h[1]      \n"
                        "fmla   v26.8h, v5.8h, v1.h[2]      \n"
                        "fmla   v27.8h, v5.8h, v1.h[3]      \n"

                        "fmla   v24.8h, v6.8h, v2.h[0]      \n"
                        "fmla   v25.8h, v6.8h, v2.h[1]      \n"
                        "fmla   v26.8h, v6.8h, v2.h[2]      \n"
                        "fmla   v27.8h, v6.8h, v2.h[3]      \n"

                        "subs   %w0, %w0, #1                \n"

                        "fmla   v24.8h, v7.8h, v3.h[0]      \n"
                        "fmla   v25.8h, v7.8h, v3.h[1]      \n"
                        "fmla   v26.8h, v7.8h, v3.h[2]      \n"
                        "fmla   v27.8h, v7.8h, v3.h[3]      \n"

                        "bne    0b                          \n"

                        "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"

                        "ext    v24.16b, v24.16b, v24.16b, #8 \n"
                        "ext    v25.16b, v25.16b, v25.16b, #8 \n"
                        "ext    v26.16b, v26.16b, v26.16b, #8 \n"
                        "ext    v27.16b, v27.16b, v27.16b, #8 \n"

                        "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%2], #32 \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(output1_tm), // %2
                        "=r"(r0),         // %3
                        "=r"(kptr)        // %4
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(output1_tm),
                        "3"(r0),
                        "4"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27");
                }
                for (; i < tiles; i++)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4 + i % 4);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

                    float16x8_t _sum0 = vdupq_n_f16(0.f);

                    for (int q = 0; q < inch; q++)
                    {
                        float16x4_t _r0 = vld1_f16(r0);

                        float16x8_t _k0 = vld1q_f16(kptr);
                        float16x8_t _k1 = vld1q_f16(kptr + 8);
                        float16x8_t _k2 = vld1q_f16(kptr + 16);
                        float16x8_t _k3 = vld1q_f16(kptr + 24);

                        _sum0 = vfmaq_lane_f16(_sum0, _k0, _r0, 0);
                        _sum0 = vfmaq_lane_f16(_sum0, _k1, _r0, 1);
                        _sum0 = vfmaq_lane_f16(_sum0, _k2, _r0, 2);
                        _sum0 = vfmaq_lane_f16(_sum0, _k3, _r0, 3);

                        kptr += 32;
                        r0 += 4;
                    }

                    vst1_f16(output0_tm, vget_low_f16(_sum0));
                    vst1_f16(output1_tm, vget_high_f16(_sum0));

                    output0_tm += 4;
                    output1_tm += 4;
                }
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_outch_start; p < outch; p++)
        {
            __fp16* output0_tm = top_blob_tm.channel(p);

            const Mat kernel0_tm = kernel_tm.channel(p / 2 + p % 2);

            for (int r = 0; r < 64; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 7 < tiles; i += 8)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

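                    // same scheme as the paired-channel kernel above, but with
                    // a single 4-outch group: accumulators and kernel loads
                    // shrink to 4h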
                    asm volatile(
                        "eor    v24.16b, v24.16b, v24.16b   \n"
                        "eor    v25.16b, v25.16b, v25.16b   \n"
                        "eor    v26.16b, v26.16b, v26.16b   \n"
                        "eor    v27.16b, v27.16b, v27.16b   \n"
                        "eor    v28.16b, v28.16b, v28.16b   \n"
                        "eor    v29.16b, v29.16b, v29.16b   \n"
                        "eor    v30.16b, v30.16b, v30.16b   \n"
                        "eor    v31.16b, v31.16b, v31.16b   \n"

                        "0:                                 \n"

                        "prfm   pldl1keep, [%2, #512]       \n"
                        "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%2], #64 \n" // r01 r23 r45 r67

                        "prfm   pldl1keep, [%3, #256]       \n"
                        "ld1    {v4.4h, v5.4h, v6.4h, v7.4h}, [%3], #32 \n" // k0123

                        "fmla   v24.4h, v4.4h, v0.h[0]      \n"
                        "fmla   v25.4h, v4.4h, v0.h[1]      \n"
                        "fmla   v26.4h, v4.4h, v0.h[2]      \n"
                        "fmla   v27.4h, v4.4h, v0.h[3]      \n"
                        "fmla   v28.4h, v4.4h, v0.h[4]      \n"
                        "fmla   v29.4h, v4.4h, v0.h[5]      \n"
                        "fmla   v30.4h, v4.4h, v0.h[6]      \n"
                        "fmla   v31.4h, v4.4h, v0.h[7]      \n"

                        "fmla   v24.4h, v5.4h, v1.h[0]      \n"
                        "fmla   v25.4h, v5.4h, v1.h[1]      \n"
                        "fmla   v26.4h, v5.4h, v1.h[2]      \n"
                        "fmla   v27.4h, v5.4h, v1.h[3]      \n"
                        "fmla   v28.4h, v5.4h, v1.h[4]      \n"
                        "fmla   v29.4h, v5.4h, v1.h[5]      \n"
                        "fmla   v30.4h, v5.4h, v1.h[6]      \n"
                        "fmla   v31.4h, v5.4h, v1.h[7]      \n"

                        "fmla   v24.4h, v6.4h, v2.h[0]      \n"
                        "fmla   v25.4h, v6.4h, v2.h[1]      \n"
                        "fmla   v26.4h, v6.4h, v2.h[2]      \n"
                        "fmla   v27.4h, v6.4h, v2.h[3]      \n"
                        "fmla   v28.4h, v6.4h, v2.h[4]      \n"
                        "fmla   v29.4h, v6.4h, v2.h[5]      \n"
                        "fmla   v30.4h, v6.4h, v2.h[6]      \n"
                        "fmla   v31.4h, v6.4h, v2.h[7]      \n"

                        "subs   %w0, %w0, #1                \n"

                        "fmla   v24.4h, v7.4h, v3.h[0]      \n"
                        "fmla   v25.4h, v7.4h, v3.h[1]      \n"
                        "fmla   v26.4h, v7.4h, v3.h[2]      \n"
                        "fmla   v27.4h, v7.4h, v3.h[3]      \n"
                        "fmla   v28.4h, v7.4h, v3.h[4]      \n"
                        "fmla   v29.4h, v7.4h, v3.h[5]      \n"
                        "fmla   v30.4h, v7.4h, v3.h[6]      \n"
                        "fmla   v31.4h, v7.4h, v3.h[7]      \n"

                        "bne    0b                          \n"

                        "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"
                        "st1    {v28.4h, v29.4h, v30.4h, v31.4h}, [%1], #32 \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(r0),         // %2
                        "=r"(kptr)        // %3
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(r0),
                        "3"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    asm volatile(
                        "eor    v24.16b, v24.16b, v24.16b   \n"
                        "eor    v25.16b, v25.16b, v25.16b   \n"
                        "eor    v26.16b, v26.16b, v26.16b   \n"
                        "eor    v27.16b, v27.16b, v27.16b   \n"

                        "0:                                 \n"

                        "prfm   pldl1keep, [%2, #256]       \n"
                        "ld1    {v0.4h, v1.4h, v2.4h, v3.4h}, [%2], #32 \n" // r01 r23 r45 r67

                        "prfm   pldl1keep, [%3, #256]       \n"
                        "ld1    {v4.4h, v5.4h, v6.4h, v7.4h}, [%3], #32 \n" // k0123

                        "fmla   v24.4h, v4.4h, v0.h[0]      \n"
                        "fmla   v25.4h, v4.4h, v0.h[1]      \n"
                        "fmla   v26.4h, v4.4h, v0.h[2]      \n"
                        "fmla   v27.4h, v4.4h, v0.h[3]      \n"

                        "fmla   v24.4h, v5.4h, v1.h[0]      \n"
                        "fmla   v25.4h, v5.4h, v1.h[1]      \n"
                        "fmla   v26.4h, v5.4h, v1.h[2]      \n"
                        "fmla   v27.4h, v5.4h, v1.h[3]      \n"

                        "fmla   v24.4h, v6.4h, v2.h[0]      \n"
                        "fmla   v25.4h, v6.4h, v2.h[1]      \n"
                        "fmla   v26.4h, v6.4h, v2.h[2]      \n"
                        "fmla   v27.4h, v6.4h, v2.h[3]      \n"

                        "subs   %w0, %w0, #1                \n"

                        "fmla   v24.4h, v7.4h, v3.h[0]      \n"
                        "fmla   v25.4h, v7.4h, v3.h[1]      \n"
                        "fmla   v26.4h, v7.4h, v3.h[2]      \n"
                        "fmla   v27.4h, v7.4h, v3.h[3]      \n"

                        "bne    0b                          \n"

                        "st1    {v24.4h, v25.4h, v26.4h, v27.4h}, [%1], #32 \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(r0),         // %2
                        "=r"(kptr)        // %3
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(r0),
                        "3"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v24", "v25", "v26", "v27");
                }
                for (; i < tiles; i++)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4 + i % 4);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    float16x4_t _sum0 = vdup_n_f16(0.f);

                    for (int q = 0; q < inch; q++)
                    {
                        float16x4_t _r0 = vld1_f16(r0);

                        float16x4_t _k0 = vld1_f16(kptr);
                        float16x4_t _k1 = vld1_f16(kptr + 4);
                        float16x4_t _k2 = vld1_f16(kptr + 8);
                        float16x4_t _k3 = vld1_f16(kptr + 12);

                        _sum0 = vfma_lane_f16(_sum0, _k0, _r0, 0);
                        _sum0 = vfma_lane_f16(_sum0, _k1, _r0, 1);
                        _sum0 = vfma_lane_f16(_sum0, _k2, _r0, 2);
                        _sum0 = vfma_lane_f16(_sum0, _k3, _r0, 3);

                        kptr += 16;
                        r0 += 4;
                    }

                    vst1_f16(output0_tm, _sum0);

                    output0_tm += 4;
                }
            }
        }
    }
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    if (outw == top_blob.w && outh == top_blob.h)
    {
        top_blob_bordered = top_blob;
    }
    else
    {
        top_blob_bordered.create(outw, outh, outch, 2u * 4, 4, opt.workspace_allocator);
    }
    {
        //         const float otm[6][8] = {
        //             {1.0f,  1.0f,   1.0f,   1.0f,   1.0f,  32.0f, 32.0f, 0.0f},
        //             {0.0f,  1.0f,  -1.0f,   2.0f,  -2.0f,  16.0f,-16.0f, 0.0f},
        //             {0.0f,  1.0f,   1.0f,   4.0f,   4.0f,   8.0f,  8.0f, 0.0f},
        //             {0.0f,  1.0f,  -1.0f,   8.0f,  -8.0f,   4.0f, -4.0f, 0.0f},
        //             {0.0f,  1.0f,   1.0f,  16.0f,  16.0f,   2.0f,  2.0f, 0.0f},
        //             {0.0f,  1.0f,  -1.0f,  32.0f, -32.0f,   1.0f, -1.0f, 1.0f}
        //         };

        // 0 = r0 + (r1 + r2) + (r3 + r4)     + (r5 + r6) * 32
        // 1 =      (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
        // 2 =      (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
        // 3 =      (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
        // 4 =      (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2
        // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6)
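        // the expressions above are the rows of A^T for F(6,3): each 8x8
        // transformed tile collapses to a 6x6 output tile, again applied
        // rows first (into tmp) and then columns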

        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;
        const int tiles = w_tm / 8 * h_tm / 8;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            const Mat out0_tm = top_blob_tm.channel(p);
            Mat out0 = top_blob_bordered.channel(p);

            //             const float bias0 = bias ? bias[p] : 0.f;
            float16x4_t _bias0 = bias ? vld1_f16((const __fp16*)bias + p * 4) : vdup_n_f16(0.f);
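            // note the __fp16 cast above: the bias blob must already be stored
            // in fp16, one 4-wide group per pack4 output channel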

            __fp16 tmp[6][8][4];

            // tile
            for (int i = 0; i < outh / 6; i++)
            {
                for (int j = 0; j < outw / 6; j++)
                {
                    //                     top_blob_tm.create(tiles, 64, outch, elemsize, elempack);

                    const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * 4;
                    const __fp16* output0_tm_1 = output0_tm_0 + tiles * 4;
                    const __fp16* output0_tm_2 = output0_tm_0 + tiles * 8;
                    const __fp16* output0_tm_3 = output0_tm_0 + tiles * 12;
                    const __fp16* output0_tm_4 = output0_tm_0 + tiles * 16;
                    const __fp16* output0_tm_5 = output0_tm_0 + tiles * 20;
                    const __fp16* output0_tm_6 = output0_tm_0 + tiles * 24;
                    const __fp16* output0_tm_7 = output0_tm_0 + tiles * 28;

                    __fp16* output0 = out0.row<__fp16>(i * 6) + (j * 6) * 4;

                    // TODO neon optimize
                    for (int m = 0; m < 8; m++)
                    {
                        float16x4_t _out0tm0 = vld1_f16(output0_tm_0);
                        float16x4_t _out0tm1 = vld1_f16(output0_tm_1);
                        float16x4_t _out0tm2 = vld1_f16(output0_tm_2);
                        float16x4_t _out0tm3 = vld1_f16(output0_tm_3);
                        float16x4_t _out0tm4 = vld1_f16(output0_tm_4);
                        float16x4_t _out0tm5 = vld1_f16(output0_tm_5);
                        float16x4_t _out0tm6 = vld1_f16(output0_tm_6);
                        float16x4_t _out0tm7 = vld1_f16(output0_tm_7);

                        float16x4_t _tmp024a = vadd_f16(_out0tm1, _out0tm2);
                        float16x4_t _tmp135a = vsub_f16(_out0tm1, _out0tm2);

                        //                         float tmp024a = output0_tm[1] + output0_tm[2];
                        //                         float tmp135a = output0_tm[1] - output0_tm[2];

                        float16x4_t _tmp024b = vadd_f16(_out0tm3, _out0tm4);
                        float16x4_t _tmp135b = vsub_f16(_out0tm3, _out0tm4);

                        //                         float tmp024b = output0_tm[3] + output0_tm[4];
                        //                         float tmp135b = output0_tm[3] - output0_tm[4];

                        float16x4_t _tmp024c = vadd_f16(_out0tm5, _out0tm6);
                        float16x4_t _tmp135c = vsub_f16(_out0tm5, _out0tm6);

                        //                         float tmp024c = output0_tm[5] + output0_tm[6];
                        //                         float tmp135c = output0_tm[5] - output0_tm[6];

                        float16x4_t _tmp0m = vadd_f16(vadd_f16(_out0tm0, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f));
                        float16x4_t _tmp2m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f);
                        float16x4_t _tmp4m = vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f);
                        vst1_f16(tmp[0][m], _tmp0m);
                        vst1_f16(tmp[2][m], _tmp2m);
                        vst1_f16(tmp[4][m], _tmp4m);

                        //                         tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32;
                        //                         tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8;
                        //                         tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c;

                        float16x4_t _tmp1m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f);
                        float16x4_t _tmp3m = vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f);
                        float16x4_t _tmp5m = vadd_f16(vadd_f16(_out0tm7, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f));
                        vst1_f16(tmp[1][m], _tmp1m);
                        vst1_f16(tmp[3][m], _tmp3m);
                        vst1_f16(tmp[5][m], _tmp5m);

                        //                         tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
                        //                         tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4;
                        //                         tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c;

                        output0_tm_0 += tiles * 32;
                        output0_tm_1 += tiles * 32;
                        output0_tm_2 += tiles * 32;
                        output0_tm_3 += tiles * 32;
                        output0_tm_4 += tiles * 32;
                        output0_tm_5 += tiles * 32;
                        output0_tm_6 += tiles * 32;
                        output0_tm_7 += tiles * 32;
                    }

                    for (int m = 0; m < 6; m++)
                    {
                        float16x4_t _tmp00 = vld1_f16(tmp[m][0]);
                        float16x4_t _tmp01 = vld1_f16(tmp[m][1]);
                        float16x4_t _tmp02 = vld1_f16(tmp[m][2]);
                        float16x4_t _tmp03 = vld1_f16(tmp[m][3]);
                        float16x4_t _tmp04 = vld1_f16(tmp[m][4]);
                        float16x4_t _tmp05 = vld1_f16(tmp[m][5]);
                        float16x4_t _tmp06 = vld1_f16(tmp[m][6]);
                        float16x4_t _tmp07 = vld1_f16(tmp[m][7]);

                        float16x4_t _tmp024a = vadd_f16(_tmp01, _tmp02);
                        float16x4_t _tmp135a = vsub_f16(_tmp01, _tmp02);

                        //                         float tmp024a = tmp0[1] + tmp0[2];
                        //                         float tmp135a = tmp0[1] - tmp0[2];

                        float16x4_t _tmp024b = vadd_f16(_tmp03, _tmp04);
                        float16x4_t _tmp135b = vsub_f16(_tmp03, _tmp04);

                        //                         float tmp024b = tmp0[3] + tmp0[4];
                        //                         float tmp135b = tmp0[3] - tmp0[4];

                        float16x4_t _tmp024c = vadd_f16(_tmp05, _tmp06);
                        float16x4_t _tmp135c = vsub_f16(_tmp05, _tmp06);

                        //                         float tmp024c = tmp0[5] + tmp0[6];
                        //                         float tmp135c = tmp0[5] - tmp0[6];

                        float16x4_t _out00 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp00, _tmp024a), vfma_n_f16(_tmp024b, _tmp024c, 32.f)));
                        float16x4_t _out02 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f));
                        float16x4_t _out04 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f));
                        vst1_f16(output0, _out00);
                        vst1_f16(output0 + 8, _out02);
                        vst1_f16(output0 + 16, _out04);

                        //                         output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32;
                        //                         output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8;
                        //                         output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c;

                        float16x4_t _out01 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f));
                        float16x4_t _out03 = vadd_f16(_bias0, vfma_n_f16(vfma_n_f16(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f));
                        float16x4_t _out05 = vadd_f16(_bias0, vadd_f16(vadd_f16(_tmp07, _tmp135a), vfma_n_f16(_tmp135c, _tmp135b, 32.f)));
                        vst1_f16(output0 + 4, _out01);
                        vst1_f16(output0 + 12, _out03);
                        vst1_f16(output0 + 20, _out05);

                        //                         output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16;
                        //                         output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4;
                        //                         output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c;

                        output0 += outw * 4;
                    }
                }
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}