// BUG1989 is pleased to support the open source community by supporting ncnn available.
//
// Copyright (C) 2019 BUG1989. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15 #if __AVX__
conv_im2col_sgemm_transform_kernel_sse(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_size)16 static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
17 {
18 const float* kernel = _kernel;
19
20 // kernel memory packed 8 x 8
21 kernel_tm.create(8 * kernel_size, inch, outch / 8 + (outch % 8) / 4 + outch % 4);
22
23 int nn_outch = 0;
24 int remain_outch_start = 0;
25
26 nn_outch = outch >> 3;
27 remain_outch_start = nn_outch << 3;
28
29 for (int pp = 0; pp < nn_outch; pp++)
30 {
31 int p = pp * 8;
32
33 const float* k0 = kernel + (p + 0) * inch * kernel_size;
34 const float* k1 = kernel + (p + 1) * inch * kernel_size;
35 const float* k2 = kernel + (p + 2) * inch * kernel_size;
36 const float* k3 = kernel + (p + 3) * inch * kernel_size;
37 const float* k4 = kernel + (p + 4) * inch * kernel_size;
38 const float* k5 = kernel + (p + 5) * inch * kernel_size;
39 const float* k6 = kernel + (p + 6) * inch * kernel_size;
40 const float* k7 = kernel + (p + 7) * inch * kernel_size;
41
42 float* ktmp = kernel_tm.channel(p / 8);
43
44 for (int q = 0; q < inch * kernel_size; q++)
45 {
46 ktmp[0] = k0[0];
47 ktmp[1] = k1[0];
48 ktmp[2] = k2[0];
49 ktmp[3] = k3[0];
50 ktmp[4] = k4[0];
51 ktmp[5] = k5[0];
52 ktmp[6] = k6[0];
53 ktmp[7] = k7[0];
54 ktmp += 8;
55
56 k0 += 1;
57 k1 += 1;
58 k2 += 1;
59 k3 += 1;
60 k4 += 1;
61 k5 += 1;
62 k6 += 1;
63 k7 += 1;
64 }
65 }
66
67 nn_outch = (outch - remain_outch_start) >> 2;
68
69 for (int pp = 0; pp < nn_outch; pp++)
70 {
71 int p = remain_outch_start + pp * 4;
72
73 const float* k0 = kernel + (p + 0) * inch * kernel_size;
74 const float* k1 = kernel + (p + 1) * inch * kernel_size;
75 const float* k2 = kernel + (p + 2) * inch * kernel_size;
76 const float* k3 = kernel + (p + 3) * inch * kernel_size;
77
78 float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4);
79
80 for (int q = 0; q < inch * kernel_size; q++)
81 {
82 ktmp[0] = k0[0];
83 ktmp[1] = k1[0];
84 ktmp[2] = k2[0];
85 ktmp[3] = k3[0];
86 ktmp += 4;
87
88 k0 += 1;
89 k1 += 1;
90 k2 += 1;
91 k3 += 1;
92 }
93 }
94
95 remain_outch_start += nn_outch << 2;
96
97 for (int p = remain_outch_start; p < outch; p++)
98 {
99 const float* k0 = kernel + (p + 0) * inch * kernel_size;
100
101 float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4);
102
103 for (int q = 0; q < inch * kernel_size; q++)
104 {
105 ktmp[0] = k0[0];
106 ktmp++;
107 k0++;
108 }
109 }
110 }
111
conv_im2col_sgemm_sse(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel_tm,const Mat & _bias,const int kernel_w,const int kernel_h,const int stride_w,const int stride_h,const Option & opt)112 static void conv_im2col_sgemm_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias,
113 const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
114 {
115 int w = bottom_blob.w;
116 int inch = bottom_blob.c;
117 size_t elemsize = bottom_blob.elemsize;
118
119 int outw = top_blob.w;
120 int outh = top_blob.h;
121 int outch = top_blob.c;
122
123 const float* bias = _bias;
124
125 // im2col
126 Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, elemsize, opt.workspace_allocator);
127 {
128 const int stride = kernel_h * kernel_w * outw * outh;
129 float* ret = (float*)bottom_im2col;
130
131 #pragma omp parallel for num_threads(opt.num_threads)
132 for (int p = 0; p < inch; p++)
133 {
134 const float* input = bottom_blob.channel(p);
135 int retID = stride * p;
136 for (int u = 0; u < kernel_h; u++)
137 {
138 for (int v = 0; v < kernel_w; v++)
139 {
140 for (int i = 0; i < outh; i++)
141 {
142 for (int j = 0; j < outw; j++)
143 {
144 int row = u + i * stride_h;
145 int col = v + j * stride_w;
146 int index = row * w + col;
147 ret[retID] = input[index];
148 retID++;
149 }
150 }
151 }
152 }
153 }
154 }
155
156 int kernel_size = kernel_w * kernel_h;
157 int out_size = outw * outh;
158
159 // bottom_im2col memory packed 8 x 8
160 Mat bottom_tm(8 * kernel_size, inch, out_size / 8 + out_size % 8, elemsize, opt.workspace_allocator);
161 {
162 int nn_size = out_size >> 3;
163 int remain_size_start = nn_size << 3;
164
165 #pragma omp parallel for num_threads(opt.num_threads)
166 for (int ii = 0; ii < nn_size; ii++)
167 {
168 int i = ii * 8;
169
170 const float* img0 = bottom_im2col.channel(0);
171 img0 += i;
172
173 float* tmpptr = bottom_tm.channel(i / 8);
174
175 for (int q = 0; q < inch * kernel_size; q++)
176 {
177 #if __AVX__
178 _mm256_storeu_ps(tmpptr, _mm256_loadu_ps(img0));
179 #else
180 tmpptr[0] = img0[0];
181 tmpptr[1] = img0[1];
182 tmpptr[2] = img0[2];
183 tmpptr[3] = img0[3];
184 tmpptr[4] = img0[4];
185 tmpptr[5] = img0[5];
186 tmpptr[6] = img0[6];
187 tmpptr[7] = img0[7];
188 #endif // __SSE__
189 tmpptr += 8;
190 img0 += out_size;
191 }
192 }
193
194 #pragma omp parallel for num_threads(opt.num_threads)
195 for (int i = remain_size_start; i < out_size; i++)
196 {
197 const float* img0 = bottom_im2col.channel(0);
198 img0 += i;
199
200 float* tmpptr = bottom_tm.channel(i / 8 + i % 8);
201
202 for (int q = 0; q < inch * kernel_size; q++)
203 {
204 tmpptr[0] = img0[0];
205
206 tmpptr += 1;
207 img0 += out_size;
208 }
209 }
210 }
211
212 // sgemm(int M, int N, int L, float* A, float* B, float* C)
213 {
214 //int M = outch; // outch
215 int N = outw * outh; // outsize or out stride
216 int L = kernel_w * kernel_h * inch; // ksize * inch
217
218 int nn_outch = 0;
219 int remain_outch_start = 0;
220
221 nn_outch = outch >> 3;
222 remain_outch_start = nn_outch << 3;
223
224 #pragma omp parallel for num_threads(opt.num_threads)
225 for (int pp = 0; pp < nn_outch; pp++)
226 {
227 int i = pp * 8;
228
229 float* output0 = top_blob.channel(i);
230 float* output1 = top_blob.channel(i + 1);
231 float* output2 = top_blob.channel(i + 2);
232 float* output3 = top_blob.channel(i + 3);
233 float* output4 = top_blob.channel(i + 4);
234 float* output5 = top_blob.channel(i + 5);
235 float* output6 = top_blob.channel(i + 6);
236 float* output7 = top_blob.channel(i + 7);
237
238 const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
239 const float* biasptr = bias ? bias + i : zeros;
240
241 int j = 0;
242 for (; j + 7 < N; j = j + 8)
243 {
244 const float* vb = bottom_tm.channel(j / 8);
245 const float* va = kernel_tm.channel(i / 8);
246 #if __AVX__
247 __m256 _sum0 = _mm256_broadcast_ss(biasptr);
248 __m256 _sum1 = _mm256_broadcast_ss(biasptr + 1);
249 __m256 _sum2 = _mm256_broadcast_ss(biasptr + 2);
250 __m256 _sum3 = _mm256_broadcast_ss(biasptr + 3);
251 __m256 _sum4 = _mm256_broadcast_ss(biasptr + 4);
252 __m256 _sum5 = _mm256_broadcast_ss(biasptr + 5);
253 __m256 _sum6 = _mm256_broadcast_ss(biasptr + 6);
254 __m256 _sum7 = _mm256_broadcast_ss(biasptr + 7);
255
256 int k = 0;
257 for (; k + 3 < L; k = k + 4)
258 {
259 // k0
260 __m256 _va0 = _mm256_broadcast_ss(va);
261 __m256 _va1 = _mm256_broadcast_ss(va + 1);
262 __m256 _va2 = _mm256_broadcast_ss(va + 2);
263 __m256 _va3 = _mm256_broadcast_ss(va + 3);
264 __m256 _vb0 = _mm256_loadu_ps(vb);
265 __m256 _vb1 = _mm256_loadu_ps(vb + 8);
266 __m256 _vb2 = _mm256_loadu_ps(vb + 16);
267 __m256 _vb3 = _mm256_loadu_ps(vb + 24);
268 _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
269 _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
270 _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
271 _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
272 _va0 = _mm256_broadcast_ss(va + 4);
273 _va1 = _mm256_broadcast_ss(va + 5);
274 _va2 = _mm256_broadcast_ss(va + 6);
275 _va3 = _mm256_broadcast_ss(va + 7);
276 _sum4 = _mm256_fmadd_ps(_vb0, _va0, _sum4); // sum4 = (a00-a07) * k40
277 _sum5 = _mm256_fmadd_ps(_vb0, _va1, _sum5); // sum5 = (a00-a07) * k50
278 _sum6 = _mm256_fmadd_ps(_vb0, _va2, _sum6); // sum6 = (a00-a07) * k60
279 _sum7 = _mm256_fmadd_ps(_vb0, _va3, _sum7); // sum7 = (a00-a07) * k70
280
281 va += 8;
282
283 // k1
284 _va0 = _mm256_broadcast_ss(va);
285 _va1 = _mm256_broadcast_ss(va + 1);
286 _va2 = _mm256_broadcast_ss(va + 2);
287 _va3 = _mm256_broadcast_ss(va + 3);
288 _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
289 _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
290 _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
291 _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
292 _va0 = _mm256_broadcast_ss(va + 4);
293 _va1 = _mm256_broadcast_ss(va + 5);
294 _va2 = _mm256_broadcast_ss(va + 6);
295 _va3 = _mm256_broadcast_ss(va + 7);
296 _sum4 = _mm256_fmadd_ps(_vb1, _va0, _sum4); // sum4 += (a10-a17) * k41
297 _sum5 = _mm256_fmadd_ps(_vb1, _va1, _sum5); // sum5 += (a10-a17) * k51
298 _sum6 = _mm256_fmadd_ps(_vb1, _va2, _sum6); // sum6 += (a10-a17) * k61
299 _sum7 = _mm256_fmadd_ps(_vb1, _va3, _sum7); // sum7 += (a10-a17) * k71
300
301 va += 8;
302
303 // k2
304 _va0 = _mm256_broadcast_ss(va);
305 _va1 = _mm256_broadcast_ss(va + 1);
306 _va2 = _mm256_broadcast_ss(va + 2);
307 _va3 = _mm256_broadcast_ss(va + 3);
308 _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
309 _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
310 _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
311 _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
312 _va0 = _mm256_broadcast_ss(va + 4);
313 _va1 = _mm256_broadcast_ss(va + 5);
314 _va2 = _mm256_broadcast_ss(va + 6);
315 _va3 = _mm256_broadcast_ss(va + 7);
316 _sum4 = _mm256_fmadd_ps(_vb2, _va0, _sum4); // sum4 += (a20-a27) * k42
317 _sum5 = _mm256_fmadd_ps(_vb2, _va1, _sum5); // sum5 += (a20-a27) * k52
318 _sum6 = _mm256_fmadd_ps(_vb2, _va2, _sum6); // sum6 += (a20-a27) * k62
319 _sum7 = _mm256_fmadd_ps(_vb2, _va3, _sum7); // sum7 += (a20-a27) * k72
320
321 va += 8;
322
323 // k3
324 _va0 = _mm256_broadcast_ss(va);
325 _va1 = _mm256_broadcast_ss(va + 1);
326 _va2 = _mm256_broadcast_ss(va + 2);
327 _va3 = _mm256_broadcast_ss(va + 3);
328 _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
329 _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
330 _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
331 _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
332 _va0 = _mm256_broadcast_ss(va + 4);
333 _va1 = _mm256_broadcast_ss(va + 5);
334 _va2 = _mm256_broadcast_ss(va + 6);
335 _va3 = _mm256_broadcast_ss(va + 7);
336 _sum4 = _mm256_fmadd_ps(_vb3, _va0, _sum4); // sum4 += (a30-a37) * k43
337 _sum5 = _mm256_fmadd_ps(_vb3, _va1, _sum5); // sum5 += (a30-a37) * k53
338 _sum6 = _mm256_fmadd_ps(_vb3, _va2, _sum6); // sum6 += (a30-a37) * k63
339 _sum7 = _mm256_fmadd_ps(_vb3, _va3, _sum7); // sum7 += (a30-a37) * k73
340
341 va += 8;
342 vb += 32;
343 }
344
345 for (; k < L; k++)
346 {
347 // k0
348 __m256 _va0 = _mm256_broadcast_ss(va);
349 __m256 _va1 = _mm256_broadcast_ss(va + 1);
350 __m256 _va2 = _mm256_broadcast_ss(va + 2);
351 __m256 _va3 = _mm256_broadcast_ss(va + 3);
352 __m256 _va4 = _mm256_broadcast_ss(va + 4);
353 __m256 _va5 = _mm256_broadcast_ss(va + 5);
354 __m256 _va6 = _mm256_broadcast_ss(va + 6);
355 __m256 _va7 = _mm256_broadcast_ss(va + 7);
356 __m256 _vb0 = _mm256_loadu_ps(vb);
357 _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
358 _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
359 _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
360 _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
361 _sum4 = _mm256_fmadd_ps(_vb0, _va4, _sum4); // sum4 = (a00-a07) * k40
362 _sum5 = _mm256_fmadd_ps(_vb0, _va5, _sum5); // sum5 = (a00-a07) * k50
363 _sum6 = _mm256_fmadd_ps(_vb0, _va6, _sum6); // sum6 = (a00-a07) * k60
364 _sum7 = _mm256_fmadd_ps(_vb0, _va7, _sum7); // sum7 = (a00-a07) * k70
365
366 va += 8;
367 vb += 8;
368 }
369
370 _mm256_storeu_ps(output0, _sum0);
371 _mm256_storeu_ps(output1, _sum1);
372 _mm256_storeu_ps(output2, _sum2);
373 _mm256_storeu_ps(output3, _sum3);
374 _mm256_storeu_ps(output4, _sum4);
375 _mm256_storeu_ps(output5, _sum5);
376 _mm256_storeu_ps(output6, _sum6);
377 _mm256_storeu_ps(output7, _sum7);
378 #else
379 float sum0[8] = {0};
380 float sum1[8] = {0};
381 float sum2[8] = {0};
382 float sum3[8] = {0};
383 float sum4[8] = {0};
384 float sum5[8] = {0};
385 float sum6[8] = {0};
386 float sum7[8] = {0};
387
388 int k = 0;
389 for (; k + 7 < L; k = k + 8)
390 {
391 for (int n = 0; n < 8; n++)
392 {
393 sum0[n] += va[0] * vb[n];
394 sum1[n] += va[1] * vb[n];
395 sum2[n] += va[2] * vb[n];
396 sum3[n] += va[3] * vb[n];
397 sum4[n] += va[4] * vb[n];
398 sum5[n] += va[5] * vb[n];
399 sum6[n] += va[6] * vb[n];
400 sum7[n] += va[7] * vb[n];
401 va += 8;
402
403 sum0[n] += va[0] * vb[n + 8];
404 sum1[n] += va[1] * vb[n + 8];
405 sum2[n] += va[2] * vb[n + 8];
406 sum3[n] += va[3] * vb[n + 8];
407 sum4[n] += va[4] * vb[n + 8];
408 sum5[n] += va[5] * vb[n + 8];
409 sum6[n] += va[6] * vb[n + 8];
410 sum7[n] += va[7] * vb[n + 8];
411 va += 8;
412
413 sum0[n] += va[0] * vb[n + 16];
414 sum1[n] += va[1] * vb[n + 16];
415 sum2[n] += va[2] * vb[n + 16];
416 sum3[n] += va[3] * vb[n + 16];
417 sum4[n] += va[4] * vb[n + 16];
418 sum5[n] += va[5] * vb[n + 16];
419 sum6[n] += va[6] * vb[n + 16];
420 sum7[n] += va[7] * vb[n + 16];
421 va += 8;
422
423 sum0[n] += va[0] * vb[n + 24];
424 sum1[n] += va[1] * vb[n + 24];
425 sum2[n] += va[2] * vb[n + 24];
426 sum3[n] += va[3] * vb[n + 24];
427 sum4[n] += va[4] * vb[n + 24];
428 sum5[n] += va[5] * vb[n + 24];
429 sum6[n] += va[6] * vb[n + 24];
430 sum7[n] += va[7] * vb[n + 24];
431 va += 8;
432
433 sum0[n] += va[0] * vb[n + 32];
434 sum1[n] += va[1] * vb[n + 32];
435 sum2[n] += va[2] * vb[n + 32];
436 sum3[n] += va[3] * vb[n + 32];
437 sum4[n] += va[4] * vb[n + 32];
438 sum5[n] += va[5] * vb[n + 32];
439 sum6[n] += va[6] * vb[n + 32];
440 sum7[n] += va[7] * vb[n + 32];
441 va += 8;
442
443 sum0[n] += va[0] * vb[n + 40];
444 sum1[n] += va[1] * vb[n + 40];
445 sum2[n] += va[2] * vb[n + 40];
446 sum3[n] += va[3] * vb[n + 40];
447 sum4[n] += va[4] * vb[n + 40];
448 sum5[n] += va[5] * vb[n + 40];
449 sum6[n] += va[6] * vb[n + 40];
450 sum7[n] += va[7] * vb[n + 40];
451 va += 8;
452
453 sum0[n] += va[0] * vb[n + 48];
454 sum1[n] += va[1] * vb[n + 48];
455 sum2[n] += va[2] * vb[n + 48];
456 sum3[n] += va[3] * vb[n + 48];
457 sum4[n] += va[4] * vb[n + 48];
458 sum5[n] += va[5] * vb[n + 48];
459 sum6[n] += va[6] * vb[n + 48];
460 sum7[n] += va[7] * vb[n + 48];
461 va += 8;
462
463 sum0[n] += va[0] * vb[n + 56];
464 sum1[n] += va[1] * vb[n + 56];
465 sum2[n] += va[2] * vb[n + 56];
466 sum3[n] += va[3] * vb[n + 56];
467 sum4[n] += va[4] * vb[n + 56];
468 sum5[n] += va[5] * vb[n + 56];
469 sum6[n] += va[6] * vb[n + 56];
470 sum7[n] += va[7] * vb[n + 56];
471 va -= 56;
472 }
473
474 va += 64;
475 vb += 64;
476 }
477
478 for (; k < L; k++)
479 {
480 for (int n = 0; n < 8; n++)
481 {
482 sum0[n] += va[0] * vb[n];
483 sum1[n] += va[1] * vb[n];
484 sum2[n] += va[2] * vb[n];
485 sum3[n] += va[3] * vb[n];
486 sum4[n] += va[4] * vb[n];
487 sum5[n] += va[5] * vb[n];
488 sum6[n] += va[6] * vb[n];
489 sum7[n] += va[7] * vb[n];
490 }
491
492 va += 8;
493 vb += 8;
494 }
495
496 for (int n = 0; n < 8; n++)
497 {
498 output0[n] = sum0[n] + biasptr[0];
499 output1[n] = sum1[n] + biasptr[1];
500 output2[n] = sum2[n] + biasptr[2];
501 output3[n] = sum3[n] + biasptr[3];
502 output4[n] = sum4[n] + biasptr[4];
503 output5[n] = sum5[n] + biasptr[5];
504 output6[n] = sum6[n] + biasptr[6];
505 output7[n] = sum7[n] + biasptr[7];
506 }
507 #endif // __AVX__
508 output0 += 8;
509 output1 += 8;
510 output2 += 8;
511 output3 += 8;
512 output4 += 8;
513 output5 += 8;
514 output6 += 8;
515 output7 += 8;
516 }
517
518 for (; j < N; j++)
519 {
520 const float* vb = bottom_tm.channel(j / 8 + j % 8);
521 const float* va = kernel_tm.channel(i / 8);
522
523 #if __AVX__
524 __m256 _sum0_7 = _mm256_loadu_ps(biasptr);
525 __m256 _sum0 = _mm256_set1_ps(0.0);
526 __m256 _sum1 = _mm256_set1_ps(0.0);
527 __m256 _sum2 = _mm256_set1_ps(0.0);
528 __m256 _sum3 = _mm256_set1_ps(0.0);
529
530 int k = 0;
531 for (; k + 3 < L; k = k + 4)
532 {
533 __m256 _vb0 = _mm256_broadcast_ss(vb);
534 __m256 _vb1 = _mm256_broadcast_ss(vb + 1);
535 __m256 _vb2 = _mm256_broadcast_ss(vb + 2);
536 __m256 _vb3 = _mm256_broadcast_ss(vb + 3);
537 __m256 _va0 = _mm256_loadu_ps(va);
538 __m256 _va1 = _mm256_loadu_ps(va + 8);
539 __m256 _va2 = _mm256_loadu_ps(va + 16);
540 __m256 _va3 = _mm256_loadu_ps(va + 24);
541
542 _sum0 = _mm256_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k70) * a00
543 _sum1 = _mm256_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k71) * a10
544 _sum2 = _mm256_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k72) * a20
545 _sum3 = _mm256_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k73) * a30
546
547 va += 32;
548 vb += 4;
549 }
550
551 _sum0 = _mm256_add_ps(_sum0, _sum1);
552 _sum2 = _mm256_add_ps(_sum2, _sum3);
553 _sum0_7 = _mm256_add_ps(_sum0_7, _sum0);
554 _sum0_7 = _mm256_add_ps(_sum0_7, _sum2);
555
556 for (; k < L; k++)
557 {
558 __m256 _vb0 = _mm256_broadcast_ss(vb);
559 __m256 _va = _mm256_loadu_ps(va);
560
561 _sum0_7 = _mm256_fmadd_ps(_va, _vb0, _sum0_7); // sum0 += (k00-k70) * a00
562
563 va += 8;
564 vb += 1;
565 }
566
567 float output_sum0_7[8] = {0.f};
568 _mm256_storeu_ps(output_sum0_7, _sum0_7);
569
570 output0[0] = output_sum0_7[0];
571 output1[0] = output_sum0_7[1];
572 output2[0] = output_sum0_7[2];
573 output3[0] = output_sum0_7[3];
574 output4[0] = output_sum0_7[4];
575 output5[0] = output_sum0_7[5];
576 output6[0] = output_sum0_7[6];
577 output7[0] = output_sum0_7[7];
578 #else
579 float sum0 = biasptr[0];
580 float sum1 = biasptr[1];
581 float sum2 = biasptr[2];
582 float sum3 = biasptr[3];
583 float sum4 = biasptr[4];
584 float sum5 = biasptr[5];
585 float sum6 = biasptr[6];
586 float sum7 = biasptr[7];
587
588 for (int k = 0; k < L; k++)
589 {
590 sum0 += va[0] * vb[0];
591 sum1 += va[1] * vb[0];
592 sum2 += va[2] * vb[0];
593 sum3 += va[3] * vb[0];
594 sum4 += va[4] * vb[0];
595 sum5 += va[5] * vb[0];
596 sum6 += va[6] * vb[0];
597 sum7 += va[7] * vb[0];
598
599 va += 8;
600 vb += 1;
601 }
602
603 output0[0] = sum0;
604 output1[0] = sum1;
605 output2[0] = sum2;
606 output3[0] = sum3;
607 output4[0] = sum4;
608 output5[0] = sum5;
609 output6[0] = sum6;
610 output7[0] = sum7;
611 #endif // __AVX__
612 output0++;
613 output1++;
614 output2++;
615 output3++;
616 output4++;
617 output5++;
618 output6++;
619 output7++;
620 }
621 }
622
623 nn_outch = (outch - remain_outch_start) >> 2;
624
625 #pragma omp parallel for num_threads(opt.num_threads)
626 for (int pp = 0; pp < nn_outch; pp++)
627 {
628 int i = remain_outch_start + pp * 4;
629
630 float* output0 = top_blob.channel(i);
631 float* output1 = top_blob.channel(i + 1);
632 float* output2 = top_blob.channel(i + 2);
633 float* output3 = top_blob.channel(i + 3);
634
635 const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
636 const float* biasptr = bias ? bias + i : zeros;
637
638 int j = 0;
639 for (; j + 7 < N; j = j + 8)
640 {
641 const float* vb = bottom_tm.channel(j / 8);
642 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
643 #if __AVX__
644 __m256 _sum0 = _mm256_broadcast_ss(biasptr);
645 __m256 _sum1 = _mm256_broadcast_ss(biasptr + 1);
646 __m256 _sum2 = _mm256_broadcast_ss(biasptr + 2);
647 __m256 _sum3 = _mm256_broadcast_ss(biasptr + 3);
648
649 int k = 0;
650 for (; k + 3 < L; k = k + 4)
651 {
652 // k0
653 __m256 _va0 = _mm256_broadcast_ss(va);
654 __m256 _va1 = _mm256_broadcast_ss(va + 1);
655 __m256 _va2 = _mm256_broadcast_ss(va + 2);
656 __m256 _va3 = _mm256_broadcast_ss(va + 3);
657 __m256 _vb0 = _mm256_loadu_ps(vb);
658 __m256 _vb1 = _mm256_loadu_ps(vb + 8);
659 __m256 _vb2 = _mm256_loadu_ps(vb + 16);
660 __m256 _vb3 = _mm256_loadu_ps(vb + 24);
661 _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
662 _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
663 _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
664 _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
665
666 va += 4;
667
668 // k1
669 _va0 = _mm256_broadcast_ss(va);
670 _va1 = _mm256_broadcast_ss(va + 1);
671 _va2 = _mm256_broadcast_ss(va + 2);
672 _va3 = _mm256_broadcast_ss(va + 3);
673 _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
674 _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
675 _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
676 _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
677
678 va += 4;
679
680 // k2
681 _va0 = _mm256_broadcast_ss(va);
682 _va1 = _mm256_broadcast_ss(va + 1);
683 _va2 = _mm256_broadcast_ss(va + 2);
684 _va3 = _mm256_broadcast_ss(va + 3);
685 _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
686 _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
687 _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
688 _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
689
690 va += 4;
691
692 // k3
693 _va0 = _mm256_broadcast_ss(va);
694 _va1 = _mm256_broadcast_ss(va + 1);
695 _va2 = _mm256_broadcast_ss(va + 2);
696 _va3 = _mm256_broadcast_ss(va + 3);
697 _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
698 _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
699 _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
700 _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
701
702 va += 4;
703 vb += 32;
704 }
705
706 for (; k < L; k++)
707 {
708 // k0
709 __m256 _va0 = _mm256_broadcast_ss(va);
710 __m256 _va1 = _mm256_broadcast_ss(va + 1);
711 __m256 _va2 = _mm256_broadcast_ss(va + 2);
712 __m256 _va3 = _mm256_broadcast_ss(va + 3);
713 __m256 _vb0 = _mm256_loadu_ps(vb);
714 _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
715 _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
716 _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
717 _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
718
719 va += 4;
720 vb += 8;
721 }
722
723 _mm256_storeu_ps(output0, _sum0);
724 _mm256_storeu_ps(output1, _sum1);
725 _mm256_storeu_ps(output2, _sum2);
726 _mm256_storeu_ps(output3, _sum3);
727 #else
728 float sum0[8] = {0};
729 float sum1[8] = {0};
730 float sum2[8] = {0};
731 float sum3[8] = {0};
732
733 int k = 0;
734 for (; k + 7 < L; k = k + 8)
735 {
736 for (int n = 0; n < 8; n++)
737 {
738 sum0[n] += va[0] * vb[n];
739 sum1[n] += va[1] * vb[n];
740 sum2[n] += va[2] * vb[n];
741 sum3[n] += va[3] * vb[n];
742 va += 4;
743
744 sum0[n] += va[0] * vb[n + 8];
745 sum1[n] += va[1] * vb[n + 8];
746 sum2[n] += va[2] * vb[n + 8];
747 sum3[n] += va[3] * vb[n + 8];
748 va += 4;
749
750 sum0[n] += va[0] * vb[n + 16];
751 sum1[n] += va[1] * vb[n + 16];
752 sum2[n] += va[2] * vb[n + 16];
753 sum3[n] += va[3] * vb[n + 16];
754 va += 4;
755
756 sum0[n] += va[0] * vb[n + 24];
757 sum1[n] += va[1] * vb[n + 24];
758 sum2[n] += va[2] * vb[n + 24];
759 sum3[n] += va[3] * vb[n + 24];
760 va += 4;
761
762 sum0[n] += va[0] * vb[n + 32];
763 sum1[n] += va[1] * vb[n + 32];
764 sum2[n] += va[2] * vb[n + 32];
765 sum3[n] += va[3] * vb[n + 32];
766 va += 4;
767
768 sum0[n] += va[0] * vb[n + 40];
769 sum1[n] += va[1] * vb[n + 40];
770 sum2[n] += va[2] * vb[n + 40];
771 sum3[n] += va[3] * vb[n + 40];
772 va += 4;
773
774 sum0[n] += va[0] * vb[n + 48];
775 sum1[n] += va[1] * vb[n + 48];
776 sum2[n] += va[2] * vb[n + 48];
777 sum3[n] += va[3] * vb[n + 48];
778 va += 4;
779
780 sum0[n] += va[0] * vb[n + 56];
781 sum1[n] += va[1] * vb[n + 56];
782 sum2[n] += va[2] * vb[n + 56];
783 sum3[n] += va[3] * vb[n + 56];
784 va -= 28;
785 }
786
787 va += 32;
788 vb += 64;
789 }
790
791 for (; k < L; k++)
792 {
793 for (int n = 0; n < 8; n++)
794 {
795 sum0[n] += va[0] * vb[n];
796 sum1[n] += va[1] * vb[n];
797 sum2[n] += va[2] * vb[n];
798 sum3[n] += va[3] * vb[n];
799 }
800
801 va += 4;
802 vb += 8;
803 }
804
805 for (int n = 0; n < 8; n++)
806 {
807 output0[n] = sum0[n] + biasptr[0];
808 output1[n] = sum1[n] + biasptr[1];
809 output2[n] = sum2[n] + biasptr[2];
810 output3[n] = sum3[n] + biasptr[3];
811 }
812 #endif // __AVX__
813 output0 += 8;
814 output1 += 8;
815 output2 += 8;
816 output3 += 8;
817 }
818
819 for (; j < N; j++)
820 {
821 const float* vb = bottom_tm.channel(j / 8 + j % 8);
822 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
823 #if __AVX__
824 __m128 _sum0_3 = _mm_loadu_ps(biasptr);
825 __m128 _sum0 = _mm_set1_ps(0.0);
826 __m128 _sum1 = _mm_set1_ps(0.0);
827 __m128 _sum2 = _mm_set1_ps(0.0);
828 __m128 _sum3 = _mm_set1_ps(0.0);
829
830 int k = 0;
831 for (; k + 3 < L; k = k + 4)
832 {
833 __m128 _vb0 = _mm_set1_ps(vb[0]);
834 __m128 _vb1 = _mm_set1_ps(vb[1]);
835 __m128 _vb2 = _mm_set1_ps(vb[2]);
836 __m128 _vb3 = _mm_set1_ps(vb[3]);
837 __m128 _va0 = _mm_loadu_ps(va);
838 __m128 _va1 = _mm_loadu_ps(va + 4);
839 __m128 _va2 = _mm_loadu_ps(va + 8);
840 __m128 _va3 = _mm_loadu_ps(va + 12);
841
842 _sum0 = _mm_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k30) * a00
843 _sum1 = _mm_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k31) * a10
844 _sum2 = _mm_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k32) * a20
845 _sum3 = _mm_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k33) * a30
846
847 va += 16;
848 vb += 4;
849 }
850
851 _sum0 = _mm_add_ps(_sum0, _sum1);
852 _sum2 = _mm_add_ps(_sum2, _sum3);
853 _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
854 _sum0_3 = _mm_add_ps(_sum0_3, _sum2);
855
856 for (; k < L; k++)
857 {
858 __m128 _vb0 = _mm_set1_ps(vb[0]);
859 __m128 _va = _mm_loadu_ps(va);
860
861 _sum0_3 = _mm_fmadd_ps(_va, _vb0, _sum0_3); // sum0 += (k00-k30) * a00
862
863 va += 4;
864 vb += 1;
865 }
866
867 float output_sum0_3[4] = {0.f};
868 _mm_storeu_ps(output_sum0_3, _sum0_3);
869 output0[0] = output_sum0_3[0];
870 output1[0] = output_sum0_3[1];
871 output2[0] = output_sum0_3[2];
872 output3[0] = output_sum0_3[3];
873 #else
874 float sum0 = biasptr[0];
875 float sum1 = biasptr[1];
876 float sum2 = biasptr[2];
877 float sum3 = biasptr[3];
878
879 for (int k = 0; k < L; k++)
880 {
881 sum0 += va[0] * vb[0];
882 sum1 += va[1] * vb[0];
883 sum2 += va[2] * vb[0];
884 sum3 += va[3] * vb[0];
885
886 va += 4;
887 vb += 1;
888 }
889
890 output0[0] = sum0;
891 output1[0] = sum1;
892 output2[0] = sum2;
893 output3[0] = sum3;
894 #endif // __AVX__
895 output0++;
896 output1++;
897 output2++;
898 output3++;
899 }
900 }
901
902 remain_outch_start += nn_outch << 2;
903
904 #pragma omp parallel for num_threads(opt.num_threads)
905 for (int i = remain_outch_start; i < outch; i++)
906 {
907 float* output = top_blob.channel(i);
908
909 const float bias0 = bias ? bias[i] : 0.f;
910
911 int j = 0;
912 for (; j + 7 < N; j = j + 8)
913 {
914 const float* vb = bottom_tm.channel(j / 8);
915 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
916 #if __AVX__
917 __m256 _sum0 = _mm256_broadcast_ss(&bias0);
918
919 int k = 0;
920 for (; k + 3 < L; k = k + 4)
921 {
922 // k0
923 __m256 _va0 = _mm256_broadcast_ss(va);
924 __m256 _va1 = _mm256_broadcast_ss(va + 1);
925 __m256 _va2 = _mm256_broadcast_ss(va + 2);
926 __m256 _va3 = _mm256_broadcast_ss(va + 3);
927 __m256 _vb0 = _mm256_loadu_ps(vb);
928 __m256 _vb1 = _mm256_loadu_ps(vb + 8);
929 __m256 _vb2 = _mm256_loadu_ps(vb + 16);
930 __m256 _vb3 = _mm256_loadu_ps(vb + 24);
931
932 _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
933 _sum0 = _mm256_fmadd_ps(_vb1, _va1, _sum0); // sum0 += (a10-a17) * k01
934 _sum0 = _mm256_fmadd_ps(_vb2, _va2, _sum0); // sum0 += (a20-a27) * k02
935 _sum0 = _mm256_fmadd_ps(_vb3, _va3, _sum0); // sum0 += (a30-a37) * k03
936
937 va += 4;
938 vb += 32;
939 }
940
941 for (; k < L; k++)
942 {
943 // k0
944 __m256 _va0 = _mm256_broadcast_ss(va);
945 __m256 _vb0 = _mm256_loadu_ps(vb);
946
947 _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
948
949 va += 1;
950 vb += 8;
951 }
952
953 _mm256_storeu_ps(output, _sum0);
954 #else
955 float sum[8] = {0};
956
957 int k = 0;
958 for (; k + 7 < L; k = k + 8)
959 {
960 for (int n = 0; n < 8; n++)
961 {
962 sum[n] += va[0] * vb[n];
963 sum[n] += va[1] * vb[n + 8];
964 sum[n] += va[2] * vb[n + 16];
965 sum[n] += va[3] * vb[n + 24];
966 sum[n] += va[4] * vb[n + 32];
967 sum[n] += va[5] * vb[n + 40];
968 sum[n] += va[6] * vb[n + 48];
969 sum[n] += va[7] * vb[n + 56];
970 }
971
972 va += 8;
973 vb += 64;
974 }
975
976 for (; k < L; k++)
977 {
978 for (int n = 0; n < 8; n++)
979 {
980 sum[n] += va[0] * vb[n];
981 }
982
983 va += 1;
984 vb += 8;
985 }
986
987 for (int n = 0; n < 8; n++)
988 {
989 output[n] = sum[n] + bias0;
990 }
991 #endif // __AVX__
992 output += 8;
993 }
994
995 for (; j < N; j++)
996 {
997 const float* vb = bottom_tm.channel(j / 8 + j % 8);
998 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
999
1000 int k = 0;
1001 #if __AVX__
1002 __m128 _sum0 = _mm_set1_ps(0.f);
1003
1004 for (; k + 3 < L; k += 4)
1005 {
1006 __m128 _p0 = _mm_loadu_ps(vb);
1007 vb += 4;
1008
1009 __m128 _k0 = _mm_loadu_ps(va);
1010 va += 4;
1011
1012 _sum0 = _mm_fmadd_ps(_p0, _k0, _sum0);
1013 }
1014
1015 float output_sum0[4] = {0.f};
1016 _mm_storeu_ps(output_sum0, _sum0);
1017
1018 float sum0 = bias0 + output_sum0[0] + output_sum0[1] + output_sum0[2] + output_sum0[3];
1019
1020 #else
1021 float sum0 = bias0;
1022 #endif // __AVX__
1023 for (; k < L; k++)
1024 {
1025 sum0 += va[0] * vb[0];
1026
1027 va += 1;
1028 vb += 1;
1029 }
1030 output[0] = sum0;
1031
1032 output++;
1033 }
1034 }
1035 }
1036 }
1037 #else
conv_im2col_sgemm_transform_kernel_sse(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_size)1038 static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
1039 {
1040 const float* kernel = _kernel;
1041
1042 // kernel memory packed 4 x 4
1043 kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4);
1044
1045 int nn_outch = 0;
1046 int remain_outch_start = 0;
1047
1048 nn_outch = outch >> 2;
1049 remain_outch_start = nn_outch << 2;
1050
1051 for (int pp = 0; pp < nn_outch; pp++)
1052 {
1053 int p = pp * 4;
1054
1055 const float* k0 = kernel + (p + 0) * inch * kernel_size;
1056 const float* k1 = kernel + (p + 1) * inch * kernel_size;
1057 const float* k2 = kernel + (p + 2) * inch * kernel_size;
1058 const float* k3 = kernel + (p + 3) * inch * kernel_size;
1059
1060 float* ktmp = kernel_tm.channel(p / 4);
1061
1062 for (int q = 0; q < inch * kernel_size; q++)
1063 {
1064 ktmp[0] = k0[0];
1065 ktmp[1] = k1[0];
1066 ktmp[2] = k2[0];
1067 ktmp[3] = k3[0];
1068 ktmp += 4;
1069
1070 k0 += 1;
1071 k1 += 1;
1072 k2 += 1;
1073 k3 += 1;
1074 }
1075 }
1076
1077 for (int p = remain_outch_start; p < outch; p++)
1078 {
1079 const float* k0 = kernel + (p + 0) * inch * kernel_size;
1080
1081 float* ktmp = kernel_tm.channel(p / 4 + p % 4);
1082
1083 for (int q = 0; q < inch * kernel_size; q++)
1084 {
1085 ktmp[0] = k0[0];
1086 ktmp++;
1087 k0++;
1088 }
1089 }
1090 }
1091
// im2col + packed sgemm convolution, non-AVX fallback path.
// Steps: (1) im2col-expand bottom_blob into bottom_im2col,
//        (2) repack output columns 4-wide into bottom_tm so the sgemm inner
//            loop can fetch 4 output positions with one 128-bit load,
//        (3) 4x4-blocked sgemm against kernel_tm (pre-packed by
//            conv_im2col_sgemm_transform_kernel_sse), adding bias into top_blob.
// Assumes top_blob has already been created with the desired outw/outh/outch.
static void conv_im2col_sgemm_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias,
                                  const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const float* bias = _bias; // may be null when the layer has no bias term

    // im2col
    Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, elemsize, opt.workspace_allocator);
    {
        const int stride = kernel_h * kernel_w * outw * outh; // floats written per input channel
        float* ret = (float*)bottom_im2col;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < inch; p++)
        {
            const float* input = bottom_blob.channel(p);
            int retID = stride * p;
            // layout: for each kernel tap (u,v), a full outh*outw plane of samples
            for (int u = 0; u < kernel_h; u++)
            {
                for (int v = 0; v < kernel_w; v++)
                {
                    for (int i = 0; i < outh; i++)
                    {
                        for (int j = 0; j < outw; j++)
                        {
                            // input pixel that kernel tap (u,v) sees at output (i,j)
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    int kernel_size = kernel_w * kernel_h;
    int out_size = outw * outh;

    // bottom_im2col memory packed 4 x 4
    Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, elemsize, opt.workspace_allocator);
    {
        int nn_size = out_size >> 2;
        int remain_size_start = nn_size << 2;

        // interleave groups of 4 output positions: channel i/4 of bottom_tm
        // holds columns [i, i+4) of the im2col matrix, row-interleaved
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 4;

            const float* img0 = bottom_im2col.channel(0);
            img0 += i;

            float* tmpptr = bottom_tm.channel(i / 4);

            for (int q = 0; q < inch * kernel_size; q++)
            {
#if __SSE__
                _mm_storeu_ps(tmpptr, _mm_loadu_ps(img0));
#else
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];
                tmpptr[2] = img0[2];
                tmpptr[3] = img0[3];
#endif // __SSE__
                tmpptr += 4;
                img0 += out_size; // next im2col row for the same 4 columns
            }
        }

        // leftover columns (out_size % 4) go one per trailing channel slot
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < out_size; i++)
        {
            const float* img0 = bottom_im2col.channel(0);
            img0 += i;

            float* tmpptr = bottom_tm.channel(i / 4 + i % 4);

            for (int q = 0; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];

                tmpptr += 1;
                img0 += out_size;
            }
        }
    }

    // sgemm(int M, int N, int L, float* A, float* B, float* C)
    {
        //int M = outch; // outch
        int N = outw * outh; // outsize or out stride
        int L = kernel_w * kernel_h * inch; // ksize * inch

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        // main body: 4 output channels x 4 output positions per iteration
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = pp * 4;

            float* output0 = top_blob.channel(i);
            float* output1 = top_blob.channel(i + 1);
            float* output2 = top_blob.channel(i + 2);
            float* output3 = top_blob.channel(i + 3);

            // substitute zeros when the layer has no bias
            const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
            const float* biasptr = bias ? bias + i : zeros;

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                const float* vb = bottom_tm.channel(j / 4);
                const float* va = kernel_tm.channel(i / 4);
#if __SSE__
                // one accumulator per output channel, each covering 4 output positions;
                // seeded with the channel's bias broadcast across the lanes
                __m128 _sum0 = _mm_set1_ps(biasptr[0]);
                __m128 _sum1 = _mm_set1_ps(biasptr[1]);
                __m128 _sum2 = _mm_set1_ps(biasptr[2]);
                __m128 _sum3 = _mm_set1_ps(biasptr[3]);

                int k = 0;
                // unroll the reduction dimension by 4
                for (; k + 3 < L; k = k + 4)
                {
                    // k0
                    __m128 _vb = _mm_loadu_ps(vb);
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _va1 = _mm_set1_ps(va[1]);
                    __m128 _va2 = _mm_set1_ps(va[2]);
                    __m128 _va3 = _mm_set1_ps(va[3]);
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a00-a03) * k00
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a00-a03) * k10
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a00-a03) * k20
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a00-a03) * k30

                    // k1
                    _vb = _mm_loadu_ps(vb + 4);
                    _va0 = _mm_set1_ps(va[4]);
                    _va1 = _mm_set1_ps(va[5]);
                    _va2 = _mm_set1_ps(va[6]);
                    _va3 = _mm_set1_ps(va[7]);
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a10-a13) * k01
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a10-a13) * k11
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a10-a13) * k21
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a10-a13) * k31

                    // k2
                    _vb = _mm_loadu_ps(vb + 8);
                    _va0 = _mm_set1_ps(va[8]);
                    _va1 = _mm_set1_ps(va[9]);
                    _va2 = _mm_set1_ps(va[10]);
                    _va3 = _mm_set1_ps(va[11]);
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a20-a23) * k02
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a20-a23) * k12
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a20-a23) * k22
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a20-a23) * k32

                    // k3
                    _vb = _mm_loadu_ps(vb + 12);
                    _va0 = _mm_set1_ps(va[12]);
                    _va1 = _mm_set1_ps(va[13]);
                    _va2 = _mm_set1_ps(va[14]);
                    _va3 = _mm_set1_ps(va[15]);
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a30-a33) * k03
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a30-a33) * k13
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a30-a33) * k23
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a30-a33) * k33

                    va += 16;
                    vb += 16;
                }

                // remaining L % 4 reduction steps
                for (; k < L; k++)
                {
                    // k0
                    __m128 _vb = _mm_loadu_ps(vb);
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _va1 = _mm_set1_ps(va[1]);
                    __m128 _va2 = _mm_set1_ps(va[2]);
                    __m128 _va3 = _mm_set1_ps(va[3]);
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a00-a03) * k00
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a00-a03) * k10
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a00-a03) * k20
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a00-a03) * k30

                    va += 4;
                    vb += 4;
                }
                _mm_storeu_ps(output0, _sum0);
                _mm_storeu_ps(output1, _sum1);
                _mm_storeu_ps(output2, _sum2);
                _mm_storeu_ps(output3, _sum3);
