// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void conv1x1s1_sgemm_transform_kernel_pack8_avx(const Mat& kernel, Mat& weight_data_pack8, int num_input, int num_output)
{
    // src = kw-kh-inch-outch
    // dst = 8b-8a-kw-kh-inch/8a-outch/8b
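    // repack the 1x1 kernel so that each input channel contributes a contiguous
    // 8 (inch) x 8 (outch) block of 64 floats per group of 8 output channels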
    Mat weight_data_r2 = kernel.reshape(1, num_input, num_output);

    weight_data_pack8.create(1, num_input / 8, num_output / 8, (size_t)4 * 64, 64);

    for (int q = 0; q + 7 < num_output; q += 8)
    {
        const Mat k0 = weight_data_r2.channel(q);
        const Mat k1 = weight_data_r2.channel(q + 1);
        const Mat k2 = weight_data_r2.channel(q + 2);
        const Mat k3 = weight_data_r2.channel(q + 3);
        const Mat k4 = weight_data_r2.channel(q + 4);
        const Mat k5 = weight_data_r2.channel(q + 5);
        const Mat k6 = weight_data_r2.channel(q + 6);
        const Mat k7 = weight_data_r2.channel(q + 7);

        Mat g0 = weight_data_pack8.channel(q / 8);

        for (int p = 0; p + 7 < num_input; p += 8)
        {
            const float* k00 = k0.row(p);
            const float* k01 = k0.row(p + 1);
            const float* k02 = k0.row(p + 2);
            const float* k03 = k0.row(p + 3);
            const float* k04 = k0.row(p + 4);
            const float* k05 = k0.row(p + 5);
            const float* k06 = k0.row(p + 6);
            const float* k07 = k0.row(p + 7);

            const float* k10 = k1.row(p);
            const float* k11 = k1.row(p + 1);
            const float* k12 = k1.row(p + 2);
            const float* k13 = k1.row(p + 3);
            const float* k14 = k1.row(p + 4);
            const float* k15 = k1.row(p + 5);
            const float* k16 = k1.row(p + 6);
            const float* k17 = k1.row(p + 7);

            const float* k20 = k2.row(p);
            const float* k21 = k2.row(p + 1);
            const float* k22 = k2.row(p + 2);
            const float* k23 = k2.row(p + 3);
            const float* k24 = k2.row(p + 4);
            const float* k25 = k2.row(p + 5);
            const float* k26 = k2.row(p + 6);
            const float* k27 = k2.row(p + 7);

            const float* k30 = k3.row(p);
            const float* k31 = k3.row(p + 1);
            const float* k32 = k3.row(p + 2);
            const float* k33 = k3.row(p + 3);
            const float* k34 = k3.row(p + 4);
            const float* k35 = k3.row(p + 5);
            const float* k36 = k3.row(p + 6);
            const float* k37 = k3.row(p + 7);

            const float* k40 = k4.row(p);
            const float* k41 = k4.row(p + 1);
            const float* k42 = k4.row(p + 2);
            const float* k43 = k4.row(p + 3);
            const float* k44 = k4.row(p + 4);
            const float* k45 = k4.row(p + 5);
            const float* k46 = k4.row(p + 6);
            const float* k47 = k4.row(p + 7);

            const float* k50 = k5.row(p);
            const float* k51 = k5.row(p + 1);
            const float* k52 = k5.row(p + 2);
            const float* k53 = k5.row(p + 3);
            const float* k54 = k5.row(p + 4);
            const float* k55 = k5.row(p + 5);
            const float* k56 = k5.row(p + 6);
            const float* k57 = k5.row(p + 7);

            const float* k60 = k6.row(p);
            const float* k61 = k6.row(p + 1);
            const float* k62 = k6.row(p + 2);
            const float* k63 = k6.row(p + 3);
            const float* k64 = k6.row(p + 4);
            const float* k65 = k6.row(p + 5);
            const float* k66 = k6.row(p + 6);
            const float* k67 = k6.row(p + 7);

            const float* k70 = k7.row(p);
            const float* k71 = k7.row(p + 1);
            const float* k72 = k7.row(p + 2);
            const float* k73 = k7.row(p + 3);
            const float* k74 = k7.row(p + 4);
            const float* k75 = k7.row(p + 5);
            const float* k76 = k7.row(p + 6);
            const float* k77 = k7.row(p + 7);

            float* g00 = g0.row(p / 8);
            g00[0] = k00[0];
            g00[1] = k10[0];
            g00[2] = k20[0];
            g00[3] = k30[0];
            g00[4] = k40[0];
            g00[5] = k50[0];
            g00[6] = k60[0];
            g00[7] = k70[0];
            g00 += 8;
            g00[0] = k01[0];
            g00[1] = k11[0];
            g00[2] = k21[0];
            g00[3] = k31[0];
            g00[4] = k41[0];
            g00[5] = k51[0];
            g00[6] = k61[0];
            g00[7] = k71[0];

            g00 += 8;
            g00[0] = k02[0];
            g00[1] = k12[0];
            g00[2] = k22[0];
            g00[3] = k32[0];
            g00[4] = k42[0];
            g00[5] = k52[0];
            g00[6] = k62[0];
            g00[7] = k72[0];

            g00 += 8;
            g00[0] = k03[0];
            g00[1] = k13[0];
            g00[2] = k23[0];
            g00[3] = k33[0];
            g00[4] = k43[0];
            g00[5] = k53[0];
            g00[6] = k63[0];
            g00[7] = k73[0];

            g00 += 8;
            g00[0] = k04[0];
            g00[1] = k14[0];
            g00[2] = k24[0];
            g00[3] = k34[0];
            g00[4] = k44[0];
            g00[5] = k54[0];
            g00[6] = k64[0];
            g00[7] = k74[0];

            g00 += 8;
            g00[0] = k05[0];
            g00[1] = k15[0];
            g00[2] = k25[0];
            g00[3] = k35[0];
            g00[4] = k45[0];
            g00[5] = k55[0];
            g00[6] = k65[0];
            g00[7] = k75[0];

            g00 += 8;
            g00[0] = k06[0];
            g00[1] = k16[0];
            g00[2] = k26[0];
            g00[3] = k36[0];
            g00[4] = k46[0];
            g00[5] = k56[0];
            g00[6] = k66[0];
            g00[7] = k76[0];

            g00 += 8;
            g00[0] = k07[0];
            g00[1] = k17[0];
            g00[2] = k27[0];
            g00[3] = k37[0];
            g00[4] = k47[0];
            g00[5] = k57[0];
            g00[6] = k67[0];
            g00[7] = k77[0];

            g00 += 8;
        }
    }
}

static void conv1x1s1_sgemm_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    int outch = top_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    const int size = w * h;

    const float* bias = _bias;
    // interleave
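    // tmp reserves one channel per tile: full tiles of 12 pixels,
    // then 8/4/2/1-pixel tiles for the remainder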
    Mat tmp(12, inch, size / 12 + (size % 12) / 8 + (size % 12 % 8) / 4 + (size % 12 % 4) / 2 + size % 12 % 2, elemsize, elempack, opt.workspace_allocator);
    {
        int nn_size = size / 12;
        int remain_size_start = nn_size * 12;
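        // pack 12 packed-8 pixels per tile so the GEMM below can keep 12 accumulators live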
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 12;
            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);
                __m256 _r4 = _mm256_loadu_ps(img0 + 32);
                __m256 _r5 = _mm256_loadu_ps(img0 + 40);
                __m256 _r6 = _mm256_loadu_ps(img0 + 48);
                __m256 _r7 = _mm256_loadu_ps(img0 + 56);
                __m256 _r8 = _mm256_loadu_ps(img0 + 64);
                __m256 _r9 = _mm256_loadu_ps(img0 + 72);
                __m256 _r10 = _mm256_loadu_ps(img0 + 80);
                __m256 _r11 = _mm256_loadu_ps(img0 + 88);
                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);
                _mm256_storeu_ps(tmpptr + 32, _r4);
                _mm256_storeu_ps(tmpptr + 40, _r5);
                _mm256_storeu_ps(tmpptr + 48, _r6);
                _mm256_storeu_ps(tmpptr + 56, _r7);
                _mm256_storeu_ps(tmpptr + 64, _r8);
                _mm256_storeu_ps(tmpptr + 72, _r9);
                _mm256_storeu_ps(tmpptr + 80, _r10);
                _mm256_storeu_ps(tmpptr + 88, _r11);

                tmpptr += 96;
                img0 += bottom_blob.cstep * 8;
            }
        }
        nn_size = (size - remain_size_start) >> 3;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 8;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);
                __m256 _r4 = _mm256_loadu_ps(img0 + 32);
                __m256 _r5 = _mm256_loadu_ps(img0 + 40);
                __m256 _r6 = _mm256_loadu_ps(img0 + 48);
                __m256 _r7 = _mm256_loadu_ps(img0 + 56);
                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);
                _mm256_storeu_ps(tmpptr + 32, _r4);
                _mm256_storeu_ps(tmpptr + 40, _r5);
                _mm256_storeu_ps(tmpptr + 48, _r6);
                _mm256_storeu_ps(tmpptr + 56, _r7);

                tmpptr += 64;
                img0 += bottom_blob.cstep * 8;
            }
        }

        remain_size_start += nn_size << 3;
        nn_size = (size - remain_size_start) >> 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 4;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);
                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);

                tmpptr += 32;
                img0 += bottom_blob.cstep * 8;
            }
        }

        remain_size_start += nn_size << 2;
        nn_size = (size - remain_size_start) >> 1;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 2;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);

                tmpptr += 16;
                img0 += bottom_blob.cstep * 8;
            }
        }

        remain_size_start += nn_size << 1;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < size; i++)
        {
            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                _mm256_storeu_ps(tmpptr, _r0);

                tmpptr += 8;
                img0 += bottom_blob.cstep * 8;
            }
        }
    }
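    // GEMM: for each group of 8 output channels, multiply the repacked input tiles
    // by the per-input-channel 8x8 weight blocks and add the bias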
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        Mat out = top_blob.channel(p);

        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);

        float* outptr = out;
        int i = 0;
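        // walk the output pixels in the same tile order used by the interleave pass: 12, 8, 4, 2, then 1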
        for (; i + 11 < size; i += 12)
        {
            const float* tmpptr = tmp.channel(i / 12);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;
            __m256 _sum4 = _bias0;
            __m256 _sum5 = _bias0;
            __m256 _sum6 = _bias0;
            __m256 _sum7 = _bias0;
            __m256 _sum8 = _bias0;
            __m256 _sum9 = _bias0;
            __m256 _sum10 = _bias0;
            __m256 _sum11 = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;
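            // each input channel owns a 64-float (8 inch x 8 outch) weight block for this output group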
            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);

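                // _mm256_fmadd_ps4 is an ncnn helper defined elsewhere in the x86 sources;
                // it is expected to accumulate the four products _w0*_val00 .. _w3*_val03 into _sum0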
                _mm256_fmadd_ps4(_sum0, _w0, _w1, _w2, _w3, _val00, _val01, _val02, _val03);

                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);

                _mm256_fmadd_ps4(_sum0, _w4, _w5, _w6, _w7, _val04, _val05, _val06, _val07);

                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);

                _mm256_fmadd_ps4(_sum1, _w0, _w1, _w2, _w3, _val10, _val11, _val12, _val13);

                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _mm256_fmadd_ps4(_sum1, _w4, _w5, _w6, _w7, _val14, _val15, _val16, _val17);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);

                _mm256_fmadd_ps4(_sum2, _w0, _w1, _w2, _w3, _val20, _val21, _val22, _val23);

                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);

                _mm256_fmadd_ps4(_sum2, _w4, _w5, _w6, _w7, _val24, _val25, _val26, _val27);

                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);

                _mm256_fmadd_ps4(_sum3, _w0, _w1, _w2, _w3, _val30, _val31, _val32, _val33);

                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _mm256_fmadd_ps4(_sum3, _w4, _w5, _w6, _w7, _val34, _val35, _val36, _val37);

                __m256 _val40 = _mm256_broadcast_ss(tmpptr + 32);
                __m256 _val41 = _mm256_broadcast_ss(tmpptr + 33);
                __m256 _val42 = _mm256_broadcast_ss(tmpptr + 34);
                __m256 _val43 = _mm256_broadcast_ss(tmpptr + 35);

                _mm256_fmadd_ps4(_sum4, _w0, _w1, _w2, _w3, _val40, _val41, _val42, _val43);

                __m256 _val44 = _mm256_broadcast_ss(tmpptr + 36);
                __m256 _val45 = _mm256_broadcast_ss(tmpptr + 37);
                __m256 _val46 = _mm256_broadcast_ss(tmpptr + 38);
                __m256 _val47 = _mm256_broadcast_ss(tmpptr + 39);

                _mm256_fmadd_ps4(_sum4, _w4, _w5, _w6, _w7, _val44, _val45, _val46, _val47);

                __m256 _val50 = _mm256_broadcast_ss(tmpptr + 40);
                __m256 _val51 = _mm256_broadcast_ss(tmpptr + 41);
                __m256 _val52 = _mm256_broadcast_ss(tmpptr + 42);
                __m256 _val53 = _mm256_broadcast_ss(tmpptr + 43);

                _mm256_fmadd_ps4(_sum5, _w0, _w1, _w2, _w3, _val50, _val51, _val52, _val53);

                __m256 _val54 = _mm256_broadcast_ss(tmpptr + 44);
                __m256 _val55 = _mm256_broadcast_ss(tmpptr + 45);
                __m256 _val56 = _mm256_broadcast_ss(tmpptr + 46);
                __m256 _val57 = _mm256_broadcast_ss(tmpptr + 47);

                _mm256_fmadd_ps4(_sum5, _w4, _w5, _w6, _w7, _val54, _val55, _val56, _val57);

                __m256 _val60 = _mm256_broadcast_ss(tmpptr + 48);
                __m256 _val61 = _mm256_broadcast_ss(tmpptr + 49);
                __m256 _val62 = _mm256_broadcast_ss(tmpptr + 50);
                __m256 _val63 = _mm256_broadcast_ss(tmpptr + 51);

                _mm256_fmadd_ps4(_sum6, _w0, _w1, _w2, _w3, _val60, _val61, _val62, _val63);

                __m256 _val64 = _mm256_broadcast_ss(tmpptr + 52);
                __m256 _val65 = _mm256_broadcast_ss(tmpptr + 53);
                __m256 _val66 = _mm256_broadcast_ss(tmpptr + 54);
                __m256 _val67 = _mm256_broadcast_ss(tmpptr + 55);

                _mm256_fmadd_ps4(_sum6, _w4, _w5, _w6, _w7, _val64, _val65, _val66, _val67);

                __m256 _val70 = _mm256_broadcast_ss(tmpptr + 56);
                __m256 _val71 = _mm256_broadcast_ss(tmpptr + 57);
                __m256 _val72 = _mm256_broadcast_ss(tmpptr + 58);
                __m256 _val73 = _mm256_broadcast_ss(tmpptr + 59);

                _mm256_fmadd_ps4(_sum7, _w0, _w1, _w2, _w3, _val70, _val71, _val72, _val73);

                __m256 _val74 = _mm256_broadcast_ss(tmpptr + 60);
                __m256 _val75 = _mm256_broadcast_ss(tmpptr + 61);
                __m256 _val76 = _mm256_broadcast_ss(tmpptr + 62);
                __m256 _val77 = _mm256_broadcast_ss(tmpptr + 63);

                _mm256_fmadd_ps4(_sum7, _w4, _w5, _w6, _w7, _val74, _val75, _val76, _val77);

                __m256 _val80 = _mm256_broadcast_ss(tmpptr + 64);
                __m256 _val81 = _mm256_broadcast_ss(tmpptr + 65);
                __m256 _val82 = _mm256_broadcast_ss(tmpptr + 66);
                __m256 _val83 = _mm256_broadcast_ss(tmpptr + 67);

                _mm256_fmadd_ps4(_sum8, _w0, _w1, _w2, _w3, _val80, _val81, _val82, _val83);

                __m256 _val84 = _mm256_broadcast_ss(tmpptr + 68);
                __m256 _val85 = _mm256_broadcast_ss(tmpptr + 69);
                __m256 _val86 = _mm256_broadcast_ss(tmpptr + 70);
                __m256 _val87 = _mm256_broadcast_ss(tmpptr + 71);

                _mm256_fmadd_ps4(_sum8, _w4, _w5, _w6, _w7, _val84, _val85, _val86, _val87);

                __m256 _val90 = _mm256_broadcast_ss(tmpptr + 72);
                __m256 _val91 = _mm256_broadcast_ss(tmpptr + 73);
                __m256 _val92 = _mm256_broadcast_ss(tmpptr + 74);
                __m256 _val93 = _mm256_broadcast_ss(tmpptr + 75);

                _mm256_fmadd_ps4(_sum9, _w0, _w1, _w2, _w3, _val90, _val91, _val92, _val93);

                __m256 _val94 = _mm256_broadcast_ss(tmpptr + 76);
                __m256 _val95 = _mm256_broadcast_ss(tmpptr + 77);
                __m256 _val96 = _mm256_broadcast_ss(tmpptr + 78);
                __m256 _val97 = _mm256_broadcast_ss(tmpptr + 79);

                _mm256_fmadd_ps4(_sum9, _w4, _w5, _w6, _w7, _val94, _val95, _val96, _val97);

                __m256 _val100 = _mm256_broadcast_ss(tmpptr + 80);
                __m256 _val101 = _mm256_broadcast_ss(tmpptr + 81);
                __m256 _val102 = _mm256_broadcast_ss(tmpptr + 82);
                __m256 _val103 = _mm256_broadcast_ss(tmpptr + 83);

                _mm256_fmadd_ps4(_sum10, _w0, _w1, _w2, _w3, _val100, _val101, _val102, _val103);

                __m256 _val104 = _mm256_broadcast_ss(tmpptr + 84);
                __m256 _val105 = _mm256_broadcast_ss(tmpptr + 85);
                __m256 _val106 = _mm256_broadcast_ss(tmpptr + 86);
                __m256 _val107 = _mm256_broadcast_ss(tmpptr + 87);

                _mm256_fmadd_ps4(_sum10, _w4, _w5, _w6, _w7, _val104, _val105, _val106, _val107);

                __m256 _val110 = _mm256_broadcast_ss(tmpptr + 88);
                __m256 _val111 = _mm256_broadcast_ss(tmpptr + 89);
                __m256 _val112 = _mm256_broadcast_ss(tmpptr + 90);
                __m256 _val113 = _mm256_broadcast_ss(tmpptr + 91);

                _mm256_fmadd_ps4(_sum11, _w0, _w1, _w2, _w3, _val110, _val111, _val112, _val113);

                __m256 _val114 = _mm256_broadcast_ss(tmpptr + 92);
                __m256 _val115 = _mm256_broadcast_ss(tmpptr + 93);
                __m256 _val116 = _mm256_broadcast_ss(tmpptr + 94);
                __m256 _val117 = _mm256_broadcast_ss(tmpptr + 95);

                _mm256_fmadd_ps4(_sum11, _w4, _w5, _w6, _w7, _val114, _val115, _val116, _val117);

                tmpptr += 96;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);
            _mm256_storeu_ps(outptr + 32, _sum4);
            _mm256_storeu_ps(outptr + 40, _sum5);
            _mm256_storeu_ps(outptr + 48, _sum6);
            _mm256_storeu_ps(outptr + 56, _sum7);
            _mm256_storeu_ps(outptr + 64, _sum8);
            _mm256_storeu_ps(outptr + 72, _sum9);
            _mm256_storeu_ps(outptr + 80, _sum10);
            _mm256_storeu_ps(outptr + 88, _sum11);

            outptr += 96;
        }
        for (; i + 7 < size; i += 8)
        {
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;
            __m256 _sum4 = _bias0;
            __m256 _sum5 = _bias0;
            __m256 _sum6 = _bias0;
            __m256 _sum7 = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _sum2 = _mm256_fmadd_ps(_w0, _val20, _sum2);
                _sum2 = _mm256_fmadd_ps(_w1, _val21, _sum2);
                _sum2 = _mm256_fmadd_ps(_w2, _val22, _sum2);
                _sum2 = _mm256_fmadd_ps(_w3, _val23, _sum2);
                _sum2 = _mm256_fmadd_ps(_w4, _val24, _sum2);
                _sum2 = _mm256_fmadd_ps(_w5, _val25, _sum2);
                _sum2 = _mm256_fmadd_ps(_w6, _val26, _sum2);
                _sum2 = _mm256_fmadd_ps(_w7, _val27, _sum2);
                _sum3 = _mm256_fmadd_ps(_w0, _val30, _sum3);
                _sum3 = _mm256_fmadd_ps(_w1, _val31, _sum3);
                _sum3 = _mm256_fmadd_ps(_w2, _val32, _sum3);
                _sum3 = _mm256_fmadd_ps(_w3, _val33, _sum3);
                _sum3 = _mm256_fmadd_ps(_w4, _val34, _sum3);
                _sum3 = _mm256_fmadd_ps(_w5, _val35, _sum3);
                _sum3 = _mm256_fmadd_ps(_w6, _val36, _sum3);
                _sum3 = _mm256_fmadd_ps(_w7, _val37, _sum3);

                __m256 _val40 = _mm256_broadcast_ss(tmpptr + 32);
                __m256 _val41 = _mm256_broadcast_ss(tmpptr + 33);
                __m256 _val42 = _mm256_broadcast_ss(tmpptr + 34);
                __m256 _val43 = _mm256_broadcast_ss(tmpptr + 35);
                __m256 _val44 = _mm256_broadcast_ss(tmpptr + 36);
                __m256 _val45 = _mm256_broadcast_ss(tmpptr + 37);
                __m256 _val46 = _mm256_broadcast_ss(tmpptr + 38);
                __m256 _val47 = _mm256_broadcast_ss(tmpptr + 39);
                __m256 _val50 = _mm256_broadcast_ss(tmpptr + 40);
                __m256 _val51 = _mm256_broadcast_ss(tmpptr + 41);
                __m256 _val52 = _mm256_broadcast_ss(tmpptr + 42);
                __m256 _val53 = _mm256_broadcast_ss(tmpptr + 43);
                __m256 _val54 = _mm256_broadcast_ss(tmpptr + 44);
                __m256 _val55 = _mm256_broadcast_ss(tmpptr + 45);
                __m256 _val56 = _mm256_broadcast_ss(tmpptr + 46);
                __m256 _val57 = _mm256_broadcast_ss(tmpptr + 47);

                _sum4 = _mm256_fmadd_ps(_w0, _val40, _sum4);
                _sum4 = _mm256_fmadd_ps(_w1, _val41, _sum4);
                _sum4 = _mm256_fmadd_ps(_w2, _val42, _sum4);
                _sum4 = _mm256_fmadd_ps(_w3, _val43, _sum4);
                _sum4 = _mm256_fmadd_ps(_w4, _val44, _sum4);
                _sum4 = _mm256_fmadd_ps(_w5, _val45, _sum4);
                _sum4 = _mm256_fmadd_ps(_w6, _val46, _sum4);
                _sum4 = _mm256_fmadd_ps(_w7, _val47, _sum4);
                _sum5 = _mm256_fmadd_ps(_w0, _val50, _sum5);
                _sum5 = _mm256_fmadd_ps(_w1, _val51, _sum5);
                _sum5 = _mm256_fmadd_ps(_w2, _val52, _sum5);
                _sum5 = _mm256_fmadd_ps(_w3, _val53, _sum5);
                _sum5 = _mm256_fmadd_ps(_w4, _val54, _sum5);
                _sum5 = _mm256_fmadd_ps(_w5, _val55, _sum5);
                _sum5 = _mm256_fmadd_ps(_w6, _val56, _sum5);
                _sum5 = _mm256_fmadd_ps(_w7, _val57, _sum5);

                __m256 _val60 = _mm256_broadcast_ss(tmpptr + 48);
                __m256 _val61 = _mm256_broadcast_ss(tmpptr + 49);
                __m256 _val62 = _mm256_broadcast_ss(tmpptr + 50);
                __m256 _val63 = _mm256_broadcast_ss(tmpptr + 51);
                __m256 _val64 = _mm256_broadcast_ss(tmpptr + 52);
                __m256 _val65 = _mm256_broadcast_ss(tmpptr + 53);
                __m256 _val66 = _mm256_broadcast_ss(tmpptr + 54);
                __m256 _val67 = _mm256_broadcast_ss(tmpptr + 55);
                __m256 _val70 = _mm256_broadcast_ss(tmpptr + 56);
                __m256 _val71 = _mm256_broadcast_ss(tmpptr + 57);
                __m256 _val72 = _mm256_broadcast_ss(tmpptr + 58);
                __m256 _val73 = _mm256_broadcast_ss(tmpptr + 59);
                __m256 _val74 = _mm256_broadcast_ss(tmpptr + 60);
                __m256 _val75 = _mm256_broadcast_ss(tmpptr + 61);
                __m256 _val76 = _mm256_broadcast_ss(tmpptr + 62);
                __m256 _val77 = _mm256_broadcast_ss(tmpptr + 63);

                _sum6 = _mm256_fmadd_ps(_w0, _val60, _sum6);
                _sum6 = _mm256_fmadd_ps(_w1, _val61, _sum6);
                _sum6 = _mm256_fmadd_ps(_w2, _val62, _sum6);
                _sum6 = _mm256_fmadd_ps(_w3, _val63, _sum6);
                _sum6 = _mm256_fmadd_ps(_w4, _val64, _sum6);
                _sum6 = _mm256_fmadd_ps(_w5, _val65, _sum6);
                _sum6 = _mm256_fmadd_ps(_w6, _val66, _sum6);
                _sum6 = _mm256_fmadd_ps(_w7, _val67, _sum6);
                _sum7 = _mm256_fmadd_ps(_w0, _val70, _sum7);
                _sum7 = _mm256_fmadd_ps(_w1, _val71, _sum7);
                _sum7 = _mm256_fmadd_ps(_w2, _val72, _sum7);
                _sum7 = _mm256_fmadd_ps(_w3, _val73, _sum7);
                _sum7 = _mm256_fmadd_ps(_w4, _val74, _sum7);
                _sum7 = _mm256_fmadd_ps(_w5, _val75, _sum7);
                _sum7 = _mm256_fmadd_ps(_w6, _val76, _sum7);
                _sum7 = _mm256_fmadd_ps(_w7, _val77, _sum7);

                tmpptr += 64;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);
            _mm256_storeu_ps(outptr + 32, _sum4);
            _mm256_storeu_ps(outptr + 40, _sum5);
            _mm256_storeu_ps(outptr + 48, _sum6);
            _mm256_storeu_ps(outptr + 56, _sum7);

            outptr += 64;
        }
        for (; i + 3 < size; i += 4)
        {
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _sum2 = _mm256_fmadd_ps(_w0, _val20, _sum2);
                _sum2 = _mm256_fmadd_ps(_w1, _val21, _sum2);
                _sum2 = _mm256_fmadd_ps(_w2, _val22, _sum2);
                _sum2 = _mm256_fmadd_ps(_w3, _val23, _sum2);
                _sum2 = _mm256_fmadd_ps(_w4, _val24, _sum2);
                _sum2 = _mm256_fmadd_ps(_w5, _val25, _sum2);
                _sum2 = _mm256_fmadd_ps(_w6, _val26, _sum2);
                _sum2 = _mm256_fmadd_ps(_w7, _val27, _sum2);
                _sum3 = _mm256_fmadd_ps(_w0, _val30, _sum3);
                _sum3 = _mm256_fmadd_ps(_w1, _val31, _sum3);
                _sum3 = _mm256_fmadd_ps(_w2, _val32, _sum3);
                _sum3 = _mm256_fmadd_ps(_w3, _val33, _sum3);
                _sum3 = _mm256_fmadd_ps(_w4, _val34, _sum3);
                _sum3 = _mm256_fmadd_ps(_w5, _val35, _sum3);
                _sum3 = _mm256_fmadd_ps(_w6, _val36, _sum3);
                _sum3 = _mm256_fmadd_ps(_w7, _val37, _sum3);

                tmpptr += 32;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);

            outptr += 32;
        }
        for (; i + 1 < size; i += 2)
        {
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);

                tmpptr += 16;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);

            outptr += 16;
        }

        for (; i < size; i++)
        {
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
            __m256 _sum = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _val0 = _mm256_broadcast_ss(tmpptr);
                __m256 _val1 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val2 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val3 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val4 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val5 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val6 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val7 = _mm256_broadcast_ss(tmpptr + 7);

                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                _sum = _mm256_fmadd_ps(_w0, _val0, _sum);
                _sum = _mm256_fmadd_ps(_w1, _val1, _sum);
                _sum = _mm256_fmadd_ps(_w2, _val2, _sum);
                _sum = _mm256_fmadd_ps(_w3, _val3, _sum);
                _sum = _mm256_fmadd_ps(_w4, _val4, _sum);
                _sum = _mm256_fmadd_ps(_w5, _val5, _sum);
                _sum = _mm256_fmadd_ps(_w6, _val6, _sum);
                _sum = _mm256_fmadd_ps(_w7, _val7, _sum);

                tmpptr += 8;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum);

            outptr += 8;
        }
    }
}

static void conv1x1s2_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;

    const int tailstep = (w - 2 * outw + w) * 8;
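    // tailstep advances past the remainder of the current input row plus the next
    // (skipped) row; every packed pixel is 8 floats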

    Mat bottom_blob_shrinked;
    bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);

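    // keep only every second packed pixel of every second row so the
    // stride-1 sgemm path can be reused on the shrunken input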
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < channels; p++)
    {
        const float* r0 = bottom_blob.channel(p);
        float* outptr = bottom_blob_shrinked.channel(p);

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                __m256 _v = _mm256_loadu_ps(r0);
                _mm256_storeu_ps(outptr, _v);

                r0 += 16;
                outptr += 8;
            }

            r0 += tailstep;
        }
    }

    conv1x1s1_sgemm_pack8_avx(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
}