1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
// Transform a 3x3 convolution kernel into the Winograd F(6x6, 3x3) domain
// (8x8 tiles) and repack it for the pack8-to-pack1 fp16 compute kernel.
// kernel            : source weights, layout outch-inch-3x3 (fp32)
// kernel_tm_pack8to1: destination, layout 8a-inch/8a-64-outch (fp16)
// inch / outch      : input / output channel counts (inch assumed multiple of 8)
static void conv3x3s1_winograd64_transform_kernel_pack8to1_fp16sa_neon(const Mat& kernel, Mat& kernel_tm_pack8to1, int inch, int outch)
{
    // winograd63 transform kernel
    Mat kernel_tm;
    kernel_tm.create(8 * 8, inch, outch);

    // G matrix for F(6,3): maps a 3-tap kernel row/column to 8 transform taps
    const float ktm[8][3] = {
        {1.0f, 0.0f, 0.0f},
        {-2.0f / 9, -2.0f / 9, -2.0f / 9},
        {-2.0f / 9, 2.0f / 9, -2.0f / 9},
        {1.0f / 90, 1.0f / 45, 2.0f / 45},
        {1.0f / 90, -1.0f / 45, 2.0f / 45},
        {1.0f / 45, 1.0f / 90, 1.0f / 180},
        {1.0f / 45, -1.0f / 90, 1.0f / 180},
        {0.0f, 0.0f, 1.0f}
    };

    #pragma omp parallel for
    for (int p = 0; p < outch; p++)
    {
        for (int q = 0; q < inch; q++)
        {
            const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
            float* kernel_tm0 = kernel_tm.channel(p).row(q);

            // transform kernel, transposed
            const float* k0 = kernel0;
            const float* k1 = kernel0 + 3;
            const float* k2 = kernel0 + 6;

            // h : tmp = G * k (applied along the horizontal direction)
            float tmp[8][3];
            for (int i = 0; i < 8; i++)
            {
                tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
                tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
                tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
            }

            // v : kernel_tm = tmp * G^T -> full 8x8 transformed tile
            for (int j = 0; j < 8; j++)
            {
                float* tmpp = &tmp[j][0];

                for (int i = 0; i < 8; i++)
                {
                    kernel_tm0[j * 8 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
                }
            }
        }
    }

    // interleave
    // src = 64-inch-outch
    // dst = 8a-inch/8a-64-outch;
    // channel index: outch/8 packed-8 groups, then one channel per leftover outch
    kernel_tm_pack8to1.create(8 * inch / 8, 64, outch / 8 + outch % 8, (size_t)2u * 8, 8);

    int p = 0;
    for (; p + 7 < outch; p += 8)
    {
        const Mat k0 = kernel_tm.channel(p);
        const Mat k1 = kernel_tm.channel(p + 1);
        const Mat k2 = kernel_tm.channel(p + 2);
        const Mat k3 = kernel_tm.channel(p + 3);
        const Mat k4 = kernel_tm.channel(p + 4);
        const Mat k5 = kernel_tm.channel(p + 5);
        const Mat k6 = kernel_tm.channel(p + 6);
        const Mat k7 = kernel_tm.channel(p + 7);

        Mat g0 = kernel_tm_pack8to1.channel(p / 8);

        for (int k = 0; k < 64; k++)
        {
            __fp16* g00 = g0.row<__fp16>(k);

            // interleave 8 output channels, fp32 -> fp16 narrowing on store
            for (int q = 0; q + 7 < inch; q += 8)
            {
                for (int i = 0; i < 8; i++)
                {
                    g00[0] = (__fp16)k0.row(q + i)[k];
                    g00[1] = (__fp16)k1.row(q + i)[k];
                    g00[2] = (__fp16)k2.row(q + i)[k];
                    g00[3] = (__fp16)k3.row(q + i)[k];
                    g00[4] = (__fp16)k4.row(q + i)[k];
                    g00[5] = (__fp16)k5.row(q + i)[k];
                    g00[6] = (__fp16)k6.row(q + i)[k];
                    g00[7] = (__fp16)k7.row(q + i)[k];

                    g00 += 8;
                }
            }
        }
    }
    // leftover output channels: one channel each, 8 fp16 per packed input group
    for (; p < outch; p++)
    {
        const Mat k0 = kernel_tm.channel(p);

        Mat g0 = kernel_tm_pack8to1.channel(p / 8 + p % 8);

        for (int k = 0; k < 64; k++)
        {
            __fp16* g00 = g0.row<__fp16>(k);

            for (int q = 0; q + 7 < inch; q += 8)
            {
                for (int i = 0; i < 8; i++)
                {
                    g00[0] = (__fp16)k0.row(q + i)[k];

                    g00 += 1;
                }
            }
        }
    }
}
130
// Winograd F(6x6, 3x3) convolution, fp16 storage+arithmetic, pack8 input -> pack1 output.
// bottom_blob: input, elempack 8 (inch counts packed-8 channels)
// top_blob   : output, elempack 1, pre-allocated to the final size
// kernel_tm  : weights prepared by conv3x3s1_winograd64_transform_kernel_pack8to1_fp16sa_neon
// _bias      : optional per-output-channel bias (fp16), may be empty
//
// Fix: in the scalar remaining-tile tail of the remaining-outch dot loop, the
// per-input-channel pointer advance was 4 halves; both the permuted input row
// and the packed kernel row store 8 halves per packed input channel (and
// vld1q_f16 consumes 8), so the stride must be 8.
static void conv3x3s1_winograd64_pack8to1_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    //size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 6n+2 so the input splits into overlapping 8x8 tiles with stride 6
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 5) / 6 * 6;
    outh = (outh + 5) / 6 * 6;

    w = outw + 2;
    h = outh + 2;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt);

    const __fp16* bias = _bias;

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = w_tm / 8 * h_tm / 8;

        // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
        bottom_blob_tm.create(tiles, 64, inch, 2u * elempack, elempack, opt.workspace_allocator);

        // const float itm[8][8] = {
        //     {1.0f,  0.0f, -5.25f,  0.00f,  5.25f,  0.00f, -1.0f, 0.0f},
        //
        //     {0.0f,  1.0f,  1.00f, -4.25f, -4.25f,  1.00f,  1.0f, 0.0f},
        //     {0.0f, -1.0f,  1.00f,  4.25f, -4.25f, -1.00f,  1.0f, 0.0f},
        //
        //     {0.0f,  0.5f,  0.25f, -2.50f, -1.25f,  2.00f,  1.0f, 0.0f},
        //     {0.0f, -0.5f,  0.25f,  2.50f, -1.25f, -2.00f,  1.0f, 0.0f},
        //
        //     {0.0f,  2.0f,  4.00f, -2.50f, -5.00f,  0.50f,  1.0f, 0.0f},
        //     {0.0f, -2.0f,  4.00f,  2.50f, -5.00f, -0.50f,  1.0f, 0.0f},
        //
        //     {0.0f, -1.0f,  0.00f,  5.25f,  0.00f, -5.25f,  0.0f, 1.0f}
        // };

        // 0 = r00 - r06 + (r04 - r02) * 5.25
        // 7 = r07 - r01 + (r03 - r05) * 5.25

        // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05)
        // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05)

        // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2)
        // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2)

        // reuse r04 * 1.25
        // reuse r03 * 2.5
        // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
        // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const Mat img0 = bottom_blob_bordered.channel(q);
            Mat img0_tm = bottom_blob_tm.channel(q);

            __fp16 tmp[8][8][8];

            // tile
            for (int i = 0; i < h_tm / 8; i++)
            {
                for (int j = 0; j < w_tm / 8; j++)
                {
                    const __fp16* r0 = img0.row<const __fp16>(i * 6) + (j * 6) * 8;

                    // pass 1: transform rows of the 8x8x8 tile into tmp
                    for (int m = 0; m < 8; m++)
                    {
                        float16x8_t _r00 = vld1q_f16(r0);
                        float16x8_t _r01 = vld1q_f16(r0 + 8);
                        float16x8_t _r02 = vld1q_f16(r0 + 16);
                        float16x8_t _r03 = vld1q_f16(r0 + 24);
                        float16x8_t _r04 = vld1q_f16(r0 + 32);
                        float16x8_t _r05 = vld1q_f16(r0 + 40);
                        float16x8_t _r06 = vld1q_f16(r0 + 48);
                        float16x8_t _r07 = vld1q_f16(r0 + 56);

                        float16x8_t _tmp0m = vfmaq_n_f16(vsubq_f16(_r00, _r06), vsubq_f16(_r04, _r02), 5.25f);
                        float16x8_t _tmp7m = vfmaq_n_f16(vsubq_f16(_r07, _r01), vsubq_f16(_r03, _r05), 5.25f);
                        vst1q_f16(tmp[0][m], _tmp0m);
                        vst1q_f16(tmp[7][m], _tmp7m);

                        // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25;
                        // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25;

                        float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_r02, _r06), _r04, 4.25f);
                        float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_r01, _r05), _r03, 4.25f);

                        // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25);
                        // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25);

                        float16x8_t _tmp1m = vaddq_f16(_tmp12a, _tmp12b);
                        float16x8_t _tmp2m = vsubq_f16(_tmp12a, _tmp12b);
                        vst1q_f16(tmp[1][m], _tmp1m);
                        vst1q_f16(tmp[2][m], _tmp2m);

                        // tmp[1][m] = tmp12a + tmp12b;
                        // tmp[2][m] = tmp12a - tmp12b;

                        float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_r06, _r02, 0.25f), _r04, 1.25f);
                        float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 0.5f), _r03, 2.5f), _r05, 2.f);

                        // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25);
                        // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2);

                        float16x8_t _tmp3m = vaddq_f16(_tmp34a, _tmp34b);
                        float16x8_t _tmp4m = vsubq_f16(_tmp34a, _tmp34b);
                        vst1q_f16(tmp[3][m], _tmp3m);
                        vst1q_f16(tmp[4][m], _tmp4m);

                        // tmp[3][m] = tmp34a + tmp34b;
                        // tmp[4][m] = tmp34a - tmp34b;

                        float16x8_t _tmp56a = vfmaq_n_f16(_r06, vfmsq_n_f16(_r02, _r04, 1.25f), 4.f);
                        float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_r01, 2.f), _r03, 2.5f), _r05, 0.5f);

                        // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4);
                        // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5);

                        float16x8_t _tmp5m = vaddq_f16(_tmp56a, _tmp56b);
                        float16x8_t _tmp6m = vsubq_f16(_tmp56a, _tmp56b);
                        vst1q_f16(tmp[5][m], _tmp5m);
                        vst1q_f16(tmp[6][m], _tmp6m);

                        // tmp[5][m] = tmp56a + tmp56b;
                        // tmp[6][m] = tmp56a - tmp56b;

                        r0 += w * 8;
                    }

                    __fp16* r0_tm_0 = (__fp16*)img0_tm + (i * w_tm / 8 + j) * 8;
                    __fp16* r0_tm_1 = r0_tm_0 + tiles * 8;
                    __fp16* r0_tm_2 = r0_tm_0 + tiles * 16;
                    __fp16* r0_tm_3 = r0_tm_0 + tiles * 24;
                    __fp16* r0_tm_4 = r0_tm_0 + tiles * 32;
                    __fp16* r0_tm_5 = r0_tm_0 + tiles * 40;
                    __fp16* r0_tm_6 = r0_tm_0 + tiles * 48;
                    __fp16* r0_tm_7 = r0_tm_0 + tiles * 56;

                    // pass 2: transform columns of tmp and scatter into bottom_blob_tm
                    for (int m = 0; m < 8; m++)
                    {
                        float16x8_t _tmp00 = vld1q_f16(tmp[m][0]);
                        float16x8_t _tmp01 = vld1q_f16(tmp[m][1]);
                        float16x8_t _tmp02 = vld1q_f16(tmp[m][2]);
                        float16x8_t _tmp03 = vld1q_f16(tmp[m][3]);
                        float16x8_t _tmp04 = vld1q_f16(tmp[m][4]);
                        float16x8_t _tmp05 = vld1q_f16(tmp[m][5]);
                        float16x8_t _tmp06 = vld1q_f16(tmp[m][6]);
                        float16x8_t _tmp07 = vld1q_f16(tmp[m][7]);

                        float16x8_t _r0tm0 = vfmaq_n_f16(vsubq_f16(_tmp00, _tmp06), vsubq_f16(_tmp04, _tmp02), 5.25f);
                        float16x8_t _r0tm7 = vfmaq_n_f16(vsubq_f16(_tmp07, _tmp01), vsubq_f16(_tmp03, _tmp05), 5.25f);

                        // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25;
                        // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25;

                        float16x8_t _tmp12a = vfmsq_n_f16(vaddq_f16(_tmp02, _tmp06), _tmp04, 4.25f);
                        float16x8_t _tmp12b = vfmsq_n_f16(vaddq_f16(_tmp01, _tmp05), _tmp03, 4.25f);

                        // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25);
                        // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25);

                        float16x8_t _r0tm1 = vaddq_f16(_tmp12a, _tmp12b);
                        float16x8_t _r0tm2 = vsubq_f16(_tmp12a, _tmp12b);

                        // r0_tm[1] = tmp12a + tmp12b;
                        // r0_tm[2] = tmp12a - tmp12b;

                        float16x8_t _tmp34a = vfmsq_n_f16(vfmaq_n_f16(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f);
                        float16x8_t _tmp34b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 0.5f), _tmp03, 2.5f), _tmp05, 2.f);

                        // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25);
                        // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2);

                        float16x8_t _r0tm3 = vaddq_f16(_tmp34a, _tmp34b);
                        float16x8_t _r0tm4 = vsubq_f16(_tmp34a, _tmp34b);

                        // r0_tm[3] = tmp34a + tmp34b;
                        // r0_tm[4] = tmp34a - tmp34b;

                        float16x8_t _tmp56a = vfmaq_n_f16(_tmp06, vfmsq_n_f16(_tmp02, _tmp04, 1.25f), 4.f);
                        float16x8_t _tmp56b = vfmaq_n_f16(vfmsq_n_f16(vmulq_n_f16(_tmp01, 2.f), _tmp03, 2.5f), _tmp05, 0.5f);

                        // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4);
                        // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5);

                        float16x8_t _r0tm5 = vaddq_f16(_tmp56a, _tmp56b);
                        float16x8_t _r0tm6 = vsubq_f16(_tmp56a, _tmp56b);

                        // r0_tm[5] = tmp56a + tmp56b;
                        // r0_tm[6] = tmp56a - tmp56b;

                        vst1q_f16(r0_tm_0, _r0tm0);
                        vst1q_f16(r0_tm_1, _r0tm1);
                        vst1q_f16(r0_tm_2, _r0tm2);
                        vst1q_f16(r0_tm_3, _r0tm3);
                        vst1q_f16(r0_tm_4, _r0tm4);
                        vst1q_f16(r0_tm_5, _r0tm5);
                        vst1q_f16(r0_tm_6, _r0tm6);
                        vst1q_f16(r0_tm_7, _r0tm7);

                        r0_tm_0 += tiles * 64;
                        r0_tm_1 += tiles * 64;
                        r0_tm_2 += tiles * 64;
                        r0_tm_3 += tiles * 64;
                        r0_tm_4 += tiles * 64;
                        r0_tm_5 += tiles * 64;
                        r0_tm_6 += tiles * 64;
                        r0_tm_7 += tiles * 64;
                    }
                }
            }
        }
    }
    bottom_blob_bordered = Mat();
    // END transform input

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = h_tm / 8 * w_tm / 8;

        // permute tiles into contiguous 8/4/1-tile groups for the GEMM inner loop
        // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
        Mat bottom_blob_tm2;
        if (tiles >= 8)
            bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + tiles % 4, 64, 2u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 4)
            bottom_blob_tm2.create(4 * inch, tiles / 4 + tiles % 4, 64, 2u * elempack, elempack, opt.workspace_allocator);
        else // if (tiles >= 1)
            bottom_blob_tm2.create(1 * inch, tiles, 64, 2u * elempack, elempack, opt.workspace_allocator);

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 64; r++)
        {
            Mat tm2 = bottom_blob_tm2.channel(r);

            // tile
            int i = 0;
            for (; i + 7 < tiles; i += 8)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 8x8
                    asm volatile(
                        "prfm   pldl1keep, [%0, #512]   \n"
                        "ld4    {v0.8h, v1.8h, v2.8h, v3.8h}, [%0], #64 \n"
                        "ld4    {v4.8h, v5.8h, v6.8h, v7.8h}, [%0] \n"
                        "sub    %0, %0, #64             \n"

                        "uzp1   v16.8h, v0.8h, v4.8h    \n"
                        "uzp2   v20.8h, v0.8h, v4.8h    \n"
                        "uzp1   v17.8h, v1.8h, v5.8h    \n"
                        "uzp2   v21.8h, v1.8h, v5.8h    \n"
                        "uzp1   v18.8h, v2.8h, v6.8h    \n"
                        "uzp2   v22.8h, v2.8h, v6.8h    \n"
                        "uzp1   v19.8h, v3.8h, v7.8h    \n"
                        "uzp2   v23.8h, v3.8h, v7.8h    \n"

                        "st1    {v16.8h, v17.8h, v18.8h, v19.8h}, [%1], #64 \n"
                        "st1    {v20.8h, v21.8h, v22.8h, v23.8h}, [%1], #64 \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p) // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");

                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
            for (; i + 3 < tiles; i += 4)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8 + (i % 8) / 4);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 8x4
                    asm volatile(
                        "prfm   pldl1keep, [%0, #256]   \n"
                        "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%0] \n"
                        "st4    {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p) // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0", "v1", "v2", "v3");

                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
            for (; i < tiles; i++)
            {
                __fp16* tm2p = tm2.row<__fp16>(i / 8 + (i % 8) / 4 + i % 4);

                const __fp16* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                // single tile: 8 halves copied per packed input channel
                for (int q = 0; q < inch; q++)
                {
                    asm volatile(
                        "prfm   pldl1keep, [%0, #128]   \n"
                        "ld1    {v0.8h}, [%0]           \n"
                        "st1    {v0.8h}, [%1], #16      \n"
                        : "=r"(r0),  // %0
                        "=r"(tm2p) // %1
                        : "0"(r0),
                        "1"(tm2p)
                        : "memory", "v0");

                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
        }

        bottom_blob_tm = Mat();
        // permute end

        top_blob_tm.create(tiles, 64, outch, 2u, 1, opt.workspace_allocator);

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int p = pp * 8;

            __fp16* output0_tm = top_blob_tm.channel(p);
            __fp16* output1_tm = top_blob_tm.channel(p + 1);
            __fp16* output2_tm = top_blob_tm.channel(p + 2);
            __fp16* output3_tm = top_blob_tm.channel(p + 3);
            __fp16* output4_tm = top_blob_tm.channel(p + 4);
            __fp16* output5_tm = top_blob_tm.channel(p + 5);
            __fp16* output6_tm = top_blob_tm.channel(p + 6);
            __fp16* output7_tm = top_blob_tm.channel(p + 7);

            const Mat kernel01_tm = kernel_tm.channel(p / 8);

            for (int r = 0; r < 64; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 7 < tiles; i += 8)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    // 8 tiles x 8 output channels, accumulate over inch in v24-v31
                    asm volatile(
                        "eor    v24.16b, v24.16b, v24.16b  \n"
                        "eor    v25.16b, v25.16b, v25.16b  \n"
                        "eor    v26.16b, v26.16b, v26.16b  \n"
                        "eor    v27.16b, v27.16b, v27.16b  \n"
                        "eor    v28.16b, v28.16b, v28.16b  \n"
                        "eor    v29.16b, v29.16b, v29.16b  \n"
                        "eor    v30.16b, v30.16b, v30.16b  \n"
                        "eor    v31.16b, v31.16b, v31.16b  \n"

                        "0:                                \n"

                        "prfm   pldl1keep, [%9, #512]      \n"
                        "ld1    {v16.8h, v17.8h, v18.8h, v19.8h}, [%9], #64 \n"

                        "prfm   pldl1keep, [%10, #512]     \n"
                        "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%10], #64 \n"

                        "fmla   v24.8h, v16.8h, v0.h[0]    \n"
                        "fmla   v25.8h, v16.8h, v0.h[1]    \n"
                        "fmla   v26.8h, v16.8h, v0.h[2]    \n"
                        "fmla   v27.8h, v16.8h, v0.h[3]    \n"
                        "fmla   v28.8h, v16.8h, v0.h[4]    \n"
                        "fmla   v29.8h, v16.8h, v0.h[5]    \n"
                        "fmla   v30.8h, v16.8h, v0.h[6]    \n"
                        "fmla   v31.8h, v16.8h, v0.h[7]    \n"

                        "fmla   v24.8h, v17.8h, v1.h[0]    \n"
                        "fmla   v25.8h, v17.8h, v1.h[1]    \n"
                        "fmla   v26.8h, v17.8h, v1.h[2]    \n"
                        "fmla   v27.8h, v17.8h, v1.h[3]    \n"
                        "fmla   v28.8h, v17.8h, v1.h[4]    \n"
                        "fmla   v29.8h, v17.8h, v1.h[5]    \n"
                        "fmla   v30.8h, v17.8h, v1.h[6]    \n"
                        "fmla   v31.8h, v17.8h, v1.h[7]    \n"

                        "prfm   pldl1keep, [%9, #512]      \n"
                        "ld1    {v20.8h, v21.8h, v22.8h, v23.8h}, [%9], #64 \n"

                        "fmla   v24.8h, v18.8h, v2.h[0]    \n"
                        "fmla   v25.8h, v18.8h, v2.h[1]    \n"
                        "fmla   v26.8h, v18.8h, v2.h[2]    \n"
                        "fmla   v27.8h, v18.8h, v2.h[3]    \n"
                        "fmla   v28.8h, v18.8h, v2.h[4]    \n"
                        "fmla   v29.8h, v18.8h, v2.h[5]    \n"
                        "fmla   v30.8h, v18.8h, v2.h[6]    \n"
                        "fmla   v31.8h, v18.8h, v2.h[7]    \n"

                        "prfm   pldl1keep, [%10, #512]     \n"
                        "ld1    {v4.8h, v5.8h, v6.8h, v7.8h}, [%10], #64 \n"

                        "fmla   v24.8h, v19.8h, v3.h[0]    \n"
                        "fmla   v25.8h, v19.8h, v3.h[1]    \n"
                        "fmla   v26.8h, v19.8h, v3.h[2]    \n"
                        "fmla   v27.8h, v19.8h, v3.h[3]    \n"
                        "fmla   v28.8h, v19.8h, v3.h[4]    \n"
                        "fmla   v29.8h, v19.8h, v3.h[5]    \n"
                        "fmla   v30.8h, v19.8h, v3.h[6]    \n"
                        "fmla   v31.8h, v19.8h, v3.h[7]    \n"

                        "fmla   v24.8h, v20.8h, v4.h[0]    \n"
                        "fmla   v25.8h, v20.8h, v4.h[1]    \n"
                        "fmla   v26.8h, v20.8h, v4.h[2]    \n"
                        "fmla   v27.8h, v20.8h, v4.h[3]    \n"
                        "fmla   v28.8h, v20.8h, v4.h[4]    \n"
                        "fmla   v29.8h, v20.8h, v4.h[5]    \n"
                        "fmla   v30.8h, v20.8h, v4.h[6]    \n"
                        "fmla   v31.8h, v20.8h, v4.h[7]    \n"

                        "fmla   v24.8h, v21.8h, v5.h[0]    \n"
                        "fmla   v25.8h, v21.8h, v5.h[1]    \n"
                        "fmla   v26.8h, v21.8h, v5.h[2]    \n"
                        "fmla   v27.8h, v21.8h, v5.h[3]    \n"
                        "fmla   v28.8h, v21.8h, v5.h[4]    \n"
                        "fmla   v29.8h, v21.8h, v5.h[5]    \n"
                        "fmla   v30.8h, v21.8h, v5.h[6]    \n"
                        "fmla   v31.8h, v21.8h, v5.h[7]    \n"

                        "fmla   v24.8h, v22.8h, v6.h[0]    \n"
                        "fmla   v25.8h, v22.8h, v6.h[1]    \n"
                        "fmla   v26.8h, v22.8h, v6.h[2]    \n"
                        "fmla   v27.8h, v22.8h, v6.h[3]    \n"
                        "fmla   v28.8h, v22.8h, v6.h[4]    \n"
                        "fmla   v29.8h, v22.8h, v6.h[5]    \n"
                        "fmla   v30.8h, v22.8h, v6.h[6]    \n"
                        "fmla   v31.8h, v22.8h, v6.h[7]    \n"

                        "subs   %w0, %w0, #1               \n"

                        "fmla   v24.8h, v23.8h, v7.h[0]    \n"
                        "fmla   v25.8h, v23.8h, v7.h[1]    \n"
                        "fmla   v26.8h, v23.8h, v7.h[2]    \n"
                        "fmla   v27.8h, v23.8h, v7.h[3]    \n"
                        "fmla   v28.8h, v23.8h, v7.h[4]    \n"
                        "fmla   v29.8h, v23.8h, v7.h[5]    \n"
                        "fmla   v30.8h, v23.8h, v7.h[6]    \n"
                        "fmla   v31.8h, v23.8h, v7.h[7]    \n"

                        "bne    0b                         \n"

                        "st1    {v24.8h}, [%1], #16        \n"
                        "st1    {v25.8h}, [%2], #16        \n"
                        "st1    {v26.8h}, [%3], #16        \n"
                        "st1    {v27.8h}, [%4], #16        \n"
                        "st1    {v28.8h}, [%5], #16        \n"
                        "st1    {v29.8h}, [%6], #16        \n"
                        "st1    {v30.8h}, [%7], #16        \n"
                        "st1    {v31.8h}, [%8], #16        \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(output1_tm), // %2
                        "=r"(output2_tm), // %3
                        "=r"(output3_tm), // %4
                        "=r"(output4_tm), // %5
                        "=r"(output5_tm), // %6
                        "=r"(output6_tm), // %7
                        "=r"(output7_tm), // %8
                        "=r"(r0),         // %9
                        "=r"(kptr)        // %10
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(output1_tm),
                        "3"(output2_tm),
                        "4"(output3_tm),
                        "5"(output4_tm),
                        "6"(output5_tm),
                        "7"(output6_tm),
                        "8"(output7_tm),
                        "9"(r0),
                        "10"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    // 4 tiles x 8 output channels
                    asm volatile(
                        "eor    v24.16b, v24.16b, v24.16b  \n"
                        "eor    v25.16b, v25.16b, v25.16b  \n"
                        "eor    v26.16b, v26.16b, v26.16b  \n"
                        "eor    v27.16b, v27.16b, v27.16b  \n"
                        "eor    v28.16b, v28.16b, v28.16b  \n"
                        "eor    v29.16b, v29.16b, v29.16b  \n"
                        "eor    v30.16b, v30.16b, v30.16b  \n"
                        "eor    v31.16b, v31.16b, v31.16b  \n"

                        "0:                                \n"

                        "prfm   pldl1keep, [%9, #256]      \n"
                        "ld1    {v16.4h, v17.4h, v18.4h, v19.4h}, [%9], #32 \n"

                        "prfm   pldl1keep, [%10, #512]     \n"
                        "ld1    {v0.8h, v1.8h, v2.8h, v3.8h}, [%10], #64 \n"

                        "fmla   v24.4h, v16.4h, v0.h[0]    \n"
                        "fmla   v25.4h, v16.4h, v0.h[1]    \n"
                        "fmla   v26.4h, v16.4h, v0.h[2]    \n"
                        "fmla   v27.4h, v16.4h, v0.h[3]    \n"
                        "fmla   v28.4h, v16.4h, v0.h[4]    \n"
                        "fmla   v29.4h, v16.4h, v0.h[5]    \n"
                        "fmla   v30.4h, v16.4h, v0.h[6]    \n"
                        "fmla   v31.4h, v16.4h, v0.h[7]    \n"

                        "fmla   v24.4h, v17.4h, v1.h[0]    \n"
                        "fmla   v25.4h, v17.4h, v1.h[1]    \n"
                        "fmla   v26.4h, v17.4h, v1.h[2]    \n"
                        "fmla   v27.4h, v17.4h, v1.h[3]    \n"
                        "fmla   v28.4h, v17.4h, v1.h[4]    \n"
                        "fmla   v29.4h, v17.4h, v1.h[5]    \n"
                        "fmla   v30.4h, v17.4h, v1.h[6]    \n"
                        "fmla   v31.4h, v17.4h, v1.h[7]    \n"

                        "prfm   pldl1keep, [%9, #256]      \n"
                        "ld1    {v20.4h, v21.4h, v22.4h, v23.4h}, [%9], #32 \n"

                        "fmla   v24.4h, v18.4h, v2.h[0]    \n"
                        "fmla   v25.4h, v18.4h, v2.h[1]    \n"
                        "fmla   v26.4h, v18.4h, v2.h[2]    \n"
                        "fmla   v27.4h, v18.4h, v2.h[3]    \n"
                        "fmla   v28.4h, v18.4h, v2.h[4]    \n"
                        "fmla   v29.4h, v18.4h, v2.h[5]    \n"
                        "fmla   v30.4h, v18.4h, v2.h[6]    \n"
                        "fmla   v31.4h, v18.4h, v2.h[7]    \n"

                        "prfm   pldl1keep, [%10, #512]     \n"
                        "ld1    {v4.8h, v5.8h, v6.8h, v7.8h}, [%10], #64 \n"

                        "fmla   v24.4h, v19.4h, v3.h[0]    \n"
                        "fmla   v25.4h, v19.4h, v3.h[1]    \n"
                        "fmla   v26.4h, v19.4h, v3.h[2]    \n"
                        "fmla   v27.4h, v19.4h, v3.h[3]    \n"
                        "fmla   v28.4h, v19.4h, v3.h[4]    \n"
                        "fmla   v29.4h, v19.4h, v3.h[5]    \n"
                        "fmla   v30.4h, v19.4h, v3.h[6]    \n"
                        "fmla   v31.4h, v19.4h, v3.h[7]    \n"

                        "fmla   v24.4h, v20.4h, v4.h[0]    \n"
                        "fmla   v25.4h, v20.4h, v4.h[1]    \n"
                        "fmla   v26.4h, v20.4h, v4.h[2]    \n"
                        "fmla   v27.4h, v20.4h, v4.h[3]    \n"
                        "fmla   v28.4h, v20.4h, v4.h[4]    \n"
                        "fmla   v29.4h, v20.4h, v4.h[5]    \n"
                        "fmla   v30.4h, v20.4h, v4.h[6]    \n"
                        "fmla   v31.4h, v20.4h, v4.h[7]    \n"

                        "fmla   v24.4h, v21.4h, v5.h[0]    \n"
                        "fmla   v25.4h, v21.4h, v5.h[1]    \n"
                        "fmla   v26.4h, v21.4h, v5.h[2]    \n"
                        "fmla   v27.4h, v21.4h, v5.h[3]    \n"
                        "fmla   v28.4h, v21.4h, v5.h[4]    \n"
                        "fmla   v29.4h, v21.4h, v5.h[5]    \n"
                        "fmla   v30.4h, v21.4h, v5.h[6]    \n"
                        "fmla   v31.4h, v21.4h, v5.h[7]    \n"

                        "fmla   v24.4h, v22.4h, v6.h[0]    \n"
                        "fmla   v25.4h, v22.4h, v6.h[1]    \n"
                        "fmla   v26.4h, v22.4h, v6.h[2]    \n"
                        "fmla   v27.4h, v22.4h, v6.h[3]    \n"
                        "fmla   v28.4h, v22.4h, v6.h[4]    \n"
                        "fmla   v29.4h, v22.4h, v6.h[5]    \n"
                        "fmla   v30.4h, v22.4h, v6.h[6]    \n"
                        "fmla   v31.4h, v22.4h, v6.h[7]    \n"

                        "subs   %w0, %w0, #1               \n"

                        "fmla   v24.4h, v23.4h, v7.h[0]    \n"
                        "fmla   v25.4h, v23.4h, v7.h[1]    \n"
                        "fmla   v26.4h, v23.4h, v7.h[2]    \n"
                        "fmla   v27.4h, v23.4h, v7.h[3]    \n"
                        "fmla   v28.4h, v23.4h, v7.h[4]    \n"
                        "fmla   v29.4h, v23.4h, v7.h[5]    \n"
                        "fmla   v30.4h, v23.4h, v7.h[6]    \n"
                        "fmla   v31.4h, v23.4h, v7.h[7]    \n"

                        "bne    0b                         \n"

                        "st1    {v24.4h}, [%1], #8         \n"
                        "st1    {v25.4h}, [%2], #8         \n"
                        "st1    {v26.4h}, [%3], #8         \n"
                        "st1    {v27.4h}, [%4], #8         \n"
                        "st1    {v28.4h}, [%5], #8         \n"
                        "st1    {v29.4h}, [%6], #8         \n"
                        "st1    {v30.4h}, [%7], #8         \n"
                        "st1    {v31.4h}, [%8], #8         \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(output1_tm), // %2
                        "=r"(output2_tm), // %3
                        "=r"(output3_tm), // %4
                        "=r"(output4_tm), // %5
                        "=r"(output5_tm), // %6
                        "=r"(output6_tm), // %7
                        "=r"(output7_tm), // %8
                        "=r"(r0),         // %9
                        "=r"(kptr)        // %10
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(output1_tm),
                        "3"(output2_tm),
                        "4"(output3_tm),
                        "5"(output4_tm),
                        "6"(output5_tm),
                        "7"(output6_tm),
                        "8"(output7_tm),
                        "9"(r0),
                        "10"(kptr)
                        : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
                }
                for (; i < tiles; i++)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4 + i % 4);

                    const __fp16* kptr = kernel01_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    // 1 tile x 8 output channels, one lane per channel in v30
                    asm volatile(
                        "eor    v30.16b, v30.16b, v30.16b  \n"

                        "0:                                \n"

                        "prfm   pldl1keep, [%9, #128]      \n"
                        "ld1    {v0.8h}, [%9], #16         \n"

                        "prfm   pldl1keep, [%10, #512]     \n"
                        "ld1    {v16.8h, v17.8h, v18.8h, v19.8h}, [%10], #64 \n"

                        "fmla   v30.8h, v16.8h, v0.h[0]    \n"
                        "fmla   v30.8h, v17.8h, v0.h[1]    \n"

                        "prfm   pldl1keep, [%10, #512]     \n"
                        "ld1    {v20.8h, v21.8h, v22.8h, v23.8h}, [%10], #64 \n"

                        "fmla   v30.8h, v18.8h, v0.h[2]    \n"
                        "fmla   v30.8h, v19.8h, v0.h[3]    \n"

                        "subs   %w0, %w0, #1               \n"

                        "fmla   v30.8h, v20.8h, v0.h[4]    \n"
                        "fmla   v30.8h, v21.8h, v0.h[5]    \n"
                        "fmla   v30.8h, v22.8h, v0.h[6]    \n"
                        "fmla   v30.8h, v23.8h, v0.h[7]    \n"

                        "bne    0b                         \n"

                        "st1    {v30.h}[0], [%1], #2       \n"
                        "st1    {v30.h}[1], [%2], #2       \n"
                        "st1    {v30.h}[2], [%3], #2       \n"
                        "st1    {v30.h}[3], [%4], #2       \n"
                        "st1    {v30.h}[4], [%5], #2       \n"
                        "st1    {v30.h}[5], [%6], #2       \n"
                        "st1    {v30.h}[6], [%7], #2       \n"
                        "st1    {v30.h}[7], [%8], #2       \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(output1_tm), // %2
                        "=r"(output2_tm), // %3
                        "=r"(output3_tm), // %4
                        "=r"(output4_tm), // %5
                        "=r"(output5_tm), // %6
                        "=r"(output6_tm), // %7
                        "=r"(output7_tm), // %8
                        "=r"(r0),         // %9
                        "=r"(kptr)        // %10
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(output1_tm),
                        "3"(output2_tm),
                        "4"(output3_tm),
                        "5"(output4_tm),
                        "6"(output5_tm),
                        "7"(output6_tm),
                        "8"(output7_tm),
                        "9"(r0),
                        "10"(kptr)
                        : "cc", "memory", "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v30");
                }
            }
        }

        remain_outch_start += nn_outch << 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_outch_start; p < outch; p++)
        {
            __fp16* output0_tm = top_blob_tm.channel(p);

            const Mat kernel0_tm = kernel_tm.channel(p / 8 + p % 8);

            for (int r = 0; r < 64; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 7 < tiles; i += 8)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    // 8 tiles x 1 output channel
                    asm volatile(
                        "eor    v30.16b, v30.16b, v30.16b  \n"

                        "0:                                \n"

                        "prfm   pldl1keep, [%2, #512]      \n"
                        "ld1    {v16.8h, v17.8h, v18.8h, v19.8h}, [%2], #64 \n"

                        "prfm   pldl1keep, [%3, #128]      \n"
                        "ld1    {v0.8h}, [%3], #16         \n"

                        "fmla   v30.8h, v16.8h, v0.h[0]    \n"
                        "fmla   v30.8h, v17.8h, v0.h[1]    \n"

                        "prfm   pldl1keep, [%2, #512]      \n"
                        "ld1    {v20.8h, v21.8h, v22.8h, v23.8h}, [%2], #64 \n"

                        "fmla   v30.8h, v18.8h, v0.h[2]    \n"
                        "fmla   v30.8h, v19.8h, v0.h[3]    \n"

                        "subs   %w0, %w0, #1               \n"

                        "fmla   v30.8h, v20.8h, v0.h[4]    \n"
                        "fmla   v30.8h, v21.8h, v0.h[5]    \n"
                        "fmla   v30.8h, v22.8h, v0.h[6]    \n"
                        "fmla   v30.8h, v23.8h, v0.h[7]    \n"

                        "bne    0b                         \n"

                        "st1    {v30.8h}, [%1], #16        \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(r0),         // %2
                        "=r"(kptr)        // %3
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(r0),
                        "3"(kptr)
                        : "cc", "memory", "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v30");
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    int nn = inch; // inch always > 0

                    // 4 tiles x 1 output channel
                    asm volatile(
                        "eor    v30.16b, v30.16b, v30.16b  \n"

                        "0:                                \n"

                        "prfm   pldl1keep, [%2, #256]      \n"
                        "ld1    {v16.4h, v17.4h, v18.4h, v19.4h}, [%2], #32 \n"

                        "prfm   pldl1keep, [%3, #128]      \n"
                        "ld1    {v0.8h}, [%3], #16         \n"

                        "fmla   v30.4h, v16.4h, v0.h[0]    \n"
                        "fmla   v30.4h, v17.4h, v0.h[1]    \n"

                        "prfm   pldl1keep, [%2, #256]      \n"
                        "ld1    {v20.4h, v21.4h, v22.4h, v23.4h}, [%2], #32 \n"

                        "fmla   v30.4h, v18.4h, v0.h[2]    \n"
                        "fmla   v30.4h, v19.4h, v0.h[3]    \n"

                        "subs   %w0, %w0, #1               \n"

                        "fmla   v30.4h, v20.4h, v0.h[4]    \n"
                        "fmla   v30.4h, v21.4h, v0.h[5]    \n"
                        "fmla   v30.4h, v22.4h, v0.h[6]    \n"
                        "fmla   v30.4h, v23.4h, v0.h[7]    \n"

                        "bne    0b                         \n"

                        "st1    {v30.4h}, [%1], #8         \n"

                        : "=r"(nn),         // %0
                        "=r"(output0_tm), // %1
                        "=r"(r0),         // %2
                        "=r"(kptr)        // %3
                        : "0"(nn),
                        "1"(output0_tm),
                        "2"(r0),
                        "3"(kptr)
                        : "cc", "memory", "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v30");
                }
                for (; i < tiles; i++)
                {
                    const __fp16* r0 = bb2.row<const __fp16>(i / 8 + (i % 8) / 4 + i % 4);

                    const __fp16* kptr = kernel0_tm.row<const __fp16>(r);

                    float16x8_t _sum0 = vdupq_n_f16((__fp16)0.f);

                    for (int q = 0; q < inch; q++)
                    {
                        float16x8_t _r0 = vld1q_f16(r0);

                        float16x8_t _k0 = vld1q_f16(kptr);

                        _sum0 = vfmaq_f16(_sum0, _r0, _k0);

                        // both streams store 8 halves per packed input channel
                        // (was += 4, which re-read overlapping data)
                        kptr += 8;
                        r0 += 8;
                    }

                    // horizontal sum of the 8 fp16 lanes, widened to fp32 for accuracy
                    __fp16 sum0 = vaddvq_f32(vcvt_f32_f16(vadd_f16(vget_low_f16(_sum0), vget_high_f16(_sum0))));

                    output0_tm[0] = sum0;

                    output0_tm++;
                }
            }
        }
    }
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    if (outw == top_blob.w && outh == top_blob.h)
    {
        top_blob_bordered = top_blob;
    }
    else
    {
        top_blob_bordered.create(outw, outh, outch, 2u, 1, opt.workspace_allocator);
    }
    {
        // const float otm[6][8] = {
        //     {1.0f,  1.0f,   1.0f,   1.0f,   1.0f,  32.0f, 32.0f, 0.0f},
        //     {0.0f,  1.0f,  -1.0f,   2.0f,  -2.0f,  16.0f,-16.0f, 0.0f},
        //     {0.0f,  1.0f,   1.0f,   4.0f,   4.0f,   8.0f,  8.0f, 0.0f},
        //     {0.0f,  1.0f,  -1.0f,   8.0f,  -8.0f,   4.0f, -4.0f, 0.0f},
        //     {0.0f,  1.0f,   1.0f,  16.0f,  16.0f,   2.0f,  2.0f, 0.0f},
        //     {0.0f,  1.0f,  -1.0f,  32.0f, -32.0f,   1.0f, -1.0f, 1.0f}
        // };

        // 0 = r0 + (r1 + r2) + (r3 + r4)     + (r5 + r6) * 32
        // 1 =      (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
        // 2 =      (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
        // 3 =      (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
        // 4 =      (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2
        // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6)

        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;
        const int tiles = w_tm / 8 * h_tm / 8;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            const Mat out0_tm = top_blob_tm.channel(p);
            Mat out0 = top_blob_bordered.channel(p);

            const __fp16 bias0 = bias ? bias[p] : 0.f;
            // float32x2_t _bias0 = vdup_n_f32(bias0);

            __fp16 tmp[6][8];

            // tile
            for (int i = 0; i < outh / 6; i++)
            {
                for (int j = 0; j < outw / 6; j++)
                {
                    // top_blob_tm.create(tiles, 64, outch, 4u, 1, opt.workspace_allocator);

                    const __fp16* output0_tm_0 = (const __fp16*)out0_tm + (i * w_tm / 8 + j) * 1;
                    const __fp16* output0_tm_1 = output0_tm_0 + tiles * 1;
                    const __fp16* output0_tm_2 = output0_tm_0 + tiles * 2;
                    const __fp16* output0_tm_3 = output0_tm_0 + tiles * 3;
                    const __fp16* output0_tm_4 = output0_tm_0 + tiles * 4;
                    const __fp16* output0_tm_5 = output0_tm_0 + tiles * 5;
                    const __fp16* output0_tm_6 = output0_tm_0 + tiles * 6;
                    const __fp16* output0_tm_7 = output0_tm_0 + tiles * 7;

                    // TODO neon optimize
                    for (int m = 0; m < 8; m++)
                    {
                        __fp16 tmp024a = output0_tm_1[0] + output0_tm_2[0];
                        __fp16 tmp135a = output0_tm_1[0] - output0_tm_2[0];

                        __fp16 tmp024b = output0_tm_3[0] + output0_tm_4[0];
                        __fp16 tmp135b = output0_tm_3[0] - output0_tm_4[0];

                        __fp16 tmp024c = output0_tm_5[0] + output0_tm_6[0];
                        __fp16 tmp135c = output0_tm_5[0] - output0_tm_6[0];

                        tmp[0][m] = output0_tm_0[0] + tmp024a + tmp024b + tmp024c * 32;
                        tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8;
                        tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c + tmp024c;

                        tmp[1][m] = tmp135a + tmp135b + tmp135b + tmp135c * 16;
                        tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4;
                        tmp[5][m] = output0_tm_7[0] + tmp135a + tmp135b * 32 + tmp135c;

                        output0_tm_0 += tiles * 8;
                        output0_tm_1 += tiles * 8;
                        output0_tm_2 += tiles * 8;
                        output0_tm_3 += tiles * 8;
                        output0_tm_4 += tiles * 8;
                        output0_tm_5 += tiles * 8;
                        output0_tm_6 += tiles * 8;
                        output0_tm_7 += tiles * 8;
                    }

                    __fp16* output0 = out0.row<__fp16>(i * 6) + j * 6;

                    for (int m = 0; m < 6; m++)
                    {
                        const __fp16* tmp0 = tmp[m];

                        __fp16 tmp024a = tmp0[1] + tmp0[2];
                        __fp16 tmp135a = tmp0[1] - tmp0[2];

                        __fp16 tmp024b = tmp0[3] + tmp0[4];
                        __fp16 tmp135b = tmp0[3] - tmp0[4];

                        __fp16 tmp024c = tmp0[5] + tmp0[6];
                        __fp16 tmp135c = tmp0[5] - tmp0[6];

                        output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32;
                        output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8;
                        output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c + tmp024c;

                        output0[1] = bias0 + tmp135a + tmp135b + tmp135b + tmp135c * 16;
                        output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4;
                        output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c;

                        output0 += outw;
                    }
                }
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}
1124