#else
                // scalar fallback: same 4x4 tile, bias added at the end instead
                float sum0[4] = {0};
                float sum1[4] = {0};
                float sum2[4] = {0};
                float sum3[4] = {0};

                int k = 0;
                // unroll the reduction dimension by 8; within each n the eight
                // "va += 4" advances are undone by the final "va -= 28" so the
                // next n re-reads the same 8 weight groups
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 4];
                        sum1[n] += va[1] * vb[n + 4];
                        sum2[n] += va[2] * vb[n + 4];
                        sum3[n] += va[3] * vb[n + 4];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 8];
                        sum1[n] += va[1] * vb[n + 8];
                        sum2[n] += va[2] * vb[n + 8];
                        sum3[n] += va[3] * vb[n + 8];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 12];
                        sum1[n] += va[1] * vb[n + 12];
                        sum2[n] += va[2] * vb[n + 12];
                        sum3[n] += va[3] * vb[n + 12];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 16];
                        sum1[n] += va[1] * vb[n + 16];
                        sum2[n] += va[2] * vb[n + 16];
                        sum3[n] += va[3] * vb[n + 16];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 20];
                        sum1[n] += va[1] * vb[n + 20];
                        sum2[n] += va[2] * vb[n + 20];
                        sum3[n] += va[3] * vb[n + 20];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 24];
                        sum1[n] += va[1] * vb[n + 24];
                        sum2[n] += va[2] * vb[n + 24];
                        sum3[n] += va[3] * vb[n + 24];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 28];
                        sum1[n] += va[1] * vb[n + 28];
                        sum2[n] += va[2] * vb[n + 28];
                        sum3[n] += va[3] * vb[n + 28];
                        va -= 28;
                    }

                    va += 32;
                    vb += 32;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                    }

                    va += 4;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output0[n] = sum0[n] + biasptr[0];
                    output1[n] = sum1[n] + biasptr[1];
                    output2[n] = sum2[n] + biasptr[2];
                    output3[n] = sum3[n] + biasptr[3];
                }
