1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 // Copyright (C) 2019 BUG1989. All rights reserved.
5 //
6 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
7 // in compliance with the License. You may obtain a copy of the License at
8 //
9 // https://opensource.org/licenses/BSD-3-Clause
10 //
11 // Unless required by applicable law or agreed to in writing, software distributed
12 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
13 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
14 // specific language governing permissions and limitations under the License.
15
conv3x3s1_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)16 static void conv3x3s1_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
17 {
18 int w = bottom_blob.w;
19 int inch = bottom_blob.c;
20
21 int outw = top_blob.w;
22 int outh = top_blob.h;
23 int outch = top_blob.c;
24
25 const float* kernel = _kernel;
26 const float* bias = _bias;
27
28 #pragma omp parallel for num_threads(opt.num_threads)
29 for (int p = 0; p < outch; p++)
30 {
31 Mat out = top_blob.channel(p);
32
33 const float bias0 = bias ? bias[p] : 0.f;
34
35 out.fill(bias0);
36
37 for (int q = 0; q < inch; q++)
38 {
39 float* outptr = out;
40 float* outptr2 = outptr + outw;
41
42 const float* img0 = bottom_blob.channel(q);
43
44 const float* kernel0 = kernel + p * inch * 9 + q * 9;
45
46 const float* r0 = img0;
47 const float* r1 = img0 + w;
48 const float* r2 = img0 + w * 2;
49 const float* r3 = img0 + w * 3;
50
51 const float* k0 = kernel0;
52 const float* k1 = kernel0 + 3;
53 const float* k2 = kernel0 + 6;
54
55 int i = 0;
56
57 for (; i + 1 < outh; i += 2)
58 {
59 int remain = outw;
60
61 for (; remain > 0; remain--)
62 {
63 float sum = 0;
64 float sum2 = 0;
65
66 sum += r0[0] * k0[0];
67 sum += r0[1] * k0[1];
68 sum += r0[2] * k0[2];
69 sum += r1[0] * k1[0];
70 sum += r1[1] * k1[1];
71 sum += r1[2] * k1[2];
72 sum += r2[0] * k2[0];
73 sum += r2[1] * k2[1];
74 sum += r2[2] * k2[2];
75
76 sum2 += r1[0] * k0[0];
77 sum2 += r1[1] * k0[1];
78 sum2 += r1[2] * k0[2];
79 sum2 += r2[0] * k1[0];
80 sum2 += r2[1] * k1[1];
81 sum2 += r2[2] * k1[2];
82 sum2 += r3[0] * k2[0];
83 sum2 += r3[1] * k2[1];
84 sum2 += r3[2] * k2[2];
85
86 *outptr += sum;
87 *outptr2 += sum2;
88
89 r0++;
90 r1++;
91 r2++;
92 r3++;
93 outptr++;
94 outptr2++;
95 }
96
97 r0 += 2 + w;
98 r1 += 2 + w;
99 r2 += 2 + w;
100 r3 += 2 + w;
101
102 outptr += outw;
103 outptr2 += outw;
104 }
105
106 for (; i < outh; i++)
107 {
108 int remain = outw;
109
110 for (; remain > 0; remain--)
111 {
112 float sum = 0;
113
114 sum += r0[0] * k0[0];
115 sum += r0[1] * k0[1];
116 sum += r0[2] * k0[2];
117 sum += r1[0] * k1[0];
118 sum += r1[1] * k1[1];
119 sum += r1[2] * k1[2];
120 sum += r2[0] * k2[0];
121 sum += r2[1] * k2[1];
122 sum += r2[2] * k2[2];
123
124 *outptr += sum;
125
126 r0++;
127 r1++;
128 r2++;
129 outptr++;
130 }
131
132 r0 += 2;
133 r1 += 2;
134 r2 += 2;
135 }
136 }
137 }
138 }
139
conv3x3s1_winograd23_transform_kernel_sse(const Mat & kernel,Mat & kernel_tm,int inch,int outch)140 static void conv3x3s1_winograd23_transform_kernel_sse(const Mat& kernel, Mat& kernel_tm, int inch, int outch)
141 {
142 kernel_tm.create(4 * 4, inch, outch);
143
144 // G
145 const float ktm[4][3] = {
146 {1.0f, 0.0f, 0.0f},
147 {1.0f / 2, 1.0f / 2, 1.0f / 2},
148 {1.0f / 2, -1.0f / 2, 1.0f / 2},
149 {0.0f, 0.0f, 1.0f}
150 };
151
152 #pragma omp parallel for
153 for (int p = 0; p < outch; p++)
154 {
155 for (int q = 0; q < inch; q++)
156 {
157 const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
158 float* kernel_tm0 = kernel_tm.channel(p).row(q);
159
160 // transform kernel
161 const float* k0 = kernel0;
162 const float* k1 = kernel0 + 3;
163 const float* k2 = kernel0 + 6;
164
165 // h
166 float tmp[4][3];
167 for (int i = 0; i < 4; i++)
168 {
169 tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
170 tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
171 tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
172 }
173
174 // U
175 for (int j = 0; j < 4; j++)
176 {
177 float* tmpp = &tmp[j][0];
178
179 for (int i = 0; i < 4; i++)
180 {
181 kernel_tm0[j * 4 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
182 }
183 }
184 }
185 }
186 }
187
// Winograd F(2,3) 3x3 stride-1 convolution:
//   1) pad input, transform each 4x4 input tile (B^T d B),
//   2) elementwise-multiply/accumulate with pre-transformed kernels (dot),
//   3) inverse-transform 4x4 tiles back to 2x2 output tiles (A^T w A) + bias,
//   4) cut padding off the result.
// kernel_tm must come from conv3x3s1_winograd23_transform_kernel_sse.
static void conv3x3s1_winograd23_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 2n+2, winograd F(2,3)
    Mat bottom_blob_bordered = bottom_blob;

    // round output up to a multiple of 2 so it tiles evenly into 2x2 blocks
    outw = (outw + 1) / 2 * 2;
    outh = (outh + 1) / 2 * 2;

    // each 4x4 input tile produces a 2x2 output tile, hence the +2 halo
    w = outw + 2;
    h = outh + 2;
    Option opt_b = opt;
    opt_b.blob_allocator = opt.workspace_allocator;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);

    const float* bias = _bias;

    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // may be the block num in Feathercnn
        int nRowBlocks = w_tm / 4;

        const int tiles = nColBlocks * nRowBlocks;

        // one 16-float transformed tile per (input channel, tile) pair
        bottom_blob_tm.create(4 * 4, tiles, inch, 4u, opt.workspace_allocator);

        // BT
        // const float itm[4][4] = {
        //     {1.0f,  0.0f, -1.0f,  0.0f},
        //     {0.0f,  1.0f,  1.00f, 0.0f},
        //     {0.0f, -1.0f,  1.00f, 0.0f},
        //     {0.0f, -1.0f,  0.00f, 1.0f}
        // };
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const float* img = bottom_blob_bordered.channel(q);
            float* out_tm0 = bottom_blob_tm.channel(q);

            for (int j = 0; j < nColBlocks; j++)
            {
                // tiles advance by 2 rows/cols but read a 4x4 window (overlap of 2)
                const float* r0 = img + w * j * 2;
                const float* r1 = r0 + w;
                const float* r2 = r1 + w;
                const float* r3 = r2 + w;

                for (int i = 0; i < nRowBlocks; i++)
                {
#if __AVX__
                    __m128 _d0, _d1, _d2, _d3;
                    __m128 _w0, _w1, _w2, _w3;

                    // load
                    _d0 = _mm_loadu_ps(r0);
                    _d1 = _mm_loadu_ps(r1);
                    _d2 = _mm_loadu_ps(r2);
                    _d3 = _mm_loadu_ps(r3);

                    // w = B_t * d
                    _w0 = _mm_sub_ps(_d0, _d2);
                    _w1 = _mm_add_ps(_d1, _d2);
                    _w2 = _mm_sub_ps(_d2, _d1);
                    _w3 = _mm_sub_ps(_d3, _d1);

                    // transpose d to d_t
                    _MM_TRANSPOSE4_PS(_w0, _w1, _w2, _w3);

                    // d = B_t * d_t
                    _d0 = _mm_sub_ps(_w0, _w2);
                    _d1 = _mm_add_ps(_w1, _w2);
                    _d2 = _mm_sub_ps(_w2, _w1);
                    _d3 = _mm_sub_ps(_w3, _w1);

                    // save to out_tm
                    _mm_storeu_ps(out_tm0, _d0);
                    _mm_storeu_ps(out_tm0 + 4, _d1);
                    _mm_storeu_ps(out_tm0 + 8, _d2);
                    _mm_storeu_ps(out_tm0 + 12, _d3);
#else
                    // scalar fallback: same B^T d B transform, element by element
                    float d0[4], d1[4], d2[4], d3[4];
                    float w0[4], w1[4], w2[4], w3[4];
                    float t0[4], t1[4], t2[4], t3[4];
                    // load
                    for (int n = 0; n < 4; n++)
                    {
                        d0[n] = r0[n];
                        d1[n] = r1[n];
                        d2[n] = r2[n];
                        d3[n] = r3[n];
                    }
                    // w = B_t * d
                    for (int n = 0; n < 4; n++)
                    {
                        w0[n] = d0[n] - d2[n];
                        w1[n] = d1[n] + d2[n];
                        w2[n] = d2[n] - d1[n];
                        w3[n] = d3[n] - d1[n];
                    }
                    // transpose d to d_t
                    {
                        t0[0] = w0[0];
                        t1[0] = w0[1];
                        t2[0] = w0[2];
                        t3[0] = w0[3];
                        t0[1] = w1[0];
                        t1[1] = w1[1];
                        t2[1] = w1[2];
                        t3[1] = w1[3];
                        t0[2] = w2[0];
                        t1[2] = w2[1];
                        t2[2] = w2[2];
                        t3[2] = w2[3];
                        t0[3] = w3[0];
                        t1[3] = w3[1];
                        t2[3] = w3[2];
                        t3[3] = w3[3];
                    }
                    // d = B_t * d_t
                    for (int n = 0; n < 4; n++)
                    {
                        d0[n] = t0[n] - t2[n];
                        d1[n] = t1[n] + t2[n];
                        d2[n] = t2[n] - t1[n];
                        d3[n] = t3[n] - t1[n];
                    }
                    // save to out_tm
                    for (int n = 0; n < 4; n++)
                    {
                        out_tm0[n] = d0[n];
                        out_tm0[n + 4] = d1[n];
                        out_tm0[n + 8] = d2[n];
                        out_tm0[n + 12] = d3[n];
                    }
#endif
                    r0 += 2;
                    r1 += 2;
                    r2 += 2;
                    r3 += 2;

                    out_tm0 += 16;
                }
            }
        }
    }
    // release the padded copy early; only the transformed tiles are needed now
    bottom_blob_bordered = Mat();

    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // may be the block num in Feathercnn
        int nRowBlocks = w_tm / 4;

        const int tiles = nColBlocks * nRowBlocks;

        top_blob_tm.create(16, tiles, outch, 4u, opt.workspace_allocator);

        // process output channels 4 at a time, remainder handled below
        int nn_outch = outch >> 2;
        int remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int p = pp * 4;

            Mat out0_tm = top_blob_tm.channel(p);
            Mat out1_tm = top_blob_tm.channel(p + 1);
            Mat out2_tm = top_blob_tm.channel(p + 2);
            Mat out3_tm = top_blob_tm.channel(p + 3);

            const Mat kernel0_tm = kernel_tm.channel(p);
            const Mat kernel1_tm = kernel_tm.channel(p + 1);
            const Mat kernel2_tm = kernel_tm.channel(p + 2);
            const Mat kernel3_tm = kernel_tm.channel(p + 3);

            for (int i = 0; i < tiles; i++)
            {
                float* output0_tm = out0_tm.row(i);
                float* output1_tm = out1_tm.row(i);
                float* output2_tm = out2_tm.row(i);
                float* output3_tm = out3_tm.row(i);

#if __AVX__
                // each 16-float tile is held as two 8-lane halves (_sumX / _sumXn)
                float zero_val = 0.f;

                __m256 _sum0 = _mm256_broadcast_ss(&zero_val);
                __m256 _sum0n = _mm256_broadcast_ss(&zero_val);
                __m256 _sum1 = _mm256_broadcast_ss(&zero_val);
                __m256 _sum1n = _mm256_broadcast_ss(&zero_val);
                __m256 _sum2 = _mm256_broadcast_ss(&zero_val);
                __m256 _sum2n = _mm256_broadcast_ss(&zero_val);
                __m256 _sum3 = _mm256_broadcast_ss(&zero_val);
                __m256 _sum3n = _mm256_broadcast_ss(&zero_val);

                int q = 0;

                // accumulate 4 input channels per iteration
                for (; q + 3 < inch; q += 4)
                {
                    const float* r0 = bottom_blob_tm.channel(q).row(i);
                    const float* r1 = bottom_blob_tm.channel(q + 1).row(i);
                    const float* r2 = bottom_blob_tm.channel(q + 2).row(i);
                    const float* r3 = bottom_blob_tm.channel(q + 3).row(i);

                    const float* k0 = kernel0_tm.row(q);
                    const float* k1 = kernel1_tm.row(q);
                    const float* k2 = kernel2_tm.row(q);
                    const float* k3 = kernel3_tm.row(q);

                    __m256 _r0 = _mm256_loadu_ps(r0);
                    __m256 _r0n = _mm256_loadu_ps(r0 + 8);
                    // k0
                    __m256 _k0 = _mm256_loadu_ps(k0);
                    __m256 _k0n = _mm256_loadu_ps(k0 + 8);
                    __m256 _k1 = _mm256_loadu_ps(k1);
                    __m256 _k1n = _mm256_loadu_ps(k1 + 8);
                    __m256 _k2 = _mm256_loadu_ps(k2);
                    __m256 _k2n = _mm256_loadu_ps(k2 + 8);
                    __m256 _k3 = _mm256_loadu_ps(k3);
                    __m256 _k3n = _mm256_loadu_ps(k3 + 8);
                    _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
                    _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
                    _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
                    _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
                    _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
                    _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
                    _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
                    _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);

                    // k1 (kernel rows are 16 floats apart per input channel)
                    _r0 = _mm256_loadu_ps(r1);
                    _r0n = _mm256_loadu_ps(r1 + 8);
                    _k0 = _mm256_loadu_ps(k0 + 16);
                    _k0n = _mm256_loadu_ps(k0 + 24);
                    _k1 = _mm256_loadu_ps(k1 + 16);
                    _k1n = _mm256_loadu_ps(k1 + 24);
                    _k2 = _mm256_loadu_ps(k2 + 16);
                    _k2n = _mm256_loadu_ps(k2 + 24);
                    _k3 = _mm256_loadu_ps(k3 + 16);
                    _k3n = _mm256_loadu_ps(k3 + 24);
                    _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
                    _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
                    _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
                    _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
                    _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
                    _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
                    _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
                    _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
                    // k2
                    _r0 = _mm256_loadu_ps(r2);
                    _r0n = _mm256_loadu_ps(r2 + 8);
                    _k0 = _mm256_loadu_ps(k0 + 32);
                    _k0n = _mm256_loadu_ps(k0 + 40);
                    _k1 = _mm256_loadu_ps(k1 + 32);
                    _k1n = _mm256_loadu_ps(k1 + 40);
                    _k2 = _mm256_loadu_ps(k2 + 32);
                    _k2n = _mm256_loadu_ps(k2 + 40);
                    _k3 = _mm256_loadu_ps(k3 + 32);
                    _k3n = _mm256_loadu_ps(k3 + 40);
                    _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
                    _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
                    _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
                    _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
                    _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
                    _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
                    _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
                    _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
                    // k3
                    _r0 = _mm256_loadu_ps(r3);
                    _r0n = _mm256_loadu_ps(r3 + 8);
                    _k0 = _mm256_loadu_ps(k0 + 48);
                    _k0n = _mm256_loadu_ps(k0 + 56);
                    _k1 = _mm256_loadu_ps(k1 + 48);
                    _k1n = _mm256_loadu_ps(k1 + 56);
                    _k2 = _mm256_loadu_ps(k2 + 48);
                    _k2n = _mm256_loadu_ps(k2 + 56);
                    _k3 = _mm256_loadu_ps(k3 + 48);
                    _k3n = _mm256_loadu_ps(k3 + 56);
                    _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
                    _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
                    _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
                    _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
                    _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
                    _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
                    _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
                    _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
                }

                // leftover input channels, one at a time
                for (; q < inch; q++)
                {
                    const float* r0 = bottom_blob_tm.channel(q).row(i);

                    const float* k0 = kernel0_tm.row(q);
                    const float* k1 = kernel1_tm.row(q);
                    const float* k2 = kernel2_tm.row(q);
                    const float* k3 = kernel3_tm.row(q);

                    __m256 _r0 = _mm256_loadu_ps(r0);
                    __m256 _r0n = _mm256_loadu_ps(r0 + 8);
                    __m256 _k0 = _mm256_loadu_ps(k0);
                    __m256 _k0n = _mm256_loadu_ps(k0 + 8);
                    __m256 _k1 = _mm256_loadu_ps(k1);
                    __m256 _k1n = _mm256_loadu_ps(k1 + 8);
                    __m256 _k2 = _mm256_loadu_ps(k2);
                    __m256 _k2n = _mm256_loadu_ps(k2 + 8);
                    __m256 _k3 = _mm256_loadu_ps(k3);
                    __m256 _k3n = _mm256_loadu_ps(k3 + 8);

                    _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
                    _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
                    _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
                    _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
                    _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
                    _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
                    _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
                    _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
                }

                _mm256_storeu_ps(output0_tm, _sum0);
                _mm256_storeu_ps(output0_tm + 8, _sum0n);
                _mm256_storeu_ps(output1_tm, _sum1);
                _mm256_storeu_ps(output1_tm + 8, _sum1n);
                _mm256_storeu_ps(output2_tm, _sum2);
                _mm256_storeu_ps(output2_tm + 8, _sum2n);
                _mm256_storeu_ps(output3_tm, _sum3);
                _mm256_storeu_ps(output3_tm + 8, _sum3n);
