1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 // Copyright (C) 2019 BUG1989. All rights reserved.
5 //
6 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
7 // in compliance with the License. You may obtain a copy of the License at
8 //
9 // https://opensource.org/licenses/BSD-3-Clause
10 //
11 // Unless required by applicable law or agreed to in writing, software distributed
12 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
13 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
14 // specific language governing permissions and limitations under the License.
15 
conv3x3s1_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)16 static void conv3x3s1_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
17 {
18     int w = bottom_blob.w;
19     int inch = bottom_blob.c;
20 
21     int outw = top_blob.w;
22     int outh = top_blob.h;
23     int outch = top_blob.c;
24 
25     const float* kernel = _kernel;
26     const float* bias = _bias;
27 
28     #pragma omp parallel for num_threads(opt.num_threads)
29     for (int p = 0; p < outch; p++)
30     {
31         Mat out = top_blob.channel(p);
32 
33         const float bias0 = bias ? bias[p] : 0.f;
34 
35         out.fill(bias0);
36 
37         for (int q = 0; q < inch; q++)
38         {
39             float* outptr = out;
40             float* outptr2 = outptr + outw;
41 
42             const float* img0 = bottom_blob.channel(q);
43 
44             const float* kernel0 = kernel + p * inch * 9 + q * 9;
45 
46             const float* r0 = img0;
47             const float* r1 = img0 + w;
48             const float* r2 = img0 + w * 2;
49             const float* r3 = img0 + w * 3;
50 
51             const float* k0 = kernel0;
52             const float* k1 = kernel0 + 3;
53             const float* k2 = kernel0 + 6;
54 
55             int i = 0;
56 
57             for (; i + 1 < outh; i += 2)
58             {
59                 int remain = outw;
60 
61                 for (; remain > 0; remain--)
62                 {
63                     float sum = 0;
64                     float sum2 = 0;
65 
66                     sum += r0[0] * k0[0];
67                     sum += r0[1] * k0[1];
68                     sum += r0[2] * k0[2];
69                     sum += r1[0] * k1[0];
70                     sum += r1[1] * k1[1];
71                     sum += r1[2] * k1[2];
72                     sum += r2[0] * k2[0];
73                     sum += r2[1] * k2[1];
74                     sum += r2[2] * k2[2];
75 
76                     sum2 += r1[0] * k0[0];
77                     sum2 += r1[1] * k0[1];
78                     sum2 += r1[2] * k0[2];
79                     sum2 += r2[0] * k1[0];
80                     sum2 += r2[1] * k1[1];
81                     sum2 += r2[2] * k1[2];
82                     sum2 += r3[0] * k2[0];
83                     sum2 += r3[1] * k2[1];
84                     sum2 += r3[2] * k2[2];
85 
86                     *outptr += sum;
87                     *outptr2 += sum2;
88 
89                     r0++;
90                     r1++;
91                     r2++;
92                     r3++;
93                     outptr++;
94                     outptr2++;
95                 }
96 
97                 r0 += 2 + w;
98                 r1 += 2 + w;
99                 r2 += 2 + w;
100                 r3 += 2 + w;
101 
102                 outptr += outw;
103                 outptr2 += outw;
104             }
105 
106             for (; i < outh; i++)
107             {
108                 int remain = outw;
109 
110                 for (; remain > 0; remain--)
111                 {
112                     float sum = 0;
113 
114                     sum += r0[0] * k0[0];
115                     sum += r0[1] * k0[1];
116                     sum += r0[2] * k0[2];
117                     sum += r1[0] * k1[0];
118                     sum += r1[1] * k1[1];
119                     sum += r1[2] * k1[2];
120                     sum += r2[0] * k2[0];
121                     sum += r2[1] * k2[1];
122                     sum += r2[2] * k2[2];
123 
124                     *outptr += sum;
125 
126                     r0++;
127                     r1++;
128                     r2++;
129                     outptr++;
130                 }
131 
132                 r0 += 2;
133                 r1 += 2;
134                 r2 += 2;
135             }
136         }
137     }
138 }
139 
conv3x3s1_winograd23_transform_kernel_sse(const Mat & kernel,Mat & kernel_tm,int inch,int outch)140 static void conv3x3s1_winograd23_transform_kernel_sse(const Mat& kernel, Mat& kernel_tm, int inch, int outch)
141 {
142     kernel_tm.create(4 * 4, inch, outch);
143 
144     // G
145     const float ktm[4][3] = {
146         {1.0f, 0.0f, 0.0f},
147         {1.0f / 2, 1.0f / 2, 1.0f / 2},
148         {1.0f / 2, -1.0f / 2, 1.0f / 2},
149         {0.0f, 0.0f, 1.0f}
150     };
151 
152     #pragma omp parallel for
153     for (int p = 0; p < outch; p++)
154     {
155         for (int q = 0; q < inch; q++)
156         {
157             const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
158             float* kernel_tm0 = kernel_tm.channel(p).row(q);
159 
160             // transform kernel
161             const float* k0 = kernel0;
162             const float* k1 = kernel0 + 3;
163             const float* k2 = kernel0 + 6;
164 
165             // h
166             float tmp[4][3];
167             for (int i = 0; i < 4; i++)
168             {
169                 tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
170                 tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
171                 tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
172             }
173 
174             // U
175             for (int j = 0; j < 4; j++)
176             {
177                 float* tmpp = &tmp[j][0];
178 
179                 for (int i = 0; i < 4; i++)
180                 {
181                     kernel_tm0[j * 4 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
182                 }
183             }
184         }
185     }
186 }
187 
conv3x3s1_winograd23_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel_tm,const Mat & _bias,const Option & opt)188 static void conv3x3s1_winograd23_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
189 {
190     int w = bottom_blob.w;
191     int h = bottom_blob.h;
192     int inch = bottom_blob.c;
193 
194     int outw = top_blob.w;
195     int outh = top_blob.h;
196     int outch = top_blob.c;
197 
198     // pad to 2n+2, winograd F(2,3)
199     Mat bottom_blob_bordered = bottom_blob;
200 
201     outw = (outw + 1) / 2 * 2;
202     outh = (outh + 1) / 2 * 2;
203 
204     w = outw + 2;
205     h = outh + 2;
206     Option opt_b = opt;
207     opt_b.blob_allocator = opt.workspace_allocator;
208     copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);
209 
210     const float* bias = _bias;
211 
212     // BEGIN transform input
213     Mat bottom_blob_tm;
214     {
215         int w_tm = outw / 2 * 4;
216         int h_tm = outh / 2 * 4;
217 
218         int nColBlocks = h_tm / 4; // may be the block num in Feathercnn
219         int nRowBlocks = w_tm / 4;
220 
221         const int tiles = nColBlocks * nRowBlocks;
222 
223         bottom_blob_tm.create(4 * 4, tiles, inch, 4u, opt.workspace_allocator);
224 
225         // BT
226         // const float itm[4][4] = {
227         //     {1.0f,  0.0f, -1.0f,  0.0f},
228         //     {0.0f,  1.0f,  1.00f, 0.0f},
229         //     {0.0f, -1.0f,  1.00f, 0.0f},
230         //     {0.0f, -1.0f,  0.00f, 1.0f}
231         // };
232         #pragma omp parallel for num_threads(opt.num_threads)
233         for (int q = 0; q < inch; q++)
234         {
235             const float* img = bottom_blob_bordered.channel(q);
236             float* out_tm0 = bottom_blob_tm.channel(q);
237 
238             for (int j = 0; j < nColBlocks; j++)
239             {
240                 const float* r0 = img + w * j * 2;
241                 const float* r1 = r0 + w;
242                 const float* r2 = r1 + w;
243                 const float* r3 = r2 + w;
244 
245                 for (int i = 0; i < nRowBlocks; i++)
246                 {
247 #if __AVX__
248                     __m128 _d0, _d1, _d2, _d3;
249                     __m128 _w0, _w1, _w2, _w3;
250 
251                     // load
252                     _d0 = _mm_loadu_ps(r0);
253                     _d1 = _mm_loadu_ps(r1);
254                     _d2 = _mm_loadu_ps(r2);
255                     _d3 = _mm_loadu_ps(r3);
256 
257                     // w = B_t * d
258                     _w0 = _mm_sub_ps(_d0, _d2);
259                     _w1 = _mm_add_ps(_d1, _d2);
260                     _w2 = _mm_sub_ps(_d2, _d1);
261                     _w3 = _mm_sub_ps(_d3, _d1);
262 
263                     // transpose d to d_t
264                     _MM_TRANSPOSE4_PS(_w0, _w1, _w2, _w3);
265 
266                     // d = B_t * d_t
267                     _d0 = _mm_sub_ps(_w0, _w2);
268                     _d1 = _mm_add_ps(_w1, _w2);
269                     _d2 = _mm_sub_ps(_w2, _w1);
270                     _d3 = _mm_sub_ps(_w3, _w1);
271 
272                     // save to out_tm
273                     _mm_storeu_ps(out_tm0, _d0);
274                     _mm_storeu_ps(out_tm0 + 4, _d1);
275                     _mm_storeu_ps(out_tm0 + 8, _d2);
276                     _mm_storeu_ps(out_tm0 + 12, _d3);
277 #else
278                     float d0[4], d1[4], d2[4], d3[4];
279                     float w0[4], w1[4], w2[4], w3[4];
280                     float t0[4], t1[4], t2[4], t3[4];
281                     // load
282                     for (int n = 0; n < 4; n++)
283                     {
284                         d0[n] = r0[n];
285                         d1[n] = r1[n];
286                         d2[n] = r2[n];
287                         d3[n] = r3[n];
288                     }
289                     // w = B_t * d
290                     for (int n = 0; n < 4; n++)
291                     {
292                         w0[n] = d0[n] - d2[n];
293                         w1[n] = d1[n] + d2[n];
294                         w2[n] = d2[n] - d1[n];
295                         w3[n] = d3[n] - d1[n];
296                     }
297                     // transpose d to d_t
298                     {
299                         t0[0] = w0[0];
300                         t1[0] = w0[1];
301                         t2[0] = w0[2];
302                         t3[0] = w0[3];
303                         t0[1] = w1[0];
304                         t1[1] = w1[1];
305                         t2[1] = w1[2];
306                         t3[1] = w1[3];
307                         t0[2] = w2[0];
308                         t1[2] = w2[1];
309                         t2[2] = w2[2];
310                         t3[2] = w2[3];
311                         t0[3] = w3[0];
312                         t1[3] = w3[1];
313                         t2[3] = w3[2];
314                         t3[3] = w3[3];
315                     }
316                     // d = B_t * d_t
317                     for (int n = 0; n < 4; n++)
318                     {
319                         d0[n] = t0[n] - t2[n];
320                         d1[n] = t1[n] + t2[n];
321                         d2[n] = t2[n] - t1[n];
322                         d3[n] = t3[n] - t1[n];
323                     }
324                     // save to out_tm
325                     for (int n = 0; n < 4; n++)
326                     {
327                         out_tm0[n] = d0[n];
328                         out_tm0[n + 4] = d1[n];
329                         out_tm0[n + 8] = d2[n];
330                         out_tm0[n + 12] = d3[n];
331                     }
332 #endif
333                     r0 += 2;
334                     r1 += 2;
335                     r2 += 2;
336                     r3 += 2;
337 
338                     out_tm0 += 16;
339                 }
340             }
341         }
342     }
343     bottom_blob_bordered = Mat();
344 
345     // BEGIN dot
346     Mat top_blob_tm;
347     {
348         int w_tm = outw / 2 * 4;
349         int h_tm = outh / 2 * 4;
350 
351         int nColBlocks = h_tm / 4; // may be the block num in Feathercnn
352         int nRowBlocks = w_tm / 4;
353 
354         const int tiles = nColBlocks * nRowBlocks;
355 
356         top_blob_tm.create(16, tiles, outch, 4u, opt.workspace_allocator);
357 
358         int nn_outch = outch >> 2;
359         int remain_outch_start = nn_outch << 2;
360 
361         #pragma omp parallel for num_threads(opt.num_threads)
362         for (int pp = 0; pp < nn_outch; pp++)
363         {
364             int p = pp * 4;
365 
366             Mat out0_tm = top_blob_tm.channel(p);
367             Mat out1_tm = top_blob_tm.channel(p + 1);
368             Mat out2_tm = top_blob_tm.channel(p + 2);
369             Mat out3_tm = top_blob_tm.channel(p + 3);
370 
371             const Mat kernel0_tm = kernel_tm.channel(p);
372             const Mat kernel1_tm = kernel_tm.channel(p + 1);
373             const Mat kernel2_tm = kernel_tm.channel(p + 2);
374             const Mat kernel3_tm = kernel_tm.channel(p + 3);
375 
376             for (int i = 0; i < tiles; i++)
377             {
378                 float* output0_tm = out0_tm.row(i);
379                 float* output1_tm = out1_tm.row(i);
380                 float* output2_tm = out2_tm.row(i);
381                 float* output3_tm = out3_tm.row(i);
382 
383 #if __AVX__
384                 float zero_val = 0.f;
385 
386                 __m256 _sum0 = _mm256_broadcast_ss(&zero_val);
387                 __m256 _sum0n = _mm256_broadcast_ss(&zero_val);
388                 __m256 _sum1 = _mm256_broadcast_ss(&zero_val);
389                 __m256 _sum1n = _mm256_broadcast_ss(&zero_val);
390                 __m256 _sum2 = _mm256_broadcast_ss(&zero_val);
391                 __m256 _sum2n = _mm256_broadcast_ss(&zero_val);
392                 __m256 _sum3 = _mm256_broadcast_ss(&zero_val);
393                 __m256 _sum3n = _mm256_broadcast_ss(&zero_val);
394 
395                 int q = 0;
396 
397                 for (; q + 3 < inch; q += 4)
398                 {
399                     const float* r0 = bottom_blob_tm.channel(q).row(i);
400                     const float* r1 = bottom_blob_tm.channel(q + 1).row(i);
401                     const float* r2 = bottom_blob_tm.channel(q + 2).row(i);
402                     const float* r3 = bottom_blob_tm.channel(q + 3).row(i);
403 
404                     const float* k0 = kernel0_tm.row(q);
405                     const float* k1 = kernel1_tm.row(q);
406                     const float* k2 = kernel2_tm.row(q);
407                     const float* k3 = kernel3_tm.row(q);
408 
409                     __m256 _r0 = _mm256_loadu_ps(r0);
410                     __m256 _r0n = _mm256_loadu_ps(r0 + 8);
411                     // k0
412                     __m256 _k0 = _mm256_loadu_ps(k0);
413                     __m256 _k0n = _mm256_loadu_ps(k0 + 8);
414                     __m256 _k1 = _mm256_loadu_ps(k1);
415                     __m256 _k1n = _mm256_loadu_ps(k1 + 8);
416                     __m256 _k2 = _mm256_loadu_ps(k2);
417                     __m256 _k2n = _mm256_loadu_ps(k2 + 8);
418                     __m256 _k3 = _mm256_loadu_ps(k3);
419                     __m256 _k3n = _mm256_loadu_ps(k3 + 8);
420                     _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
421                     _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
422                     _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
423                     _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
424                     _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
425                     _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
426                     _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
427                     _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
428 
429                     // k1
430                     _r0 = _mm256_loadu_ps(r1);
431                     _r0n = _mm256_loadu_ps(r1 + 8);
432                     _k0 = _mm256_loadu_ps(k0 + 16);
433                     _k0n = _mm256_loadu_ps(k0 + 24);
434                     _k1 = _mm256_loadu_ps(k1 + 16);
435                     _k1n = _mm256_loadu_ps(k1 + 24);
436                     _k2 = _mm256_loadu_ps(k2 + 16);
437                     _k2n = _mm256_loadu_ps(k2 + 24);
438                     _k3 = _mm256_loadu_ps(k3 + 16);
439                     _k3n = _mm256_loadu_ps(k3 + 24);
440                     _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
441                     _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
442                     _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
443                     _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
444                     _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
445                     _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
446                     _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
447                     _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
448                     // k2
449                     _r0 = _mm256_loadu_ps(r2);
450                     _r0n = _mm256_loadu_ps(r2 + 8);
451                     _k0 = _mm256_loadu_ps(k0 + 32);
452                     _k0n = _mm256_loadu_ps(k0 + 40);
453                     _k1 = _mm256_loadu_ps(k1 + 32);
454                     _k1n = _mm256_loadu_ps(k1 + 40);
455                     _k2 = _mm256_loadu_ps(k2 + 32);
456                     _k2n = _mm256_loadu_ps(k2 + 40);
457                     _k3 = _mm256_loadu_ps(k3 + 32);
458                     _k3n = _mm256_loadu_ps(k3 + 40);
459                     _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
460                     _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
461                     _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
462                     _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
463                     _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
464                     _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
465                     _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
466                     _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
467                     // k3
468                     _r0 = _mm256_loadu_ps(r3);
469                     _r0n = _mm256_loadu_ps(r3 + 8);
470                     _k0 = _mm256_loadu_ps(k0 + 48);
471                     _k0n = _mm256_loadu_ps(k0 + 56);
472                     _k1 = _mm256_loadu_ps(k1 + 48);
473                     _k1n = _mm256_loadu_ps(k1 + 56);
474                     _k2 = _mm256_loadu_ps(k2 + 48);
475                     _k2n = _mm256_loadu_ps(k2 + 56);
476                     _k3 = _mm256_loadu_ps(k3 + 48);
477                     _k3n = _mm256_loadu_ps(k3 + 56);
478                     _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
479                     _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
480                     _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
481                     _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
482                     _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
483                     _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
484                     _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
485                     _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
486                 }
487 
488                 for (; q < inch; q++)
489                 {
490                     const float* r0 = bottom_blob_tm.channel(q).row(i);
491 
492                     const float* k0 = kernel0_tm.row(q);
493                     const float* k1 = kernel1_tm.row(q);
494                     const float* k2 = kernel2_tm.row(q);
495                     const float* k3 = kernel3_tm.row(q);
496 
497                     __m256 _r0 = _mm256_loadu_ps(r0);
498                     __m256 _r0n = _mm256_loadu_ps(r0 + 8);
499                     __m256 _k0 = _mm256_loadu_ps(k0);
500                     __m256 _k0n = _mm256_loadu_ps(k0 + 8);
501                     __m256 _k1 = _mm256_loadu_ps(k1);
502                     __m256 _k1n = _mm256_loadu_ps(k1 + 8);
503                     __m256 _k2 = _mm256_loadu_ps(k2);
504                     __m256 _k2n = _mm256_loadu_ps(k2 + 8);
505                     __m256 _k3 = _mm256_loadu_ps(k3);
506                     __m256 _k3n = _mm256_loadu_ps(k3 + 8);
507 
508                     _sum0 = _mm256_fmadd_ps(_r0, _k0, _sum0);
509                     _sum0n = _mm256_fmadd_ps(_r0n, _k0n, _sum0n);
510                     _sum1 = _mm256_fmadd_ps(_r0, _k1, _sum1);
511                     _sum1n = _mm256_fmadd_ps(_r0n, _k1n, _sum1n);
512                     _sum2 = _mm256_fmadd_ps(_r0, _k2, _sum2);
513                     _sum2n = _mm256_fmadd_ps(_r0n, _k2n, _sum2n);
514                     _sum3 = _mm256_fmadd_ps(_r0, _k3, _sum3);
515                     _sum3n = _mm256_fmadd_ps(_r0n, _k3n, _sum3n);
516                 }
517 
518                 _mm256_storeu_ps(output0_tm, _sum0);
519                 _mm256_storeu_ps(output0_tm + 8, _sum0n);
520                 _mm256_storeu_ps(output1_tm, _sum1);
521                 _mm256_storeu_ps(output1_tm + 8, _sum1n);
522                 _mm256_storeu_ps(output2_tm, _sum2);
523                 _mm256_storeu_ps(output2_tm + 8, _sum2n);
524                 _mm256_storeu_ps(output3_tm, _sum3);
525                 _mm256_storeu_ps(output3_tm + 8, _sum3n);
526 #else
527                 float sum0[16] = {0.0f};
528                 float sum1[16] = {0.0f};
529                 float sum2[16] = {0.0f};
530                 float sum3[16] = {0.0f};
531 
532                 int q = 0;
533                 for (; q + 3 < inch; q += 4)
534                 {
535                     const float* r0 = bottom_blob_tm.channel(q).row(i);
536                     const float* r1 = bottom_blob_tm.channel(q + 1).row(i);
537                     const float* r2 = bottom_blob_tm.channel(q + 2).row(i);
538                     const float* r3 = bottom_blob_tm.channel(q + 3).row(i);
539 
540                     const float* k0 = kernel0_tm.row(q);
541                     const float* k1 = kernel1_tm.row(q);
542                     const float* k2 = kernel2_tm.row(q);
543                     const float* k3 = kernel3_tm.row(q);
544 
545                     for (int n = 0; n < 16; n++)
546                     {
547                         sum0[n] += r0[n] * k0[n];
548                         k0 += 16;
549                         sum0[n] += r1[n] * k0[n];
550                         k0 += 16;
551                         sum0[n] += r2[n] * k0[n];
552                         k0 += 16;
553                         sum0[n] += r3[n] * k0[n];
554                         k0 -= 16 * 3;
555 
556                         sum1[n] += r0[n] * k1[n];
557                         k1 += 16;
558                         sum1[n] += r1[n] * k1[n];
559                         k1 += 16;
560                         sum1[n] += r2[n] * k1[n];
561                         k1 += 16;
562                         sum1[n] += r3[n] * k1[n];
563                         k1 -= 16 * 3;
564 
565                         sum2[n] += r0[n] * k2[n];
566                         k2 += 16;
567                         sum2[n] += r1[n] * k2[n];
568                         k2 += 16;
569                         sum2[n] += r2[n] * k2[n];
570                         k2 += 16;
571                         sum2[n] += r3[n] * k2[n];
572                         k2 -= 16 * 3;
573 
574                         sum3[n] += r0[n] * k3[n];
575                         k3 += 16;
576                         sum3[n] += r1[n] * k3[n];
577                         k3 += 16;
578                         sum3[n] += r2[n] * k3[n];
579                         k3 += 16;
580                         sum3[n] += r3[n] * k3[n];
581                         k3 -= 16 * 3;
582                     }
583                 }
584 
585                 for (; q < inch; q++)
586                 {
587                     const float* r0 = bottom_blob_tm.channel(q).row(i);
588 
589                     const float* k0 = kernel0_tm.row(q);
590                     const float* k1 = kernel1_tm.row(q);
591                     const float* k2 = kernel2_tm.row(q);
592                     const float* k3 = kernel3_tm.row(q);
593 
594                     for (int n = 0; n < 16; n++)
595                     {
596                         sum0[n] += r0[n] * k0[n];
597                         sum1[n] += r0[n] * k1[n];
598                         sum2[n] += r0[n] * k2[n];
599                         sum3[n] += r0[n] * k3[n];
600                     }
601                 }
602 
603                 for (int n = 0; n < 16; n++)
604                 {
605                     output0_tm[n] = sum0[n];
606                     output1_tm[n] = sum1[n];
607                     output2_tm[n] = sum2[n];
608                     output3_tm[n] = sum3[n];
609                 }
610 #endif
611             }
612         }
613 
614         #pragma omp parallel for num_threads(opt.num_threads)
615         for (int p = remain_outch_start; p < outch; p++)
616         {
617             Mat out0_tm = top_blob_tm.channel(p);
618             const Mat kernel0_tm = kernel_tm.channel(p);
619 
620             for (int i = 0; i < tiles; i++)
621             {
622                 float* output0_tm = out0_tm.row(i);
623 
624                 float sum0[16] = {0.0f};
625 
626                 int q = 0;
627                 for (; q + 3 < inch; q += 4)
628                 {
629                     const float* r0 = bottom_blob_tm.channel(q).row(i);
630                     const float* r1 = bottom_blob_tm.channel(q + 1).row(i);
631                     const float* r2 = bottom_blob_tm.channel(q + 2).row(i);
632                     const float* r3 = bottom_blob_tm.channel(q + 3).row(i);
633 
634                     const float* k0 = kernel0_tm.row(q);
635                     const float* k1 = kernel0_tm.row(q + 1);
636                     const float* k2 = kernel0_tm.row(q + 2);
637                     const float* k3 = kernel0_tm.row(q + 3);
638 
639                     for (int n = 0; n < 16; n++)
640                     {
641                         sum0[n] += r0[n] * k0[n];
642                         sum0[n] += r1[n] * k1[n];
643                         sum0[n] += r2[n] * k2[n];
644                         sum0[n] += r3[n] * k3[n];
645                     }
646                 }
647 
648                 for (; q < inch; q++)
649                 {
650                     const float* r0 = bottom_blob_tm.channel(q).row(i);
651                     const float* k0 = kernel0_tm.row(q);
652 
653                     for (int n = 0; n < 16; n++)
654                     {
655                         sum0[n] += r0[n] * k0[n];
656                     }
657                 }
658 
659                 for (int n = 0; n < 16; n++)
660                 {
661                     output0_tm[n] = sum0[n];
662                 }
663             }
664         }
665     }
666     bottom_blob_tm = Mat();
667     // END dot
668 
669     // BEGIN transform output
670     Mat top_blob_bordered;
671     if (outw == top_blob.w && outh == top_blob.h)
672     {
673         top_blob_bordered = top_blob;
674     }
675     else
676     {
677         top_blob_bordered.create(outw, outh, outch, 4u, opt.workspace_allocator);
678     }
679     {
680         // AT
681         // const float itm[2][4] = {
682         //     {1.0f,  1.0f,  1.0f,  0.0f},
683         //     {0.0f,  1.0f, -1.0f,  1.0f}
684         // };
685 
686         int w_tm = outw / 2 * 4;
687         int h_tm = outh / 2 * 4;
688 
689         int nColBlocks = h_tm / 4; // may be the block num in Feathercnn
690         int nRowBlocks = w_tm / 4;
691 
692         #pragma omp parallel for num_threads(opt.num_threads)
693         for (int p = 0; p < outch; p++)
694         {
695             Mat out_tm = top_blob_tm.channel(p);
696             Mat out = top_blob_bordered.channel(p);
697 
698             const float bias0 = bias ? bias[p] : 0.f;
699 
700             for (int j = 0; j < nColBlocks; j++)
701             {
702                 float* outRow0 = out.row(j * 2);
703                 float* outRow1 = out.row(j * 2 + 1);
704 
705                 for (int i = 0; i < nRowBlocks; i++)
706                 {
707                     float* out_tile = out_tm.row(j * nRowBlocks + i);
708 
709                     float s0[4], s1[4], s2[4], s3[4];
710                     float w0[4], w1[4];
711                     float d0[2], d1[2], d2[2], d3[2];
712                     float o0[2], o1[2];
713                     // load
714                     for (int n = 0; n < 4; n++)
715                     {
716                         s0[n] = out_tile[n];
717                         s1[n] = out_tile[n + 4];
718                         s2[n] = out_tile[n + 8];
719                         s3[n] = out_tile[n + 12];
720                     }
721                     // w = A_T * W
722                     for (int n = 0; n < 4; n++)
723                     {
724                         w0[n] = s0[n] + s1[n] + s2[n];
725                         w1[n] = s1[n] - s2[n] + s3[n];
726                     }
727                     // transpose w to w_t
728                     {
729                         d0[0] = w0[0];
730                         d0[1] = w1[0];
731                         d1[0] = w0[1];
732                         d1[1] = w1[1];
733                         d2[0] = w0[2];
734                         d2[1] = w1[2];
735                         d3[0] = w0[3];
736                         d3[1] = w1[3];
737                     }
738                     // Y = A_T * w_t
739                     for (int n = 0; n < 2; n++)
740                     {
741                         o0[n] = d0[n] + d1[n] + d2[n] + bias0;
742                         o1[n] = d1[n] - d2[n] + d3[n] + bias0;
743                     }
744                     // save to top blob tm
745                     outRow0[0] = o0[0];
746                     outRow0[1] = o0[1];
747                     outRow1[0] = o1[0];
748                     outRow1[1] = o1[1];
749 
750                     outRow0 += 2;
751                     outRow1 += 2;
752                 }
753             }
754         }
755     }
756     // END transform output
757 
758     // cut result pad
759     copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
760 }
761 
conv3x3s1_winograd43_transform_kernel_sse(const Mat & kernel,std::vector<Mat> & kernel_tm2,int inch,int outch)762 static void conv3x3s1_winograd43_transform_kernel_sse(const Mat& kernel, std::vector<Mat>& kernel_tm2, int inch, int outch)
763 {
764     Mat kernel_tm(6 * 6, inch, outch);
765 
766     // G
767     const float ktm[6][3] = {
768         {1.0f / 4, 0.0f, 0.0f},
769         {-1.0f / 6, -1.0f / 6, -1.0f / 6},
770         {-1.0f / 6, 1.0f / 6, -1.0f / 6},
771         {1.0f / 24, 1.0f / 12, 1.0f / 6},
772         {1.0f / 24, -1.0f / 12, 1.0f / 6},
773         {0.0f, 0.0f, 1.0f}
774     };
775 
776     #pragma omp parallel for
777     for (int p = 0; p < outch; p++)
778     {
779         for (int q = 0; q < inch; q++)
780         {
781             const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
782             float* kernel_tm0 = kernel_tm.channel(p).row(q);
783 
784             // transform kernel
785             const float* k0 = kernel0;
786             const float* k1 = kernel0 + 3;
787             const float* k2 = kernel0 + 6;
788 
789             // h
790             float tmp[6][3];
791             for (int i = 0; i < 6; i++)
792             {
793                 tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
794                 tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
795                 tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
796             }
797 
798             // U
799             for (int j = 0; j < 6; j++)
800             {
801                 float* tmpp = &tmp[j][0];
802 
803                 for (int i = 0; i < 6; i++)
804                 {
805                     kernel_tm0[j * 6 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
806                 }
807             }
808         }
809     }
810 
811     for (int r = 0; r < 9; r++)
812     {
813         Mat kernel_tm_test(4 * 8, inch, outch / 8 + (outch % 8) / 4 + outch % 4);
814 
815         int p = 0;
816         for (; p + 7 < outch; p += 8)
817         {
818             const float* kernel0 = (const float*)kernel_tm.channel(p);
819             const float* kernel1 = (const float*)kernel_tm.channel(p + 1);
820             const float* kernel2 = (const float*)kernel_tm.channel(p + 2);
821             const float* kernel3 = (const float*)kernel_tm.channel(p + 3);
822             const float* kernel4 = (const float*)kernel_tm.channel(p + 4);
823             const float* kernel5 = (const float*)kernel_tm.channel(p + 5);
824             const float* kernel6 = (const float*)kernel_tm.channel(p + 6);
825             const float* kernel7 = (const float*)kernel_tm.channel(p + 7);
826 
827             float* ktmp = kernel_tm_test.channel(p / 8);
828 
829             for (int q = 0; q < inch; q++)
830             {
831                 ktmp[0] = kernel0[r * 4 + 0];
832                 ktmp[1] = kernel0[r * 4 + 1];
833                 ktmp[2] = kernel0[r * 4 + 2];
834                 ktmp[3] = kernel0[r * 4 + 3];
835 
836                 ktmp[4] = kernel1[r * 4 + 0];
837                 ktmp[5] = kernel1[r * 4 + 1];
838                 ktmp[6] = kernel1[r * 4 + 2];
839                 ktmp[7] = kernel1[r * 4 + 3];
840 
841                 ktmp[8] = kernel2[r * 4 + 0];
842                 ktmp[9] = kernel2[r * 4 + 1];
843                 ktmp[10] = kernel2[r * 4 + 2];
844                 ktmp[11] = kernel2[r * 4 + 3];
845 
846                 ktmp[12] = kernel3[r * 4 + 0];
847                 ktmp[13] = kernel3[r * 4 + 1];
848                 ktmp[14] = kernel3[r * 4 + 2];
849                 ktmp[15] = kernel3[r * 4 + 3];
850 
851                 ktmp[16] = kernel4[r * 4 + 0];
852                 ktmp[17] = kernel4[r * 4 + 1];
853                 ktmp[18] = kernel4[r * 4 + 2];
854                 ktmp[19] = kernel4[r * 4 + 3];
855 
856                 ktmp[20] = kernel5[r * 4 + 0];
857                 ktmp[21] = kernel5[r * 4 + 1];
858                 ktmp[22] = kernel5[r * 4 + 2];
859                 ktmp[23] = kernel5[r * 4 + 3];
860 
861                 ktmp[24] = kernel6[r * 4 + 0];
862                 ktmp[25] = kernel6[r * 4 + 1];
863                 ktmp[26] = kernel6[r * 4 + 2];
864                 ktmp[27] = kernel6[r * 4 + 3];
865 
866                 ktmp[28] = kernel7[r * 4 + 0];
867                 ktmp[29] = kernel7[r * 4 + 1];
868                 ktmp[30] = kernel7[r * 4 + 2];
869                 ktmp[31] = kernel7[r * 4 + 3];
870 
871                 ktmp += 32;
872                 kernel0 += 36;
873                 kernel1 += 36;
874                 kernel2 += 36;
875                 kernel3 += 36;
876                 kernel4 += 36;
877                 kernel5 += 36;
878                 kernel6 += 36;
879                 kernel7 += 36;
880             }
881         }
882 
883         for (; p + 3 < outch; p += 4)
884         {
885             const float* kernel0 = (const float*)kernel_tm.channel(p);
886             const float* kernel1 = (const float*)kernel_tm.channel(p + 1);
887             const float* kernel2 = (const float*)kernel_tm.channel(p + 2);
888             const float* kernel3 = (const float*)kernel_tm.channel(p + 3);
889 
890             float* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4);
891 
892             for (int q = 0; q < inch; q++)
893             {
894                 ktmp[0] = kernel0[r * 4 + 0];
895                 ktmp[1] = kernel0[r * 4 + 1];
896                 ktmp[2] = kernel0[r * 4 + 2];
897                 ktmp[3] = kernel0[r * 4 + 3];
898 
899                 ktmp[4] = kernel1[r * 4 + 0];
900                 ktmp[5] = kernel1[r * 4 + 1];
901                 ktmp[6] = kernel1[r * 4 + 2];
902                 ktmp[7] = kernel1[r * 4 + 3];
903 
904                 ktmp[8] = kernel2[r * 4 + 0];
905                 ktmp[9] = kernel2[r * 4 + 1];
906                 ktmp[10] = kernel2[r * 4 + 2];
907                 ktmp[11] = kernel2[r * 4 + 3];
908 
909                 ktmp[12] = kernel3[r * 4 + 0];
910                 ktmp[13] = kernel3[r * 4 + 1];
911                 ktmp[14] = kernel3[r * 4 + 2];
912                 ktmp[15] = kernel3[r * 4 + 3];
913 
914                 ktmp += 16;
915                 kernel0 += 36;
916                 kernel1 += 36;
917                 kernel2 += 36;
918                 kernel3 += 36;
919             }
920         }
921 
922         for (; p < outch; p++)
923         {
924             const float* kernel0 = (const float*)kernel_tm.channel(p);
925 
926             float* ktmp = kernel_tm_test.channel(p / 8 + (p % 8) / 4 + p % 4);
927 
928             for (int q = 0; q < inch; q++)
929             {
930                 ktmp[0] = kernel0[r * 4 + 0];
931                 ktmp[1] = kernel0[r * 4 + 1];
932                 ktmp[2] = kernel0[r * 4 + 2];
933                 ktmp[3] = kernel0[r * 4 + 3];
934 
935                 ktmp += 4;
936                 kernel0 += 36;
937             }
938         }
939         kernel_tm2.push_back(kernel_tm_test);
940     }
941 }
942 
conv3x3s1_winograd43_sse(const Mat & bottom_blob,Mat & top_blob,const std::vector<Mat> & kernel_tm_test,const Mat & _bias,const Option & opt)943 static void conv3x3s1_winograd43_sse(const Mat& bottom_blob, Mat& top_blob, const std::vector<Mat>& kernel_tm_test, const Mat& _bias, const Option& opt)
944 {
945     int w = bottom_blob.w;
946     int h = bottom_blob.h;
947     int inch = bottom_blob.c;
948 
949     int outw = top_blob.w;
950     int outh = top_blob.h;
951     int outch = top_blob.c;
952 
953     size_t elemsize = bottom_blob.elemsize;
954     const float* bias = _bias;
955 
956     // pad to 4n+2, winograd F(4,3)
957     Mat bottom_blob_bordered = bottom_blob;
958 
959     outw = (outw + 3) / 4 * 4;
960     outh = (outh + 3) / 4 * 4;
961 
962     w = outw + 2;
963     h = outh + 2;
964 
965     Option opt_b = opt;
966     opt_b.blob_allocator = opt.workspace_allocator;
967     copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, 0, 0.f, opt_b);
968 
969     // BEGIN transform input
970     Mat bottom_blob_tm;
971     {
972         int w_tm = outw / 4 * 6;
973         int h_tm = outh / 4 * 6;
974 
975         int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
976         int nRowBlocks = w_tm / 6;
977 
978         const int tiles = nColBlocks * nRowBlocks;
979 
980         bottom_blob_tm.create(4, inch, tiles * 9, elemsize, opt.workspace_allocator);
981 
982         // BT
983         // const float itm[4][4] = {
984         //     {4.0f, 0.0f, -5.0f, 0.0f, 1.0f, 0.0f},
985         //     {0.0f,-4.0f, -4.0f, 1.0f, 1.0f, 0.0f},
986         //     {0.0f, 4.0f, -4.0f,-1.0f, 1.0f, 0.0f},
987         //     {0.0f,-2.0f, -1.0f, 2.0f, 1.0f, 0.0f},
988         //     {0.0f, 2.0f, -1.0f,-2.0f, 1.0f, 0.0f},
989         //     {0.0f, 4.0f,  0.0f,-5.0f, 0.0f, 1.0f}
990         // };
991 
992         // 0 =	4 * r00  - 5 * r02	+ r04
993         // 1 = -4 * (r01 + r02)  + r03 + r04
994         // 2 =	4 * (r01 - r02)  - r03 + r04
995         // 3 = -2 * r01 - r02 + 2 * r03 + r04
996         // 4 =	2 * r01 - r02 - 2 * r03 + r04
997         // 5 =	4 * r01 - 5 * r03 + r05
998 
999         // 0 =	4 * r00  - 5 * r02	+ r04
1000         // 1 = -4 * (r01 + r02)  + r03 + r04
1001         // 2 =	4 * (r01 - r02)  - r03 + r04
1002         // 3 = -2 * r01 - r02 + 2 * r03 + r04
1003         // 4 =	2 * r01 - r02 - 2 * r03 + r04
1004         // 5 =	4 * r01 - 5 * r03 + r05
1005 
1006 #if __AVX__
1007         __m256 _1_n = _mm256_set1_ps(-1);
1008         __m256 _2_p = _mm256_set1_ps(2);
1009         __m256 _2_n = _mm256_set1_ps(-2);
1010         __m256 _4_p = _mm256_set1_ps(4);
1011         __m256 _4_n = _mm256_set1_ps(-4);
1012         __m256 _5_n = _mm256_set1_ps(-5);
1013 #endif
1014 
1015         #pragma omp parallel for num_threads(opt.num_threads)
1016         for (int q = 0; q < inch; q++)
1017         {
1018             const float* img = bottom_blob_bordered.channel(q);
1019 
1020             for (int j = 0; j < nColBlocks; j++)
1021             {
1022                 const float* r0 = img + w * j * 4;
1023                 const float* r1 = r0 + w;
1024                 const float* r2 = r1 + w;
1025                 const float* r3 = r2 + w;
1026                 const float* r4 = r3 + w;
1027                 const float* r5 = r4 + w;
1028 
1029                 for (int i = 0; i < nRowBlocks; i++)
1030                 {
1031                     float* out_tm0 = bottom_blob_tm.channel(tiles * 0 + j * nRowBlocks + i).row(q);
1032                     float* out_tm1 = bottom_blob_tm.channel(tiles * 1 + j * nRowBlocks + i).row(q);
1033                     float* out_tm2 = bottom_blob_tm.channel(tiles * 2 + j * nRowBlocks + i).row(q);
1034                     float* out_tm3 = bottom_blob_tm.channel(tiles * 3 + j * nRowBlocks + i).row(q);
1035                     float* out_tm4 = bottom_blob_tm.channel(tiles * 4 + j * nRowBlocks + i).row(q);
1036                     float* out_tm5 = bottom_blob_tm.channel(tiles * 5 + j * nRowBlocks + i).row(q);
1037                     float* out_tm6 = bottom_blob_tm.channel(tiles * 6 + j * nRowBlocks + i).row(q);
1038                     float* out_tm7 = bottom_blob_tm.channel(tiles * 7 + j * nRowBlocks + i).row(q);
1039                     float* out_tm8 = bottom_blob_tm.channel(tiles * 8 + j * nRowBlocks + i).row(q);
1040 #if __AVX__
1041                     __m256 _d0, _d1, _d2, _d3, _d4, _d5;
1042                     __m256 _w0, _w1, _w2, _w3, _w4, _w5;
1043                     __m256 _t0, _t1, _t2, _t3, _t4, _t5;
1044                     __m256 _n0, _n1, _n2, _n3, _n4, _n5;
1045                     // load
1046                     _d0 = _mm256_loadu_ps(r0);
1047                     _d1 = _mm256_loadu_ps(r1);
1048                     _d2 = _mm256_loadu_ps(r2);
1049                     _d3 = _mm256_loadu_ps(r3);
1050                     _d4 = _mm256_loadu_ps(r4);
1051                     _d5 = _mm256_loadu_ps(r5);
1052 
1053                     // w = B_t * d
1054                     _w0 = _mm256_mul_ps(_d0, _4_p);
1055                     _w0 = _mm256_fmadd_ps(_d2, _5_n, _w0);
1056                     _w0 = _mm256_add_ps(_w0, _d4);
1057 
1058                     _w1 = _mm256_mul_ps(_d1, _4_n);
1059                     _w1 = _mm256_fmadd_ps(_d2, _4_n, _w1);
1060                     _w1 = _mm256_add_ps(_w1, _d3);
1061                     _w1 = _mm256_add_ps(_w1, _d4);
1062 
1063                     _w2 = _mm256_mul_ps(_d1, _4_p);
1064                     _w2 = _mm256_fmadd_ps(_d2, _4_n, _w2);
1065                     _w2 = _mm256_fmadd_ps(_d3, _1_n, _w2);
1066                     _w2 = _mm256_add_ps(_w2, _d4);
1067 
1068                     _w3 = _mm256_mul_ps(_d1, _2_n);
1069                     _w3 = _mm256_fmadd_ps(_d2, _1_n, _w3);
1070                     _w3 = _mm256_fmadd_ps(_d3, _2_p, _w3);
1071                     _w3 = _mm256_add_ps(_w3, _d4);
1072 
1073                     _w4 = _mm256_mul_ps(_d1, _2_p);
1074                     _w4 = _mm256_fmadd_ps(_d2, _1_n, _w4);
1075                     _w4 = _mm256_fmadd_ps(_d3, _2_n, _w4);
1076                     _w4 = _mm256_add_ps(_w4, _d4);
1077 
1078                     _w5 = _mm256_mul_ps(_d1, _4_p);
1079                     _w5 = _mm256_fmadd_ps(_d3, _5_n, _w5);
1080                     _w5 = _mm256_add_ps(_w5, _d5);
1081                     // transpose d to d_t
1082 #if (defined _WIN32 && !(defined __MINGW32__))
1083                     {
1084                         _t0.m256_f32[0] = _w0.m256_f32[0];
1085                         _t1.m256_f32[0] = _w0.m256_f32[1];
1086                         _t2.m256_f32[0] = _w0.m256_f32[2];
1087                         _t3.m256_f32[0] = _w0.m256_f32[3];
1088                         _t4.m256_f32[0] = _w0.m256_f32[4];
1089                         _t5.m256_f32[0] = _w0.m256_f32[5];
1090                         _t0.m256_f32[1] = _w1.m256_f32[0];
1091                         _t1.m256_f32[1] = _w1.m256_f32[1];
1092                         _t2.m256_f32[1] = _w1.m256_f32[2];
1093                         _t3.m256_f32[1] = _w1.m256_f32[3];
1094                         _t4.m256_f32[1] = _w1.m256_f32[4];
1095                         _t5.m256_f32[1] = _w1.m256_f32[5];
1096                         _t0.m256_f32[2] = _w2.m256_f32[0];
1097                         _t1.m256_f32[2] = _w2.m256_f32[1];
1098                         _t2.m256_f32[2] = _w2.m256_f32[2];
1099                         _t3.m256_f32[2] = _w2.m256_f32[3];
1100                         _t4.m256_f32[2] = _w2.m256_f32[4];
1101                         _t5.m256_f32[2] = _w2.m256_f32[5];
1102                         _t0.m256_f32[3] = _w3.m256_f32[0];
1103                         _t1.m256_f32[3] = _w3.m256_f32[1];
1104                         _t2.m256_f32[3] = _w3.m256_f32[2];
1105                         _t3.m256_f32[3] = _w3.m256_f32[3];
1106                         _t4.m256_f32[3] = _w3.m256_f32[4];
1107                         _t5.m256_f32[3] = _w3.m256_f32[5];
1108                         _t0.m256_f32[4] = _w4.m256_f32[0];
1109                         _t1.m256_f32[4] = _w4.m256_f32[1];
1110                         _t2.m256_f32[4] = _w4.m256_f32[2];
1111                         _t3.m256_f32[4] = _w4.m256_f32[3];
1112                         _t4.m256_f32[4] = _w4.m256_f32[4];
1113                         _t5.m256_f32[4] = _w4.m256_f32[5];
1114                         _t0.m256_f32[5] = _w5.m256_f32[0];
1115                         _t1.m256_f32[5] = _w5.m256_f32[1];
1116                         _t2.m256_f32[5] = _w5.m256_f32[2];
1117                         _t3.m256_f32[5] = _w5.m256_f32[3];
1118                         _t4.m256_f32[5] = _w5.m256_f32[4];
1119                         _t5.m256_f32[5] = _w5.m256_f32[5];
1120                     }
1121 #else
1122                     {
1123                         _t0[0] = _w0[0];
1124                         _t1[0] = _w0[1];
1125                         _t2[0] = _w0[2];
1126                         _t3[0] = _w0[3];
1127                         _t4[0] = _w0[4];
1128                         _t5[0] = _w0[5];
1129                         _t0[1] = _w1[0];
1130                         _t1[1] = _w1[1];
1131                         _t2[1] = _w1[2];
1132                         _t3[1] = _w1[3];
1133                         _t4[1] = _w1[4];
1134                         _t5[1] = _w1[5];
1135                         _t0[2] = _w2[0];
1136                         _t1[2] = _w2[1];
1137                         _t2[2] = _w2[2];
1138                         _t3[2] = _w2[3];
1139                         _t4[2] = _w2[4];
1140                         _t5[2] = _w2[5];
1141                         _t0[3] = _w3[0];
1142                         _t1[3] = _w3[1];
1143                         _t2[3] = _w3[2];
1144                         _t3[3] = _w3[3];
1145                         _t4[3] = _w3[4];
1146                         _t5[3] = _w3[5];
1147                         _t0[4] = _w4[0];
1148                         _t1[4] = _w4[1];
1149                         _t2[4] = _w4[2];
1150                         _t3[4] = _w4[3];
1151                         _t4[4] = _w4[4];
1152                         _t5[4] = _w4[5];
1153                         _t0[5] = _w5[0];
1154                         _t1[5] = _w5[1];
1155                         _t2[5] = _w5[2];
1156                         _t3[5] = _w5[3];
1157                         _t4[5] = _w5[4];
1158                         _t5[5] = _w5[5];
1159                     }
1160 #endif
1161                     // d = B_t * d_t
1162                     _n0 = _mm256_mul_ps(_t0, _4_p);
1163                     _n0 = _mm256_fmadd_ps(_t2, _5_n, _n0);
1164                     _n0 = _mm256_add_ps(_n0, _t4);
1165 
1166                     _n1 = _mm256_mul_ps(_t1, _4_n);
1167                     _n1 = _mm256_fmadd_ps(_t2, _4_n, _n1);
1168                     _n1 = _mm256_add_ps(_n1, _t3);
1169                     _n1 = _mm256_add_ps(_n1, _t4);
1170 
1171                     _n2 = _mm256_mul_ps(_t1, _4_p);
1172                     _n2 = _mm256_fmadd_ps(_t2, _4_n, _n2);
1173                     _n2 = _mm256_fmadd_ps(_t3, _1_n, _n2);
1174                     _n2 = _mm256_add_ps(_n2, _t4);
1175 
1176                     _n3 = _mm256_mul_ps(_t1, _2_n);
1177                     _n3 = _mm256_fmadd_ps(_t2, _1_n, _n3);
1178                     _n3 = _mm256_fmadd_ps(_t3, _2_p, _n3);
1179                     _n3 = _mm256_add_ps(_n3, _t4);
1180 
1181                     _n4 = _mm256_mul_ps(_t1, _2_p);
1182                     _n4 = _mm256_fmadd_ps(_t2, _1_n, _n4);
1183                     _n4 = _mm256_fmadd_ps(_t3, _2_n, _n4);
1184                     _n4 = _mm256_add_ps(_n4, _t4);
1185 
1186                     _n5 = _mm256_mul_ps(_t1, _4_p);
1187                     _n5 = _mm256_fmadd_ps(_t3, _5_n, _n5);
1188                     _n5 = _mm256_add_ps(_n5, _t5);
1189                     // save to out_tm
1190                     float output_n0[8] = {0.f};
1191                     _mm256_storeu_ps(output_n0, _n0);
1192                     float output_n1[8] = {0.f};
1193                     _mm256_storeu_ps(output_n1, _n1);
1194                     float output_n2[8] = {0.f};
1195                     _mm256_storeu_ps(output_n2, _n2);
1196                     float output_n3[8] = {0.f};
1197                     _mm256_storeu_ps(output_n3, _n3);
1198                     float output_n4[8] = {0.f};
1199                     _mm256_storeu_ps(output_n4, _n4);
1200                     float output_n5[8] = {0.f};
1201                     _mm256_storeu_ps(output_n5, _n5);
1202 
1203                     out_tm0[0] = output_n0[0];
1204                     out_tm0[1] = output_n0[1];
1205                     out_tm0[2] = output_n0[2];
1206                     out_tm0[3] = output_n0[3];
1207                     out_tm1[0] = output_n0[4];
1208                     out_tm1[1] = output_n0[5];
1209                     out_tm1[2] = output_n1[0];
1210                     out_tm1[3] = output_n1[1];
1211                     out_tm2[0] = output_n1[2];
1212                     out_tm2[1] = output_n1[3];
1213                     out_tm2[2] = output_n1[4];
1214                     out_tm2[3] = output_n1[5];
1215 
1216                     out_tm3[0] = output_n2[0];
1217                     out_tm3[1] = output_n2[1];
1218                     out_tm3[2] = output_n2[2];
1219                     out_tm3[3] = output_n2[3];
1220                     out_tm4[0] = output_n2[4];
1221                     out_tm4[1] = output_n2[5];
1222                     out_tm4[2] = output_n3[0];
1223                     out_tm4[3] = output_n3[1];
1224                     out_tm5[0] = output_n3[2];
1225                     out_tm5[1] = output_n3[3];
1226                     out_tm5[2] = output_n3[4];
1227                     out_tm5[3] = output_n3[5];
1228 
1229                     out_tm6[0] = output_n4[0];
1230                     out_tm6[1] = output_n4[1];
1231                     out_tm6[2] = output_n4[2];
1232                     out_tm6[3] = output_n4[3];
1233                     out_tm7[0] = output_n4[4];
1234                     out_tm7[1] = output_n4[5];
1235                     out_tm7[2] = output_n5[0];
1236                     out_tm7[3] = output_n5[1];
1237                     out_tm8[0] = output_n5[2];
1238                     out_tm8[1] = output_n5[3];
1239                     out_tm8[2] = output_n5[4];
1240                     out_tm8[3] = output_n5[5];
1241 #else
1242                     float d0[6], d1[6], d2[6], d3[6], d4[6], d5[6];
1243                     float w0[6], w1[6], w2[6], w3[6], w4[6], w5[6];
1244                     float t0[6], t1[6], t2[6], t3[6], t4[6], t5[6];
1245 
1246                     // load
1247                     for (int n = 0; n < 6; n++)
1248                     {
1249                         d0[n] = r0[n];
1250                         d1[n] = r1[n];
1251                         d2[n] = r2[n];
1252                         d3[n] = r3[n];
1253                         d4[n] = r4[n];
1254                         d5[n] = r5[n];
1255                     }
1256                     // w = B_t * d
1257                     for (int n = 0; n < 6; n++)
1258                     {
1259                         w0[n] = 4 * d0[n] - 5 * d2[n] + d4[n];
1260                         w1[n] = -4 * d1[n] - 4 * d2[n] + d3[n] + d4[n];
1261                         w2[n] = 4 * d1[n] - 4 * d2[n] - d3[n] + d4[n];
1262                         w3[n] = -2 * d1[n] - d2[n] + 2 * d3[n] + d4[n];
1263                         w4[n] = 2 * d1[n] - d2[n] - 2 * d3[n] + d4[n];
1264                         w5[n] = 4 * d1[n] - 5 * d3[n] + d5[n];
1265                     }
1266                     // transpose d to d_t
1267                     {
1268                         t0[0] = w0[0];
1269                         t1[0] = w0[1];
1270                         t2[0] = w0[2];
1271                         t3[0] = w0[3];
1272                         t4[0] = w0[4];
1273                         t5[0] = w0[5];
1274                         t0[1] = w1[0];
1275                         t1[1] = w1[1];
1276                         t2[1] = w1[2];
1277                         t3[1] = w1[3];
1278                         t4[1] = w1[4];
1279                         t5[1] = w1[5];
1280                         t0[2] = w2[0];
1281                         t1[2] = w2[1];
1282                         t2[2] = w2[2];
1283                         t3[2] = w2[3];
1284                         t4[2] = w2[4];
1285                         t5[2] = w2[5];
1286                         t0[3] = w3[0];
1287                         t1[3] = w3[1];
1288                         t2[3] = w3[2];
1289                         t3[3] = w3[3];
1290                         t4[3] = w3[4];
1291                         t5[3] = w3[5];
1292                         t0[4] = w4[0];
1293                         t1[4] = w4[1];
1294                         t2[4] = w4[2];
1295                         t3[4] = w4[3];
1296                         t4[4] = w4[4];
1297                         t5[4] = w4[5];
1298                         t0[5] = w5[0];
1299                         t1[5] = w5[1];
1300                         t2[5] = w5[2];
1301                         t3[5] = w5[3];
1302                         t4[5] = w5[4];
1303                         t5[5] = w5[5];
1304                     }
1305                     // d = B_t * d_t
1306                     for (int n = 0; n < 6; n++)
1307                     {
1308                         d0[n] = 4 * t0[n] - 5 * t2[n] + t4[n];
1309                         d1[n] = -4 * t1[n] - 4 * t2[n] + t3[n] + t4[n];
1310                         d2[n] = 4 * t1[n] - 4 * t2[n] - t3[n] + t4[n];
1311                         d3[n] = -2 * t1[n] - t2[n] + 2 * t3[n] + t4[n];
1312                         d4[n] = 2 * t1[n] - t2[n] - 2 * t3[n] + t4[n];
1313                         d5[n] = 4 * t1[n] - 5 * t3[n] + t5[n];
1314                     }
1315                     // save to out_tm
1316                     {
1317                         out_tm0[0] = d0[0];
1318                         out_tm0[1] = d0[1];
1319                         out_tm0[2] = d0[2];
1320                         out_tm0[3] = d0[3];
1321                         out_tm1[0] = d0[4];
1322                         out_tm1[1] = d0[5];
1323                         out_tm1[2] = d1[0];
1324                         out_tm1[3] = d1[1];
1325                         out_tm2[0] = d1[2];
1326                         out_tm2[1] = d1[3];
1327                         out_tm2[2] = d1[4];
1328                         out_tm2[3] = d1[5];
1329 
1330                         out_tm3[0] = d2[0];
1331                         out_tm3[1] = d2[1];
1332                         out_tm3[2] = d2[2];
1333                         out_tm3[3] = d2[3];
1334                         out_tm4[0] = d2[4];
1335                         out_tm4[1] = d2[5];
1336                         out_tm4[2] = d3[0];
1337                         out_tm4[3] = d3[1];
1338                         out_tm5[0] = d3[2];
1339                         out_tm5[1] = d3[3];
1340                         out_tm5[2] = d3[4];
1341                         out_tm5[3] = d3[5];
1342 
1343                         out_tm6[0] = d4[0];
1344                         out_tm6[1] = d4[1];
1345                         out_tm6[2] = d4[2];
1346                         out_tm6[3] = d4[3];
1347                         out_tm7[0] = d4[4];
1348                         out_tm7[1] = d4[5];
1349                         out_tm7[2] = d5[0];
1350                         out_tm7[3] = d5[1];
1351                         out_tm8[0] = d5[2];
1352                         out_tm8[1] = d5[3];
1353                         out_tm8[2] = d5[4];
1354                         out_tm8[3] = d5[5];
1355                     }
1356 #endif // __AVX__
1357                     r0 += 4;
1358                     r1 += 4;
1359                     r2 += 4;
1360                     r3 += 4;
1361                     r4 += 4;
1362                     r5 += 4;
1363                 }
1364             }
1365         }
1366     }
1367     bottom_blob_bordered = Mat();
1368 
1369     // BEGIN dot
1370     Mat top_blob_tm;
1371     {
1372         int w_tm = outw / 4 * 6;
1373         int h_tm = outh / 4 * 6;
1374 
1375         int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
1376         int nRowBlocks = w_tm / 6;
1377 
1378         const int tiles = nColBlocks * nRowBlocks;
1379 
1380         top_blob_tm.create(36, tiles, outch, elemsize, opt.workspace_allocator);
1381 
1382         #pragma omp parallel for num_threads(opt.num_threads)
1383         for (int r = 0; r < 9; r++)
1384         {
1385             int nn_outch = 0;
1386             int remain_outch_start = 0;
1387 
1388             nn_outch = outch >> 3;
1389             remain_outch_start = nn_outch << 3;
1390 
1391             for (int pp = 0; pp < nn_outch; pp++)
1392             {
1393                 int p = pp * 8;
1394 
1395                 float* output0_tm = top_blob_tm.channel(p);
1396                 float* output1_tm = top_blob_tm.channel(p + 1);
1397                 float* output2_tm = top_blob_tm.channel(p + 2);
1398                 float* output3_tm = top_blob_tm.channel(p + 3);
1399                 float* output4_tm = top_blob_tm.channel(p + 4);
1400                 float* output5_tm = top_blob_tm.channel(p + 5);
1401                 float* output6_tm = top_blob_tm.channel(p + 6);
1402                 float* output7_tm = top_blob_tm.channel(p + 7);
1403 
1404                 output0_tm = output0_tm + r * 4;
1405                 output1_tm = output1_tm + r * 4;
1406                 output2_tm = output2_tm + r * 4;
1407                 output3_tm = output3_tm + r * 4;
1408                 output4_tm = output4_tm + r * 4;
1409                 output5_tm = output5_tm + r * 4;
1410                 output6_tm = output6_tm + r * 4;
1411                 output7_tm = output7_tm + r * 4;
1412 
1413                 for (int i = 0; i < tiles; i++)
1414                 {
1415                     const float* kptr = kernel_tm_test[r].channel(p / 8);
1416                     const float* r0 = bottom_blob_tm.channel(tiles * r + i);
1417 #if __AVX__ || __SSE__
1418 #if __AVX__
1419                     float zero_val = 0.f;
1420                     __m128 _sum0 = _mm_broadcast_ss(&zero_val);
1421                     __m128 _sum1 = _mm_broadcast_ss(&zero_val);
1422                     __m128 _sum2 = _mm_broadcast_ss(&zero_val);
1423                     __m128 _sum3 = _mm_broadcast_ss(&zero_val);
1424                     __m128 _sum4 = _mm_broadcast_ss(&zero_val);
1425                     __m128 _sum5 = _mm_broadcast_ss(&zero_val);
1426                     __m128 _sum6 = _mm_broadcast_ss(&zero_val);
1427                     __m128 _sum7 = _mm_broadcast_ss(&zero_val);
1428 #else
1429                     __m128 _sum0 = _mm_set1_ps(0.f);
1430                     __m128 _sum1 = _mm_set1_ps(0.f);
1431                     __m128 _sum2 = _mm_set1_ps(0.f);
1432                     __m128 _sum3 = _mm_set1_ps(0.f);
1433                     __m128 _sum4 = _mm_set1_ps(0.f);
1434                     __m128 _sum5 = _mm_set1_ps(0.f);
1435                     __m128 _sum6 = _mm_set1_ps(0.f);
1436                     __m128 _sum7 = _mm_set1_ps(0.f);
1437 #endif
1438                     int q = 0;
1439                     for (; q + 3 < inch; q = q + 4)
1440                     {
1441                         __m128 _r0 = _mm_loadu_ps(r0);
1442                         __m128 _r1 = _mm_loadu_ps(r0 + 4);
1443                         __m128 _r2 = _mm_loadu_ps(r0 + 8);
1444                         __m128 _r3 = _mm_loadu_ps(r0 + 12);
1445 
1446                         __m128 _k0 = _mm_loadu_ps(kptr);
1447                         __m128 _k1 = _mm_loadu_ps(kptr + 4);
1448                         __m128 _k2 = _mm_loadu_ps(kptr + 8);
1449                         __m128 _k3 = _mm_loadu_ps(kptr + 12);
1450                         __m128 _k4 = _mm_loadu_ps(kptr + 16);
1451                         __m128 _k5 = _mm_loadu_ps(kptr + 20);
1452                         __m128 _k6 = _mm_loadu_ps(kptr + 24);
1453                         __m128 _k7 = _mm_loadu_ps(kptr + 28);
1454 #if __AVX__
1455                         _sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
1456                         _sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
1457                         _sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
1458                         _sum3 = _mm_fmadd_ps(_r0, _k3, _sum3);
1459                         _sum4 = _mm_fmadd_ps(_r0, _k4, _sum4);
1460                         _sum5 = _mm_fmadd_ps(_r0, _k5, _sum5);
1461                         _sum6 = _mm_fmadd_ps(_r0, _k6, _sum6);
1462                         _sum7 = _mm_fmadd_ps(_r0, _k7, _sum7);
1463 #else
1464                         _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
1465                         _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
1466                         _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
1467                         _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
1468                         _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r0, _k4));
1469                         _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r0, _k5));
1470                         _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r0, _k6));
1471                         _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r0, _k7));
1472 #endif
1473                         kptr += 32;
1474                         _k0 = _mm_loadu_ps(kptr);
1475                         _k1 = _mm_loadu_ps(kptr + 4);
1476                         _k2 = _mm_loadu_ps(kptr + 8);
1477                         _k3 = _mm_loadu_ps(kptr + 12);
1478                         _k4 = _mm_loadu_ps(kptr + 16);
1479                         _k5 = _mm_loadu_ps(kptr + 20);
1480                         _k6 = _mm_loadu_ps(kptr + 24);
1481                         _k7 = _mm_loadu_ps(kptr + 28);
1482 #if __AVX__
1483                         _sum0 = _mm_fmadd_ps(_r1, _k0, _sum0);
1484                         _sum1 = _mm_fmadd_ps(_r1, _k1, _sum1);
1485                         _sum2 = _mm_fmadd_ps(_r1, _k2, _sum2);
1486                         _sum3 = _mm_fmadd_ps(_r1, _k3, _sum3);
1487                         _sum4 = _mm_fmadd_ps(_r1, _k4, _sum4);
1488                         _sum5 = _mm_fmadd_ps(_r1, _k5, _sum5);
1489                         _sum6 = _mm_fmadd_ps(_r1, _k6, _sum6);
1490                         _sum7 = _mm_fmadd_ps(_r1, _k7, _sum7);
1491 #else
1492                         _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r1, _k0));
1493                         _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r1, _k1));
1494                         _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r1, _k2));
1495                         _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r1, _k3));
1496                         _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r1, _k4));
1497                         _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r1, _k5));
1498                         _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r1, _k6));
1499                         _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r1, _k7));
1500 #endif
1501 
1502                         kptr += 32;
1503                         _k0 = _mm_loadu_ps(kptr);
1504                         _k1 = _mm_loadu_ps(kptr + 4);
1505                         _k2 = _mm_loadu_ps(kptr + 8);
1506                         _k3 = _mm_loadu_ps(kptr + 12);
1507                         _k4 = _mm_loadu_ps(kptr + 16);
1508                         _k5 = _mm_loadu_ps(kptr + 20);
1509                         _k6 = _mm_loadu_ps(kptr + 24);
1510                         _k7 = _mm_loadu_ps(kptr + 28);
1511 #if __AVX__
1512                         _sum0 = _mm_fmadd_ps(_r2, _k0, _sum0);
1513                         _sum1 = _mm_fmadd_ps(_r2, _k1, _sum1);
1514                         _sum2 = _mm_fmadd_ps(_r2, _k2, _sum2);
1515                         _sum3 = _mm_fmadd_ps(_r2, _k3, _sum3);
1516                         _sum4 = _mm_fmadd_ps(_r2, _k4, _sum4);
1517                         _sum5 = _mm_fmadd_ps(_r2, _k5, _sum5);
1518                         _sum6 = _mm_fmadd_ps(_r2, _k6, _sum6);
1519                         _sum7 = _mm_fmadd_ps(_r2, _k7, _sum7);
1520 #else
1521                         _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r2, _k0));
1522                         _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r2, _k1));
1523                         _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r2, _k2));
1524                         _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r2, _k3));
1525                         _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r2, _k4));
1526                         _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r2, _k5));
1527                         _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r2, _k6));
1528                         _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r2, _k7));
1529 #endif
1530                         kptr += 32;
1531                         _k0 = _mm_loadu_ps(kptr);
1532                         _k1 = _mm_loadu_ps(kptr + 4);
1533                         _k2 = _mm_loadu_ps(kptr + 8);
1534                         _k3 = _mm_loadu_ps(kptr + 12);
1535                         _k4 = _mm_loadu_ps(kptr + 16);
1536                         _k5 = _mm_loadu_ps(kptr + 20);
1537                         _k6 = _mm_loadu_ps(kptr + 24);
1538                         _k7 = _mm_loadu_ps(kptr + 28);
1539 #if __AVX__
1540                         _sum0 = _mm_fmadd_ps(_r3, _k0, _sum0);
1541                         _sum1 = _mm_fmadd_ps(_r3, _k1, _sum1);
1542                         _sum2 = _mm_fmadd_ps(_r3, _k2, _sum2);
1543                         _sum3 = _mm_fmadd_ps(_r3, _k3, _sum3);
1544                         _sum4 = _mm_fmadd_ps(_r3, _k4, _sum4);
1545                         _sum5 = _mm_fmadd_ps(_r3, _k5, _sum5);
1546                         _sum6 = _mm_fmadd_ps(_r3, _k6, _sum6);
1547                         _sum7 = _mm_fmadd_ps(_r3, _k7, _sum7);
1548 #else
1549                         _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r3, _k0));
1550                         _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r3, _k1));
1551                         _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r3, _k2));
1552                         _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r3, _k3));
1553                         _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r3, _k4));
1554                         _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r3, _k5));
1555                         _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r3, _k6));
1556                         _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r3, _k7));
1557 #endif
1558                         kptr += 32;
1559                         r0 += 16;
1560                     }
1561 
1562                     for (; q < inch; q++)
1563                     {
1564                         __m128 _r0 = _mm_loadu_ps(r0);
1565                         __m128 _k0 = _mm_loadu_ps(kptr);
1566                         __m128 _k1 = _mm_loadu_ps(kptr + 4);
1567                         __m128 _k2 = _mm_loadu_ps(kptr + 8);
1568                         __m128 _k3 = _mm_loadu_ps(kptr + 12);
1569                         __m128 _k4 = _mm_loadu_ps(kptr + 16);
1570                         __m128 _k5 = _mm_loadu_ps(kptr + 20);
1571                         __m128 _k6 = _mm_loadu_ps(kptr + 24);
1572                         __m128 _k7 = _mm_loadu_ps(kptr + 28);
1573 
1574 #if __AVX__
1575                         _sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
1576                         _sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
1577                         _sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
1578                         _sum3 = _mm_fmadd_ps(_r0, _k3, _sum3);
1579                         _sum4 = _mm_fmadd_ps(_r0, _k4, _sum4);
1580                         _sum5 = _mm_fmadd_ps(_r0, _k5, _sum5);
1581                         _sum6 = _mm_fmadd_ps(_r0, _k6, _sum6);
1582                         _sum7 = _mm_fmadd_ps(_r0, _k7, _sum7);
1583 #else
1584                         _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
1585                         _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
1586                         _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
1587                         _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
1588                         _sum4 = _mm_add_ps(_sum4, _mm_mul_ps(_r0, _k4));
1589                         _sum5 = _mm_add_ps(_sum5, _mm_mul_ps(_r0, _k5));
1590                         _sum6 = _mm_add_ps(_sum6, _mm_mul_ps(_r0, _k6));
1591                         _sum7 = _mm_add_ps(_sum7, _mm_mul_ps(_r0, _k7));
1592 #endif
1593 
1594                         kptr += 32;
1595                         r0 += 4;
1596                     }
1597 
1598                     _mm_storeu_ps(output0_tm, _sum0);
1599                     _mm_storeu_ps(output1_tm, _sum1);
1600                     _mm_storeu_ps(output2_tm, _sum2);
1601                     _mm_storeu_ps(output3_tm, _sum3);
1602                     _mm_storeu_ps(output4_tm, _sum4);
1603                     _mm_storeu_ps(output5_tm, _sum5);
1604                     _mm_storeu_ps(output6_tm, _sum6);
1605                     _mm_storeu_ps(output7_tm, _sum7);
1606 #else
1607                     float sum0[4] = {0};
1608                     float sum1[4] = {0};
1609                     float sum2[4] = {0};
1610                     float sum3[4] = {0};
1611                     float sum4[4] = {0};
1612                     float sum5[4] = {0};
1613                     float sum6[4] = {0};
1614                     float sum7[4] = {0};
1615 
1616                     for (int q = 0; q < inch; q++)
1617                     {
1618                         for (int n = 0; n < 4; n++)
1619                         {
1620                             sum0[n] += r0[n] * kptr[n];
1621                             sum1[n] += r0[n] * kptr[n + 4];
1622                             sum2[n] += r0[n] * kptr[n + 8];
1623                             sum3[n] += r0[n] * kptr[n + 12];
1624                             sum4[n] += r0[n] * kptr[n + 16];
1625                             sum5[n] += r0[n] * kptr[n + 20];
1626                             sum6[n] += r0[n] * kptr[n + 24];
1627                             sum7[n] += r0[n] * kptr[n + 28];
1628                         }
1629                         kptr += 32;
1630                         r0 += 4;
1631                     }
1632 
1633                     for (int n = 0; n < 4; n++)
1634                     {
1635                         output0_tm[n] = sum0[n];
1636                         output1_tm[n] = sum1[n];
1637                         output2_tm[n] = sum2[n];
1638                         output3_tm[n] = sum3[n];
1639                         output4_tm[n] = sum4[n];
1640                         output5_tm[n] = sum5[n];
1641                         output6_tm[n] = sum6[n];
1642                         output7_tm[n] = sum7[n];
1643                     }
1644 #endif // __AVX__
1645                     output0_tm += 36;
1646                     output1_tm += 36;
1647                     output2_tm += 36;
1648                     output3_tm += 36;
1649                     output4_tm += 36;
1650                     output5_tm += 36;
1651                     output6_tm += 36;
1652                     output7_tm += 36;
1653                 }
1654             }
1655 
1656             nn_outch = (outch - remain_outch_start) >> 2;
1657 
1658             for (int pp = 0; pp < nn_outch; pp++)
1659             {
1660                 int p = remain_outch_start + pp * 4;
1661 
1662                 float* output0_tm = top_blob_tm.channel(p);
1663                 float* output1_tm = top_blob_tm.channel(p + 1);
1664                 float* output2_tm = top_blob_tm.channel(p + 2);
1665                 float* output3_tm = top_blob_tm.channel(p + 3);
1666 
1667                 output0_tm = output0_tm + r * 4;
1668                 output1_tm = output1_tm + r * 4;
1669                 output2_tm = output2_tm + r * 4;
1670                 output3_tm = output3_tm + r * 4;
1671 
1672                 for (int i = 0; i < tiles; i++)
1673                 {
1674                     const float* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4);
1675                     const float* r0 = bottom_blob_tm.channel(tiles * r + i);
1676 #if __AVX__ || __SSE__
1677 #if __AVX__
1678                     float zero_val = 0.f;
1679                     __m128 _sum0 = _mm_broadcast_ss(&zero_val);
1680                     __m128 _sum1 = _mm_broadcast_ss(&zero_val);
1681                     __m128 _sum2 = _mm_broadcast_ss(&zero_val);
1682                     __m128 _sum3 = _mm_broadcast_ss(&zero_val);
1683 #else
1684                     __m128 _sum0 = _mm_set1_ps(0.f);
1685                     __m128 _sum1 = _mm_set1_ps(0.f);
1686                     __m128 _sum2 = _mm_set1_ps(0.f);
1687                     __m128 _sum3 = _mm_set1_ps(0.f);
1688 #endif
1689                     for (int q = 0; q < inch; q++)
1690                     {
1691                         __m128 _r0 = _mm_loadu_ps(r0);
1692                         __m128 _k0 = _mm_loadu_ps(kptr);
1693                         __m128 _k1 = _mm_loadu_ps(kptr + 4);
1694                         __m128 _k2 = _mm_loadu_ps(kptr + 8);
1695                         __m128 _k3 = _mm_loadu_ps(kptr + 12);
1696 #if __AVX__
1697                         _sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
1698                         _sum1 = _mm_fmadd_ps(_r0, _k1, _sum1);
1699                         _sum2 = _mm_fmadd_ps(_r0, _k2, _sum2);
1700                         _sum3 = _mm_fmadd_ps(_r0, _k3, _sum3);
1701 #else
1702                         _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
1703                         _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_r0, _k1));
1704                         _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_r0, _k2));
1705                         _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_r0, _k3));
1706 #endif
1707                         kptr += 16;
1708                         r0 += 4;
1709                     }
1710 
1711                     _mm_storeu_ps(output0_tm, _sum0);
1712                     _mm_storeu_ps(output1_tm, _sum1);
1713                     _mm_storeu_ps(output2_tm, _sum2);
1714                     _mm_storeu_ps(output3_tm, _sum3);
1715 #else
1716                     float sum0[4] = {0};
1717                     float sum1[4] = {0};
1718                     float sum2[4] = {0};
1719                     float sum3[4] = {0};
1720 
1721                     for (int q = 0; q < inch; q++)
1722                     {
1723                         for (int n = 0; n < 4; n++)
1724                         {
1725                             sum0[n] += r0[n] * kptr[n];
1726                             sum1[n] += r0[n] * kptr[n + 4];
1727                             sum2[n] += r0[n] * kptr[n + 8];
1728                             sum3[n] += r0[n] * kptr[n + 12];
1729                         }
1730                         kptr += 16;
1731                         r0 += 4;
1732                     }
1733 
1734                     for (int n = 0; n < 4; n++)
1735                     {
1736                         output0_tm[n] = sum0[n];
1737                         output1_tm[n] = sum1[n];
1738                         output2_tm[n] = sum2[n];
1739                         output3_tm[n] = sum3[n];
1740                     }
1741 #endif // __AVX__
1742                     output0_tm += 36;
1743                     output1_tm += 36;
1744                     output2_tm += 36;
1745                     output3_tm += 36;
1746                 }
1747             }
1748 
1749             remain_outch_start += nn_outch << 2;
1750 
1751             for (int p = remain_outch_start; p < outch; p++)
1752             {
1753                 float* output0_tm = top_blob_tm.channel(p);
1754 
1755                 output0_tm = output0_tm + r * 4;
1756 
1757                 for (int i = 0; i < tiles; i++)
1758                 {
1759                     const float* kptr = kernel_tm_test[r].channel(p / 8 + (p % 8) / 4 + p % 4);
1760                     const float* r0 = bottom_blob_tm.channel(tiles * r + i);
1761 #if __AVX__ || __SSE__
1762 #if __AVX__
1763                     float zero_val = 0.f;
1764                     __m128 _sum0 = _mm_broadcast_ss(&zero_val);
1765 #else
1766                     __m128 _sum0 = _mm_set1_ps(0.f);
1767 #endif
1768 
1769                     for (int q = 0; q < inch; q++)
1770                     {
1771                         __m128 _r0 = _mm_loadu_ps(r0);
1772                         __m128 _k0 = _mm_loadu_ps(kptr);
1773 #if __AVX__
1774                         _sum0 = _mm_fmadd_ps(_r0, _k0, _sum0);
1775 #else
1776                         _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_r0, _k0));
1777 #endif
1778                         kptr += 16;
1779                         r0 += 4;
1780                     }
1781                     _mm_storeu_ps(output0_tm, _sum0);
1782 #else
1783                     float sum0[4] = {0};
1784 
1785                     for (int q = 0; q < inch; q++)
1786                     {
1787                         for (int n = 0; n < 4; n++)
1788                         {
1789                             sum0[n] += (int)r0[n] * kptr[n];
1790                         }
1791                         kptr += 4;
1792                         r0 += 4;
1793                     }
1794 
1795                     for (int n = 0; n < 4; n++)
1796                     {
1797                         output0_tm[n] = sum0[n];
1798                     }
1799 #endif // __AVX__ || __SSE__
1800                     output0_tm += 36;
1801                 }
1802             }
1803 
1804             // for (int p=0; p<outch; p++)
1805             // {
1806             //     Mat out0_tm = top_blob_tm.channel(p);
1807             //     const Mat kernel0_tm = kernel_tm.channel(p);
1808 
1809             //     for (int i=0; i<tiles; i++)
1810             //     {
1811             //         float* output0_tm = out0_tm.row<int>(i);
1812 
1813             //         int sum0[36] = {0};
1814 
1815             //         for (int q=0; q<inch; q++)
1816             //         {
1817             //             const float* r0 = bottom_blob_tm.channel(q).row<float>(i);
1818             //             const float* k0 = kernel0_tm.row<float>(q);
1819 
1820             //             for (int n=0; n<36; n++)
1821             //             {
1822             //                 sum0[n] += (int)r0[n] * k0[n];
1823             //             }
1824             //         }
1825 
1826             //         for (int n=0; n<36; n++)
1827             //         {
1828             //             output0_tm[n] = sum0[n];
1829             //         }
1830             //     }
1831             // }
1832         }
1833     }
1834     bottom_blob_tm = Mat();
1835     // END dot
1836 
1837     // BEGIN transform output
1838     Mat top_blob_bordered;
1839     if (outw == top_blob.w && outh == top_blob.h)
1840     {
1841         top_blob_bordered = top_blob;
1842     }
1843     else
1844     {
1845         top_blob_bordered.create(outw, outh, outch, elemsize, opt.workspace_allocator);
1846     }
1847     {
1848         // AT
1849         // const float itm[4][6] = {
1850         //     {1.0f, 1.0f,  1.0f, 1.0f,  1.0f, 0.0f},
1851         //     {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.0f},
1852         //     {0.0f, 1.0f,  1.0f, 4.0f,  4.0f, 0.0f},
1853         //     {0.0f, 1.0f, -1.0f, 8.0f, -8.0f, 1.0f}
1854         // };
1855 
1856         // 0 =	r00 + r01 + r02 + r03 +	r04
1857         // 1 =		  r01 - r02 + 2 * (r03 - r04)
1858         // 2 =		  r01 + r02 + 4 * (r03 + r04)
1859         // 3 =		  r01 - r02 + 8 * (r03 - r04)  + r05
1860 
1861         int w_tm = outw / 4 * 6;
1862         int h_tm = outh / 4 * 6;
1863 
1864         int nColBlocks = h_tm / 6; // may be the block num in Feathercnn
1865         int nRowBlocks = w_tm / 6;
1866 
1867         #pragma omp parallel for num_threads(opt.num_threads)
1868         for (int p = 0; p < outch; p++)
1869         {
1870             float* out_tile = top_blob_tm.channel(p);
1871             float* outRow0 = top_blob_bordered.channel(p);
1872             float* outRow1 = outRow0 + outw;
1873             float* outRow2 = outRow0 + outw * 2;
1874             float* outRow3 = outRow0 + outw * 3;
1875 
1876             const float bias0 = bias ? bias[p] : 0.f;
1877 
1878             for (int j = 0; j < nColBlocks; j++)
1879             {
1880                 for (int i = 0; i < nRowBlocks; i++)
1881                 {
1882                     // TODO AVX2
1883                     float s0[6], s1[6], s2[6], s3[6], s4[6], s5[6];
1884                     float w0[6], w1[6], w2[6], w3[6];
1885                     float d0[4], d1[4], d2[4], d3[4], d4[4], d5[4];
1886                     float o0[4], o1[4], o2[4], o3[4];
1887 
1888                     // load
1889                     for (int n = 0; n < 6; n++)
1890                     {
1891                         s0[n] = out_tile[n];
1892                         s1[n] = out_tile[n + 6];
1893                         s2[n] = out_tile[n + 12];
1894                         s3[n] = out_tile[n + 18];
1895                         s4[n] = out_tile[n + 24];
1896                         s5[n] = out_tile[n + 30];
1897                     }
1898                     // w = A_T * W
1899                     for (int n = 0; n < 6; n++)
1900                     {
1901                         w0[n] = s0[n] + s1[n] + s2[n] + s3[n] + s4[n];
1902                         w1[n] = s1[n] - s2[n] + 2 * s3[n] - 2 * s4[n];
1903                         w2[n] = s1[n] + s2[n] + 4 * s3[n] + 4 * s4[n];
1904                         w3[n] = s1[n] - s2[n] + 8 * s3[n] - 8 * s4[n] + s5[n];
1905                     }
1906                     // transpose w to w_t
1907                     {
1908                         d0[0] = w0[0];
1909                         d0[1] = w1[0];
1910                         d0[2] = w2[0];
1911                         d0[3] = w3[0];
1912                         d1[0] = w0[1];
1913                         d1[1] = w1[1];
1914                         d1[2] = w2[1];
1915                         d1[3] = w3[1];
1916                         d2[0] = w0[2];
1917                         d2[1] = w1[2];
1918                         d2[2] = w2[2];
1919                         d2[3] = w3[2];
1920                         d3[0] = w0[3];
1921                         d3[1] = w1[3];
1922                         d3[2] = w2[3];
1923                         d3[3] = w3[3];
1924                         d4[0] = w0[4];
1925                         d4[1] = w1[4];
1926                         d4[2] = w2[4];
1927                         d4[3] = w3[4];
1928                         d5[0] = w0[5];
1929                         d5[1] = w1[5];
1930                         d5[2] = w2[5];
1931                         d5[3] = w3[5];
1932                     }
1933                     // Y = A_T * w_t
1934                     for (int n = 0; n < 4; n++)
1935                     {
1936                         o0[n] = d0[n] + d1[n] + d2[n] + d3[n] + d4[n];
1937                         o1[n] = d1[n] - d2[n] + 2 * d3[n] - 2 * d4[n];
1938                         o2[n] = d1[n] + d2[n] + 4 * d3[n] + 4 * d4[n];
1939                         o3[n] = d1[n] - d2[n] + 8 * d3[n] - 8 * d4[n] + d5[n];
1940                     }
1941                     // save to top blob tm
1942                     for (int n = 0; n < 4; n++)
1943                     {
1944                         outRow0[n] = o0[n] + bias0;
1945                         outRow1[n] = o1[n] + bias0;
1946                         outRow2[n] = o2[n] + bias0;
1947                         outRow3[n] = o3[n] + bias0;
1948                     }
1949 
1950                     out_tile += 36;
1951 
1952                     outRow0 += 4;
1953                     outRow1 += 4;
1954                     outRow2 += 4;
1955                     outRow3 += 4;
1956                 }
1957 
1958                 outRow0 += outw * 3;
1959                 outRow1 += outw * 3;
1960                 outRow2 += outw * 3;
1961                 outRow3 += outw * 3;
1962             }
1963         }
1964     }
1965     // END transform output
1966 
1967     // cut result pad
1968     copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
1969 }
1970 
conv3x3s2_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)1971 static void conv3x3s2_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
1972 {
1973     int w = bottom_blob.w;
1974     int inch = bottom_blob.c;
1975 
1976     int outw = top_blob.w;
1977     int outh = top_blob.h;
1978     int outch = top_blob.c;
1979 
1980     const int tailstep = w - 2 * outw + w;
1981 
1982     const float* kernel = _kernel;
1983     const float* bias = _bias;
1984 
1985     #pragma omp parallel for num_threads(opt.num_threads)
1986     for (int p = 0; p < outch; p++)
1987     {
1988         Mat out = top_blob.channel(p);
1989 
1990         const float bias0 = bias ? bias[p] : 0.f;
1991 
1992         out.fill(bias0);
1993 
1994         for (int q = 0; q < inch; q++)
1995         {
1996             float* outptr = out;
1997 
1998             const float* img = bottom_blob.channel(q);
1999             const float* kernel0 = kernel + p * inch * 9 + q * 9;
2000 
2001             const float* r0 = img;
2002             const float* r1 = img + w;
2003             const float* r2 = img + w * 2;
2004 
2005             const float* k0 = kernel0;
2006             const float* k1 = kernel0 + 3;
2007             const float* k2 = kernel0 + 6;
2008 
2009             for (int i = 0; i < outh; i++)
2010             {
2011                 int remain = outw;
2012 
2013                 for (; remain > 0; remain--)
2014                 {
2015                     float sum = 0;
2016 
2017                     sum += r0[0] * k0[0];
2018                     sum += r0[1] * k0[1];
2019                     sum += r0[2] * k0[2];
2020                     sum += r1[0] * k1[0];
2021                     sum += r1[1] * k1[1];
2022                     sum += r1[2] * k1[2];
2023                     sum += r2[0] * k2[0];
2024                     sum += r2[1] * k2[1];
2025                     sum += r2[2] * k2[2];
2026 
2027                     *outptr += sum;
2028 
2029                     r0 += 2;
2030                     r1 += 2;
2031                     r2 += 2;
2032                     outptr++;
2033                 }
2034 
2035                 r0 += tailstep;
2036                 r1 += tailstep;
2037                 r2 += tailstep;
2038             }
2039         }
2040     }
2041 }
2042