#endif // __SSE__
                output0 += 4;
                output1 += 4;
                output2 += 4;
                output3 += 4;
            }

            // remaining N % 4 output positions: 4 channels x 1 position
            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 4 + j % 4);
                const float* va = kernel_tm.channel(i / 4);
#if __SSE__
                // here each SSE lane is one output CHANNEL (transposed layout vs above)
                __m128 _sum0_3 = _mm_loadu_ps(biasptr);
                __m128 _sum0 = _mm_set1_ps(0.0);
                __m128 _sum1 = _mm_set1_ps(0.0);
                __m128 _sum2 = _mm_set1_ps(0.0);
                __m128 _sum3 = _mm_set1_ps(0.0);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    __m128 _vb0 = _mm_set1_ps(vb[0]);
                    __m128 _vb1 = _mm_set1_ps(vb[1]);
                    __m128 _vb2 = _mm_set1_ps(vb[2]);
                    __m128 _vb3 = _mm_set1_ps(vb[3]);
                    __m128 _va0 = _mm_loadu_ps(va);
                    __m128 _va1 = _mm_loadu_ps(va + 4);
                    __m128 _va2 = _mm_loadu_ps(va + 8);
                    __m128 _va3 = _mm_loadu_ps(va + 12);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_va0, _vb0)); // sum0 += (k00-k30) * a00
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_va1, _vb1)); // sum1 += (k01-k31) * a10
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_va2, _vb2)); // sum2 += (k02-k32) * a20
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_va3, _vb3)); // sum3 += (k03-k33) * a30

                    va += 16;
                    vb += 4;
                }

                // fold the four partial accumulators into the bias vector
                _sum0 = _mm_add_ps(_sum0, _sum1);
                _sum2 = _mm_add_ps(_sum2, _sum3);
                _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
                _sum0_3 = _mm_add_ps(_sum0_3, _sum2);

                for (; k < L; k++)
                {
                    __m128 _vb0 = _mm_set1_ps(vb[0]);
                    __m128 _va = _mm_loadu_ps(va);

                    _sum0_3 = _mm_add_ps(_sum0_3, _mm_mul_ps(_va, _vb0)); // sum0 += (k00-k30) * a00

                    va += 4;
                    vb += 1;
                }

                // scatter the 4 channel results to their output rows
                float sum0_3_tmp[4];
                _mm_storeu_ps(sum0_3_tmp, _sum0_3);
                output0[0] = sum0_3_tmp[0];
                output1[0] = sum0_3_tmp[1];
                output2[0] = sum0_3_tmp[2];
                output3[0] = sum0_3_tmp[3];
