1 // BUG1989 is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 BUG1989. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
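// im2col + sgemm convolution kernels for x86. Two builds of the same pair of
// functions follow: an AVX path that packs data 8-wide and an SSE/scalar
// fallback that packs 4-wide, selected by the preprocessor condition below.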
15 #if __AVX__
16 static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
17 {
18     const float* kernel = _kernel;
19 
20     // kernel memory packed 8 x 8
21     kernel_tm.create(8 * kernel_size, inch, outch / 8 + (outch % 8) / 4 + outch % 4);
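    // Layout of kernel_tm: output channels are taken 8 at a time, then 4 at a
    // time, then one by one; each group occupies a single channel of kernel_tm
    // and stores its weights interleaved along inch * kernel_size, so the sgemm
    // micro-kernels below can read 8 (or 4, or 1) kernel values per step.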
22 
23     int nn_outch = 0;
24     int remain_outch_start = 0;
25 
26     nn_outch = outch >> 3;
27     remain_outch_start = nn_outch << 3;
28 
29     for (int pp = 0; pp < nn_outch; pp++)
30     {
31         int p = pp * 8;
32 
33         const float* k0 = kernel + (p + 0) * inch * kernel_size;
34         const float* k1 = kernel + (p + 1) * inch * kernel_size;
35         const float* k2 = kernel + (p + 2) * inch * kernel_size;
36         const float* k3 = kernel + (p + 3) * inch * kernel_size;
37         const float* k4 = kernel + (p + 4) * inch * kernel_size;
38         const float* k5 = kernel + (p + 5) * inch * kernel_size;
39         const float* k6 = kernel + (p + 6) * inch * kernel_size;
40         const float* k7 = kernel + (p + 7) * inch * kernel_size;
41 
42         float* ktmp = kernel_tm.channel(p / 8);
43 
44         for (int q = 0; q < inch * kernel_size; q++)
45         {
46             ktmp[0] = k0[0];
47             ktmp[1] = k1[0];
48             ktmp[2] = k2[0];
49             ktmp[3] = k3[0];
50             ktmp[4] = k4[0];
51             ktmp[5] = k5[0];
52             ktmp[6] = k6[0];
53             ktmp[7] = k7[0];
54             ktmp += 8;
55 
56             k0 += 1;
57             k1 += 1;
58             k2 += 1;
59             k3 += 1;
60             k4 += 1;
61             k5 += 1;
62             k6 += 1;
63             k7 += 1;
64         }
65     }
66 
67     nn_outch = (outch - remain_outch_start) >> 2;
68 
69     for (int pp = 0; pp < nn_outch; pp++)
70     {
71         int p = remain_outch_start + pp * 4;
72 
73         const float* k0 = kernel + (p + 0) * inch * kernel_size;
74         const float* k1 = kernel + (p + 1) * inch * kernel_size;
75         const float* k2 = kernel + (p + 2) * inch * kernel_size;
76         const float* k3 = kernel + (p + 3) * inch * kernel_size;
77 
78         float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4);
79 
80         for (int q = 0; q < inch * kernel_size; q++)
81         {
82             ktmp[0] = k0[0];
83             ktmp[1] = k1[0];
84             ktmp[2] = k2[0];
85             ktmp[3] = k3[0];
86             ktmp += 4;
87 
88             k0 += 1;
89             k1 += 1;
90             k2 += 1;
91             k3 += 1;
92         }
93     }
94 
95     remain_outch_start += nn_outch << 2;
96 
97     for (int p = remain_outch_start; p < outch; p++)
98     {
99         const float* k0 = kernel + (p + 0) * inch * kernel_size;
100 
101         float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4);
102 
103         for (int q = 0; q < inch * kernel_size; q++)
104         {
105             ktmp[0] = k0[0];
106             ktmp++;
107             k0++;
108         }
109     }
110 }
111 
112 static void conv_im2col_sgemm_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias,
113                                   const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
114 {
115     int w = bottom_blob.w;
116     int inch = bottom_blob.c;
117     size_t elemsize = bottom_blob.elemsize;
118 
119     int outw = top_blob.w;
120     int outh = top_blob.h;
121     int outch = top_blob.c;
122 
123     const float* bias = _bias;
124 
125     // im2col
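    // bottom_im2col is a (kernel_h * kernel_w * inch) x (outw * outh) matrix:
    // each row holds one (input channel, kernel offset) tap sampled at every
    // output position, turning the convolution into a plain matrix product.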
126     Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, elemsize, opt.workspace_allocator);
127     {
128         const int stride = kernel_h * kernel_w * outw * outh;
129         float* ret = (float*)bottom_im2col;
130 
131         #pragma omp parallel for num_threads(opt.num_threads)
132         for (int p = 0; p < inch; p++)
133         {
134             const float* input = bottom_blob.channel(p);
135             int retID = stride * p;
136             for (int u = 0; u < kernel_h; u++)
137             {
138                 for (int v = 0; v < kernel_w; v++)
139                 {
140                     for (int i = 0; i < outh; i++)
141                     {
142                         for (int j = 0; j < outw; j++)
143                         {
144                             int row = u + i * stride_h;
145                             int col = v + j * stride_w;
146                             int index = row * w + col;
147                             ret[retID] = input[index];
148                             retID++;
149                         }
150                     }
151                 }
152             }
153         }
154     }
155 
156     int kernel_size = kernel_w * kernel_h;
157     int out_size = outw * outh;
158 
159     // bottom_im2col memory packed 8 x 8
160     Mat bottom_tm(8 * kernel_size, inch, out_size / 8 + out_size % 8, elemsize, opt.workspace_allocator);
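    // Pack the im2col matrix column-wise: 8 consecutive output positions form
    // an (inch * kernel_size) x 8 tile in one channel of bottom_tm, and each
    // leftover column (out_size % 8 of them) gets a channel of its own, giving
    // the micro-kernel contiguous 8-wide loads per reduction step.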
161     {
162         int nn_size = out_size >> 3;
163         int remain_size_start = nn_size << 3;
164 
165         #pragma omp parallel for num_threads(opt.num_threads)
166         for (int ii = 0; ii < nn_size; ii++)
167         {
168             int i = ii * 8;
169 
170             const float* img0 = bottom_im2col.channel(0);
171             img0 += i;
172 
173             float* tmpptr = bottom_tm.channel(i / 8);
174 
175             for (int q = 0; q < inch * kernel_size; q++)
176             {
177 #if __AVX__
178                 _mm256_storeu_ps(tmpptr, _mm256_loadu_ps(img0));
179 #else
180                 tmpptr[0] = img0[0];
181                 tmpptr[1] = img0[1];
182                 tmpptr[2] = img0[2];
183                 tmpptr[3] = img0[3];
184                 tmpptr[4] = img0[4];
185                 tmpptr[5] = img0[5];
186                 tmpptr[6] = img0[6];
187                 tmpptr[7] = img0[7];
188 #endif // __AVX__
189                 tmpptr += 8;
190                 img0 += out_size;
191             }
192         }
193 
194         #pragma omp parallel for num_threads(opt.num_threads)
195         for (int i = remain_size_start; i < out_size; i++)
196         {
197             const float* img0 = bottom_im2col.channel(0);
198             img0 += i;
199 
200             float* tmpptr = bottom_tm.channel(i / 8 + i % 8);
201 
202             for (int q = 0; q < inch * kernel_size; q++)
203             {
204                 tmpptr[0] = img0[0];
205 
206                 tmpptr += 1;
207                 img0 += out_size;
208             }
209         }
210     }
211 
212     // sgemm(int M, int N, int L, float* A, float* B, float* C)
213     {
214         //int M = outch;                    // outch
215         int N = outw * outh;                // outsize or out stride
216         int L = kernel_w * kernel_h * inch; // ksize * inch
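        // C(M x N) = A(M x L) * B(L x N) computed with an 8x8 register-blocked
        // micro-kernel: each pass of the loops below produces a tile of 8 output
        // channels by 8 output positions, accumulating over L in steps of 4 with
        // broadcast kernel values against the packed input columns.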
217 
218         int nn_outch = 0;
219         int remain_outch_start = 0;
220 
221         nn_outch = outch >> 3;
222         remain_outch_start = nn_outch << 3;
223 
224         #pragma omp parallel for num_threads(opt.num_threads)
225         for (int pp = 0; pp < nn_outch; pp++)
226         {
227             int i = pp * 8;
228 
229             float* output0 = top_blob.channel(i);
230             float* output1 = top_blob.channel(i + 1);
231             float* output2 = top_blob.channel(i + 2);
232             float* output3 = top_blob.channel(i + 3);
233             float* output4 = top_blob.channel(i + 4);
234             float* output5 = top_blob.channel(i + 5);
235             float* output6 = top_blob.channel(i + 6);
236             float* output7 = top_blob.channel(i + 7);
237 
238             const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
239             const float* biasptr = bias ? bias + i : zeros;
240 
241             int j = 0;
242             for (; j + 7 < N; j = j + 8)
243             {
244                 const float* vb = bottom_tm.channel(j / 8);
245                 const float* va = kernel_tm.channel(i / 8);
246 #if __AVX__
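                // Note: _mm256_fmadd_ps is an FMA intrinsic rather than plain AVX,
                // so this path assumes the file is compiled with FMA enabled as well.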
247                 __m256 _sum0 = _mm256_broadcast_ss(biasptr);
248                 __m256 _sum1 = _mm256_broadcast_ss(biasptr + 1);
249                 __m256 _sum2 = _mm256_broadcast_ss(biasptr + 2);
250                 __m256 _sum3 = _mm256_broadcast_ss(biasptr + 3);
251                 __m256 _sum4 = _mm256_broadcast_ss(biasptr + 4);
252                 __m256 _sum5 = _mm256_broadcast_ss(biasptr + 5);
253                 __m256 _sum6 = _mm256_broadcast_ss(biasptr + 6);
254                 __m256 _sum7 = _mm256_broadcast_ss(biasptr + 7);
255 
256                 int k = 0;
257                 for (; k + 3 < L; k = k + 4)
258                 {
259                     // k0
260                     __m256 _va0 = _mm256_broadcast_ss(va);
261                     __m256 _va1 = _mm256_broadcast_ss(va + 1);
262                     __m256 _va2 = _mm256_broadcast_ss(va + 2);
263                     __m256 _va3 = _mm256_broadcast_ss(va + 3);
264                     __m256 _vb0 = _mm256_loadu_ps(vb);
265                     __m256 _vb1 = _mm256_loadu_ps(vb + 8);
266                     __m256 _vb2 = _mm256_loadu_ps(vb + 16);
267                     __m256 _vb3 = _mm256_loadu_ps(vb + 24);
268                     _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
269                     _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
270                     _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
271                     _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
272                     _va0 = _mm256_broadcast_ss(va + 4);
273                     _va1 = _mm256_broadcast_ss(va + 5);
274                     _va2 = _mm256_broadcast_ss(va + 6);
275                     _va3 = _mm256_broadcast_ss(va + 7);
276                     _sum4 = _mm256_fmadd_ps(_vb0, _va0, _sum4); // sum4 = (a00-a07) * k40
277                     _sum5 = _mm256_fmadd_ps(_vb0, _va1, _sum5); // sum5 = (a00-a07) * k50
278                     _sum6 = _mm256_fmadd_ps(_vb0, _va2, _sum6); // sum6 = (a00-a07) * k60
279                     _sum7 = _mm256_fmadd_ps(_vb0, _va3, _sum7); // sum7 = (a00-a07) * k70
280 
281                     va += 8;
282 
283                     // k1
284                     _va0 = _mm256_broadcast_ss(va);
285                     _va1 = _mm256_broadcast_ss(va + 1);
286                     _va2 = _mm256_broadcast_ss(va + 2);
287                     _va3 = _mm256_broadcast_ss(va + 3);
288                     _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
289                     _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
290                     _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
291                     _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
292                     _va0 = _mm256_broadcast_ss(va + 4);
293                     _va1 = _mm256_broadcast_ss(va + 5);
294                     _va2 = _mm256_broadcast_ss(va + 6);
295                     _va3 = _mm256_broadcast_ss(va + 7);
296                     _sum4 = _mm256_fmadd_ps(_vb1, _va0, _sum4); // sum4 += (a10-a17) * k41
297                     _sum5 = _mm256_fmadd_ps(_vb1, _va1, _sum5); // sum5 += (a10-a17) * k51
298                     _sum6 = _mm256_fmadd_ps(_vb1, _va2, _sum6); // sum6 += (a10-a17) * k61
299                     _sum7 = _mm256_fmadd_ps(_vb1, _va3, _sum7); // sum7 += (a10-a17) * k71
300 
301                     va += 8;
302 
303                     // k2
304                     _va0 = _mm256_broadcast_ss(va);
305                     _va1 = _mm256_broadcast_ss(va + 1);
306                     _va2 = _mm256_broadcast_ss(va + 2);
307                     _va3 = _mm256_broadcast_ss(va + 3);
308                     _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
309                     _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
310                     _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
311                     _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
312                     _va0 = _mm256_broadcast_ss(va + 4);
313                     _va1 = _mm256_broadcast_ss(va + 5);
314                     _va2 = _mm256_broadcast_ss(va + 6);
315                     _va3 = _mm256_broadcast_ss(va + 7);
316                     _sum4 = _mm256_fmadd_ps(_vb2, _va0, _sum4); // sum4 += (a20-a27) * k42
317                     _sum5 = _mm256_fmadd_ps(_vb2, _va1, _sum5); // sum5 += (a20-a27) * k52
318                     _sum6 = _mm256_fmadd_ps(_vb2, _va2, _sum6); // sum6 += (a20-a27) * k62
319                     _sum7 = _mm256_fmadd_ps(_vb2, _va3, _sum7); // sum7 += (a20-a27) * k72
320 
321                     va += 8;
322 
323                     // k3
324                     _va0 = _mm256_broadcast_ss(va);
325                     _va1 = _mm256_broadcast_ss(va + 1);
326                     _va2 = _mm256_broadcast_ss(va + 2);
327                     _va3 = _mm256_broadcast_ss(va + 3);
328                     _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
329                     _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
330                     _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
331                     _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
332                     _va0 = _mm256_broadcast_ss(va + 4);
333                     _va1 = _mm256_broadcast_ss(va + 5);
334                     _va2 = _mm256_broadcast_ss(va + 6);
335                     _va3 = _mm256_broadcast_ss(va + 7);
336                     _sum4 = _mm256_fmadd_ps(_vb3, _va0, _sum4); // sum4 += (a30-a37) * k43
337                     _sum5 = _mm256_fmadd_ps(_vb3, _va1, _sum5); // sum5 += (a30-a37) * k53
338                     _sum6 = _mm256_fmadd_ps(_vb3, _va2, _sum6); // sum6 += (a30-a37) * k63
339                     _sum7 = _mm256_fmadd_ps(_vb3, _va3, _sum7); // sum7 += (a30-a37) * k73
340 
341                     va += 8;
342                     vb += 32;
343                 }
344 
345                 for (; k < L; k++)
346                 {
347                     // k0
348                     __m256 _va0 = _mm256_broadcast_ss(va);
349                     __m256 _va1 = _mm256_broadcast_ss(va + 1);
350                     __m256 _va2 = _mm256_broadcast_ss(va + 2);
351                     __m256 _va3 = _mm256_broadcast_ss(va + 3);
352                     __m256 _va4 = _mm256_broadcast_ss(va + 4);
353                     __m256 _va5 = _mm256_broadcast_ss(va + 5);
354                     __m256 _va6 = _mm256_broadcast_ss(va + 6);
355                     __m256 _va7 = _mm256_broadcast_ss(va + 7);
356                     __m256 _vb0 = _mm256_loadu_ps(vb);
357                     _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
358                     _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
359                     _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
360                     _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
361                     _sum4 = _mm256_fmadd_ps(_vb0, _va4, _sum4); // sum4 = (a00-a07) * k40
362                     _sum5 = _mm256_fmadd_ps(_vb0, _va5, _sum5); // sum5 = (a00-a07) * k50
363                     _sum6 = _mm256_fmadd_ps(_vb0, _va6, _sum6); // sum6 = (a00-a07) * k60
364                     _sum7 = _mm256_fmadd_ps(_vb0, _va7, _sum7); // sum7 = (a00-a07) * k70
365 
366                     va += 8;
367                     vb += 8;
368                 }
369 
370                 _mm256_storeu_ps(output0, _sum0);
371                 _mm256_storeu_ps(output1, _sum1);
372                 _mm256_storeu_ps(output2, _sum2);
373                 _mm256_storeu_ps(output3, _sum3);
374                 _mm256_storeu_ps(output4, _sum4);
375                 _mm256_storeu_ps(output5, _sum5);
376                 _mm256_storeu_ps(output6, _sum6);
377                 _mm256_storeu_ps(output7, _sum7);
378 #else
379                 float sum0[8] = {0};
380                 float sum1[8] = {0};
381                 float sum2[8] = {0};
382                 float sum3[8] = {0};
383                 float sum4[8] = {0};
384                 float sum5[8] = {0};
385                 float sum6[8] = {0};
386                 float sum7[8] = {0};
387 
388                 int k = 0;
389                 for (; k + 7 < L; k = k + 8)
390                 {
391                     for (int n = 0; n < 8; n++)
392                     {
393                         sum0[n] += va[0] * vb[n];
394                         sum1[n] += va[1] * vb[n];
395                         sum2[n] += va[2] * vb[n];
396                         sum3[n] += va[3] * vb[n];
397                         sum4[n] += va[4] * vb[n];
398                         sum5[n] += va[5] * vb[n];
399                         sum6[n] += va[6] * vb[n];
400                         sum7[n] += va[7] * vb[n];
401                         va += 8;
402 
403                         sum0[n] += va[0] * vb[n + 8];
404                         sum1[n] += va[1] * vb[n + 8];
405                         sum2[n] += va[2] * vb[n + 8];
406                         sum3[n] += va[3] * vb[n + 8];
407                         sum4[n] += va[4] * vb[n + 8];
408                         sum5[n] += va[5] * vb[n + 8];
409                         sum6[n] += va[6] * vb[n + 8];
410                         sum7[n] += va[7] * vb[n + 8];
411                         va += 8;
412 
413                         sum0[n] += va[0] * vb[n + 16];
414                         sum1[n] += va[1] * vb[n + 16];
415                         sum2[n] += va[2] * vb[n + 16];
416                         sum3[n] += va[3] * vb[n + 16];
417                         sum4[n] += va[4] * vb[n + 16];
418                         sum5[n] += va[5] * vb[n + 16];
419                         sum6[n] += va[6] * vb[n + 16];
420                         sum7[n] += va[7] * vb[n + 16];
421                         va += 8;
422 
423                         sum0[n] += va[0] * vb[n + 24];
424                         sum1[n] += va[1] * vb[n + 24];
425                         sum2[n] += va[2] * vb[n + 24];
426                         sum3[n] += va[3] * vb[n + 24];
427                         sum4[n] += va[4] * vb[n + 24];
428                         sum5[n] += va[5] * vb[n + 24];
429                         sum6[n] += va[6] * vb[n + 24];
430                         sum7[n] += va[7] * vb[n + 24];
431                         va += 8;
432 
433                         sum0[n] += va[0] * vb[n + 32];
434                         sum1[n] += va[1] * vb[n + 32];
435                         sum2[n] += va[2] * vb[n + 32];
436                         sum3[n] += va[3] * vb[n + 32];
437                         sum4[n] += va[4] * vb[n + 32];
438                         sum5[n] += va[5] * vb[n + 32];
439                         sum6[n] += va[6] * vb[n + 32];
440                         sum7[n] += va[7] * vb[n + 32];
441                         va += 8;
442 
443                         sum0[n] += va[0] * vb[n + 40];
444                         sum1[n] += va[1] * vb[n + 40];
445                         sum2[n] += va[2] * vb[n + 40];
446                         sum3[n] += va[3] * vb[n + 40];
447                         sum4[n] += va[4] * vb[n + 40];
448                         sum5[n] += va[5] * vb[n + 40];
449                         sum6[n] += va[6] * vb[n + 40];
450                         sum7[n] += va[7] * vb[n + 40];
451                         va += 8;
452 
453                         sum0[n] += va[0] * vb[n + 48];
454                         sum1[n] += va[1] * vb[n + 48];
455                         sum2[n] += va[2] * vb[n + 48];
456                         sum3[n] += va[3] * vb[n + 48];
457                         sum4[n] += va[4] * vb[n + 48];
458                         sum5[n] += va[5] * vb[n + 48];
459                         sum6[n] += va[6] * vb[n + 48];
460                         sum7[n] += va[7] * vb[n + 48];
461                         va += 8;
462 
463                         sum0[n] += va[0] * vb[n + 56];
464                         sum1[n] += va[1] * vb[n + 56];
465                         sum2[n] += va[2] * vb[n + 56];
466                         sum3[n] += va[3] * vb[n + 56];
467                         sum4[n] += va[4] * vb[n + 56];
468                         sum5[n] += va[5] * vb[n + 56];
469                         sum6[n] += va[6] * vb[n + 56];
470                         sum7[n] += va[7] * vb[n + 56];
471                         va -= 56;
472                     }
473 
474                     va += 64;
475                     vb += 64;
476                 }
477 
478                 for (; k < L; k++)
479                 {
480                     for (int n = 0; n < 8; n++)
481                     {
482                         sum0[n] += va[0] * vb[n];
483                         sum1[n] += va[1] * vb[n];
484                         sum2[n] += va[2] * vb[n];
485                         sum3[n] += va[3] * vb[n];
486                         sum4[n] += va[4] * vb[n];
487                         sum5[n] += va[5] * vb[n];
488                         sum6[n] += va[6] * vb[n];
489                         sum7[n] += va[7] * vb[n];
490                     }
491 
492                     va += 8;
493                     vb += 8;
494                 }
495 
496                 for (int n = 0; n < 8; n++)
497                 {
498                     output0[n] = sum0[n] + biasptr[0];
499                     output1[n] = sum1[n] + biasptr[1];
500                     output2[n] = sum2[n] + biasptr[2];
501                     output3[n] = sum3[n] + biasptr[3];
502                     output4[n] = sum4[n] + biasptr[4];
503                     output5[n] = sum5[n] + biasptr[5];
504                     output6[n] = sum6[n] + biasptr[6];
505                     output7[n] = sum7[n] + biasptr[7];
506                 }
507 #endif // __AVX__
508                 output0 += 8;
509                 output1 += 8;
510                 output2 += 8;
511                 output3 += 8;
512                 output4 += 8;
513                 output5 += 8;
514                 output6 += 8;
515                 output7 += 8;
516             }
517 
518             for (; j < N; j++)
519             {
520                 const float* vb = bottom_tm.channel(j / 8 + j % 8);
521                 const float* va = kernel_tm.channel(i / 8);
522 
523 #if __AVX__
524                 __m256 _sum0_7 = _mm256_loadu_ps(biasptr);
525                 __m256 _sum0 = _mm256_set1_ps(0.0);
526                 __m256 _sum1 = _mm256_set1_ps(0.0);
527                 __m256 _sum2 = _mm256_set1_ps(0.0);
528                 __m256 _sum3 = _mm256_set1_ps(0.0);
529 
530                 int k = 0;
531                 for (; k + 3 < L; k = k + 4)
532                 {
533                     __m256 _vb0 = _mm256_broadcast_ss(vb);
534                     __m256 _vb1 = _mm256_broadcast_ss(vb + 1);
535                     __m256 _vb2 = _mm256_broadcast_ss(vb + 2);
536                     __m256 _vb3 = _mm256_broadcast_ss(vb + 3);
537                     __m256 _va0 = _mm256_loadu_ps(va);
538                     __m256 _va1 = _mm256_loadu_ps(va + 8);
539                     __m256 _va2 = _mm256_loadu_ps(va + 16);
540                     __m256 _va3 = _mm256_loadu_ps(va + 24);
541 
542                     _sum0 = _mm256_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k70) * a00
543                     _sum1 = _mm256_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k71) * a10
544                     _sum2 = _mm256_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k72) * a20
545                     _sum3 = _mm256_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k73) * a30
546 
547                     va += 32;
548                     vb += 4;
549                 }
550 
551                 _sum0 = _mm256_add_ps(_sum0, _sum1);
552                 _sum2 = _mm256_add_ps(_sum2, _sum3);
553                 _sum0_7 = _mm256_add_ps(_sum0_7, _sum0);
554                 _sum0_7 = _mm256_add_ps(_sum0_7, _sum2);
555 
556                 for (; k < L; k++)
557                 {
558                     __m256 _vb0 = _mm256_broadcast_ss(vb);
559                     __m256 _va = _mm256_loadu_ps(va);
560 
561                     _sum0_7 = _mm256_fmadd_ps(_va, _vb0, _sum0_7); // sum0 += (k00-k70) * a00
562 
563                     va += 8;
564                     vb += 1;
565                 }
566 
567                 float output_sum0_7[8] = {0.f};
568                 _mm256_storeu_ps(output_sum0_7, _sum0_7);
569 
570                 output0[0] = output_sum0_7[0];
571                 output1[0] = output_sum0_7[1];
572                 output2[0] = output_sum0_7[2];
573                 output3[0] = output_sum0_7[3];
574                 output4[0] = output_sum0_7[4];
575                 output5[0] = output_sum0_7[5];
576                 output6[0] = output_sum0_7[6];
577                 output7[0] = output_sum0_7[7];
578 #else
579                 float sum0 = biasptr[0];
580                 float sum1 = biasptr[1];
581                 float sum2 = biasptr[2];
582                 float sum3 = biasptr[3];
583                 float sum4 = biasptr[4];
584                 float sum5 = biasptr[5];
585                 float sum6 = biasptr[6];
586                 float sum7 = biasptr[7];
587 
588                 for (int k = 0; k < L; k++)
589                 {
590                     sum0 += va[0] * vb[0];
591                     sum1 += va[1] * vb[0];
592                     sum2 += va[2] * vb[0];
593                     sum3 += va[3] * vb[0];
594                     sum4 += va[4] * vb[0];
595                     sum5 += va[5] * vb[0];
596                     sum6 += va[6] * vb[0];
597                     sum7 += va[7] * vb[0];
598 
599                     va += 8;
600                     vb += 1;
601                 }
602 
603                 output0[0] = sum0;
604                 output1[0] = sum1;
605                 output2[0] = sum2;
606                 output3[0] = sum3;
607                 output4[0] = sum4;
608                 output5[0] = sum5;
609                 output6[0] = sum6;
610                 output7[0] = sum7;
611 #endif // __AVX__
612                 output0++;
613                 output1++;
614                 output2++;
615                 output3++;
616                 output4++;
617                 output5++;
618                 output6++;
619                 output7++;
620             }
621         }
622 
623         nn_outch = (outch - remain_outch_start) >> 2;
624 
625         #pragma omp parallel for num_threads(opt.num_threads)
626         for (int pp = 0; pp < nn_outch; pp++)
627         {
628             int i = remain_outch_start + pp * 4;
629 
630             float* output0 = top_blob.channel(i);
631             float* output1 = top_blob.channel(i + 1);
632             float* output2 = top_blob.channel(i + 2);
633             float* output3 = top_blob.channel(i + 3);
634 
635             const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
636             const float* biasptr = bias ? bias + i : zeros;
637 
638             int j = 0;
639             for (; j + 7 < N; j = j + 8)
640             {
641                 const float* vb = bottom_tm.channel(j / 8);
642                 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
643 #if __AVX__
644                 __m256 _sum0 = _mm256_broadcast_ss(biasptr);
645                 __m256 _sum1 = _mm256_broadcast_ss(biasptr + 1);
646                 __m256 _sum2 = _mm256_broadcast_ss(biasptr + 2);
647                 __m256 _sum3 = _mm256_broadcast_ss(biasptr + 3);
648 
649                 int k = 0;
650                 for (; k + 3 < L; k = k + 4)
651                 {
652                     // k0
653                     __m256 _va0 = _mm256_broadcast_ss(va);
654                     __m256 _va1 = _mm256_broadcast_ss(va + 1);
655                     __m256 _va2 = _mm256_broadcast_ss(va + 2);
656                     __m256 _va3 = _mm256_broadcast_ss(va + 3);
657                     __m256 _vb0 = _mm256_loadu_ps(vb);
658                     __m256 _vb1 = _mm256_loadu_ps(vb + 8);
659                     __m256 _vb2 = _mm256_loadu_ps(vb + 16);
660                     __m256 _vb3 = _mm256_loadu_ps(vb + 24);
661                     _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
662                     _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
663                     _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
664                     _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
665 
666                     va += 4;
667 
668                     // k1
669                     _va0 = _mm256_broadcast_ss(va);
670                     _va1 = _mm256_broadcast_ss(va + 1);
671                     _va2 = _mm256_broadcast_ss(va + 2);
672                     _va3 = _mm256_broadcast_ss(va + 3);
673                     _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
674                     _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
675                     _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
676                     _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
677 
678                     va += 4;
679 
680                     // k2
681                     _va0 = _mm256_broadcast_ss(va);
682                     _va1 = _mm256_broadcast_ss(va + 1);
683                     _va2 = _mm256_broadcast_ss(va + 2);
684                     _va3 = _mm256_broadcast_ss(va + 3);
685                     _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
686                     _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
687                     _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
688                     _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
689 
690                     va += 4;
691 
692                     // k3
693                     _va0 = _mm256_broadcast_ss(va);
694                     _va1 = _mm256_broadcast_ss(va + 1);
695                     _va2 = _mm256_broadcast_ss(va + 2);
696                     _va3 = _mm256_broadcast_ss(va + 3);
697                     _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
698                     _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
699                     _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
700                     _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
701 
702                     va += 4;
703                     vb += 32;
704                 }
705 
706                 for (; k < L; k++)
707                 {
708                     // k0
709                     __m256 _va0 = _mm256_broadcast_ss(va);
710                     __m256 _va1 = _mm256_broadcast_ss(va + 1);
711                     __m256 _va2 = _mm256_broadcast_ss(va + 2);
712                     __m256 _va3 = _mm256_broadcast_ss(va + 3);
713                     __m256 _vb0 = _mm256_loadu_ps(vb);
714                     _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
715                     _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
716                     _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
717                     _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
718 
719                     va += 4;
720                     vb += 8;
721                 }
722 
723                 _mm256_storeu_ps(output0, _sum0);
724                 _mm256_storeu_ps(output1, _sum1);
725                 _mm256_storeu_ps(output2, _sum2);
726                 _mm256_storeu_ps(output3, _sum3);
727 #else
728                 float sum0[8] = {0};
729                 float sum1[8] = {0};
730                 float sum2[8] = {0};
731                 float sum3[8] = {0};
732 
733                 int k = 0;
734                 for (; k + 7 < L; k = k + 8)
735                 {
736                     for (int n = 0; n < 8; n++)
737                     {
738                         sum0[n] += va[0] * vb[n];
739                         sum1[n] += va[1] * vb[n];
740                         sum2[n] += va[2] * vb[n];
741                         sum3[n] += va[3] * vb[n];
742                         va += 4;
743 
744                         sum0[n] += va[0] * vb[n + 8];
745                         sum1[n] += va[1] * vb[n + 8];
746                         sum2[n] += va[2] * vb[n + 8];
747                         sum3[n] += va[3] * vb[n + 8];
748                         va += 4;
749 
750                         sum0[n] += va[0] * vb[n + 16];
751                         sum1[n] += va[1] * vb[n + 16];
752                         sum2[n] += va[2] * vb[n + 16];
753                         sum3[n] += va[3] * vb[n + 16];
754                         va += 4;
755 
756                         sum0[n] += va[0] * vb[n + 24];
757                         sum1[n] += va[1] * vb[n + 24];
758                         sum2[n] += va[2] * vb[n + 24];
759                         sum3[n] += va[3] * vb[n + 24];
760                         va += 4;
761 
762                         sum0[n] += va[0] * vb[n + 32];
763                         sum1[n] += va[1] * vb[n + 32];
764                         sum2[n] += va[2] * vb[n + 32];
765                         sum3[n] += va[3] * vb[n + 32];
766                         va += 4;
767 
768                         sum0[n] += va[0] * vb[n + 40];
769                         sum1[n] += va[1] * vb[n + 40];
770                         sum2[n] += va[2] * vb[n + 40];
771                         sum3[n] += va[3] * vb[n + 40];
772                         va += 4;
773 
774                         sum0[n] += va[0] * vb[n + 48];
775                         sum1[n] += va[1] * vb[n + 48];
776                         sum2[n] += va[2] * vb[n + 48];
777                         sum3[n] += va[3] * vb[n + 48];
778                         va += 4;
779 
780                         sum0[n] += va[0] * vb[n + 56];
781                         sum1[n] += va[1] * vb[n + 56];
782                         sum2[n] += va[2] * vb[n + 56];
783                         sum3[n] += va[3] * vb[n + 56];
784                         va -= 28;
785                     }
786 
787                     va += 32;
788                     vb += 64;
789                 }
790 
791                 for (; k < L; k++)
792                 {
793                     for (int n = 0; n < 8; n++)
794                     {
795                         sum0[n] += va[0] * vb[n];
796                         sum1[n] += va[1] * vb[n];
797                         sum2[n] += va[2] * vb[n];
798                         sum3[n] += va[3] * vb[n];
799                     }
800 
801                     va += 4;
802                     vb += 8;
803                 }
804 
805                 for (int n = 0; n < 8; n++)
806                 {
807                     output0[n] = sum0[n] + biasptr[0];
808                     output1[n] = sum1[n] + biasptr[1];
809                     output2[n] = sum2[n] + biasptr[2];
810                     output3[n] = sum3[n] + biasptr[3];
811                 }
812 #endif // __AVX__
813                 output0 += 8;
814                 output1 += 8;
815                 output2 += 8;
816                 output3 += 8;
817             }
818 
819             for (; j < N; j++)
820             {
821                 const float* vb = bottom_tm.channel(j / 8 + j % 8);
822                 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
823 #if __AVX__
824                 __m128 _sum0_3 = _mm_loadu_ps(biasptr);
825                 __m128 _sum0 = _mm_set1_ps(0.0);
826                 __m128 _sum1 = _mm_set1_ps(0.0);
827                 __m128 _sum2 = _mm_set1_ps(0.0);
828                 __m128 _sum3 = _mm_set1_ps(0.0);
829 
830                 int k = 0;
831                 for (; k + 3 < L; k = k + 4)
832                 {
833                     __m128 _vb0 = _mm_set1_ps(vb[0]);
834                     __m128 _vb1 = _mm_set1_ps(vb[1]);
835                     __m128 _vb2 = _mm_set1_ps(vb[2]);
836                     __m128 _vb3 = _mm_set1_ps(vb[3]);
837                     __m128 _va0 = _mm_loadu_ps(va);
838                     __m128 _va1 = _mm_loadu_ps(va + 4);
839                     __m128 _va2 = _mm_loadu_ps(va + 8);
840                     __m128 _va3 = _mm_loadu_ps(va + 12);
841 
842                     _sum0 = _mm_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k30) * a00
843                     _sum1 = _mm_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k31) * a10
844                     _sum2 = _mm_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k32) * a20
845                     _sum3 = _mm_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k33) * a30
846 
847                     va += 16;
848                     vb += 4;
849                 }
850 
851                 _sum0 = _mm_add_ps(_sum0, _sum1);
852                 _sum2 = _mm_add_ps(_sum2, _sum3);
853                 _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
854                 _sum0_3 = _mm_add_ps(_sum0_3, _sum2);
855 
856                 for (; k < L; k++)
857                 {
858                     __m128 _vb0 = _mm_set1_ps(vb[0]);
859                     __m128 _va = _mm_loadu_ps(va);
860 
861                     _sum0_3 = _mm_fmadd_ps(_va, _vb0, _sum0_3); // sum0 += (k00-k30) * a00
862 
863                     va += 4;
864                     vb += 1;
865                 }
866 
867                 float output_sum0_3[4] = {0.f};
868                 _mm_storeu_ps(output_sum0_3, _sum0_3);
869                 output0[0] = output_sum0_3[0];
870                 output1[0] = output_sum0_3[1];
871                 output2[0] = output_sum0_3[2];
872                 output3[0] = output_sum0_3[3];
873 #else
874                 float sum0 = biasptr[0];
875                 float sum1 = biasptr[1];
876                 float sum2 = biasptr[2];
877                 float sum3 = biasptr[3];
878 
879                 for (int k = 0; k < L; k++)
880                 {
881                     sum0 += va[0] * vb[0];
882                     sum1 += va[1] * vb[0];
883                     sum2 += va[2] * vb[0];
884                     sum3 += va[3] * vb[0];
885 
886                     va += 4;
887                     vb += 1;
888                 }
889 
890                 output0[0] = sum0;
891                 output1[0] = sum1;
892                 output2[0] = sum2;
893                 output3[0] = sum3;
894 #endif // __AVX__
895                 output0++;
896                 output1++;
897                 output2++;
898                 output3++;
899             }
900         }
901 
902         remain_outch_start += nn_outch << 2;
903 
904         #pragma omp parallel for num_threads(opt.num_threads)
905         for (int i = remain_outch_start; i < outch; i++)
906         {
907             float* output = top_blob.channel(i);
908 
909             const float bias0 = bias ? bias[i] : 0.f;
910 
911             int j = 0;
912             for (; j + 7 < N; j = j + 8)
913             {
914                 const float* vb = bottom_tm.channel(j / 8);
915                 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
916 #if __AVX__
917                 __m256 _sum0 = _mm256_broadcast_ss(&bias0);
918 
919                 int k = 0;
920                 for (; k + 3 < L; k = k + 4)
921                 {
922                     // k0
923                     __m256 _va0 = _mm256_broadcast_ss(va);
924                     __m256 _va1 = _mm256_broadcast_ss(va + 1);
925                     __m256 _va2 = _mm256_broadcast_ss(va + 2);
926                     __m256 _va3 = _mm256_broadcast_ss(va + 3);
927                     __m256 _vb0 = _mm256_loadu_ps(vb);
928                     __m256 _vb1 = _mm256_loadu_ps(vb + 8);
929                     __m256 _vb2 = _mm256_loadu_ps(vb + 16);
930                     __m256 _vb3 = _mm256_loadu_ps(vb + 24);
931 
932                     _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
933                     _sum0 = _mm256_fmadd_ps(_vb1, _va1, _sum0); // sum0 += (a10-a17) * k01
934                     _sum0 = _mm256_fmadd_ps(_vb2, _va2, _sum0); // sum0 += (a20-a27) * k02
935                     _sum0 = _mm256_fmadd_ps(_vb3, _va3, _sum0); // sum0 += (a30-a37) * k03
936 
937                     va += 4;
938                     vb += 32;
939                 }
940 
941                 for (; k < L; k++)
942                 {
943                     // k0
944                     __m256 _va0 = _mm256_broadcast_ss(va);
945                     __m256 _vb0 = _mm256_loadu_ps(vb);
946 
947                     _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
948 
949                     va += 1;
950                     vb += 8;
951                 }
952 
953                 _mm256_storeu_ps(output, _sum0);
954 #else
955                 float sum[8] = {0};
956 
957                 int k = 0;
958                 for (; k + 7 < L; k = k + 8)
959                 {
960                     for (int n = 0; n < 8; n++)
961                     {
962                         sum[n] += va[0] * vb[n];
963                         sum[n] += va[1] * vb[n + 8];
964                         sum[n] += va[2] * vb[n + 16];
965                         sum[n] += va[3] * vb[n + 24];
966                         sum[n] += va[4] * vb[n + 32];
967                         sum[n] += va[5] * vb[n + 40];
968                         sum[n] += va[6] * vb[n + 48];
969                         sum[n] += va[7] * vb[n + 56];
970                     }
971 
972                     va += 8;
973                     vb += 64;
974                 }
975 
976                 for (; k < L; k++)
977                 {
978                     for (int n = 0; n < 8; n++)
979                     {
980                         sum[n] += va[0] * vb[n];
981                     }
982 
983                     va += 1;
984                     vb += 8;
985                 }
986 
987                 for (int n = 0; n < 8; n++)
988                 {
989                     output[n] = sum[n] + bias0;
990                 }
991 #endif // __AVX__
992                 output += 8;
993             }
994 
995             for (; j < N; j++)
996             {
997                 const float* vb = bottom_tm.channel(j / 8 + j % 8);
998                 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
999 
1000                 int k = 0;
1001 #if __AVX__
1002                 __m128 _sum0 = _mm_set1_ps(0.f);
1003 
1004                 for (; k + 3 < L; k += 4)
1005                 {
1006                     __m128 _p0 = _mm_loadu_ps(vb);
1007                     vb += 4;
1008 
1009                     __m128 _k0 = _mm_loadu_ps(va);
1010                     va += 4;
1011 
1012                     _sum0 = _mm_fmadd_ps(_p0, _k0, _sum0);
1013                 }
1014 
1015                 float output_sum0[4] = {0.f};
1016                 _mm_storeu_ps(output_sum0, _sum0);
1017 
1018                 float sum0 = bias0 + output_sum0[0] + output_sum0[1] + output_sum0[2] + output_sum0[3];
1019 
1020 #else
1021                 float sum0 = bias0;
1022 #endif // __AVX__
1023                 for (; k < L; k++)
1024                 {
1025                     sum0 += va[0] * vb[0];
1026 
1027                     va += 1;
1028                     vb += 1;
1029                 }
1030                 output[0] = sum0;
1031 
1032                 output++;
1033             }
1034         }
1035     }
1036 }
1037 #else
1038 static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
1039 {
1040     const float* kernel = _kernel;
1041 
1042     // kernel memory packed 4 x 4
1043     kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4);
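    // Non-AVX variant of the packing above, at 4 x 4 granularity: output
    // channels are grouped 4 at a time (remainders stored singly) and each
    // group's weights are interleaved so the SSE micro-kernel reads 4 per step.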
1044 
1045     int nn_outch = 0;
1046     int remain_outch_start = 0;
1047 
1048     nn_outch = outch >> 2;
1049     remain_outch_start = nn_outch << 2;
1050 
1051     for (int pp = 0; pp < nn_outch; pp++)
1052     {
1053         int p = pp * 4;
1054 
1055         const float* k0 = kernel + (p + 0) * inch * kernel_size;
1056         const float* k1 = kernel + (p + 1) * inch * kernel_size;
1057         const float* k2 = kernel + (p + 2) * inch * kernel_size;
1058         const float* k3 = kernel + (p + 3) * inch * kernel_size;
1059 
1060         float* ktmp = kernel_tm.channel(p / 4);
1061 
1062         for (int q = 0; q < inch * kernel_size; q++)
1063         {
1064             ktmp[0] = k0[0];
1065             ktmp[1] = k1[0];
1066             ktmp[2] = k2[0];
1067             ktmp[3] = k3[0];
1068             ktmp += 4;
1069 
1070             k0 += 1;
1071             k1 += 1;
1072             k2 += 1;
1073             k3 += 1;
1074         }
1075     }
1076 
1077     for (int p = remain_outch_start; p < outch; p++)
1078     {
1079         const float* k0 = kernel + (p + 0) * inch * kernel_size;
1080 
1081         float* ktmp = kernel_tm.channel(p / 4 + p % 4);
1082 
1083         for (int q = 0; q < inch * kernel_size; q++)
1084         {
1085             ktmp[0] = k0[0];
1086             ktmp++;
1087             k0++;
1088         }
1089     }
1090 }
1091 
1092 static void conv_im2col_sgemm_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias,
1093                                   const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
1094 {
1095     int w = bottom_blob.w;
1096     int inch = bottom_blob.c;
1097     size_t elemsize = bottom_blob.elemsize;
1098 
1099     int outw = top_blob.w;
1100     int outh = top_blob.h;
1101     int outch = top_blob.c;
1102 
1103     const float* bias = _bias;
1104 
1105     // im2col
1106     Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, elemsize, opt.workspace_allocator);
1107     {
1108         const int stride = kernel_h * kernel_w * outw * outh;
1109         float* ret = (float*)bottom_im2col;
1110 
1111         #pragma omp parallel for num_threads(opt.num_threads)
1112         for (int p = 0; p < inch; p++)
1113         {
1114             const float* input = bottom_blob.channel(p);
1115             int retID = stride * p;
1116             for (int u = 0; u < kernel_h; u++)
1117             {
1118                 for (int v = 0; v < kernel_w; v++)
1119                 {
1120                     for (int i = 0; i < outh; i++)
1121                     {
1122                         for (int j = 0; j < outw; j++)
1123                         {
1124                             int row = u + i * stride_h;
1125                             int col = v + j * stride_w;
1126                             int index = row * w + col;
1127                             ret[retID] = input[index];
1128                             retID++;
1129                         }
1130                     }
1131                 }
1132             }
1133         }
1134     }
1135 
1136     int kernel_size = kernel_w * kernel_h;
1137     int out_size = outw * outh;
1138 
1139     // bottom_im2col memory packed 4 x 4
1140     Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, elemsize, opt.workspace_allocator);
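    // As in the AVX path, 4 consecutive output positions are packed into one
    // channel of bottom_tm (leftover columns get one channel each) so the
    // 4-wide SSE micro-kernel can load contiguous tiles.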
1141     {
1142         int nn_size = out_size >> 2;
1143         int remain_size_start = nn_size << 2;
1144 
1145         #pragma omp parallel for num_threads(opt.num_threads)
1146         for (int ii = 0; ii < nn_size; ii++)
1147         {
1148             int i = ii * 4;
1149 
1150             const float* img0 = bottom_im2col.channel(0);
1151             img0 += i;
1152 
1153             float* tmpptr = bottom_tm.channel(i / 4);
1154 
1155             for (int q = 0; q < inch * kernel_size; q++)
1156             {
1157 #if __SSE__
1158                 _mm_storeu_ps(tmpptr, _mm_loadu_ps(img0));
1159 #else
1160                 tmpptr[0] = img0[0];
1161                 tmpptr[1] = img0[1];
1162                 tmpptr[2] = img0[2];
1163                 tmpptr[3] = img0[3];
1164 #endif // __SSE__
1165                 tmpptr += 4;
1166                 img0 += out_size;
1167             }
1168         }
1169 
1170         #pragma omp parallel for num_threads(opt.num_threads)
1171         for (int i = remain_size_start; i < out_size; i++)
1172         {
1173             const float* img0 = bottom_im2col.channel(0);
1174             img0 += i;
1175 
1176             float* tmpptr = bottom_tm.channel(i / 4 + i % 4);
1177 
1178             for (int q = 0; q < inch * kernel_size; q++)
1179             {
1180                 tmpptr[0] = img0[0];
1181 
1182                 tmpptr += 1;
1183                 img0 += out_size;
1184             }
1185         }
1186     }
1187 
1188     // sgemm(int M, int N, int L, float* A, float* B, float* C)
1189     {
1190         //int M = outch;                    // outch
1191         int N = outw * outh;                // outsize or out stride
1192         int L = kernel_w * kernel_h * inch; // ksize * inch
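        // 4x4 register-blocked micro-kernel: each pass of the loops below forms
        // a tile of 4 output channels by 4 output positions with _mm_mul_ps and
        // _mm_add_ps (no FMA assumed on this path), accumulating over L in
        // steps of 4.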
1193 
1194         int nn_outch = 0;
1195         int remain_outch_start = 0;
1196 
1197         nn_outch = outch >> 2;
1198         remain_outch_start = nn_outch << 2;
1199 
1200         #pragma omp parallel for num_threads(opt.num_threads)
1201         for (int pp = 0; pp < nn_outch; pp++)
1202         {
1203             int i = pp * 4;
1204 
1205             float* output0 = top_blob.channel(i);
1206             float* output1 = top_blob.channel(i + 1);
1207             float* output2 = top_blob.channel(i + 2);
1208             float* output3 = top_blob.channel(i + 3);
1209 
1210             const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
1211             const float* biasptr = bias ? bias + i : zeros;
1212 
1213             int j = 0;
1214             for (; j + 3 < N; j = j + 4)
1215             {
1216                 const float* vb = bottom_tm.channel(j / 4);
1217                 const float* va = kernel_tm.channel(i / 4);
1218 #if __SSE__
1219                 __m128 _sum0 = _mm_set1_ps(biasptr[0]);
1220                 __m128 _sum1 = _mm_set1_ps(biasptr[1]);
1221                 __m128 _sum2 = _mm_set1_ps(biasptr[2]);
1222                 __m128 _sum3 = _mm_set1_ps(biasptr[3]);
1223 
1224                 int k = 0;
1225                 for (; k + 3 < L; k = k + 4)
1226                 {
1227                     // k0
1228                     __m128 _vb = _mm_loadu_ps(vb);
1229                     __m128 _va0 = _mm_set1_ps(va[0]);
1230                     __m128 _va1 = _mm_set1_ps(va[1]);
1231                     __m128 _va2 = _mm_set1_ps(va[2]);
1232                     __m128 _va3 = _mm_set1_ps(va[3]);
1233                     _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a00-a03) * k00
1234                     _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a00-a03) * k10
1235                     _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a00-a03) * k20
1236                     _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a00-a03) * k30
1237 
1238                     // k1
1239                     _vb = _mm_loadu_ps(vb + 4);
1240                     _va0 = _mm_set1_ps(va[4]);
1241                     _va1 = _mm_set1_ps(va[5]);
1242                     _va2 = _mm_set1_ps(va[6]);
1243                     _va3 = _mm_set1_ps(va[7]);
1244                     _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a10-a13) * k01
1245                     _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a10-a13) * k11
1246                     _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a10-a13) * k21
1247                     _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a10-a13) * k31
1248 
1249                     // k2
1250                     _vb = _mm_loadu_ps(vb + 8);
1251                     _va0 = _mm_set1_ps(va[8]);
1252                     _va1 = _mm_set1_ps(va[9]);
1253                     _va2 = _mm_set1_ps(va[10]);
1254                     _va3 = _mm_set1_ps(va[11]);
1255                     _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a20-a23) * k02
1256                     _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a20-a23) * k12
1257                     _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a20-a23) * k22
1258                     _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a20-a23) * k32
1259 
1260                     // k3
1261                     _vb = _mm_loadu_ps(vb + 12);
1262                     _va0 = _mm_set1_ps(va[12]);
1263                     _va1 = _mm_set1_ps(va[13]);
1264                     _va2 = _mm_set1_ps(va[14]);
1265                     _va3 = _mm_set1_ps(va[15]);
1266                     _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a30-a33) * k03
1267                     _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a30-a33) * k13
1268                     _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a30-a33) * k23
1269                     _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a30-a33) * k33
1270 
1271                     va += 16;
1272                     vb += 16;
1273                 }
1274 
1275                 for (; k < L; k++)
1276                 {
1277                     // k0
1278                     __m128 _vb = _mm_loadu_ps(vb);
1279                     __m128 _va0 = _mm_set1_ps(va[0]);
1280                     __m128 _va1 = _mm_set1_ps(va[1]);
1281                     __m128 _va2 = _mm_set1_ps(va[2]);
1282                     __m128 _va3 = _mm_set1_ps(va[3]);
1283                     _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a00-a03) * k00
1284                     _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a00-a03) * k10
1285                     _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a00-a03) * k20
1286                     _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a00-a03) * k30
1287 
1288                     va += 4;
1289                     vb += 4;
1290                 }
1291                 _mm_storeu_ps(output0, _sum0);
1292                 _mm_storeu_ps(output1, _sum1);
1293                 _mm_storeu_ps(output2, _sum2);
1294                 _mm_storeu_ps(output3, _sum3);
1295 #else
                float sum0[4] = {0};
                float sum1[4] = {0};
                float sum2[4] = {0};
                float sum3[4] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 4];
                        sum1[n] += va[1] * vb[n + 4];
                        sum2[n] += va[2] * vb[n + 4];
                        sum3[n] += va[3] * vb[n + 4];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 8];
                        sum1[n] += va[1] * vb[n + 8];
                        sum2[n] += va[2] * vb[n + 8];
                        sum3[n] += va[3] * vb[n + 8];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 12];
                        sum1[n] += va[1] * vb[n + 12];
                        sum2[n] += va[2] * vb[n + 12];
                        sum3[n] += va[3] * vb[n + 12];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 16];
                        sum1[n] += va[1] * vb[n + 16];
                        sum2[n] += va[2] * vb[n + 16];
                        sum3[n] += va[3] * vb[n + 16];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 20];
                        sum1[n] += va[1] * vb[n + 20];
                        sum2[n] += va[2] * vb[n + 20];
                        sum3[n] += va[3] * vb[n + 20];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 24];
                        sum1[n] += va[1] * vb[n + 24];
                        sum2[n] += va[2] * vb[n + 24];
                        sum3[n] += va[3] * vb[n + 24];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 28];
                        sum1[n] += va[1] * vb[n + 28];
                        sum2[n] += va[2] * vb[n + 28];
                        sum3[n] += va[3] * vb[n + 28];
                        va -= 28;
                    }

                    va += 32;
                    vb += 32;
                }

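                // leftover k, one step at a time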
                for (; k < L; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                    }

                    va += 4;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output0[n] = sum0[n] + biasptr[0];
                    output1[n] = sum1[n] + biasptr[1];
                    output2[n] = sum2[n] + biasptr[2];
                    output3[n] = sum3[n] + biasptr[3];
                }
#endif // __SSE__
                output0 += 4;
                output1 += 4;
                output2 += 4;
                output3 += 4;
            }

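            // remaining columns (N % 4), one output column at a time
            // note: the leftover columns of bottom_tm are assumed to be packed one per channel, hence channel(j / 4 + j % 4)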
            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 4 + j % 4);
                const float* va = kernel_tm.channel(i / 4);
#if __SSE__
                __m128 _sum0_3 = _mm_loadu_ps(biasptr);
                __m128 _sum0 = _mm_set1_ps(0.f);
                __m128 _sum1 = _mm_set1_ps(0.f);
                __m128 _sum2 = _mm_set1_ps(0.f);
                __m128 _sum3 = _mm_set1_ps(0.f);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    __m128 _vb0 = _mm_set1_ps(vb[0]);
                    __m128 _vb1 = _mm_set1_ps(vb[1]);
                    __m128 _vb2 = _mm_set1_ps(vb[2]);
                    __m128 _vb3 = _mm_set1_ps(vb[3]);
                    __m128 _va0 = _mm_loadu_ps(va);
                    __m128 _va1 = _mm_loadu_ps(va + 4);
                    __m128 _va2 = _mm_loadu_ps(va + 8);
                    __m128 _va3 = _mm_loadu_ps(va + 12);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_va0, _vb0)); // sum0 += (k00-k30) * a00
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_va1, _vb1)); // sum1 += (k01-k31) * a10
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_va2, _vb2)); // sum2 += (k02-k32) * a20
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_va3, _vb3)); // sum3 += (k03-k33) * a30

                    va += 16;
                    vb += 4;
                }

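                // fold the four partial accumulators (and the bias) into one vector, one lane per output channel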
                _sum0 = _mm_add_ps(_sum0, _sum1);
                _sum2 = _mm_add_ps(_sum2, _sum3);
                _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
                _sum0_3 = _mm_add_ps(_sum0_3, _sum2);

                for (; k < L; k++)
                {
                    __m128 _vb0 = _mm_set1_ps(vb[0]);
                    __m128 _va = _mm_loadu_ps(va);

                    _sum0_3 = _mm_add_ps(_sum0_3, _mm_mul_ps(_va, _vb0)); // sum0_3 += (k00-k30) * a00

                    va += 4;
                    vb += 1;
                }

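                // scatter the four lanes to the four output rows handled by this block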
                float sum0_3_tmp[4];
                _mm_storeu_ps(sum0_3_tmp, _sum0_3);
                output0[0] = sum0_3_tmp[0];
                output1[0] = sum0_3_tmp[1];
                output2[0] = sum0_3_tmp[2];
                output3[0] = sum0_3_tmp[3];
#else
                float sum0 = biasptr[0];
                float sum1 = biasptr[1];
                float sum2 = biasptr[2];
                float sum3 = biasptr[3];

                for (int k = 0; k < L; k++)
                {
                    sum0 += va[0] * vb[0];
                    sum1 += va[1] * vb[0];
                    sum2 += va[2] * vb[0];
                    sum3 += va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
#endif // __SSE__
                output0++;
                output1++;
                output2++;
                output3++;
            }
        }

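        // remaining output channels, one channel at a time with a 1x4 micro-kernel
        // note: the leftover rows of kernel_tm are assumed to be packed one per channel, hence channel(i / 4 + i % 4)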
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            float* output = top_blob.channel(i);

            const float bias0 = bias ? bias[i] : 0.f;

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                const float* vb = bottom_tm.channel(j / 4);
                const float* va = kernel_tm.channel(i / 4 + i % 4);
#if __SSE__
                __m128 _sum0 = _mm_set1_ps(bias0);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    // k0
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _va1 = _mm_set1_ps(va[1]);
                    __m128 _va2 = _mm_set1_ps(va[2]);
                    __m128 _va3 = _mm_set1_ps(va[3]);
                    __m128 _vb0 = _mm_loadu_ps(vb);
                    __m128 _vb1 = _mm_loadu_ps(vb + 4);
                    __m128 _vb2 = _mm_loadu_ps(vb + 8);
                    __m128 _vb3 = _mm_loadu_ps(vb + 12);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0)); // sum0 += (a00-a03) * k00
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb1, _va1)); // sum0 += (a10-a13) * k01
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb2, _va2)); // sum0 += (a20-a23) * k02
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb3, _va3)); // sum0 += (a30-a33) * k03

                    va += 4;
                    vb += 16;
                }

                for (; k < L; k++)
                {
                    // k0
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _vb0 = _mm_loadu_ps(vb);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0)); // sum0 += (a00-a03) * k00

                    va += 1;
                    vb += 4;
                }
                _mm_storeu_ps(output, _sum0);
#else
                float sum[4] = {0};

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += va[0] * vb[n];
                        sum[n] += va[1] * vb[n + 4];
                        sum[n] += va[2] * vb[n + 8];
                        sum[n] += va[3] * vb[n + 12];
                        //sum[n] += va[4] * vb[n+16];
                        //sum[n] += va[5] * vb[n+20];
                        //sum[n] += va[6] * vb[n+24];
                        //sum[n] += va[7] * vb[n+28];
                    }

                    va += 4;
                    vb += 16;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += va[0] * vb[n];
                    }

                    va += 1;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output[n] = sum[n] + bias0;
                }
#endif // __SSE__
                output += 4;
            }

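            // remaining columns: a plain dot product of one kernel row with one column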
            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 4 + j % 4);
                const float* va = kernel_tm.channel(i / 4 + i % 4);

                int k = 0;
#if __SSE__
                __m128 _sum0 = _mm_set1_ps(0.f);

                for (; k + 3 < L; k += 4)
                {
                    __m128 _p0 = _mm_loadu_ps(vb);
                    __m128 _k0 = _mm_loadu_ps(va);
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_p0, _k0));

                    va += 4;
                    vb += 4;
                }
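                // horizontal add of the 4 lanes, then fold in the bias; the scalar tail below finishes L % 4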
                float sum0_tmp[4];
                _mm_storeu_ps(sum0_tmp, _sum0);
                float sum0 = bias0 + sum0_tmp[0] + sum0_tmp[1] + sum0_tmp[2] + sum0_tmp[3];
#else
                float sum0 = bias0;
#endif // __SSE__
                for (; k < L; k++)
                {
                    sum0 += va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = sum0;

                output++;
            }
        }
    }
}
#endif