#else
                // scalar fallback: 16-element elementwise multiply-accumulate
                float sum0[16] = {0.0f};
                float sum1[16] = {0.0f};
                float sum2[16] = {0.0f};
                float sum3[16] = {0.0f};

                int q = 0;
                for (; q + 3 < inch; q += 4)
                {
                    const float* r0 = bottom_blob_tm.channel(q).row(i);
                    const float* r1 = bottom_blob_tm.channel(q + 1).row(i);
                    const float* r2 = bottom_blob_tm.channel(q + 2).row(i);
                    const float* r3 = bottom_blob_tm.channel(q + 3).row(i);

                    const float* k0 = kernel0_tm.row(q);
                    const float* k1 = kernel1_tm.row(q);
                    const float* k2 = kernel2_tm.row(q);
                    const float* k3 = kernel3_tm.row(q);

                    // kX is bumped by 16 to step to the next input channel's
                    // kernel tile, then rewound (-= 16 * 3) for the next n
                    for (int n = 0; n < 16; n++)
                    {
                        sum0[n] += r0[n] * k0[n];
                        k0 += 16;
                        sum0[n] += r1[n] * k0[n];
                        k0 += 16;
                        sum0[n] += r2[n] * k0[n];
                        k0 += 16;
                        sum0[n] += r3[n] * k0[n];
                        k0 -= 16 * 3;

                        sum1[n] += r0[n] * k1[n];
                        k1 += 16;
                        sum1[n] += r1[n] * k1[n];
                        k1 += 16;
                        sum1[n] += r2[n] * k1[n];
                        k1 += 16;
                        sum1[n] += r3[n] * k1[n];
                        k1 -= 16 * 3;

                        sum2[n] += r0[n] * k2[n];
                        k2 += 16;
                        sum2[n] += r1[n] * k2[n];
                        k2 += 16;
                        sum2[n] += r2[n] * k2[n];
                        k2 += 16;
                        sum2[n] += r3[n] * k2[n];
                        k2 -= 16 * 3;

                        sum3[n] += r0[n] * k3[n];
                        k3 += 16;
                        sum3[n] += r1[n] * k3[n];
                        k3 += 16;
                        sum3[n] += r2[n] * k3[n];
                        k3 += 16;
                        sum3[n] += r3[n] * k3[n];
                        k3 -= 16 * 3;
                    }
                }

                for (; q < inch; q++)
                {
                    const float* r0 = bottom_blob_tm.channel(q).row(i);

                    const float* k0 = kernel0_tm.row(q);
                    const float* k1 = kernel1_tm.row(q);
                    const float* k2 = kernel2_tm.row(q);
                    const float* k3 = kernel3_tm.row(q);

                    for (int n = 0; n < 16; n++)
                    {
                        sum0[n] += r0[n] * k0[n];
                        sum1[n] += r0[n] * k1[n];
                        sum2[n] += r0[n] * k2[n];
                        sum3[n] += r0[n] * k3[n];
                    }
                }

                for (int n = 0; n < 16; n++)
                {
                    output0_tm[n] = sum0[n];
                    output1_tm[n] = sum1[n];
                    output2_tm[n] = sum2[n];
                    output3_tm[n] = sum3[n];
                }