#else
                float sum0 = biasptr[0];
                float sum1 = biasptr[1];
                float sum2 = biasptr[2];
                float sum3 = biasptr[3];

                for (int k = 0; k < L; k++)
                {
                    sum0 += va[0] * vb[0];
                    sum1 += va[1] * vb[0];
                    sum2 += va[2] * vb[0];
                    sum3 += va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
#endif // __SSE__
                output0++;
                output1++;
                output2++;
                output3++;
            }
        }

        // remaining outch % 4 output channels: 1 channel at a time
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            float* output = top_blob.channel(i);

            const float bias0 = bias ? bias[i] : 0.f;

            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                const float* vb = bottom_tm.channel(j / 4);
                // remainder kernels are stored flat after the 4-wide blocks
                const float* va = kernel_tm.channel(i / 4 + i % 4);
#if __SSE__
                __m128 _sum0 = _mm_set1_ps(bias0);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    // k0
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _va1 = _mm_set1_ps(va[1]);
                    __m128 _va2 = _mm_set1_ps(va[2]);
                    __m128 _va3 = _mm_set1_ps(va[3]);
                    __m128 _vb0 = _mm_loadu_ps(vb);
                    __m128 _vb1 = _mm_loadu_ps(vb + 4);
                    __m128 _vb2 = _mm_loadu_ps(vb + 8);
                    __m128 _vb3 = _mm_loadu_ps(vb + 12);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0)); // sum0 = (a00-a03) * k00
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb1, _va1)); // sum0 += (a10-a13) * k01
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb2, _va2)); // sum0 += (a20-a23) * k02
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb3, _va3)); // sum0 += (a30-a33) * k03

                    va += 4;
                    vb += 16;
                }

                for (; k < L; k++)
                {
                    // k0
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _vb0 = _mm_loadu_ps(vb);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0)); // sum0 = (a00-a03) * k00

                    va += 1;
                    vb += 4;
                }
                _mm_storeu_ps(output, _sum0);
