1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
conv3x3s1_winograd64_transform_kernel_pack4_msa(const Mat & kernel,Mat & kernel_tm_pack4,int inch,int outch,const Option & opt)15 static void conv3x3s1_winograd64_transform_kernel_pack4_msa(const Mat& kernel, Mat& kernel_tm_pack4, int inch, int outch, const Option& opt)
16 {
17 // winograd63 transform kernel
18 Mat kernel_tm;
19 kernel_tm.create(8 * 8, inch, outch);
20
21 const float ktm[8][3] = {
22 {1.0f, 0.0f, 0.0f},
23 {-2.0f / 9, -2.0f / 9, -2.0f / 9},
24 {-2.0f / 9, 2.0f / 9, -2.0f / 9},
25 {1.0f / 90, 1.0f / 45, 2.0f / 45},
26 {1.0f / 90, -1.0f / 45, 2.0f / 45},
27 {1.0f / 45, 1.0f / 90, 1.0f / 180},
28 {1.0f / 45, -1.0f / 90, 1.0f / 180},
29 {0.0f, 0.0f, 1.0f}
30 };
31
32 #pragma omp parallel for num_threads(opt.num_threads)
33 for (int p = 0; p < outch; p++)
34 {
35 for (int q = 0; q < inch; q++)
36 {
37 const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
38 float* kernel_tm0 = kernel_tm.channel(p).row(q);
39
40 // transform kernel, transposed
41 const float* k0 = kernel0;
42 const float* k1 = kernel0 + 3;
43 const float* k2 = kernel0 + 6;
44
45 // h
46 float tmp[8][3];
47 for (int i = 0; i < 8; i++)
48 {
49 tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
50 tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
51 tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
52 }
53
54 // v
55 for (int j = 0; j < 8; j++)
56 {
57 float* tmpp = &tmp[j][0];
58
59 for (int i = 0; i < 8; i++)
60 {
61 kernel_tm0[j * 8 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
62 }
63 }
64 }
65 }
66
67 // interleave
68 // src = 64-inch-outch
69 // dst = pb-pa-inch/pa-64-outch/pb
70 kernel_tm_pack4.create(inch / 4, 64, outch / 4, (size_t)4u * 4 * 4, 4 * 4);
71
72 for (int q = 0; q + (4 - 1) < outch; q += 4)
73 {
74 Mat g0 = kernel_tm_pack4.channel(q / 4);
75
76 for (int k = 0; k < 64; k++)
77 {
78 float* g00 = g0.row<float>(k);
79
80 for (int p = 0; p + (4 - 1) < inch; p += 4)
81 {
82 for (int i = 0; i < 4; i++)
83 {
84 for (int j = 0; j < 4; j++)
85 {
86 const float* k00 = kernel_tm.channel(q + j).row(p + i);
87 g00[0] = (float)k00[k];
88 g00++;
89 }
90 }
91 }
92 }
93 }
94 }
95
// Winograd F(6,3) 3x3 stride-1 convolution for pack4 fp32 blobs (MIPS MSA).
// Pipeline: pad input to a multiple of the 6x6 output tile -> transform each
// 8x8 input tile (B^T d B) -> batched dot against the pre-transformed kernel
// (see conv3x3s1_winograd64_transform_kernel_pack4_msa) -> inverse transform
// (A^T m A) with bias add -> crop back to the requested output size.
// NOTE(review): assumes elempack == 4 and inch/outch multiples of 4 — confirm at caller.
static void conv3x3s1_winograd64_pack4_msa(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 6n+2: each 8x8 input tile produces a 6x6 output tile with 2-pixel overlap
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 5) / 6 * 6;
    outh = (outh + 5) / 6 * 6;

    w = outw + 2;
    h = outh + 2;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt);

    const float* bias = _bias;

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = w_tm / 8 * h_tm / 8;

        // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
        bottom_blob_tm.create(tiles, 64, inch, 4u * elempack, elempack, opt.workspace_allocator);

        // B^T matrix of F(6,3); applied row-wise then column-wise below
        // const float itm[8][8] = {
        //     {1.0f,  0.0f, -5.25f,  0.00f,  5.25f,  0.00f, -1.0f, 0.0f},
        //
        //     {0.0f,  1.0f,  1.00f, -4.25f, -4.25f,  1.00f,  1.0f, 0.0f},
        //     {0.0f, -1.0f,  1.00f,  4.25f, -4.25f, -1.00f,  1.0f, 0.0f},
        //
        //     {0.0f,  0.5f,  0.25f, -2.50f, -1.25f,  2.00f,  1.0f, 0.0f},
        //     {0.0f, -0.5f,  0.25f,  2.50f, -1.25f, -2.00f,  1.0f, 0.0f},
        //
        //     {0.0f,  2.0f,  4.00f, -2.50f, -5.00f,  0.50f,  1.0f, 0.0f},
        //     {0.0f, -2.0f,  4.00f,  2.50f, -5.00f, -0.50f,  1.0f, 0.0f},
        //
        //     {0.0f, -1.0f,  0.00f,  5.25f,  0.00f, -5.25f,  0.0f, 1.0f}
        // };

        // 0 = r00 - r06 + (r04 - r02) * 5.25
        // 7 = r07 - r01 + (r03 - r05) * 5.25

        // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05)
        // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05)

        // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2)
        // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2)

        // reuse r04 * 1.25
        // reuse r03 * 2.5
        // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
        // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const Mat img0 = bottom_blob_bordered.channel(q);
            Mat img0_tm = bottom_blob_tm.channel(q);

            // per-tile scratch: [row of B^T d][column][4-lane pack]
            float tmp[8][8][4];

            // broadcast transform constants once per channel
            v4f32 _v5_25 = __msa_fill_w_f32(5.25f);
            v4f32 _vm4_25 = __msa_fill_w_f32(-4.25f);
            v4f32 _vm1_25 = __msa_fill_w_f32(-1.25f);
            v4f32 _v0_25 = __msa_fill_w_f32(0.25f);
            v4f32 _vm2_5 = __msa_fill_w_f32(-2.5f);
            v4f32 _v0_5 = __msa_fill_w_f32(0.5f);
            v4f32 _v2 = __msa_fill_w_f32(2.f);
            v4f32 _v4 = __msa_fill_w_f32(4.f);

            // tile
            for (int i = 0; i < h_tm / 8; i++)
            {
                for (int j = 0; j < w_tm / 8; j++)
                {
                    // tiles overlap by 2 pixels, hence the *6 stride into the padded input
                    const float* r0 = img0.row(i * 6) + (j * 6) * 4;

                    // first pass: transform each of the 8 rows of the tile
                    for (int m = 0; m < 8; m++)
                    {
                        v4f32 _r00 = (v4f32)__msa_ld_w(r0, 0);
                        v4f32 _r01 = (v4f32)__msa_ld_w(r0 + 4, 0);
                        v4f32 _r02 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0);
                        v4f32 _r03 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0);
                        v4f32 _r04 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0);
                        v4f32 _r05 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0);
                        v4f32 _r06 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0);
                        v4f32 _r07 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0);

                        v4f32 _tmp0m = __msa_fmadd_w(__msa_fsub_w(_r00, _r06), _v5_25, __msa_fsub_w(_r04, _r02));
                        v4f32 _tmp7m = __msa_fmadd_w(__msa_fsub_w(_r07, _r01), _v5_25, __msa_fsub_w(_r03, _r05));
                        __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0);
                        __msa_st_w((v4i32)_tmp7m, tmp[7][m], 0);

                        v4f32 _tmp12a = __msa_fmadd_w(__msa_fadd_w(_r02, _r06), _vm4_25, _r04);
                        v4f32 _tmp12b = __msa_fmadd_w(__msa_fadd_w(_r01, _r05), _vm4_25, _r03);

                        v4f32 _tmp1m = __msa_fadd_w(_tmp12a, _tmp12b);
                        v4f32 _tmp2m = __msa_fsub_w(_tmp12a, _tmp12b);
                        __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0);
                        __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0);

                        v4f32 _tmp34a = __msa_fmadd_w(__msa_fmadd_w(_r06, _v0_25, _r02), _vm1_25, _r04);
                        v4f32 _tmp34b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_r01, _v0_5), _vm2_5, _r03), _v2, _r05);

                        v4f32 _tmp3m = __msa_fadd_w(_tmp34a, _tmp34b);
                        v4f32 _tmp4m = __msa_fsub_w(_tmp34a, _tmp34b);
                        __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0);
                        __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0);

                        v4f32 _tmp56a = __msa_fmadd_w(_r06, _v4, __msa_fmadd_w(_r02, _vm1_25, _r04));
                        v4f32 _tmp56b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_r01, _v2), _vm2_5, _r03), _v0_5, _r05);

                        v4f32 _tmp5m = __msa_fadd_w(_tmp56a, _tmp56b);
                        v4f32 _tmp6m = __msa_fsub_w(_tmp56a, _tmp56b);
                        __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0);
                        __msa_st_w((v4i32)_tmp6m, tmp[6][m], 0);

                        r0 += w * 4;
                    }

                    // destination rows 0..7 of the transformed tile, strided by tiles*4
                    float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 4;
                    float* r0_tm_1 = r0_tm_0 + tiles * 4;
                    float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2;
                    float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3;
                    float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4;
                    float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5;
                    float* r0_tm_6 = r0_tm_0 + tiles * 4 * 6;
                    float* r0_tm_7 = r0_tm_0 + tiles * 4 * 7;

                    // second pass: transform each column of the row-transformed tile
                    for (int m = 0; m < 8; m++)
                    {
                        v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0);
                        v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0);
                        v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0);
                        v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0);
                        v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0);
                        v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0);
                        v4f32 _tmp06 = (v4f32)__msa_ld_w(tmp[m][6], 0);
                        v4f32 _tmp07 = (v4f32)__msa_ld_w(tmp[m][7], 0);

                        v4f32 _r0tm0 = __msa_fmadd_w(__msa_fsub_w(_tmp00, _tmp06), _v5_25, __msa_fsub_w(_tmp04, _tmp02));
                        v4f32 _r0tm7 = __msa_fmadd_w(__msa_fsub_w(_tmp07, _tmp01), _v5_25, __msa_fsub_w(_tmp03, _tmp05));

                        v4f32 _tmp12a = __msa_fmadd_w(__msa_fadd_w(_tmp02, _tmp06), _vm4_25, _tmp04);
                        v4f32 _tmp12b = __msa_fmadd_w(__msa_fadd_w(_tmp01, _tmp05), _vm4_25, _tmp03);

                        v4f32 _r0tm1 = __msa_fadd_w(_tmp12a, _tmp12b);
                        v4f32 _r0tm2 = __msa_fsub_w(_tmp12a, _tmp12b);

                        v4f32 _tmp34a = __msa_fmadd_w(__msa_fmadd_w(_tmp06, _v0_25, _tmp02), _vm1_25, _tmp04);
                        v4f32 _tmp34b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_tmp01, _v0_5), _vm2_5, _tmp03), _v2, _tmp05);

                        v4f32 _r0tm3 = __msa_fadd_w(_tmp34a, _tmp34b);
                        v4f32 _r0tm4 = __msa_fsub_w(_tmp34a, _tmp34b);

                        v4f32 _tmp56a = __msa_fmadd_w(_tmp06, _v4, __msa_fmadd_w(_tmp02, _vm1_25, _tmp04));
                        v4f32 _tmp56b = __msa_fmadd_w(__msa_fmadd_w(__msa_fmul_w(_tmp01, _v2), _vm2_5, _tmp03), _v0_5, _tmp05);

                        v4f32 _r0tm5 = __msa_fadd_w(_tmp56a, _tmp56b);
                        v4f32 _r0tm6 = __msa_fsub_w(_tmp56a, _tmp56b);

                        __msa_st_w((v4i32)_r0tm0, r0_tm_0, 0);
                        __msa_st_w((v4i32)_r0tm1, r0_tm_1, 0);
                        __msa_st_w((v4i32)_r0tm2, r0_tm_2, 0);
                        __msa_st_w((v4i32)_r0tm3, r0_tm_3, 0);
                        __msa_st_w((v4i32)_r0tm4, r0_tm_4, 0);
                        __msa_st_w((v4i32)_r0tm5, r0_tm_5, 0);
                        __msa_st_w((v4i32)_r0tm6, r0_tm_6, 0);
                        __msa_st_w((v4i32)_r0tm7, r0_tm_7, 0);

                        r0_tm_0 += tiles * 4 * 8;
                        r0_tm_1 += tiles * 4 * 8;
                        r0_tm_2 += tiles * 4 * 8;
                        r0_tm_3 += tiles * 4 * 8;
                        r0_tm_4 += tiles * 4 * 8;
                        r0_tm_5 += tiles * 4 * 8;
                        r0_tm_6 += tiles * 4 * 8;
                        r0_tm_7 += tiles * 4 * 8;
                    }
                }
            }
        }
    }
    bottom_blob_bordered = Mat();
    // END transform input

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = h_tm / 8 * w_tm / 8;

        // permute
        // repack transformed tiles into contiguous 12/8/4/2/1-tile groups so the
        // dot loops below stream linearly
        // bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);
        Mat bottom_blob_tm2;
        if (tiles >= 12)
            bottom_blob_tm2.create(12 * inch, tiles / 12 + (tiles % 12) / 8 + (tiles % 12 % 8) / 4 + (tiles % 12 % 4) / 2 + tiles % 12 % 2, 64, 4u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 8)
            bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + (tiles % 4) / 2 + tiles % 2, 64, 4u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 4)
            bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 64, 4u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 2)
            bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 64, 4u * elempack, elempack, opt.workspace_allocator);
        else // if (tiles >= 1)
            bottom_blob_tm2.create(1 * inch, tiles, 64, 4u * elempack, elempack, opt.workspace_allocator);

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 64; r++)
        {
            Mat tm2 = bottom_blob_tm2.channel(r);

            // tile
            int i = 0;
            for (; i + 11 < tiles; i += 12)
            {
                float* tmpptr = tm2.row(i / 12);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x12 pack4 values (three 4x4 sub-transposes via ilv pairs)
                    v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0);
                    v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0);
                    v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0);
                    v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0);
                    v4f32 _r4 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0);
                    v4f32 _r5 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0);
                    v4f32 _r6 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0);
                    v4f32 _r7 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0);
                    v4f32 _r8 = (v4f32)__msa_ld_w(r0 + 4 * 8, 0);
                    v4f32 _r9 = (v4f32)__msa_ld_w(r0 + 4 * 9, 0);
                    v4f32 _ra = (v4f32)__msa_ld_w(r0 + 4 * 10, 0);
                    v4f32 _rb = (v4f32)__msa_ld_w(r0 + 4 * 11, 0);

                    v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r45r = __msa_ilvr_w((v4i32)_r5, (v4i32)_r4);
                    v4i32 _r45l = __msa_ilvl_w((v4i32)_r5, (v4i32)_r4);
                    v4i32 _r67r = __msa_ilvr_w((v4i32)_r7, (v4i32)_r6);
                    v4i32 _r67l = __msa_ilvl_w((v4i32)_r7, (v4i32)_r6);
                    v4i32 _r89r = __msa_ilvr_w((v4i32)_r9, (v4i32)_r8);
                    v4i32 _r89l = __msa_ilvl_w((v4i32)_r9, (v4i32)_r8);
                    v4i32 _rabr = __msa_ilvr_w((v4i32)_rb, (v4i32)_ra);
                    v4i32 _rabl = __msa_ilvl_w((v4i32)_rb, (v4i32)_ra);
                    v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r4567_0 = __msa_ilvr_d((v2i64)_r67r, (v2i64)_r45r);
                    v2i64 _r4567_1 = __msa_ilvl_d((v2i64)_r67r, (v2i64)_r45r);
                    v2i64 _r4567_2 = __msa_ilvr_d((v2i64)_r67l, (v2i64)_r45l);
                    v2i64 _r4567_3 = __msa_ilvl_d((v2i64)_r67l, (v2i64)_r45l);
                    v2i64 _r89ab_0 = __msa_ilvr_d((v2i64)_rabr, (v2i64)_r89r);
                    v2i64 _r89ab_1 = __msa_ilvl_d((v2i64)_rabr, (v2i64)_r89r);
                    v2i64 _r89ab_2 = __msa_ilvr_d((v2i64)_rabl, (v2i64)_r89l);
                    v2i64 _r89ab_3 = __msa_ilvl_d((v2i64)_rabl, (v2i64)_r89l);

                    __msa_st_w((v4i32)_r0123_0, tmpptr, 0);
                    __msa_st_w((v4i32)_r4567_0, tmpptr + 4, 0);
                    __msa_st_w((v4i32)_r89ab_0, tmpptr + 4 * 2, 0);
                    __msa_st_w((v4i32)_r0123_1, tmpptr + 4 * 3, 0);
                    __msa_st_w((v4i32)_r4567_1, tmpptr + 4 * 4, 0);
                    __msa_st_w((v4i32)_r89ab_1, tmpptr + 4 * 5, 0);
                    __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 6, 0);
                    __msa_st_w((v4i32)_r4567_2, tmpptr + 4 * 7, 0);
                    __msa_st_w((v4i32)_r89ab_2, tmpptr + 4 * 8, 0);
                    __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 9, 0);
                    __msa_st_w((v4i32)_r4567_3, tmpptr + 4 * 10, 0);
                    __msa_st_w((v4i32)_r89ab_3, tmpptr + 4 * 11, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 48;
                }
            }
            for (; i + 7 < tiles; i += 8)
            {
                float* tmpptr = tm2.row(i / 12 + (i % 12) / 8);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x8
                    v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0);
                    v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0);
                    v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0);
                    v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0);
                    v4f32 _r4 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0);
                    v4f32 _r5 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0);
                    v4f32 _r6 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0);
                    v4f32 _r7 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0);

                    v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r45r = __msa_ilvr_w((v4i32)_r5, (v4i32)_r4);
                    v4i32 _r45l = __msa_ilvl_w((v4i32)_r5, (v4i32)_r4);
                    v4i32 _r67r = __msa_ilvr_w((v4i32)_r7, (v4i32)_r6);
                    v4i32 _r67l = __msa_ilvl_w((v4i32)_r7, (v4i32)_r6);
                    v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r4567_0 = __msa_ilvr_d((v2i64)_r67r, (v2i64)_r45r);
                    v2i64 _r4567_1 = __msa_ilvl_d((v2i64)_r67r, (v2i64)_r45r);
                    v2i64 _r4567_2 = __msa_ilvr_d((v2i64)_r67l, (v2i64)_r45l);
                    v2i64 _r4567_3 = __msa_ilvl_d((v2i64)_r67l, (v2i64)_r45l);

                    __msa_st_w((v4i32)_r0123_0, tmpptr, 0);
                    __msa_st_w((v4i32)_r4567_0, tmpptr + 4, 0);
                    __msa_st_w((v4i32)_r0123_1, tmpptr + 4 * 2, 0);
                    __msa_st_w((v4i32)_r4567_1, tmpptr + 4 * 3, 0);
                    __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 4, 0);
                    __msa_st_w((v4i32)_r4567_2, tmpptr + 4 * 5, 0);
                    __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 6, 0);
                    __msa_st_w((v4i32)_r4567_3, tmpptr + 4 * 7, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 32;
                }
            }
            for (; i + 3 < tiles; i += 4)
            {
                float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x4
                    v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0);
                    v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0);
                    v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0);
                    v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0);

                    v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
                    v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);

                    __msa_st_w((v4i32)_r0123_0, tmpptr, 0);
                    __msa_st_w((v4i32)_r0123_1, tmpptr + 4, 0);
                    __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 2, 0);
                    __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 3, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 16;
                }
            }
            for (; i + 1 < tiles; i += 2)
            {
                float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x2
                    v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0);
                    v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0);

                    v4i32 _r01_0 = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r01_1 = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);

                    __msa_st_w((v4i32)_r01_0, tmpptr, 0);
                    __msa_st_w((v4i32)_r01_1, tmpptr + 4, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 8;
                }
            }
            for (; i < tiles; i++)
            {
                float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // single tile: plain copy, no transpose needed
                    v4f32 _val = (v4f32)__msa_ld_w(r0, 0);
                    __msa_st_w((v4i32)_val, tmpptr, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 4;
                }
            }
        }

        bottom_blob_tm = Mat();
        // permute end

        top_blob_tm.create(tiles, 64, outch, 4u * elempack, elempack, opt.workspace_allocator);

        // gemm in the transform domain: for each output channel and each of the
        // 64 frequencies, accumulate over all input-channel lanes
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            float* output0_tm = top_blob_tm.channel(p);

            const Mat kernel0_tm = kernel_tm.channel(p);

            for (int r = 0; r < 64; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 11 < tiles; i += 12)
                {
                    const float* r0 = bb2.row(i / 12);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum4 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum5 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum6 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum7 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum8 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum9 = (v4f32)__msa_fill_w(0);
                    v4f32 _suma = (v4f32)__msa_fill_w(0);
                    v4f32 _sumb = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 96);
                        __builtin_prefetch(k0 + 32);
                        v4i32 _val0123 = __msa_ld_w(r0, 0);
                        v4i32 _val4567 = __msa_ld_w(r0 + 4, 0);
                        v4i32 _val89ab = __msa_ld_w(r0 + 8, 0);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0);
                        _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0);
                        _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0);
                        _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0);
                        _sum4 = __msa_fmadd_w(_sum4, (v4f32)__msa_splati_w(_val4567, 0), _w0);
                        _sum5 = __msa_fmadd_w(_sum5, (v4f32)__msa_splati_w(_val4567, 1), _w0);
                        _sum6 = __msa_fmadd_w(_sum6, (v4f32)__msa_splati_w(_val4567, 2), _w0);
                        _sum7 = __msa_fmadd_w(_sum7, (v4f32)__msa_splati_w(_val4567, 3), _w0);
                        _sum8 = __msa_fmadd_w(_sum8, (v4f32)__msa_splati_w(_val89ab, 0), _w0);
                        _sum9 = __msa_fmadd_w(_sum9, (v4f32)__msa_splati_w(_val89ab, 1), _w0);
                        _suma = __msa_fmadd_w(_suma, (v4f32)__msa_splati_w(_val89ab, 2), _w0);
                        _sumb = __msa_fmadd_w(_sumb, (v4f32)__msa_splati_w(_val89ab, 3), _w0);

                        r0 += 12;
                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum0, output0_tm, 0);
                    __msa_st_w((v4i32)_sum1, output0_tm + 4, 0);
                    __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0);
                    __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0);
                    __msa_st_w((v4i32)_sum4, output0_tm + 4 * 4, 0);
                    __msa_st_w((v4i32)_sum5, output0_tm + 4 * 5, 0);
                    __msa_st_w((v4i32)_sum6, output0_tm + 4 * 6, 0);
                    __msa_st_w((v4i32)_sum7, output0_tm + 4 * 7, 0);
                    __msa_st_w((v4i32)_sum8, output0_tm + 4 * 8, 0);
                    __msa_st_w((v4i32)_sum9, output0_tm + 4 * 9, 0);
                    __msa_st_w((v4i32)_suma, output0_tm + 4 * 10, 0);
                    __msa_st_w((v4i32)_sumb, output0_tm + 4 * 11, 0);

                    output0_tm += 4 * 12;
                }
                for (; i + 7 < tiles; i += 8)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum4 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum5 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum6 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum7 = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 64);
                        __builtin_prefetch(k0 + 32);
                        v4i32 _val0123 = __msa_ld_w(r0, 0);
                        v4i32 _val4567 = __msa_ld_w(r0 + 4, 0);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0);
                        _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0);
                        _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0);
                        _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0);
                        _sum4 = __msa_fmadd_w(_sum4, (v4f32)__msa_splati_w(_val4567, 0), _w0);
                        _sum5 = __msa_fmadd_w(_sum5, (v4f32)__msa_splati_w(_val4567, 1), _w0);
                        _sum6 = __msa_fmadd_w(_sum6, (v4f32)__msa_splati_w(_val4567, 2), _w0);
                        _sum7 = __msa_fmadd_w(_sum7, (v4f32)__msa_splati_w(_val4567, 3), _w0);

                        r0 += 8;
                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum0, output0_tm, 0);
                    __msa_st_w((v4i32)_sum1, output0_tm + 4, 0);
                    __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0);
                    __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0);
                    __msa_st_w((v4i32)_sum4, output0_tm + 4 * 4, 0);
                    __msa_st_w((v4i32)_sum5, output0_tm + 4 * 5, 0);
                    __msa_st_w((v4i32)_sum6, output0_tm + 4 * 6, 0);
                    __msa_st_w((v4i32)_sum7, output0_tm + 4 * 7, 0);

                    output0_tm += 4 * 8;
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 32);
                        __builtin_prefetch(k0 + 32);
                        v4i32 _val0123 = __msa_ld_w(r0, 0);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0);
                        _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0);
                        _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0);
                        _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0);

                        r0 += 4;
                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum0, output0_tm, 0);
                    __msa_st_w((v4i32)_sum1, output0_tm + 4, 0);
                    __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0);
                    __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0);

                    output0_tm += 4 * 4;
                }
                for (; i + 1 < tiles; i += 2)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 16);
                        __builtin_prefetch(k0 + 32);
                        v4f32 _val0 = __msa_fill_w_f32(*r0++);
                        v4f32 _val1 = __msa_fill_w_f32(*r0++);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum0 = __msa_fmadd_w(_sum0, _val0, _w0);
                        _sum1 = __msa_fmadd_w(_sum1, _val1, _w0);

                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum0, output0_tm, 0);
                    __msa_st_w((v4i32)_sum1, output0_tm + 4, 0);

                    output0_tm += 4 * 2;
                }
                for (; i < tiles; i++)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 8);
                        __builtin_prefetch(k0 + 32);
                        v4f32 _val0 = __msa_fill_w_f32(*r0++);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum = __msa_fmadd_w(_sum, _val0, _w0);

                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum, output0_tm, 0);

                    output0_tm += 4;
                }
            }
        }
    }
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    if (outw == top_blob.w && outh == top_blob.h)
    {
        top_blob_bordered = top_blob;
    }
    else
    {
        top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator);
    }
    {
        // A^T matrix of F(6,3); applied row-wise then column-wise below
        // const float otm[6][8] = {
        //     {1.0f, 1.0f,  1.0f,  1.0f,  1.0f, 32.0f, 32.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f,  2.0f, -2.0f, 16.0f,-16.0f, 0.0f},
        //     {0.0f, 1.0f,  1.0f,  4.0f,  4.0f,  8.0f,  8.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f,  8.0f, -8.0f,  4.0f, -4.0f, 0.0f},
        //     {0.0f, 1.0f,  1.0f, 16.0f, 16.0f,  2.0f,  2.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 32.0f,-32.0f,  1.0f, -1.0f, 1.0f}
        // };

        // 0 = r0 + (r1 + r2) + (r3 + r4)     + (r5 + r6) * 32
        // 1 =      (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
        // 2 =      (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
        // 3 =      (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
        // 4 =      (r1 + r2) + (r3 + r4) * 16+ (r5 + r6) * 2
        // 5 = r7 + (r1 - r2) + (r3 - r4) * 32+ (r5 - r6)

        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;
        const int tiles = w_tm / 8 * h_tm / 8;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            const Mat out0_tm = top_blob_tm.channel(p);
            Mat out0 = top_blob_bordered.channel(p);

            // const float bias0 = bias ? bias[p] : 0.f;
            // bias is pack4: 4 consecutive floats per output channel group
            v4f32 _bias0 = bias ? (v4f32)__msa_ld_w((const float*)bias + p * 4, 0) : (v4f32)__msa_fill_w(0);

            float tmp[6][8][4];

            v4f32 _v32 = __msa_fill_w_f32(32.f);
            v4f32 _v16 = __msa_fill_w_f32(16.f);
            v4f32 _v8 = __msa_fill_w_f32(8.f);
            v4f32 _v4 = __msa_fill_w_f32(4.f);
            v4f32 _v2 = __msa_fill_w_f32(2.f);

            // tile
            for (int i = 0; i < outh / 6; i++)
            {
                for (int j = 0; j < outw / 6; j++)
                {
                    // top_blob_tm.create(tiles, 64, outch, elemsize, elempack);

                    const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 4;
                    const float* output0_tm_1 = output0_tm_0 + tiles * 4;
                    const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2;
                    const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3;
                    const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4;
                    const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5;
                    const float* output0_tm_6 = output0_tm_0 + tiles * 4 * 6;
                    const float* output0_tm_7 = output0_tm_0 + tiles * 4 * 7;

                    float* output0 = out0.row<float>(i * 6) + (j * 6) * 4;

                    // TODO msa optimize
                    // first pass: 8x8 -> 6x8 (vertical inverse transform)
                    for (int m = 0; m < 8; m++)
                    {
                        v4f32 _out0tm0 = (v4f32)__msa_ld_w(output0_tm_0, 0);
                        v4f32 _out0tm1 = (v4f32)__msa_ld_w(output0_tm_1, 0);
                        v4f32 _out0tm2 = (v4f32)__msa_ld_w(output0_tm_2, 0);
                        v4f32 _out0tm3 = (v4f32)__msa_ld_w(output0_tm_3, 0);
                        v4f32 _out0tm4 = (v4f32)__msa_ld_w(output0_tm_4, 0);
                        v4f32 _out0tm5 = (v4f32)__msa_ld_w(output0_tm_5, 0);
                        v4f32 _out0tm6 = (v4f32)__msa_ld_w(output0_tm_6, 0);
                        v4f32 _out0tm7 = (v4f32)__msa_ld_w(output0_tm_7, 0);

                        v4f32 _tmp024a = __msa_fadd_w(_out0tm1, _out0tm2);
                        v4f32 _tmp135a = __msa_fsub_w(_out0tm1, _out0tm2);

                        v4f32 _tmp024b = __msa_fadd_w(_out0tm3, _out0tm4);
                        v4f32 _tmp135b = __msa_fsub_w(_out0tm3, _out0tm4);

                        v4f32 _tmp024c = __msa_fadd_w(_out0tm5, _out0tm6);
                        v4f32 _tmp135c = __msa_fsub_w(_out0tm5, _out0tm6);

                        v4f32 _tmp0m = __msa_fadd_w(__msa_fadd_w(_out0tm0, _tmp024a), __msa_fmadd_w(_tmp024b, _v32, _tmp024c));
                        v4f32 _tmp2m = __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v4, _tmp024b), _v8, _tmp024c);
                        v4f32 _tmp4m = __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v16, _tmp024b), _v2, _tmp024c);
                        __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0);
                        __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0);
                        __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0);

                        v4f32 _tmp1m = __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v2, _tmp135b), _v16, _tmp135c);
                        v4f32 _tmp3m = __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v8, _tmp135b), _v4, _tmp135c);
                        v4f32 _tmp5m = __msa_fadd_w(__msa_fadd_w(_out0tm7, _tmp135a), __msa_fmadd_w(_tmp135c, _v32, _tmp135b));
                        __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0);
                        __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0);
                        __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0);

                        output0_tm_0 += tiles * 4 * 8;
                        output0_tm_1 += tiles * 4 * 8;
                        output0_tm_2 += tiles * 4 * 8;
                        output0_tm_3 += tiles * 4 * 8;
                        output0_tm_4 += tiles * 4 * 8;
                        output0_tm_5 += tiles * 4 * 8;
                        output0_tm_6 += tiles * 4 * 8;
                        output0_tm_7 += tiles * 4 * 8;
                    }

                    // second pass: 6x8 -> 6x6, add bias and store the output tile
                    for (int m = 0; m < 6; m++)
                    {
                        v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0);
                        v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0);
                        v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0);
                        v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0);
                        v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0);
                        v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0);
                        v4f32 _tmp06 = (v4f32)__msa_ld_w(tmp[m][6], 0);
                        v4f32 _tmp07 = (v4f32)__msa_ld_w(tmp[m][7], 0);

                        v4f32 _tmp024a = __msa_fadd_w(_tmp01, _tmp02);
                        v4f32 _tmp135a = __msa_fsub_w(_tmp01, _tmp02);

                        v4f32 _tmp024b = __msa_fadd_w(_tmp03, _tmp04);
                        v4f32 _tmp135b = __msa_fsub_w(_tmp03, _tmp04);

                        v4f32 _tmp024c = __msa_fadd_w(_tmp05, _tmp06);
                        v4f32 _tmp135c = __msa_fsub_w(_tmp05, _tmp06);

                        v4f32 _out00 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp00, _tmp024a), __msa_fmadd_w(_tmp024b, _v32, _tmp024c)));
                        v4f32 _out02 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v4, _tmp024b), _v8, _tmp024c));
                        v4f32 _out04 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp024a, _v16, _tmp024b), _v2, _tmp024c));
                        __msa_st_w((v4i32)_out00, output0, 0);
                        __msa_st_w((v4i32)_out02, output0 + 4 * 2, 0);
                        __msa_st_w((v4i32)_out04, output0 + 4 * 4, 0);

                        v4f32 _out01 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v2, _tmp135b), _v16, _tmp135c));
                        v4f32 _out03 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fmadd_w(_tmp135a, _v8, _tmp135b), _v4, _tmp135c));
                        v4f32 _out05 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp07, _tmp135a), __msa_fmadd_w(_tmp135c, _v32, _tmp135b)));
                        __msa_st_w((v4i32)_out01, output0 + 4, 0);
                        __msa_st_w((v4i32)_out03, output0 + 4 * 3, 0);
                        __msa_st_w((v4i32)_out05, output0 + 4 * 5, 0);

                        output0 += outw * 4;
                    }
                }
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}
887
conv3x3s1_winograd42_transform_kernel_pack4_msa(const Mat & kernel,Mat & kernel_tm_pack4,int inch,int outch,const Option & opt)888 static void conv3x3s1_winograd42_transform_kernel_pack4_msa(const Mat& kernel, Mat& kernel_tm_pack4, int inch, int outch, const Option& opt)
889 {
890 // winograd42 transform kernel
891 Mat kernel_tm(6 * 6, inch, outch);
892
893 const float ktm[6][3] = {
894 {1.0f / 4, 0.0f, 0.0f},
895 {-1.0f / 6, -1.0f / 6, -1.0f / 6},
896 {-1.0f / 6, 1.0f / 6, -1.0f / 6},
897 {1.0f / 24, 1.0f / 12, 1.0f / 6},
898 {1.0f / 24, -1.0f / 12, 1.0f / 6},
899 {0.0f, 0.0f, 1.0f}
900 };
901
902 #pragma omp parallel for num_threads(opt.num_threads)
903 for (int p = 0; p < outch; p++)
904 {
905 for (int q = 0; q < inch; q++)
906 {
907 const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
908 float* kernel_tm0 = kernel_tm.channel(p).row(q);
909
910 // transform kernel
911 const float* k0 = kernel0;
912 const float* k1 = kernel0 + 3;
913 const float* k2 = kernel0 + 6;
914
915 // h
916 float tmp[6][3];
917 for (int i = 0; i < 6; i++)
918 {
919 tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
920 tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
921 tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
922 }
923
924 // U
925 for (int j = 0; j < 6; j++)
926 {
927 float* tmpp = &tmp[j][0];
928
929 for (int i = 0; i < 6; i++)
930 {
931 kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
932 }
933 }
934 }
935 }
936
937 // interleave
938 // src = 36-inch-outch
939 // dst = pb-pa-inch/pa-36-outch/pb
940 kernel_tm_pack4.create(inch / 4, 36, outch / 4, (size_t)4u * 4 * 4, 4 * 4);
941
942 for (int q = 0; q + (4 - 1) < outch; q += 4)
943 {
944 Mat g0 = kernel_tm_pack4.channel(q / 4);
945
946 for (int k = 0; k < 36; k++)
947 {
948 float* g00 = g0.row<float>(k);
949
950 for (int p = 0; p + (4 - 1) < inch; p += 4)
951 {
952 for (int i = 0; i < 4; i++)
953 {
954 for (int j = 0; j < 4; j++)
955 {
956 const float* k00 = kernel_tm.channel(q + j).row(p + i);
957 g00[0] = (float)k00[k];
958 g00++;
959 }
960 }
961 }
962 }
963 }
964 }
965
static void conv3x3s1_winograd42_pack4_msa(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
{
    // 3x3 stride-1 convolution via Winograd F(4x4, 3x3) for pack4 layout (MIPS MSA).
    // Pipeline: pad input -> input transform (Bt * d * B per overlapping 6x6 tile)
    // -> per-frequency GEMM against the pre-transformed kernel -> output transform
    // (At * m * A, producing a 4x4 output block per tile) -> crop the padding.
    // kernel_tm must come from conv3x3s1_winograd42_transform_kernel_pack4_msa.
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 4n+2: round the output up to a multiple of 4 (the tile size);
    // the input needs 2 extra border pixels for the overlapping 6x6 tiles
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 3) / 4 * 4;
    outh = (outh + 3) / 4 * 4;

    w = outw + 2;
    h = outh + 2;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt);

    const float* bias = _bias; // may be NULL when the layer has no bias term

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;

        const int tiles = w_tm / 6 * h_tm / 6;

        // layout: tiles (x4 floats) x 36 frequency positions x inch
        bottom_blob_tm.create(tiles, 36, inch, 4u * elempack, elempack, opt.workspace_allocator);

        // const float itm[6][6] = {
        //     {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
        //     {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
        //     {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
        //     {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
        //     {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
        //     {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f}
        // };

        // 0 = 4 * r00 - 5 * r02 + r04
        // 1 = -4 * (r01 + r02) + r04 + r03
        // 2 = 4 * (r01 - r02) + r04 - r03
        // 3 = -2 * (r01 - r03) + r04 - r02
        // 4 = 2 * (r01 - r03) + r04 - r02
        // 5 = 4 * r01 - 5 * r03 + r05

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const Mat img0 = bottom_blob_bordered.channel(q);
            Mat img0_tm = bottom_blob_tm.channel(q);

            // scratch for the separable transform: [row][col][pack4]
            float tmp[6][6][4];

            v4f32 _vm5 = __msa_fill_w_f32(-5.f);
            v4f32 _vm4 = __msa_fill_w_f32(-4.f);
            v4f32 _v4 = __msa_fill_w_f32(4.f);
            v4f32 _vm2 = __msa_fill_w_f32(-2.f);
            v4f32 _v2 = __msa_fill_w_f32(2.f);

            // tile
            for (int i = 0; i < h_tm / 6; i++)
            {
                for (int j = 0; j < w_tm / 6; j++)
                {
                    // tiles advance by 4 pixels but read 6, overlapping by 2
                    const float* r0 = img0.row(i * 4) + (j * 4) * 4;

                    // first pass: transform along rows into tmp (note transposed store)
                    for (int m = 0; m < 6; m++)
                    {
                        v4f32 _r00 = (v4f32)__msa_ld_w(r0, 0);
                        v4f32 _r01 = (v4f32)__msa_ld_w(r0 + 4, 0);
                        v4f32 _r02 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0);
                        v4f32 _r03 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0);
                        v4f32 _r04 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0);
                        v4f32 _r05 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0);

                        v4f32 _tmp0m = __msa_fmadd_w(__msa_fmadd_w(_r04, _v4, _r00), _vm5, _r02);
                        v4f32 _tmp1m = __msa_fmadd_w(__msa_fadd_w(_r04, _r03), _vm4, __msa_fadd_w(_r01, _r02));
                        v4f32 _tmp2m = __msa_fmadd_w(__msa_fsub_w(_r04, _r03), _v4, __msa_fsub_w(_r01, _r02));
                        v4f32 _tmp3m = __msa_fmadd_w(__msa_fsub_w(_r04, _r02), _vm2, __msa_fsub_w(_r01, _r03));
                        v4f32 _tmp4m = __msa_fmadd_w(__msa_fsub_w(_r04, _r02), _v2, __msa_fsub_w(_r01, _r03));
                        v4f32 _tmp5m = __msa_fmadd_w(__msa_fmadd_w(_r05, _v4, _r01), _vm5, _r03);

                        __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0);
                        __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0);
                        __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0);
                        __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0);
                        __msa_st_w((v4i32)_tmp4m, tmp[4][m], 0);
                        __msa_st_w((v4i32)_tmp5m, tmp[5][m], 0);

                        r0 += w * 4;
                    }

                    // six destination rows of this tile, one per frequency row
                    float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 6 + j) * 4;
                    float* r0_tm_1 = r0_tm_0 + tiles * 4;
                    float* r0_tm_2 = r0_tm_0 + tiles * 4 * 2;
                    float* r0_tm_3 = r0_tm_0 + tiles * 4 * 3;
                    float* r0_tm_4 = r0_tm_0 + tiles * 4 * 4;
                    float* r0_tm_5 = r0_tm_0 + tiles * 4 * 5;

                    // second pass: transform along columns, scatter to the 36 positions
                    for (int m = 0; m < 6; m++)
                    {
                        v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0);
                        v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0);
                        v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0);
                        v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0);
                        v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0);
                        v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0);

                        v4f32 _r0tm0 = __msa_fmadd_w(__msa_fmadd_w(_tmp04, _v4, _tmp00), _vm5, _tmp02);
                        v4f32 _r0tm1 = __msa_fmadd_w(__msa_fadd_w(_tmp04, _tmp03), _vm4, __msa_fadd_w(_tmp01, _tmp02));
                        v4f32 _r0tm2 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp03), _v4, __msa_fsub_w(_tmp01, _tmp02));
                        v4f32 _r0tm3 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp02), _vm2, __msa_fsub_w(_tmp01, _tmp03));
                        v4f32 _r0tm4 = __msa_fmadd_w(__msa_fsub_w(_tmp04, _tmp02), _v2, __msa_fsub_w(_tmp01, _tmp03));
                        v4f32 _r0tm5 = __msa_fmadd_w(__msa_fmadd_w(_tmp05, _v4, _tmp01), _vm5, _tmp03);

                        __msa_st_w((v4i32)_r0tm0, r0_tm_0, 0);
                        __msa_st_w((v4i32)_r0tm1, r0_tm_1, 0);
                        __msa_st_w((v4i32)_r0tm2, r0_tm_2, 0);
                        __msa_st_w((v4i32)_r0tm3, r0_tm_3, 0);
                        __msa_st_w((v4i32)_r0tm4, r0_tm_4, 0);
                        __msa_st_w((v4i32)_r0tm5, r0_tm_5, 0);

                        r0_tm_0 += tiles * 4 * 6;
                        r0_tm_1 += tiles * 4 * 6;
                        r0_tm_2 += tiles * 4 * 6;
                        r0_tm_3 += tiles * 4 * 6;
                        r0_tm_4 += tiles * 4 * 6;
                        r0_tm_5 += tiles * 4 * 6;
                    }
                }
            }
        }
    }
    bottom_blob_bordered = Mat(); // release the padded copy early
    // END transform input

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;

        const int tiles = h_tm / 6 * w_tm / 6;

        // permute: repack tiles into batches of 12/8/4/2/1 so the GEMM inner
        // loops below read contiguous data
        // bottom_blob_tm.create(tiles, 36, inch, elemsize, elempack, opt.workspace_allocator);
        Mat bottom_blob_tm2;
        if (tiles >= 12)
            bottom_blob_tm2.create(12 * inch, tiles / 12 + (tiles % 12) / 8 + (tiles % 12 % 8) / 4 + (tiles % 12 % 4) / 2 + tiles % 12 % 2, 36, 4u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 8)
            bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + (tiles % 4) / 2 + tiles % 2, 36, 4u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 4)
            bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 36, 4u * elempack, elempack, opt.workspace_allocator);
        else if (tiles >= 2)
            bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 36, 4u * elempack, elempack, opt.workspace_allocator);
        else // if (tiles >= 1)
            bottom_blob_tm2.create(1 * inch, tiles, 36, 4u * elempack, elempack, opt.workspace_allocator);

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 36; r++)
        {
            Mat tm2 = bottom_blob_tm2.channel(r);

            // tile
            int i = 0;
            for (; i + 11 < tiles; i += 12)
            {
                float* tmpptr = tm2.row(i / 12);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x12
                    v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0);
                    v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0);
                    v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0);
                    v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0);
                    v4f32 _r4 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0);
                    v4f32 _r5 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0);
                    v4f32 _r6 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0);
                    v4f32 _r7 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0);
                    v4f32 _r8 = (v4f32)__msa_ld_w(r0 + 4 * 8, 0);
                    v4f32 _r9 = (v4f32)__msa_ld_w(r0 + 4 * 9, 0);
                    v4f32 _ra = (v4f32)__msa_ld_w(r0 + 4 * 10, 0);
                    v4f32 _rb = (v4f32)__msa_ld_w(r0 + 4 * 11, 0);

                    v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r45r = __msa_ilvr_w((v4i32)_r5, (v4i32)_r4);
                    v4i32 _r45l = __msa_ilvl_w((v4i32)_r5, (v4i32)_r4);
                    v4i32 _r67r = __msa_ilvr_w((v4i32)_r7, (v4i32)_r6);
                    v4i32 _r67l = __msa_ilvl_w((v4i32)_r7, (v4i32)_r6);
                    v4i32 _r89r = __msa_ilvr_w((v4i32)_r9, (v4i32)_r8);
                    v4i32 _r89l = __msa_ilvl_w((v4i32)_r9, (v4i32)_r8);
                    v4i32 _rabr = __msa_ilvr_w((v4i32)_rb, (v4i32)_ra);
                    v4i32 _rabl = __msa_ilvl_w((v4i32)_rb, (v4i32)_ra);
                    v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r4567_0 = __msa_ilvr_d((v2i64)_r67r, (v2i64)_r45r);
                    v2i64 _r4567_1 = __msa_ilvl_d((v2i64)_r67r, (v2i64)_r45r);
                    v2i64 _r4567_2 = __msa_ilvr_d((v2i64)_r67l, (v2i64)_r45l);
                    v2i64 _r4567_3 = __msa_ilvl_d((v2i64)_r67l, (v2i64)_r45l);
                    v2i64 _r89ab_0 = __msa_ilvr_d((v2i64)_rabr, (v2i64)_r89r);
                    v2i64 _r89ab_1 = __msa_ilvl_d((v2i64)_rabr, (v2i64)_r89r);
                    v2i64 _r89ab_2 = __msa_ilvr_d((v2i64)_rabl, (v2i64)_r89l);
                    v2i64 _r89ab_3 = __msa_ilvl_d((v2i64)_rabl, (v2i64)_r89l);

                    __msa_st_w((v4i32)_r0123_0, tmpptr, 0);
                    __msa_st_w((v4i32)_r4567_0, tmpptr + 4, 0);
                    __msa_st_w((v4i32)_r89ab_0, tmpptr + 4 * 2, 0);
                    __msa_st_w((v4i32)_r0123_1, tmpptr + 4 * 3, 0);
                    __msa_st_w((v4i32)_r4567_1, tmpptr + 4 * 4, 0);
                    __msa_st_w((v4i32)_r89ab_1, tmpptr + 4 * 5, 0);
                    __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 6, 0);
                    __msa_st_w((v4i32)_r4567_2, tmpptr + 4 * 7, 0);
                    __msa_st_w((v4i32)_r89ab_2, tmpptr + 4 * 8, 0);
                    __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 9, 0);
                    __msa_st_w((v4i32)_r4567_3, tmpptr + 4 * 10, 0);
                    __msa_st_w((v4i32)_r89ab_3, tmpptr + 4 * 11, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 48;
                }
            }
            for (; i + 7 < tiles; i += 8)
            {
                float* tmpptr = tm2.row(i / 12 + (i % 12) / 8);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x8
                    v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0);
                    v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0);
                    v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0);
                    v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0);
                    v4f32 _r4 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0);
                    v4f32 _r5 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0);
                    v4f32 _r6 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0);
                    v4f32 _r7 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0);

                    v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r45r = __msa_ilvr_w((v4i32)_r5, (v4i32)_r4);
                    v4i32 _r45l = __msa_ilvl_w((v4i32)_r5, (v4i32)_r4);
                    v4i32 _r67r = __msa_ilvr_w((v4i32)_r7, (v4i32)_r6);
                    v4i32 _r67l = __msa_ilvl_w((v4i32)_r7, (v4i32)_r6);
                    v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r4567_0 = __msa_ilvr_d((v2i64)_r67r, (v2i64)_r45r);
                    v2i64 _r4567_1 = __msa_ilvl_d((v2i64)_r67r, (v2i64)_r45r);
                    v2i64 _r4567_2 = __msa_ilvr_d((v2i64)_r67l, (v2i64)_r45l);
                    v2i64 _r4567_3 = __msa_ilvl_d((v2i64)_r67l, (v2i64)_r45l);

                    __msa_st_w((v4i32)_r0123_0, tmpptr, 0);
                    __msa_st_w((v4i32)_r4567_0, tmpptr + 4, 0);
                    __msa_st_w((v4i32)_r0123_1, tmpptr + 4 * 2, 0);
                    __msa_st_w((v4i32)_r4567_1, tmpptr + 4 * 3, 0);
                    __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 4, 0);
                    __msa_st_w((v4i32)_r4567_2, tmpptr + 4 * 5, 0);
                    __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 6, 0);
                    __msa_st_w((v4i32)_r4567_3, tmpptr + 4 * 7, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 32;
                }
            }
            for (; i + 3 < tiles; i += 4)
            {
                float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x4
                    v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0);
                    v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0);
                    v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0);
                    v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0);

                    v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2);
                    v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2);
                    v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r);
                    v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l);
                    v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l);

                    __msa_st_w((v4i32)_r0123_0, tmpptr, 0);
                    __msa_st_w((v4i32)_r0123_1, tmpptr + 4, 0);
                    __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 2, 0);
                    __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 3, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 16;
                }
            }
            for (; i + 1 < tiles; i += 2)
            {
                float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // transpose 4x2
                    v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0);
                    v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0);

                    v4i32 _r01_0 = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0);
                    v4i32 _r01_1 = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0);

                    __msa_st_w((v4i32)_r01_0, tmpptr, 0);
                    __msa_st_w((v4i32)_r01_1, tmpptr + 4, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 8;
                }
            }
            for (; i < tiles; i++)
            {
                float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 4;

                for (int q = 0; q < inch; q++)
                {
                    // single tile: straight copy, no transpose needed
                    v4f32 _val = (v4f32)__msa_ld_w(r0, 0);
                    __msa_st_w((v4i32)_val, tmpptr, 0);

                    r0 += bottom_blob_tm.cstep * 4;
                    tmpptr += 4;
                }
            }
        }

        bottom_blob_tm = Mat();
        // permute end

        top_blob_tm.create(tiles, 36, outch, 4u * elempack, elempack, opt.workspace_allocator);

        // GEMM: for each output channel group and each of the 36 frequency
        // positions, accumulate over inch*4 scalar input lanes against the
        // 4-wide kernel columns; tiles processed in 12/8/4/2/1 batches
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            float* output0_tm = top_blob_tm.channel(p);

            const Mat kernel0_tm = kernel_tm.channel(p);

            for (int r = 0; r < 36; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 11 < tiles; i += 12)
                {
                    const float* r0 = bb2.row(i / 12);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum4 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum5 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum6 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum7 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum8 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum9 = (v4f32)__msa_fill_w(0);
                    v4f32 _suma = (v4f32)__msa_fill_w(0);
                    v4f32 _sumb = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 96);
                        __builtin_prefetch(k0 + 32);
                        v4i32 _val0123 = __msa_ld_w(r0, 0);
                        v4i32 _val4567 = __msa_ld_w(r0 + 4, 0);
                        v4i32 _val89ab = __msa_ld_w(r0 + 8, 0);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0);
                        _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0);
                        _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0);
                        _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0);
                        _sum4 = __msa_fmadd_w(_sum4, (v4f32)__msa_splati_w(_val4567, 0), _w0);
                        _sum5 = __msa_fmadd_w(_sum5, (v4f32)__msa_splati_w(_val4567, 1), _w0);
                        _sum6 = __msa_fmadd_w(_sum6, (v4f32)__msa_splati_w(_val4567, 2), _w0);
                        _sum7 = __msa_fmadd_w(_sum7, (v4f32)__msa_splati_w(_val4567, 3), _w0);
                        _sum8 = __msa_fmadd_w(_sum8, (v4f32)__msa_splati_w(_val89ab, 0), _w0);
                        _sum9 = __msa_fmadd_w(_sum9, (v4f32)__msa_splati_w(_val89ab, 1), _w0);
                        _suma = __msa_fmadd_w(_suma, (v4f32)__msa_splati_w(_val89ab, 2), _w0);
                        _sumb = __msa_fmadd_w(_sumb, (v4f32)__msa_splati_w(_val89ab, 3), _w0);

                        r0 += 12;
                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum0, output0_tm, 0);
                    __msa_st_w((v4i32)_sum1, output0_tm + 4, 0);
                    __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0);
                    __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0);
                    __msa_st_w((v4i32)_sum4, output0_tm + 4 * 4, 0);
                    __msa_st_w((v4i32)_sum5, output0_tm + 4 * 5, 0);
                    __msa_st_w((v4i32)_sum6, output0_tm + 4 * 6, 0);
                    __msa_st_w((v4i32)_sum7, output0_tm + 4 * 7, 0);
                    __msa_st_w((v4i32)_sum8, output0_tm + 4 * 8, 0);
                    __msa_st_w((v4i32)_sum9, output0_tm + 4 * 9, 0);
                    __msa_st_w((v4i32)_suma, output0_tm + 4 * 10, 0);
                    __msa_st_w((v4i32)_sumb, output0_tm + 4 * 11, 0);

                    output0_tm += 4 * 12;
                }
                for (; i + 7 < tiles; i += 8)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum4 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum5 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum6 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum7 = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 64);
                        __builtin_prefetch(k0 + 32);
                        v4i32 _val0123 = __msa_ld_w(r0, 0);
                        v4i32 _val4567 = __msa_ld_w(r0 + 4, 0);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0);
                        _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0);
                        _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0);
                        _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0);
                        _sum4 = __msa_fmadd_w(_sum4, (v4f32)__msa_splati_w(_val4567, 0), _w0);
                        _sum5 = __msa_fmadd_w(_sum5, (v4f32)__msa_splati_w(_val4567, 1), _w0);
                        _sum6 = __msa_fmadd_w(_sum6, (v4f32)__msa_splati_w(_val4567, 2), _w0);
                        _sum7 = __msa_fmadd_w(_sum7, (v4f32)__msa_splati_w(_val4567, 3), _w0);

                        r0 += 8;
                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum0, output0_tm, 0);
                    __msa_st_w((v4i32)_sum1, output0_tm + 4, 0);
                    __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0);
                    __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0);
                    __msa_st_w((v4i32)_sum4, output0_tm + 4 * 4, 0);
                    __msa_st_w((v4i32)_sum5, output0_tm + 4 * 5, 0);
                    __msa_st_w((v4i32)_sum6, output0_tm + 4 * 6, 0);
                    __msa_st_w((v4i32)_sum7, output0_tm + 4 * 7, 0);

                    output0_tm += 4 * 8;
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum2 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum3 = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 32);
                        __builtin_prefetch(k0 + 32);
                        v4i32 _val0123 = __msa_ld_w(r0, 0);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0);
                        _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0);
                        _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0);
                        _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0);

                        r0 += 4;
                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum0, output0_tm, 0);
                    __msa_st_w((v4i32)_sum1, output0_tm + 4, 0);
                    __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0);
                    __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0);

                    output0_tm += 4 * 4;
                }
                for (; i + 1 < tiles; i += 2)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);
                    const float* k0 = kernel0_tm.row(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum0 = (v4f32)__msa_fill_w(0);
                    v4f32 _sum1 = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 16);
                        __builtin_prefetch(k0 + 32);
                        v4f32 _val0 = __msa_fill_w_f32(*r0++);
                        v4f32 _val1 = __msa_fill_w_f32(*r0++);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum0 = __msa_fmadd_w(_sum0, _val0, _w0);
                        _sum1 = __msa_fmadd_w(_sum1, _val1, _w0);

                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum0, output0_tm, 0);
                    __msa_st_w((v4i32)_sum1, output0_tm + 4, 0);

                    output0_tm += 4 * 2;
                }
                for (; i < tiles; i++)
                {
                    const float* r0 = bb2.row<const float>(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
                    const float* k0 = kernel0_tm.row<const float>(r);

                    int nn = inch * 4; // inch always > 0

                    v4f32 _sum = (v4f32)__msa_fill_w(0);

                    for (int j = 0; j < nn; j++)
                    {
                        __builtin_prefetch(r0 + 8);
                        __builtin_prefetch(k0 + 32);
                        v4f32 _val0 = __msa_fill_w_f32(*r0++);
                        v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0);
                        _sum = __msa_fmadd_w(_sum, _val0, _w0);

                        k0 += 4;
                    }

                    __msa_st_w((v4i32)_sum, output0_tm, 0);

                    output0_tm += 4;
                }
            }
        }
    }
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    if (outw == top_blob.w && outh == top_blob.h)
    {
        // no rounding happened, write directly into top_blob
        top_blob_bordered = top_blob;
    }
    else
    {
        top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator);
    }
    {
        // const float otm[4][6] = {
        //     {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
        //     {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
        // };

        // 0 = r00 + (r01 + r02) + (r03 + r04)
        // 1 = (r01 - r02) + (r03 - r04) * 2
        // 2 = (r01 + r02) + (r03 + r04) * 4
        // 3 = r05 + (r01 - r02) + (r03 - r04) * 8

        int w_tm = outw / 4 * 6;
        int h_tm = outh / 4 * 6;
        const int tiles = w_tm / 6 * h_tm / 6;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            const Mat out0_tm = top_blob_tm.channel(p);
            Mat out0 = top_blob_bordered.channel(p);

            // bias is applied per pack4 output channel group during the
            // second transform pass below
            // const float bias0 = bias ? bias[p] : 0.f;
            v4f32 _bias0 = bias ? (v4f32)__msa_ld_w((const float*)bias + p * 4, 0) : (v4f32)__msa_fill_w(0);

            float tmp[4][6][4];

            v4f32 _v2 = __msa_fill_w_f32(2.f);
            v4f32 _v4 = __msa_fill_w_f32(4.f);
            v4f32 _v8 = __msa_fill_w_f32(8.f);

            // tile
            for (int i = 0; i < outh / 4; i++)
            {
                for (int j = 0; j < outw / 4; j++)
                {
                    // top_blob_tm.create(tiles, 36, outch, elemsize, elempack);

                    const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 6 + j) * 4;
                    const float* output0_tm_1 = output0_tm_0 + tiles * 4;
                    const float* output0_tm_2 = output0_tm_0 + tiles * 4 * 2;
                    const float* output0_tm_3 = output0_tm_0 + tiles * 4 * 3;
                    const float* output0_tm_4 = output0_tm_0 + tiles * 4 * 4;
                    const float* output0_tm_5 = output0_tm_0 + tiles * 4 * 5;

                    float* output0 = out0.row<float>(i * 4) + (j * 4) * 4;

                    // first pass: columns of the 6x6 transform-domain tile -> 4x6 tmp
                    // TODO msa optimize
                    for (int m = 0; m < 6; m++)
                    {
                        v4f32 _out0tm0 = (v4f32)__msa_ld_w(output0_tm_0, 0);
                        v4f32 _out0tm1 = (v4f32)__msa_ld_w(output0_tm_1, 0);
                        v4f32 _out0tm2 = (v4f32)__msa_ld_w(output0_tm_2, 0);
                        v4f32 _out0tm3 = (v4f32)__msa_ld_w(output0_tm_3, 0);
                        v4f32 _out0tm4 = (v4f32)__msa_ld_w(output0_tm_4, 0);
                        v4f32 _out0tm5 = (v4f32)__msa_ld_w(output0_tm_5, 0);

                        v4f32 _tmp02a = __msa_fadd_w(_out0tm1, _out0tm2);
                        v4f32 _tmp13a = __msa_fsub_w(_out0tm1, _out0tm2);

                        v4f32 _tmp02b = __msa_fadd_w(_out0tm3, _out0tm4);
                        v4f32 _tmp13b = __msa_fsub_w(_out0tm3, _out0tm4);

                        v4f32 _tmp0m = __msa_fadd_w(__msa_fadd_w(_out0tm0, _tmp02a), _tmp02b);
                        v4f32 _tmp1m = __msa_fmadd_w(_tmp13a, _v2, _tmp13b);
                        v4f32 _tmp2m = __msa_fmadd_w(_tmp02a, _v4, _tmp02b);
                        v4f32 _tmp3m = __msa_fmadd_w(__msa_fadd_w(_out0tm5, _tmp13a), _v8, _tmp13b);

                        __msa_st_w((v4i32)_tmp0m, tmp[0][m], 0);
                        __msa_st_w((v4i32)_tmp1m, tmp[1][m], 0);
                        __msa_st_w((v4i32)_tmp2m, tmp[2][m], 0);
                        __msa_st_w((v4i32)_tmp3m, tmp[3][m], 0);

                        output0_tm_0 += tiles * 4 * 6;
                        output0_tm_1 += tiles * 4 * 6;
                        output0_tm_2 += tiles * 4 * 6;
                        output0_tm_3 += tiles * 4 * 6;
                        output0_tm_4 += tiles * 4 * 6;
                        output0_tm_5 += tiles * 4 * 6;
                    }

                    // second pass: rows of tmp -> final 4x4 output block, bias added
                    for (int m = 0; m < 4; m++)
                    {
                        v4f32 _tmp00 = (v4f32)__msa_ld_w(tmp[m][0], 0);
                        v4f32 _tmp01 = (v4f32)__msa_ld_w(tmp[m][1], 0);
                        v4f32 _tmp02 = (v4f32)__msa_ld_w(tmp[m][2], 0);
                        v4f32 _tmp03 = (v4f32)__msa_ld_w(tmp[m][3], 0);
                        v4f32 _tmp04 = (v4f32)__msa_ld_w(tmp[m][4], 0);
                        v4f32 _tmp05 = (v4f32)__msa_ld_w(tmp[m][5], 0);

                        v4f32 _tmp02a = __msa_fadd_w(_tmp01, _tmp02);
                        v4f32 _tmp13a = __msa_fsub_w(_tmp01, _tmp02);

                        v4f32 _tmp02b = __msa_fadd_w(_tmp03, _tmp04);
                        v4f32 _tmp13b = __msa_fsub_w(_tmp03, _tmp04);

                        v4f32 _out00 = __msa_fadd_w(_bias0, __msa_fadd_w(__msa_fadd_w(_tmp00, _tmp02a), _tmp02b));
                        v4f32 _out01 = __msa_fadd_w(_bias0, __msa_fmadd_w(_tmp13a, _v2, _tmp13b));
                        v4f32 _out02 = __msa_fadd_w(_bias0, __msa_fmadd_w(_tmp02a, _v4, _tmp02b));
                        v4f32 _out03 = __msa_fadd_w(_bias0, __msa_fmadd_w(__msa_fadd_w(_tmp05, _tmp13a), _v8, _tmp13b));

                        __msa_st_w((v4i32)_out00, output0, 0);
                        __msa_st_w((v4i32)_out01, output0 + 4, 0);
                        __msa_st_w((v4i32)_out02, output0 + 4 * 2, 0);
                        __msa_st_w((v4i32)_out03, output0 + 4 * 3, 0);

                        output0 += outw * 4;
                    }
                }
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}
1673