#endif
            }
        }

        // remaining output channels (outch % 4), scalar path only
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = remain_outch_start; p < outch; p++)
        {
            Mat out0_tm = top_blob_tm.channel(p);
            const Mat kernel0_tm = kernel_tm.channel(p);

            for (int i = 0; i < tiles; i++)
            {
                float* output0_tm = out0_tm.row(i);

                float sum0[16] = {0.0f};

                int q = 0;
                for (; q + 3 < inch; q += 4)
                {
                    const float* r0 = bottom_blob_tm.channel(q).row(i);
                    const float* r1 = bottom_blob_tm.channel(q + 1).row(i);
                    const float* r2 = bottom_blob_tm.channel(q + 2).row(i);
                    const float* r3 = bottom_blob_tm.channel(q + 3).row(i);

                    const float* k0 = kernel0_tm.row(q);
                    const float* k1 = kernel0_tm.row(q + 1);
                    const float* k2 = kernel0_tm.row(q + 2);
                    const float* k3 = kernel0_tm.row(q + 3);

                    for (int n = 0; n < 16; n++)
                    {
                        sum0[n] += r0[n] * k0[n];
                        sum0[n] += r1[n] * k1[n];
                        sum0[n] += r2[n] * k2[n];
                        sum0[n] += r3[n] * k3[n];
                    }
                }

                for (; q < inch; q++)
                {
                    const float* r0 = bottom_blob_tm.channel(q).row(i);
                    const float* k0 = kernel0_tm.row(q);

                    for (int n = 0; n < 16; n++)
                    {
                        sum0[n] += r0[n] * k0[n];
                    }
                }

                for (int n = 0; n < 16; n++)
                {
                    output0_tm[n] = sum0[n];
                }
            }
        }
    }
    // transformed input no longer needed
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    if (outw == top_blob.w && outh == top_blob.h)
    {
        // no padding was added; write straight into the caller's blob
        top_blob_bordered = top_blob;
    }
    else
    {
        top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator);
    }
    {
        // AT
        // const float itm[2][4] = {
        //     {1.0f,  1.0f,  1.0f,  0.0f},
        //     {0.0f,  1.0f, -1.0f,  1.0f}
        // };

        int w_tm = outw / 2 * 4;
        int h_tm = outh / 2 * 4;

        int nColBlocks = h_tm / 4; // may be the block num in Feathercnn
        int nRowBlocks = w_tm / 4;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            Mat out_tm = top_blob_tm.channel(p);
            Mat out = top_blob_bordered.channel(p);

            const float bias0 = bias ? bias[p] : 0.f;

            for (int j = 0; j < nColBlocks; j++)
            {
                float* outRow0 = out.row(j * 2);
                float* outRow1 = out.row(j * 2 + 1);

                for (int i = 0; i < nRowBlocks; i++)
                {
                    float* out_tile = out_tm.row(j * nRowBlocks + i);

                    float s0[4], s1[4], s2[4], s3[4];
                    float w0[4], w1[4];
                    float d0[2], d1[2], d2[2], d3[2];
                    float o0[2], o1[2];
                    // load
                    for (int n = 0; n < 4; n++)
                    {
                        s0[n] = out_tile[n];
                        s1[n] = out_tile[n + 4];
                        s2[n] = out_tile[n + 8];
                        s3[n] = out_tile[n + 12];
                    }
                    // w = A_T * W
                    for (int n = 0; n < 4; n++)
                    {
                        w0[n] = s0[n] + s1[n] + s2[n];
                        w1[n] = s1[n] - s2[n] + s3[n];
                    }
                    // transpose w to w_t
                    {
                        d0[0] = w0[0];
                        d0[1] = w1[0];
                        d1[0] = w0[1];
                        d1[1] = w1[1];
                        d2[0] = w0[2];
                        d2[1] = w1[2];
                        d3[0] = w0[3];
                        d3[1] = w1[3];
                    }
                    // Y = A_T * w_t (bias is added once per output pixel here)
                    for (int n = 0; n < 2; n++)
                    {
                        o0[n] = d0[n] + d1[n] + d2[n] + bias0;
                        o1[n] = d1[n] - d2[n] + d3[n] + bias0;
                    }
                    // save to top blob tm
                    outRow0[0] = o0[0];
                    outRow0[1] = o0[1];
                    outRow1[0] = o1[0];
                    outRow1[1] = o1[1];

                    outRow0 += 2;
                    outRow1 += 2;
                }
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}
761
// Pre-transform 3x3 kernels for Winograd F(4,3): first U = G * g * G^T (6x6
// per kernel), then repack the 36 coefficients into 9 groups of 4 ("r" below),
// interleaving output channels 8-at-a-time (then 4, then 1) for the compute
// kernel's load pattern. One Mat per r-group is pushed into kernel_tm2.
static void conv3x3s1_winograd43_transform_kernel_sse(const Mat& kernel, std::vector<Mat>& kernel_tm2, int inch, int outch)
{
    Mat kernel_tm(6 * 6, inch, outch);

    // G
    const float ktm[6][3] = {
        {1.0f / 4, 0.0f, 0.0f},
        {-1.0f / 6, -1.0f / 6, -1.0f / 6},
        {-1.0f / 6, 1.0f / 6, -1.0f / 6},
        {1.0f / 24, 1.0f / 12, 1.0f / 6},
        {1.0f / 24, -1.0f / 12, 1.0f / 6},
        {0.0f, 0.0f, 1.0f}
    };

    #pragma omp parallel for
    for (int p = 0; p < outch; p++)
    {
        for (int q = 0; q < inch; q++)
        {
            const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
            float* kernel_tm0 = kernel_tm.channel(p).row(q);

            // transform kernel
            const float* k0 = kernel0;
            const float* k1 = kernel0 + 3;
            const float* k2 = kernel0 + 6;

            // h = G * g  (6x3)
            float tmp[6][3];
            for (int i = 0; i < 6; i++)
            {
                tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
                tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
                tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
            }

            // U = h * G^T  (6x6), stored row-major as 36 floats
            for (int j = 0; j < 6; j++)
            {
                float* tmpp = &tmp[j][0];

                for (int i = 0; i < 6; i++)
                {
                    kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
                }
            }
        }
    }

    // repack: 9 r-groups of 4 consecutive coefficients (9 * 4 = 36)
    for (int r = 0; r < 9; r++)
    {
        // channel count covers the 8-wide, 4-wide and single-channel packs
        Mat kernel_tm_test(4 * 8, inch, outch / 8 + (outch % 8) / 4 + outch % 4);

        int p = 0;
        // pack 8 output channels together: 32 floats per input channel
        for (; p + 7 < outch; p += 8)
        {
            const float* kernel0 = (const float*)kernel_tm.channel(p);
            const float* kernel1 = (const float*)kernel_tm.channel(p + 1);
            const float* kernel2 = (const float*)kernel_tm.channel(p + 2);
            const float* kernel3 = (const float*)kernel_tm.channel(p + 3);
            const float* kernel4 = (const float*)kernel_tm.channel(p + 4);
            const float* kernel5 = (const float*)kernel_tm.channel(p + 5);
            const float* kernel6 = (const float*)kernel_tm.channel(p + 6);
            const float* kernel7 = (const float*)kernel_tm.channel(p + 7);

            float* ktmp = kernel_tm_test.channel(p / 8);

            for (int q = 0; q < inch; q++)
            {
                ktmp[0] = kernel0[r * 4 + 0];
                ktmp[1] = kernel0[r * 4 + 1];
                ktmp[2] = kernel0[r * 4 + 2];
                ktmp[3] = kernel0[r * 4 + 3];

                ktmp[4] = kernel1[r * 4 + 0];
                ktmp[5] = kernel1[r * 4 + 1];
                ktmp[6] = kernel1[r * 4 + 2];
                ktmp[7] = kernel1[r * 4 + 3];

                ktmp[8] = kernel2[r * 4 + 0];
                ktmp[9] = kernel2[r * 4 + 1];
                ktmp[10] = kernel2[r * 4 + 2];
                ktmp[11] = kernel2[r * 4 + 3];

                ktmp[12] = kernel3[r * 4 + 0];
                ktmp[13] = kernel3[r * 4 + 1];
                ktmp[14] = kernel3[r * 4 + 2];
                ktmp[15] = kernel3[r * 4 + 3];

                ktmp[16] = kernel4[r * 4 + 0];
                ktmp[17] = kernel4[r * 4 + 1];
                ktmp[18] = kernel4[r * 4 + 2];
                ktmp[19] = kernel4[r * 4 + 3];

                ktmp[20] = kernel5[r * 4 + 0];
                ktmp[21] = kernel5[r * 4 + 1];
                ktmp[22] = kernel5[r * 4 + 2];
                ktmp[23] = kernel5[r * 4 + 3];

                ktmp[24] = kernel6[r * 4 + 0];
                ktmp[25] = kernel6[r * 4 + 1];
                ktmp[26] = kernel6[r * 4 + 2];
                ktmp[27] = kernel6[r * 4 + 3];

                ktmp[28] = kernel7[r * 4 + 0];
                ktmp[29] = kernel7[r * 4 + 1];
                ktmp[30] = kernel7[r * 4 + 2];
                ktmp[31] = kernel7[r * 4 + 3];

                ktmp += 32;
                // advance to the next input channel's 36-float tile
                kernel0 += 36;
                kernel1 += 36;
                kernel2 += 36;
                kernel3 += 36;
                kernel4 += 36;
                kernel5 += 36;
                kernel6 += 36;
                kernel7 += 36;
            }
        }

        // pack 4 output channels together: 16 floats per input channel
        for (; p + 3 < outch; p += 4)
        {
            const float* kernel0 = (const float*)kernel_tm.channel(p);
            const float* kernel1 = (const float*)kernel_tm.channel(p + 1);
            const float* kernel2 = (const float*)kernel_tm.channel(p + 2);
            const float* kernel3 = (const float*)kernel_tm.channel(p + 3);

            // channel index continues after the 8-wide packs
            float* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4);

            for (int q = 0; q < inch; q++)
            {
                ktmp[0] = kernel0[r * 4 + 0];
                ktmp[1] = kernel0[r * 4 + 1];
                ktmp[2] = kernel0[r * 4 + 2];
                ktmp[3] = kernel0[r * 4 + 3];

                ktmp[4] = kernel1[r * 4 + 0];
                ktmp[5] = kernel1[r * 4 + 1];
                ktmp[6] = kernel1[r * 4 + 2];
                ktmp[7] = kernel1[r * 4 + 3];

                ktmp[8] = kernel2[r * 4 + 0];
                ktmp[9] = kernel2[r * 4 + 1];
                ktmp[10] = kernel2[r * 4 + 2];
                ktmp[11] = kernel2[r * 4 + 3];

                ktmp[12] = kernel3[r * 4 + 0];
                ktmp[13] = kernel3[r * 4 + 1];
                ktmp[14] = kernel3[r * 4 + 2];
                ktmp[15] = kernel3[r * 4 + 3];

                ktmp += 16;
                kernel0 += 36;
                kernel1 += 36;
                kernel2 += 36;
                kernel3 += 36;
            }
        }

        // remaining single output channels: 4 floats per input channel
        for (; p < outch; p++)
        {
            const float* kernel0 = (const float*)kernel_tm.channel(p);

            float* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4 + p % 4);

            for (int q = 0; q < inch; q++)
            {
                ktmp[0] = kernel0[r * 4 + 0];
                ktmp[1] = kernel0[r * 4 + 1];
                ktmp[2] = kernel0[r * 4 + 2];
                ktmp[3] = kernel0[r * 4 + 3];

                ktmp += 4;
                kernel0 += 36;
            }
        }
        kernel_tm2.push_back(kernel_tm_test);
    }
}
942
conv3x3s1_winograd43_sse(const Mat & bottom_blob,Mat & top_blob,const std::vector<Mat> & kernel_tm_test,const Mat & _bias,const Option & opt)943 static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, const std::vector<Mat>& kernel_tm_test, const Mat& _bias, const Option& opt)
944 {
945 int w = bottom_blob.w;
946 int h = bottom_blob.h;
947 int inch = bottom_blob.c;
948
949 int outw = top_blob.w;
950 int outh = top_blob.h;
951 int outch = top_blob.c;
952
953 size_t elemsize = bottom_blob.elemsize;
954 const float* bias = _bias;
955
956 // pad to 4n+2, winograd F(4,3)
957 Mat bottom_blob_bordered = bottom_blob;
958
959 outw = (outw + 3) / 4 * 4;
960 outh = (outh + 3) / 4 * 4;
961
962 w = outw + 2;
963 h = outh + 2;
964
965 Option opt_b = opt;
966 opt_b.blob_allocator = opt.workspace_allocator;
967 copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);
968
969 // BEGIN transform input
970 Mat bottom_blob_tm;
971 {
972 int w_tm = outw / 4 * 6;
973 int h_tm = outh / 4 * 6;
974
975 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
976 int nRowBlocks = w_tm / 6;
977
978 const int tiles = nColBlocks * nRowBlocks;
979
980 bottom_blob_tm.create(4, inch, tiles * 9, elemsize, opt.workspace_allocator);
981
982 // BT
983 // const float itm[4][4] = {
984 // {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
985 // {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
986 // {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
987 // {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
988 // {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
989 // {0.0f, 4.0f, 0.0f,-5.0f, 0.0f, 1.0f}
990 // };
991
992 // 0 = 4 * r00 - 5 * r02 + r04
993 // 1 = -4 * (r01 + r02) + r03 + r04
994 // 2 = 4 * (r01 - r02) - r03 + r04
995 // 3 = -2 * r01 - r02 + 2 * r03 + r04
996 // 4 = 2 * r01 - r02 - 2 * r03 + r04
997 // 5 = 4 * r01 - 5 * r03 + r05
998
999 // 0 = 4 * r00 - 5 * r02 + r04
1000 // 1 = -4 * (r01 + r02) + r03 + r04
1001 // 2 = 4 * (r01 - r02) - r03 + r04
1002 // 3 = -2 * r01 - r02 + 2 * r03 + r04
1003 // 4 = 2 * r01 - r02 - 2 * r03 + r04
1004 // 5 = 4 * r01 - 5 * r03 + r05
1005
1006 #if __AVX__
1007 __m256 _1_n = _mm256_set1_ps(-1);
1008 __m256 _2_p = _mm256_set1_ps(2);
1009 __m256 _2_n = _mm256_set1_ps(-2);
1010 __m256 _4_p = _mm256_set1_ps(4);
1011 __m256 _4_n = _mm256_set1_ps(-4);
1012 __m256 _5_n = _mm256_set1_ps(-5);
1013 #endif
1014
1015 #pragma omp parallel for num_threads(opt.num_threads)
1016 for (int q = 0; q < inch; q++)
1017 {
1018 const float* img = bottom_blob_bordered.channel(q);
1019
1020 for (int j = 0; j < nColBlocks; j++)
1021 {
1022 const float* r0 = img + w * j * 4;
1023 const float* r1 = r0 + w;
1024 const float* r2 = r1 + w;
1025 const float* r3 = r2 + w;
1026 const float* r4 = r3 + w;
1027 const float* r5 = r4 + w;
1028
1029 for (int i = 0; i < nRowBlocks; i++)
1030 {
1031 float* out_tm0 = bottom_blob_tm.channel(tiles * 0 + j * nRowBlocks + i).row(q);
1032 float* out_tm1 = bottom_blob_tm.channel(tiles * 1 + j * nRowBlocks + i).row(q);
1033 float* out_tm2 = bottom_blob_tm.channel(tiles * 2 + j * nRowBlocks + i).row(q);
1034 float* out_tm3 = bottom_blob_tm.channel(tiles * 3 + j * nRowBlocks + i).row(q);
1035 float* out_tm4 = bottom_blob_tm.channel(tiles * 4 + j * nRowBlocks + i).row(q);
1036 float* out_tm5 = bottom_blob_tm.channel(tiles * 5 + j * nRowBlocks + i).row(q);
1037 float* out_tm6 = bottom_blob_tm.channel(tiles * 6 + j * nRowBlocks + i).row(q);
1038 float* out_tm7 = bottom_blob_tm.channel(tiles * 7 + j * nRowBlocks + i).row(q);
1039 float* out_tm8 = bottom_blob_tm.channel(tiles * 8 + j * nRowBlocks + i).row(q);
1040 #if __AVX__
1041 __m256 _d0, _d1, _d2, _d3, _d4, _d5;
1042 __m256 _w0, _w1, _w2, _w3, _w4, _w5;
1043 __m256 _t0, _t1, _t2, _t3, _t4, _t5;
1044 __m256 _n0, _n1, _n2, _n3, _n4, _n5;
1045 // load
1046 _d0 = _mm256_loadu_ps(r0);
1047 _d1 = _mm256_loadu_ps(r1);
1048 _d2 = _mm256_loadu_ps(r2);
1049 _d3 = _mm256_loadu_ps(r3);
1050 _d4 = _mm256_loadu_ps(r4);
1051 _d5 = _mm256_loadu_ps(r5);
1052
1053 // w = B_t * d
1054 _w0 = _mm256_mul_ps(_d0, _4_p);
1055 _w0 = _mm256_fmadd_ps(_d2, _5_n, _w0);
1056 _w0 = _mm256_add_ps(_w0, _d4);
1057
1058 _w1 = _mm256_mul_ps(_d1, _4_n);
1059 _w1 = _mm256_fmadd_ps(_d2, _4_n, _w1);
1060 _w1 = _mm256_add_ps(_w1, _d3);
1061 _w1 = _mm256_add_ps(_w1, _d4);
1062
1063 _w2 = _mm256_mul_ps(_d1, _4_p);
1064 _w2 = _mm256_fmadd_ps(_d2, _4_n, _w2);
1065 _w2 = _mm256_fmadd_ps(_d3, _1_n, _w2);
1066 _w2 = _mm256_add_ps(_w2, _d4);
1067
1068 _w3 = _mm256_mul_ps(_d1, _2_n);
1069 _w3 = _mm256_fmadd_ps(_d2, _1_n, _w3);
1070 _w3 = _mm256_fmadd_ps(_d3, _2_p, _w3);
1071 _w3 = _mm256_add_ps(_w3, _d4);
1072
1073 _w4 = _mm256_mul_ps(_d1, _2_p);
1074 _w4 = _mm256_fmadd_ps(_d2, _1_n, _w4);
1075 _w4 = _mm256_fmadd_ps(_d3, _2_n, _w4);
1076 _w4 = _mm256_add_ps(_w4, _d4);
1077
1078 _w5 = _mm256_mul_ps(_d1, _4_p);
1079 _w5 = _mm256_fmadd_ps(_d3, _5_n, _w5);
1080 _w5 = _mm256_add_ps(_w5, _d5);
1081 // transpose d to d_t
1082 #if (defined _WIN32 && !(defined __MINGW32__))
1083 {
1084 _t0.m256_f32[0] = _w0.m256_f32[0];
1085 _t1.m256_f32[0] = _w0.m256_f32[1];
1086 _t2.m256_f32[0] = _w0.m256_f32[2];
1087 _t3.m256_f32[0] = _w0.m256_f32[3];
1088 _t4.m256_f32[0] = _w0.m256_f32[4];
1089 _t5.m256_f32[0] = _w0.m256_f32[5];
1090 _t0.m256_f32[1] = _w1.m256_f32[0];
1091 _t1.m256_f32[1] = _w1.m256_f32[1];
1092 _t2.m256_f32[1] = _w1.m256_f32[2];
1093 _t3.m256_f32[1] = _w1.m256_f32[3];
1094 _t4.m256_f32[1] = _w1.m256_f32[4];
1095 _t5.m256_f32[1] = _w1.m256_f32[5];
1096 _t0.m256_f32[2] = _w2.m256_f32[0];
1097 _t1.m256_f32[2] = _w2.m256_f32[1];
1098 _t2.m256_f32[2] = _w2.m256_f32[2];
1099 _t3.m256_f32[2] = _w2.m256_f32[3];
1100 _t4.m256_f32[2] = _w2.m256_f32[4];
1101 _t5.m256_f32[2] = _w2.m256_f32[5];
1102 _t0.m256_f32[3] = _w3.m256_f32[0];
1103 _t1.m256_f32[3] = _w3.m256_f32[1];
1104 _t2.m256_f32[3] = _w3.m256_f32[2];
1105 _t3.m256_f32[3] = _w3.m256_f32[3];
1106 _t4.m256_f32[3] = _w3.m256_f32[4];
1107 _t5.m256_f32[3] = _w3.m256_f32[5];
1108 _t0.m256_f32[4] = _w4.m256_f32[0];
1109 _t1.m256_f32[4] = _w4.m256_f32[1];
1110 _t2.m256_f32[4] = _w4.m256_f32[2];
1111 _t3.m256_f32[4] = _w4.m256_f32[3];
1112 _t4.m256_f32[4] = _w4.m256_f32[4];
1113 _t5.m256_f32[4] = _w4.m256_f32[5];
1114 _t0.m256_f32[5] = _w5.m256_f32[0];
1115 _t1.m256_f32[5] = _w5.m256_f32[1];
1116 _t2.m256_f32[5] = _w5.m256_f32[2];
1117 _t3.m256_f32[5] = _w5.m256_f32[3];
1118 _t4.m256_f32[5] = _w5.m256_f32[4];
1119 _t5.m256_f32[5] = _w5.m256_f32[5];
1120 }
1121 #else
1122 {
1123 _t0[0] = _w0[0];
1124 _t1[0] = _w0[1];
1125 _t2[0] = _w0[2];
1126 _t3[0] = _w0[3];
1127 _t4[0] = _w0[4];
1128 _t5[0] = _w0[5];
1129 _t0[1] = _w1[0];
1130 _t1[1] = _w1[1];
1131 _t2[1] = _w1[2];
1132 _t3[1] = _w1[3];
1133 _t4[1] = _w1[4];
1134 _t5[1] = _w1[5];
1135 _t0[2] = _w2[0];
1136 _t1[2] = _w2[1];
1137 _t2[2] = _w2[2];
1138 _t3[2] = _w2[3];
1139 _t4[2] = _w2[4];
1140 _t5[2] = _w2[5];
1141 _t0[3] = _w3[0];
1142 _t1[3] = _w3[1];
1143 _t2[3] = _w3[2];
1144 _t3[3] = _w3[3];
1145 _t4[3] = _w3[4];
1146 _t5[3] = _w3[5];
1147 _t0[4] = _w4[0];
1148 _t1[4] = _w4[1];
1149 _t2[4] = _w4[2];
1150 _t3[4] = _w4[3];
1151 _t4[4] = _w4[4];
1152 _t5[4] = _w4[5];
1153 _t0[5] = _w5[0];
1154 _t1[5] = _w5[1];
1155 _t2[5] = _w5[2];
1156 _t3[5] = _w5[3];
1157 _t4[5] = _w5[4];
1158 _t5[5] = _w5[5];
1159 }
1160 #endif
1161 // d = B_t * d_t
1162 _n0 = _mm256_mul_ps(_t0, _4_p);
1163 _n0 = _mm256_fmadd_ps(_t2, _5_n, _n0);
1164 _n0 = _mm256_add_ps(_n0, _t4);
1165
1166 _n1 = _mm256_mul_ps(_t1, _4_n);
1167 _n1 = _mm256_fmadd_ps(_t2, _4_n, _n1);
1168 _n1 = _mm256_add_ps(_n1, _t3);
1169 _n1 = _mm256_add_ps(_n1, _t4);
1170
1171 _n2 = _mm256_mul_ps(_t1, _4_p);
1172 _n2 = _mm256_fmadd_ps(_t2, _4_n, _n2);
1173 _n2 = _mm256_fmadd_ps(_t3, _1_n, _n2);
1174 _n2 = _mm256_add_ps(_n2, _t4);
1175
1176 _n3 = _mm256_mul_ps(_t1, _2_n);
1177 _n3 = _mm256_fmadd_ps(_t2, _1_n, _n3);
1178 _n3 = _mm256_fmadd_ps(_t3, _2_p, _n3);
1179 _n3 = _mm256_add_ps(_n3, _t4);
1180
1181 _n4 = _mm256_mul_ps(_t1, _2_p);
1182 _n4 = _mm256_fmadd_ps(_t2, _1_n, _n4);
1183 _n4 = _mm256_fmadd_ps(_t3, _2_n, _n4);
1184 _n4 = _mm256_add_ps(_n4, _t4);
1185
1186 _n5 = _mm256_mul_ps(_t1, _4_p);
1187 _n5 = _mm256_fmadd_ps(_t3, _5_n, _n5);
1188 _n5 = _mm256_add_ps(_n5, _t5);
1189 // save to out_tm
1190 float output_n0[8] = {0.f};
1191 _mm256_storeu_ps(output_n0, _n0);
1192 float output_n1[8] = {0.f};
1193 _mm256_storeu_ps(output_n1, _n1);
1194 float output_n2[8] = {0.f};
1195 _mm256_storeu_ps(output_n2, _n2);
1196 float output_n3[8] = {0.f};
1197 _mm256_storeu_ps(output_n3, _n3);
1198 float output_n4[8] = {0.f};
1199 _mm256_storeu_ps(output_n4, _n4);
1200 float output_n5[8] = {0.f};
1201 _mm256_storeu_ps(output_n5, _n5);
1202
1203 out_tm0[0] = output_n0[0];
1204 out_tm0[1] = output_n0[1];
1205 out_tm0[2] = output_n0[2];
1206 out_tm0[3] = output_n0[3];
1207 out_tm1[0] = output_n0[4];
1208 out_tm1[1] = output_n0[5];
1209 out_tm1[2] = output_n1[0];
1210 out_tm1[3] = output_n1[1];
1211 out_tm2[0] = output_n1[2];
1212 out_tm2[1] = output_n1[3];
1213 out_tm2[2] = output_n1[4];
1214 out_tm2[3] = output_n1[5];
1215
1216 out_tm3[0] = output_n2[0];
1217 out_tm3[1] = output_n2[1];
1218 out_tm3[2] = output_n2[2];
1219 out_tm3[3] = output_n2[3];
1220 out_tm4[0] = output_n2[4];
1221 out_tm4[1] = output_n2[5];
1222 out_tm4[2] = output_n3[0];
1223 out_tm4[3] = output_n3[1];
1224 out_tm5[0] = output_n3[2];
1225 out_tm5[1] = output_n3[3];
1226 out_tm5[2] = output_n3[4];
1227 out_tm5[3] = output_n3[5];
1228
1229 out_tm6[0] = output_n4[0];
1230 out_tm6[1] = output_n4[1];
1231 out_tm6[2] = output_n4[2];
1232 out_tm6[3] = output_n4[3];
1233 out_tm7[0] = output_n4[4];
1234 out_tm7[1] = output_n4[5];
1235 out_tm7[2] = output_n5[0];
1236 out_tm7[3] = output_n5[1];
1237 out_tm8[0] = output_n5[2];
1238 out_tm8[1] = output_n5[3];
1239 out_tm8[2] = output_n5[4];
1240 out_tm8[3] = output_n5[5];
1241 #else
1242 float d0[6], d1[6], d2[6], d3[6], d4[6], d5[6];
1243 float w0[6], w1[6], w2[6], w3[6], w4[6], w5[6];
1244 float t0[6], t1[6], t2[6], t3[6], t4[6], t5[6];
1245
1246 // load
1247 for (int n = 0; n < 6; n++)
1248 {
1249 d0[n] = r0[n];
1250 d1[n] = r1[n];
1251 d2[n] = r2[n];
1252 d3[n] = r3[n];
1253 d4[n] = r4[n];
1254 d5[n] = r5[n];
1255 }
1256 // w = B_t * d
1257 for (int n = 0; n < 6; n++)
1258 {
1259 w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n];
1260 w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n];
1261 w2[n] = 4 * d1[n] - 4 * d2[n] - d3[n] + d4[n];
1262 w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n];
1263 w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n];
1264 w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n];
1265 }
1266 // transpose d to d_t
1267 {
1268 t0[0] = w0[0];
1269 t1[0] = w0[1];
1270 t2[0] = w0[2];
1271 t3[0] = w0[3];
1272 t4[0] = w0[4];
1273 t5[0] = w0[5];
1274 t0[1] = w1[0];
1275 t1[1] = w1[1];
1276 t2[1] = w1[2];
1277 t3[1] = w1[3];
1278 t4[1] = w1[4];
1279 t5[1] = w1[5];
1280 t0[2] = w2[0];
1281 t1[2] = w2[1];
1282 t2[2] = w2[2];
1283 t3[2] = w2[3];
1284 t4[2] = w2[4];
1285 t5[2] = w2[5];
1286 t0[3] = w3[0];
1287 t1[3] = w3[1];
1288 t2[3] = w3[2];
1289 t3[3] = w3[3];
1290 t4[3] = w3[4];
1291 t5[3] = w3[5];
1292 t0[4] = w4[0];
1293 t1[4] = w4[1];
1294 t2[4] = w4[2];
1295 t3[4] = w4[3];
1296 t4[4] = w4[4];
1297 t5[4] = w4[5];
1298 t0[5] = w5[0];
1299 t1[5] = w5[1];
1300 t2[5] = w5[2];
1301 t3[5] = w5[3];
1302 t4[5] = w5[4];
1303 t5[5] = w5[5];
1304 }
1305 // d = B_t * d_t
1306 for (int n = 0; n < 6; n++)
1307 {
1308 d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n];
1309 d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n];
1310 d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n];
1311 d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n];
1312 d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n];
1313 d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n];
1314 }
1315 // save to out_tm
1316 {
1317 out_tm0[0] = d0[0];
1318 out_tm0[1] = d0[1];
1319 out_tm0[2] = d0[2];
1320 out_tm0[3] = d0[3];
1321 out_tm1[0] = d0[4];
1322 out_tm1[1] = d0[5];
1323 out_tm1[2] = d1[0];
1324 out_tm1[3] = d1[1];
1325 out_tm2[0] = d1[2];
1326 out_tm2[1] = d1[3];
1327 out_tm2[2] = d1[4];
1328 out_tm2[3] = d1[5];
1329
1330 out_tm3[0] = d2[0];
1331 out_tm3[1] = d2[1];
1332 out_tm3[2] = d2[2];
1333 out_tm3[3] = d2[3];
1334 out_tm4[0] = d2[4];
1335 out_tm4[1] = d2[5];
1336 out_tm4[2] = d3[0];
1337 out_tm4[3] = d3[1];
1338 out_tm5[0] = d3[2];
1339 out_tm5[1] = d3[3];
1340 out_tm5[2] = d3[4];
1341 out_tm5[3] = d3[5];
1342
1343 out_tm6[0] = d4[0];
1344 out_tm6[1] = d4[1];
1345 out_tm6[2] = d4[2];
1346 out_tm6[3] = d4[3];
1347 out_tm7[0] = d4[4];
1348 out_tm7[1] = d4[5];
1349 out_tm7[2] = d5[0];
1350 out_tm7[3] = d5[1];
1351 out_tm8[0] = d5[2];
1352 out_tm8[1] = d5[3];
1353 out_tm8[2] = d5[4];
1354 out_tm8[3] = d5[5];
1355 }
1356 #endif // __AVX__
1357 r0 += 4;
1358 r1 += 4;
1359 r2 += 4;
1360 r3 += 4;
1361 r4 += 4;
1362 r5 += 4;
1363 }
1364 }
1365 }
1366 }
1367 bottom_blob_bordered = Mat();
1368
1369 // BEGIN dot
1370 Mat top_blob_tm;
1371 {
1372 int w_tm = outw / 4 * 6;
1373 int h_tm = outh / 4 * 6;
1374
1375 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
1376 int nRowBlocks = w_tm / 6;
1377
1378 const int tiles = nColBlocks * nRowBlocks;
1379
1380 top_blob_tm.create(36, tiles, outch, elemsize, opt.workspace_allocator);
1381
1382 #pragma omp parallel for num_threads(opt.num_threads)
1383 for (int r = 0; r < 9; r++)
1384 {
1385 int nn_outch = 0;
1386 int remain_outch_start = 0;
1387
1388 nn_outch = outch >> 3;
1389 remain_outch_start = nn_outch << 3;
1390
1391 for (int pp = 0; pp < nn_outch; pp++)
1392 {
1393 int p = pp * 8;
1394
1395 float* output0_tm = top_blob_tm.channel(p);
1396 float* output1_tm = top_blob_tm.channel(p + 1);
1397 float* output2_tm = top_blob_tm.channel(p + 2);
1398 float* output3_tm = top_blob_tm.channel(p + 3);
1399 float* output4_tm = top_blob_tm.channel(p + 4);
1400 float* output5_tm = top_blob_tm.channel(p + 5);
1401 float* output6_tm = top_blob_tm.channel(p + 6);
1402 float* output7_tm = top_blob_tm.channel(p + 7);
1403
1404 output0_tm = output0_tm + r * 4;
1405 output1_tm = output1_tm + r * 4;
1406 output2_tm = output2_tm + r * 4;
1407 output3_tm = output3_tm + r * 4;
1408 output4_tm = output4_tm + r * 4;
1409 output5_tm = output5_tm + r * 4;
1410 output6_tm = output6_tm + r * 4;
1411 output7_tm = output7_tm + r * 4;
1412
1413 for (int i = 0; i < tiles; i++)
1414 {
1415 const float* kptr = kernel_tm_test[r].channel(p / 8);
1416 const float* r0 = bottom_blob_tm.channel(tiles * r + i);
1417 #if __AVX__ || __SSE__
1418 #if __AVX__
1419 float zero_val = 0.f;
1420 __m128 _sum0 = _mm_broadcast_ss(&zero_val);
1421 __m128 _sum1 = _mm_broadcast_ss(&zero_val);
1422 __m128 _sum2 = _mm_broadcast_ss(&zero_val);
1423 __m128 _sum3 = _mm_broadcast_ss(&zero_val);
1424 __m128 _sum4 = _mm_broadcast_ss(&zero_val);
1425 __m128 _sum5 = _mm_broadcast_ss(&zero_val);
1426 __m128 _sum6 = _mm_broadcast_ss(&zero_val);
1427 __m128 _sum7 = _mm_broadcast_ss(&zero_val);
1428 #else
1429 __m128 _sum0 = _mm_set1_ps(0.f);
1430 __m128 _sum1 = _mm_set1_ps(0.f);
1431 __m128 _sum2 = _mm_set1_ps(0.f);
1432 __m128 _sum3 = _mm_set1_ps(0.f);
1433 __m128 _sum4 = _mm_set1_ps(0.f);
1434 __m128 _sum5 = _mm_set1_ps(0.f);
1435 __m128 _sum6 = _mm_set1_ps(0.f);
1436 __m128 _sum7 = _mm_set1_ps(0.f);
1437 #endif
1438 int q = 0;
1439 for (; q + 3 < inch; q = q + 4)
1440 {
1441 __m128 _r0 = _mm_loadu_ps(r0);
1442 __m128 _r1 = _mm_loadu_ps(r0 + 4);
1443 __m128 _r2 = _mm_loadu_ps(r0 + 8);
1444 __m128 _r3 = _mm_loadu_ps(r0 + 12);
1445
1446 __m128 _k0 = _mm_loadu_ps(kptr);
1447 __m128 _k1 = _mm_loadu_ps(kptr + 4);
1448 __m128 _k2 = _mm_loadu_ps(kptr + 8);
1449 __m128 _k3 = _mm_loadu_ps(kptr + 12);
1450 __m128 _k4 = _mm_loadu_ps(kptr + 16);
1451 __m128 _k5 = _mm_loadu_ps(kptr + 20);
1452 __m128 _k6 = _mm_loadu_ps(kptr + 24);
1453 __m128 _k7 = _mm_loadu_ps(kptr + 28);
1454 #if __AVX__
1455 _sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
1456 _sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
1457 _sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
1458 _sum3 = _mm_fmadd_ps(_r0, _k3, _sum3);
1459 _sum4 = _mm_fmadd_ps(_r0, _k4, _sum4);
1460 _sum5 = _mm_fmadd_ps(_r0, _k5, _sum5);
1461 _sum6 = _mm_fmadd_ps(_r0, _k6, _sum6);
1462 _sum7 = _mm_fmadd_ps(_r0, _k7, _sum7);
1463 #else
1464 _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
1465 _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
1466 _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
1467 _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
1468 _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r0, _k4));
1469 _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r0, _k5));
1470 _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r0, _k6));
1471 _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r0, _k7));
1472 #endif
1473 kptr += 32;
1474 _k0 = _mm_loadu_ps(kptr);
1475 _k1 = _mm_loadu_ps(kptr + 4);
1476 _k2 = _mm_loadu_ps(kptr + 8);
1477 _k3 = _mm_loadu_ps(kptr + 12);
1478 _k4 = _mm_loadu_ps(kptr + 16);
1479 _k5 = _mm_loadu_ps(kptr + 20);
1480 _k6 = _mm_loadu_ps(kptr + 24);
1481 _k7 = _mm_loadu_ps(kptr + 28);
1482 #if __AVX__
1483 _sum0 = _mm_fmadd_ps(_r1, _k0, _sum0);
1484 _sum1 = _mm_fmadd_ps(_r1, _k1, _sum1);
1485 _sum2 = _mm_fmadd_ps(_r1, _k2, _sum2);
1486 _sum3 = _mm_fmadd_ps(_r1, _k3, _sum3);
1487 _sum4 = _mm_fmadd_ps(_r1, _k4, _sum4);
1488 _sum5 = _mm_fmadd_ps(_r1, _k5, _sum5);
1489 _sum6 = _mm_fmadd_ps(_r1, _k6, _sum6);
1490 _sum7 = _mm_fmadd_ps(_r1, _k7, _sum7);
1491 #else
1492 _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r1, _k0));
1493 _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r1, _k1));
1494 _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r1, _k2));
1495 _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r1, _k3));
1496 _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r1, _k4));
1497 _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r1, _k5));
1498 _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r1, _k6));
1499 _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r1, _k7));
1500 #endif
1501
1502 kptr += 32;
1503 _k0 = _mm_loadu_ps(kptr);
1504 _k1 = _mm_loadu_ps(kptr + 4);
1505 _k2 = _mm_loadu_ps(kptr + 8);
1506 _k3 = _mm_loadu_ps(kptr + 12);
1507 _k4 = _mm_loadu_ps(kptr + 16);
1508 _k5 = _mm_loadu_ps(kptr + 20);
1509 _k6 = _mm_loadu_ps(kptr + 24);
1510 _k7 = _mm_loadu_ps(kptr + 28);
1511 #if __AVX__
1512 _sum0 = _mm_fmadd_ps(_r2, _k0, _sum0);
1513 _sum1 = _mm_fmadd_ps(_r2, _k1, _sum1);
1514 _sum2 = _mm_fmadd_ps(_r2, _k2, _sum2);
1515 _sum3 = _mm_fmadd_ps(_r2, _k3, _sum3);
1516 _sum4 = _mm_fmadd_ps(_r2, _k4, _sum4);
1517 _sum5 = _mm_fmadd_ps(_r2, _k5, _sum5);
1518 _sum6 = _mm_fmadd_ps(_r2, _k6, _sum6);
1519 _sum7 = _mm_fmadd_ps(_r2, _k7, _sum7);
1520 #else
1521 _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r2, _k0));
1522 _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r2, _k1));
1523 _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r2, _k2));
1524 _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r2, _k3));
1525 _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r2, _k4));
1526 _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r2, _k5));
1527 _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r2, _k6));
1528 _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r2, _k7));
1529 #endif
1530 kptr += 32;
1531 _k0 = _mm_loadu_ps(kptr);
1532 _k1 = _mm_loadu_ps(kptr + 4);
1533 _k2 = _mm_loadu_ps(kptr + 8);
1534 _k3 = _mm_loadu_ps(kptr + 12);
1535 _k4 = _mm_loadu_ps(kptr + 16);
1536 _k5 = _mm_loadu_ps(kptr + 20);
1537 _k6 = _mm_loadu_ps(kptr + 24);
1538 _k7 = _mm_loadu_ps(kptr + 28);
1539 #if __AVX__
1540 _sum0 = _mm_fmadd_ps(_r3, _k0, _sum0);
1541 _sum1 = _mm_fmadd_ps(_r3, _k1, _sum1);
1542 _sum2 = _mm_fmadd_ps(_r3, _k2, _sum2);
1543 _sum3 = _mm_fmadd_ps(_r3, _k3, _sum3);
1544 _sum4 = _mm_fmadd_ps(_r3, _k4, _sum4);
1545 _sum5 = _mm_fmadd_ps(_r3, _k5, _sum5);
1546 _sum6 = _mm_fmadd_ps(_r3, _k6, _sum6);
1547 _sum7 = _mm_fmadd_ps(_r3, _k7, _sum7);
1548 #else
1549 _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r3, _k0));
1550 _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r3, _k1));
1551 _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r3, _k2));
1552 _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r3, _k3));
1553 _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r3, _k4));
1554 _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r3, _k5));
1555 _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r3, _k6));
1556 _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r3, _k7));
1557 #endif
1558 kptr += 32;
1559 r0 += 16;
1560 }
1561
1562 for (; q < inch; q++)
1563 {
1564 __m128 _r0 = _mm_loadu_ps(r0);
1565 __m128 _k0 = _mm_loadu_ps(kptr);
1566 __m128 _k1 = _mm_loadu_ps(kptr + 4);
1567 __m128 _k2 = _mm_loadu_ps(kptr + 8);
1568 __m128 _k3 = _mm_loadu_ps(kptr + 12);
1569 __m128 _k4 = _mm_loadu_ps(kptr + 16);
1570 __m128 _k5 = _mm_loadu_ps(kptr + 20);
1571 __m128 _k6 = _mm_loadu_ps(kptr + 24);
1572 __m128 _k7 = _mm_loadu_ps(kptr + 28);
1573
1574 #if __AVX__
1575 _sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
1576 _sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
1577 _sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
1578 _sum3 = _mm_fmadd_ps(_r0, _k3, _sum3);
1579 _sum4 = _mm_fmadd_ps(_r0, _k4, _sum4);
1580 _sum5 = _mm_fmadd_ps(_r0, _k5, _sum5);
1581 _sum6 = _mm_fmadd_ps(_r0, _k6, _sum6);
1582 _sum7 = _mm_fmadd_ps(_r0, _k7, _sum7);
1583 #else
1584 _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
1585 _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
1586 _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
1587 _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
1588 _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r0, _k4));
1589 _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r0, _k5));
1590 _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r0, _k6));
1591 _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r0, _k7));
1592 #endif
1593
1594 kptr += 32;
1595 r0 += 4;
1596 }
1597
1598 _mm_storeu_ps(output0_tm, _sum0);
1599 _mm_storeu_ps(output1_tm, _sum1);
1600 _mm_storeu_ps(output2_tm, _sum2);
1601 _mm_storeu_ps(output3_tm, _sum3);
1602 _mm_storeu_ps(output4_tm, _sum4);
1603 _mm_storeu_ps(output5_tm, _sum5);
1604 _mm_storeu_ps(output6_tm, _sum6);
1605 _mm_storeu_ps(output7_tm, _sum7);
1606 #else
1607 float sum0[4] = {0};
1608 float sum1[4] = {0};
1609 float sum2[4] = {0};
1610 float sum3[4] = {0};
1611 float sum4[4] = {0};
1612 float sum5[4] = {0};
1613 float sum6[4] = {0};
1614 float sum7[4] = {0};
1615
1616 for (int q = 0; q < inch; q++)
1617 {
1618 for (int n = 0; n < 4; n++)
1619 {
1620 sum0[n] += r0[n] * kptr[n];
1621 sum1[n] += r0[n] * kptr[n + 4];
1622 sum2[n] += r0[n] * kptr[n + 8];
1623 sum3[n] += r0[n] * kptr[n + 12];
1624 sum4[n] += r0[n] * kptr[n + 16];
1625 sum5[n] += r0[n] * kptr[n + 20];
1626 sum6[n] += r0[n] * kptr[n + 24];
1627 sum7[n] += r0[n] * kptr[n + 28];
1628 }
1629 kptr += 32;
1630 r0 += 4;
1631 }
1632
1633 for (int n = 0; n < 4; n++)
1634 {
1635 output0_tm[n] = sum0[n];
1636 output1_tm[n] = sum1[n];
1637 output2_tm[n] = sum2[n];
1638 output3_tm[n] = sum3[n];
1639 output4_tm[n] = sum4[n];
1640 output5_tm[n] = sum5[n];
1641 output6_tm[n] = sum6[n];
1642 output7_tm[n] = sum7[n];
1643 }
1644 #endif // __AVX__
1645 output0_tm += 36;
1646 output1_tm += 36;
1647 output2_tm += 36;
1648 output3_tm += 36;
1649 output4_tm += 36;
1650 output5_tm += 36;
1651 output6_tm += 36;
1652 output7_tm += 36;
1653 }
1654 }
1655
1656 nn_outch = (outch - remain_outch_start) >> 2;
1657
1658 for (int pp = 0; pp < nn_outch; pp++)
1659 {
1660 int p = remain_outch_start + pp * 4;
1661
1662 float* output0_tm = top_blob_tm.channel(p);
1663 float* output1_tm = top_blob_tm.channel(p + 1);
1664 float* output2_tm = top_blob_tm.channel(p + 2);
1665 float* output3_tm = top_blob_tm.channel(p + 3);
1666
1667 output0_tm = output0_tm + r * 4;
1668 output1_tm = output1_tm + r * 4;
1669 output2_tm = output2_tm + r * 4;
1670 output3_tm = output3_tm + r * 4;
1671
1672 for (int i = 0; i < tiles; i++)
1673 {
1674 const float* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4);
1675 const float* r0 = bottom_blob_tm.channel(tiles * r + i);
1676 #if __AVX__ || __SSE__
1677 #if __AVX__
1678 float zero_val = 0.f;
1679 __m128 _sum0 = _mm_broadcast_ss(&zero_val);
1680 __m128 _sum1 = _mm_broadcast_ss(&zero_val);
1681 __m128 _sum2 = _mm_broadcast_ss(&zero_val);
1682 __m128 _sum3 = _mm_broadcast_ss(&zero_val);
1683 #else
1684 __m128 _sum0 = _mm_set1_ps(0.f);
1685 __m128 _sum1 = _mm_set1_ps(0.f);
1686 __m128 _sum2 = _mm_set1_ps(0.f);
1687 __m128 _sum3 = _mm_set1_ps(0.f);
1688 #endif
1689 for (int q = 0; q < inch; q++)
1690 {
1691 __m128 _r0 = _mm_loadu_ps(r0);
1692 __m128 _k0 = _mm_loadu_ps(kptr);
1693 __m128 _k1 = _mm_loadu_ps(kptr + 4);
1694 __m128 _k2 = _mm_loadu_ps(kptr + 8);
1695 __m128 _k3 = _mm_loadu_ps(kptr + 12);
1696 #if __AVX__
1697 _sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
1698 _sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
1699 _sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
1700 _sum3 = _mm_fmadd_ps(_r0, _k3, _sum3);
1701 #else
1702 _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
1703 _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
1704 _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
1705 _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
1706 #endif
1707 kptr += 16;
1708 r0 += 4;
1709 }
1710
1711 _mm_storeu_ps(output0_tm, _sum0);
1712 _mm_storeu_ps(output1_tm, _sum1);
1713 _mm_storeu_ps(output2_tm, _sum2);
1714 _mm_storeu_ps(output3_tm, _sum3);
1715 #else
1716 float sum0[4] = {0};
1717 float sum1[4] = {0};
1718 float sum2[4] = {0};
1719 float sum3[4] = {0};
1720
1721 for (int q = 0; q < inch; q++)
1722 {
1723 for (int n = 0; n < 4; n++)
1724 {
1725 sum0[n] += r0[n] * kptr[n];
1726 sum1[n] += r0[n] * kptr[n + 4];
1727 sum2[n] += r0[n] * kptr[n + 8];
1728 sum3[n] += r0[n] * kptr[n + 12];
1729 }
1730 kptr += 16;
1731 r0 += 4;
1732 }
1733
1734 for (int n = 0; n < 4; n++)
1735 {
1736 output0_tm[n] = sum0[n];
1737 output1_tm[n] = sum1[n];
1738 output2_tm[n] = sum2[n];
1739 output3_tm[n] = sum3[n];
1740 }
1741 #endif // __AVX__
1742 output0_tm += 36;
1743 output1_tm += 36;
1744 output2_tm += 36;
1745 output3_tm += 36;
1746 }
1747 }
1748
1749 remain_outch_start += nn_outch << 2;
1750
1751 for (int p = remain_outch_start; p < outch; p++)
1752 {
1753 float* output0_tm = top_blob_tm.channel(p);
1754
1755 output0_tm = output0_tm + r * 4;
1756
1757 for (int i = 0; i < tiles; i++)
1758 {
1759 const float* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4 + p % 4);
1760 const float* r0 = bottom_blob_tm.channel(tiles * r + i);
1761 #if __AVX__ || __SSE__
1762 #if __AVX__
1763 float zero_val = 0.f;
1764 __m128 _sum0 = _mm_broadcast_ss(&zero_val);
1765 #else
1766 __m128 _sum0 = _mm_set1_ps(0.f);
1767 #endif
1768
1769 for (int q = 0; q < inch; q++)
1770 {
1771 __m128 _r0 = _mm_loadu_ps(r0);
1772 __m128 _k0 = _mm_loadu_ps(kptr);
1773 #if __AVX__
1774 _sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
1775 #else
1776 _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
1777 #endif
1778 kptr += 16;
1779 r0 += 4;
1780 }
1781 _mm_storeu_ps(output0_tm, _sum0);
1782 #else
1783 float sum0[4] = {0};
1784
1785 for (int q = 0; q < inch; q++)
1786 {
1787 for (int n = 0; n < 4; n++)
1788 {
1789 sum0[n] += (int)r0[n] * kptr[n];
1790 }
1791 kptr += 4;
1792 r0 += 4;
1793 }
1794
1795 for (int n = 0; n < 4; n++)
1796 {
1797 output0_tm[n] = sum0[n];
1798 }
1799 #endif // __AVX__ || __SSE__
1800 output0_tm += 36;
1801 }
1802 }
1803
1804 // for (int p=0; p<outch; p++)
1805 // {
1806 // Mat out0_tm = top_blob_tm.channel(p);
1807 // const Mat kernel0_tm = kernel_tm.channel(p);
1808
1809 // for (int i=0; i<tiles; i++)
1810 // {
1811 // float* output0_tm = out0_tm.row<int>(i);
1812
1813 // int sum0[36] = {0};
1814
1815 // for (int q=0; q<inch; q++)
1816 // {
1817 // const float* r0 = bottom_blob_tm.channel(q).row<float>(i);
1818 // const float* k0 = kernel0_tm.row<float>(q);
1819
1820 // for (int n=0; n<36; n++)
1821 // {
1822 // sum0[n] += (int)r0[n] * k0[n];
1823 // }
1824 // }
1825
1826 // for (int n=0; n<36; n++)
1827 // {
1828 // output0_tm[n] = sum0[n];
1829 // }
1830 // }
1831 // }
1832 }
1833 }
1834 bottom_blob_tm = Mat();
1835 // END dot
1836
1837 // BEGIN transform output
1838 Mat top_blob_bordered;
1839 if (outw == top_blob.w && outh == top_blob.h)
1840 {
1841 top_blob_bordered = top_blob;
1842 }
1843 else
1844 {
1845 top_blob_bordered.create(outw, outh, outch, elemsize, opt.workspace_allocator);
1846 }
1847 {
1848 // AT
1849 // const float itm[4][6] = {
1850 // {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.0f},
1851 // {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
1852 // {0.0f, 1.0f, 1.0f, 4.0f, 4.0f, 0.0f},
1853 // {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
1854 // };
1855
1856 // 0 = r00 + r01 + r02 + r03 + r04
1857 // 1 = r01 - r02 + 2 * (r03 - r04)
1858 // 2 = r01 + r02 + 4 * (r03 + r04)
1859 // 3 = r01 - r02 + 8 * (r03 - r04) + r05
1860
1861 int w_tm = outw / 4 * 6;
1862 int h_tm = outh / 4 * 6;
1863
1864 int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
1865 int nRowBlocks = w_tm / 6;
1866
1867 #pragma omp parallel for num_threads(opt.num_threads)
1868 for (int p = 0; p < outch; p++)
1869 {
1870 float* out_tile = top_blob_tm.channel(p);
1871 float* outRow0 = top_blob_bordered.channel(p);
1872 float* outRow1 = outRow0 + outw;
1873 float* outRow2 = outRow0 + outw * 2;
1874 float* outRow3 = outRow0 + outw * 3;
1875
1876 const float bias0 = bias ? bias[p] : 0.f;
1877
1878 for (int j = 0; j < nColBlocks; j++)
1879 {
1880 for (int i = 0; i < nRowBlocks; i++)
1881 {
1882 // TODO AVX2
1883 float s0[6], s1[6], s2[6], s3[6], s4[6], s5[6];
1884 float w0[6], w1[6], w2[6], w3[6];
1885 float d0[4], d1[4], d2[4], d3[4], d4[4], d5[4];
1886 float o0[4], o1[4], o2[4], o3[4];
1887
1888 // load
1889 for (int n = 0; n < 6; n++)
1890 {
1891 s0[n] = out_tile[n];
1892 s1[n] = out_tile[n + 6];
1893 s2[n] = out_tile[n + 12];
1894 s3[n] = out_tile[n + 18];
1895 s4[n] = out_tile[n + 24];
1896 s5[n] = out_tile[n + 30];
1897 }
1898 // w = A_T * W
1899 for (int n = 0; n < 6; n++)
1900 {
1901 w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n];
1902 w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n];
1903 w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n];
1904 w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + s5[n];
1905 }
1906 // transpose w to w_t
1907 {
1908 d0[0] = w0[0];
1909 d0[1] = w1[0];
1910 d0[2] = w2[0];
1911 d0[3] = w3[0];
1912 d1[0] = w0[1];
1913 d1[1] = w1[1];
1914 d1[2] = w2[1];
1915 d1[3] = w3[1];
1916 d2[0] = w0[2];
1917 d2[1] = w1[2];
1918 d2[2] = w2[2];
1919 d2[3] = w3[2];
1920 d3[0] = w0[3];
1921 d3[1] = w1[3];
1922 d3[2] = w2[3];
1923 d3[3] = w3[3];
1924 d4[0] = w0[4];
1925 d4[1] = w1[4];
1926 d4[2] = w2[4];
1927 d4[3] = w3[4];
1928 d5[0] = w0[5];
1929 d5[1] = w1[5];
1930 d5[2] = w2[5];
1931 d5[3] = w3[5];
1932 }
1933 // Y = A_T * w_t
1934 for (int n = 0; n < 4; n++)
1935 {
1936 o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n];
1937 o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n];
1938 o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n];
1939 o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n];
1940 }
1941 // save to top blob tm
1942 for (int n = 0; n < 4; n++)
1943 {
1944 outRow0[n] = o0[n] + bias0;
1945 outRow1[n] = o1[n] + bias0;
1946 outRow2[n] = o2[n] + bias0;
1947 outRow3[n] = o3[n] + bias0;
1948 }
1949
1950 out_tile += 36;
1951
1952 outRow0 += 4;
1953 outRow1 += 4;
1954 outRow2 += 4;
1955 outRow3 += 4;
1956 }
1957
1958 outRow0 += outw * 3;
1959 outRow1 += outw * 3;
1960 outRow2 += outw * 3;
1961 outRow3 += outw * 3;
1962 }
1963 }
1964 }
1965 // END transform output
1966
1967 // cut result pad
1968 copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
1969 }
1970
conv3x3s2_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)1971 static void conv3x3s2_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
1972 {
1973 int w = bottom_blob.w;
1974 int inch = bottom_blob.c;
1975
1976 int outw = top_blob.w;
1977 int outh = top_blob.h;
1978 int outch = top_blob.c;
1979
1980 const int tailstep = w - 2 * outw + w;
1981
1982 const float* kernel = _kernel;
1983 const float* bias = _bias;
1984
1985 #pragma omp parallel for num_threads(opt.num_threads)
1986 for (int p = 0; p < outch; p++)
1987 {
1988 Mat out = top_blob.channel(p);
1989
1990 const float bias0 = bias ? bias[p] : 0.f;
1991
1992 out.fill(bias0);
1993
1994 for (int q = 0; q < inch; q++)
1995 {
1996 float* outptr = out;
1997
1998 const float* img = bottom_blob.channel(q);
1999 const float* kernel0 = kernel + p * inch * 9 + q * 9;
2000
2001 const float* r0 = img;
2002 const float* r1 = img + w;
2003 const float* r2 = img + w * 2;
2004
2005 const float* k0 = kernel0;
2006 const float* k1 = kernel0 + 3;
2007 const float* k2 = kernel0 + 6;
2008
2009 for (int i = 0; i < outh; i++)
2010 {
2011 int remain = outw;
2012
2013 for (; remain > 0; remain--)
2014 {
2015 float sum = 0;
2016
2017 sum += r0[0] * k0[0];
2018 sum += r0[1] * k0[1];
2019 sum += r0[2] * k0[2];
2020 sum += r1[0] * k1[0];
2021 sum += r1[1] * k1[1];
2022 sum += r1[2] * k1[2];
2023 sum += r2[0] * k2[0];
2024 sum += r2[1] * k2[1];
2025 sum += r2[2] * k2[2];
2026
2027 *outptr += sum;
2028
2029 r0 += 2;
2030 r1 += 2;
2031 r2 += 2;
2032 outptr++;
2033 }
2034
2035 r0 += tailstep;
2036 r1 += tailstep;
2037 r2 += tailstep;
2038 }
2039 }
2040 }
2041 }
2042