#else
                float sum[4] = {0};

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += va[0] * vb[n];
                        sum[n] += va[1] * vb[n + 4];
                        sum[n] += va[2] * vb[n + 8];
                        sum[n] += va[3] * vb[n + 12];
                        //sum[n] += va[4] * vb[n+16];
                        //sum[n] += va[5] * vb[n+20];
                        //sum[n] += va[6] * vb[n+24];
                        //sum[n] += va[7] * vb[n+28];
                    }

                    va += 4;
                    vb += 16;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += va[0] * vb[n];
                    }

                    va += 1;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output[n] = sum[n] + bias0;
                }
#endif // __SSE__
                output += 4;
            }

            // remaining N % 4 positions: plain dot product
            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 4 + j % 4);
                const float* va = kernel_tm.channel(i / 4 + i % 4);

                int k = 0;
#if __SSE__
                __m128 _sum0 = _mm_set1_ps(0.f);

                for (; k + 3 < L; k += 4)
                {
                    __m128 _p0 = _mm_loadu_ps(vb);
                    __m128 _k0 = _mm_loadu_ps(va);
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_p0, _k0));

                    va += 4;
                    vb += 4;
                }
                // horizontal add of the 4 lanes, plus bias
                float sum0_tmp[4];
                _mm_storeu_ps(sum0_tmp, _sum0);
                float sum0 = bias0 + sum0_tmp[0] + sum0_tmp[1] + sum0_tmp[2] + sum0_tmp[3];
#else
                float sum0 = bias0;
#endif // __SSE__
                for (; k < L; k++)
                {
                    sum0 += va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = sum0;

                output++;
            }
        }
    }
}
1598 #endif
1599