// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void conv1x1s1_sgemm_transform_kernel_pack8_avx(const Mat& kernel, Mat& weight_data_pack8, int num_input, int num_output)
{
    // src = kw-kh-inch-outch
    // dst = 8b-8a-kw-kh-inch/8a-outch/8b
    Mat weight_data_r2 = kernel.reshape(1, num_input, num_output);

    weight_data_pack8.create(1, num_input / 8, num_output / 8, (size_t)4 * 64, 64);

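    // pack the weights in 8x8 blocks: each run of 8 consecutive floats in the
    // destination holds one input channel's weight for output channels q..q+7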
    for (int q = 0; q + 7 < num_output; q += 8)
    {
        const Mat k0 = weight_data_r2.channel(q);
        const Mat k1 = weight_data_r2.channel(q + 1);
        const Mat k2 = weight_data_r2.channel(q + 2);
        const Mat k3 = weight_data_r2.channel(q + 3);
        const Mat k4 = weight_data_r2.channel(q + 4);
        const Mat k5 = weight_data_r2.channel(q + 5);
        const Mat k6 = weight_data_r2.channel(q + 6);
        const Mat k7 = weight_data_r2.channel(q + 7);

        Mat g0 = weight_data_pack8.channel(q / 8);

        for (int p = 0; p + 7 < num_input; p += 8)
        {
            const float* k00 = k0.row(p);
            const float* k01 = k0.row(p + 1);
            const float* k02 = k0.row(p + 2);
            const float* k03 = k0.row(p + 3);
            const float* k04 = k0.row(p + 4);
            const float* k05 = k0.row(p + 5);
            const float* k06 = k0.row(p + 6);
            const float* k07 = k0.row(p + 7);

            const float* k10 = k1.row(p);
            const float* k11 = k1.row(p + 1);
            const float* k12 = k1.row(p + 2);
            const float* k13 = k1.row(p + 3);
            const float* k14 = k1.row(p + 4);
            const float* k15 = k1.row(p + 5);
            const float* k16 = k1.row(p + 6);
            const float* k17 = k1.row(p + 7);

            const float* k20 = k2.row(p);
            const float* k21 = k2.row(p + 1);
            const float* k22 = k2.row(p + 2);
            const float* k23 = k2.row(p + 3);
            const float* k24 = k2.row(p + 4);
            const float* k25 = k2.row(p + 5);
            const float* k26 = k2.row(p + 6);
            const float* k27 = k2.row(p + 7);

            const float* k30 = k3.row(p);
            const float* k31 = k3.row(p + 1);
            const float* k32 = k3.row(p + 2);
            const float* k33 = k3.row(p + 3);
            const float* k34 = k3.row(p + 4);
            const float* k35 = k3.row(p + 5);
            const float* k36 = k3.row(p + 6);
            const float* k37 = k3.row(p + 7);

            const float* k40 = k4.row(p);
            const float* k41 = k4.row(p + 1);
            const float* k42 = k4.row(p + 2);
            const float* k43 = k4.row(p + 3);
            const float* k44 = k4.row(p + 4);
            const float* k45 = k4.row(p + 5);
            const float* k46 = k4.row(p + 6);
            const float* k47 = k4.row(p + 7);

            const float* k50 = k5.row(p);
            const float* k51 = k5.row(p + 1);
            const float* k52 = k5.row(p + 2);
            const float* k53 = k5.row(p + 3);
            const float* k54 = k5.row(p + 4);
            const float* k55 = k5.row(p + 5);
            const float* k56 = k5.row(p + 6);
            const float* k57 = k5.row(p + 7);

            const float* k60 = k6.row(p);
            const float* k61 = k6.row(p + 1);
            const float* k62 = k6.row(p + 2);
            const float* k63 = k6.row(p + 3);
            const float* k64 = k6.row(p + 4);
            const float* k65 = k6.row(p + 5);
            const float* k66 = k6.row(p + 6);
            const float* k67 = k6.row(p + 7);

            const float* k70 = k7.row(p);
            const float* k71 = k7.row(p + 1);
            const float* k72 = k7.row(p + 2);
            const float* k73 = k7.row(p + 3);
            const float* k74 = k7.row(p + 4);
            const float* k75 = k7.row(p + 5);
            const float* k76 = k7.row(p + 6);
            const float* k77 = k7.row(p + 7);

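            // transpose the 8x8 block into the packed layout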
            float* g00 = g0.row(p / 8);

            g00[0] = k00[0];
            g00[1] = k10[0];
            g00[2] = k20[0];
            g00[3] = k30[0];
            g00[4] = k40[0];
            g00[5] = k50[0];
            g00[6] = k60[0];
            g00[7] = k70[0];
            g00 += 8;

            g00[0] = k01[0];
            g00[1] = k11[0];
            g00[2] = k21[0];
            g00[3] = k31[0];
            g00[4] = k41[0];
            g00[5] = k51[0];
            g00[6] = k61[0];
            g00[7] = k71[0];
            g00 += 8;

            g00[0] = k02[0];
            g00[1] = k12[0];
            g00[2] = k22[0];
            g00[3] = k32[0];
            g00[4] = k42[0];
            g00[5] = k52[0];
            g00[6] = k62[0];
            g00[7] = k72[0];
            g00 += 8;

            g00[0] = k03[0];
            g00[1] = k13[0];
            g00[2] = k23[0];
            g00[3] = k33[0];
            g00[4] = k43[0];
            g00[5] = k53[0];
            g00[6] = k63[0];
            g00[7] = k73[0];
            g00 += 8;

            g00[0] = k04[0];
            g00[1] = k14[0];
            g00[2] = k24[0];
            g00[3] = k34[0];
            g00[4] = k44[0];
            g00[5] = k54[0];
            g00[6] = k64[0];
            g00[7] = k74[0];
            g00 += 8;

            g00[0] = k05[0];
            g00[1] = k15[0];
            g00[2] = k25[0];
            g00[3] = k35[0];
            g00[4] = k45[0];
            g00[5] = k55[0];
            g00[6] = k65[0];
            g00[7] = k75[0];
            g00 += 8;

            g00[0] = k06[0];
            g00[1] = k16[0];
            g00[2] = k26[0];
            g00[3] = k36[0];
            g00[4] = k46[0];
            g00[5] = k56[0];
            g00[6] = k66[0];
            g00[7] = k76[0];
            g00 += 8;

            g00[0] = k07[0];
            g00[1] = k17[0];
            g00[2] = k27[0];
            g00[3] = k37[0];
            g00[4] = k47[0];
            g00[5] = k57[0];
            g00[6] = k67[0];
            g00[7] = k77[0];
            g00 += 8;
        }
    }
}

static void conv1x1s1_sgemm_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    int outch = top_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    const int size = w * h;

    const float* bias = _bias;

    // interleave
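    // tmp re-tiles the input for the GEMM below: tiles of 12, 8, 4, 2 and 1
    // pixels, each pixel carrying 8 packed channel floats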
    Mat tmp(12, inch, size / 12 + (size % 12) / 8 + (size % 12 % 8) / 4 + (size % 12 % 4) / 2 + size % 12 % 2, elemsize, elempack, opt.workspace_allocator);
    {
        int nn_size = size / 12;
        int remain_size_start = nn_size * 12;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 12;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);
                __m256 _r4 = _mm256_loadu_ps(img0 + 32);
                __m256 _r5 = _mm256_loadu_ps(img0 + 40);
                __m256 _r6 = _mm256_loadu_ps(img0 + 48);
                __m256 _r7 = _mm256_loadu_ps(img0 + 56);
                __m256 _r8 = _mm256_loadu_ps(img0 + 64);
                __m256 _r9 = _mm256_loadu_ps(img0 + 72);
                __m256 _r10 = _mm256_loadu_ps(img0 + 80);
                __m256 _r11 = _mm256_loadu_ps(img0 + 88);

                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);
                _mm256_storeu_ps(tmpptr + 32, _r4);
                _mm256_storeu_ps(tmpptr + 40, _r5);
                _mm256_storeu_ps(tmpptr + 48, _r6);
                _mm256_storeu_ps(tmpptr + 56, _r7);
                _mm256_storeu_ps(tmpptr + 64, _r8);
                _mm256_storeu_ps(tmpptr + 72, _r9);
                _mm256_storeu_ps(tmpptr + 80, _r10);
                _mm256_storeu_ps(tmpptr + 88, _r11);

                tmpptr += 96;
                img0 += bottom_blob.cstep * 8;
            }
        }

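        // remaining pixels: tiles of 8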
        nn_size = (size - remain_size_start) >> 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 8;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);
                __m256 _r4 = _mm256_loadu_ps(img0 + 32);
                __m256 _r5 = _mm256_loadu_ps(img0 + 40);
                __m256 _r6 = _mm256_loadu_ps(img0 + 48);
                __m256 _r7 = _mm256_loadu_ps(img0 + 56);

                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);
                _mm256_storeu_ps(tmpptr + 32, _r4);
                _mm256_storeu_ps(tmpptr + 40, _r5);
                _mm256_storeu_ps(tmpptr + 48, _r6);
                _mm256_storeu_ps(tmpptr + 56, _r7);

                tmpptr += 64;
                img0 += bottom_blob.cstep * 8;
            }
        }

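        // tiles of 4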
        remain_size_start += nn_size << 3;
        nn_size = (size - remain_size_start) >> 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 4;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);

                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);

                tmpptr += 32;
                img0 += bottom_blob.cstep * 8;
            }
        }

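        // tiles of 2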
        remain_size_start += nn_size << 2;
        nn_size = (size - remain_size_start) >> 1;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 2;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);

                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);

                tmpptr += 16;
                img0 += bottom_blob.cstep * 8;
            }
        }

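        // leftover single pixels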
        remain_size_start += nn_size << 1;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < size; i++)
        {
            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                _mm256_storeu_ps(tmpptr, _r0);

                tmpptr += 8;
                img0 += bottom_blob.cstep * 8;
            }
        }
    }
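
    // GEMM: for each group of 8 output channels, accumulate every re-tiled
    // pixel tile against the packed 8x8 weight blocks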
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        Mat out = top_blob.channel(p);

        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);

        float* outptr = out;

        int i = 0;
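        // 12 pixels at a time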
        for (; i + 11 < size; i += 12)
        {
            const float* tmpptr = tmp.channel(i / 12);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;
            __m256 _sum4 = _bias0;
            __m256 _sum5 = _bias0;
            __m256 _sum6 = _bias0;
            __m256 _sum7 = _bias0;
            __m256 _sum8 = _bias0;
            __m256 _sum9 = _bias0;
            __m256 _sum10 = _bias0;
            __m256 _sum11 = _bias0;

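            // packed weights for this output-channel group: inch blocks of
            // 64 floats (8 input channels x 8 output channels)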
            const float* kptr = (const float*)kernel + p * inch * 64;

            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);

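                // _mm256_comp_fmadd_ps4 accumulates four fused multiply-adds
                // into its first argument; it is assumed to be the helper
                // macro defined alongside these x86 convolution kernels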
                _mm256_comp_fmadd_ps4(_sum0, _w0, _w1, _w2, _w3, _val00, _val01, _val02, _val03);

                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);

                _mm256_comp_fmadd_ps4(_sum0, _w4, _w5, _w6, _w7, _val04, _val05, _val06, _val07);

                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);

                _mm256_comp_fmadd_ps4(_sum1, _w0, _w1, _w2, _w3, _val10, _val11, _val12, _val13);

                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _mm256_comp_fmadd_ps4(_sum1, _w4, _w5, _w6, _w7, _val14, _val15, _val16, _val17);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);

                _mm256_comp_fmadd_ps4(_sum2, _w0, _w1, _w2, _w3, _val20, _val21, _val22, _val23);

                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);

                _mm256_comp_fmadd_ps4(_sum2, _w4, _w5, _w6, _w7, _val24, _val25, _val26, _val27);

                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);

                _mm256_comp_fmadd_ps4(_sum3, _w0, _w1, _w2, _w3, _val30, _val31, _val32, _val33);

                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _mm256_comp_fmadd_ps4(_sum3, _w4, _w5, _w6, _w7, _val34, _val35, _val36, _val37);

                __m256 _val40 = _mm256_broadcast_ss(tmpptr + 32);
                __m256 _val41 = _mm256_broadcast_ss(tmpptr + 33);
                __m256 _val42 = _mm256_broadcast_ss(tmpptr + 34);
                __m256 _val43 = _mm256_broadcast_ss(tmpptr + 35);

                _mm256_comp_fmadd_ps4(_sum4, _w0, _w1, _w2, _w3, _val40, _val41, _val42, _val43);

                __m256 _val44 = _mm256_broadcast_ss(tmpptr + 36);
                __m256 _val45 = _mm256_broadcast_ss(tmpptr + 37);
                __m256 _val46 = _mm256_broadcast_ss(tmpptr + 38);
                __m256 _val47 = _mm256_broadcast_ss(tmpptr + 39);

                _mm256_comp_fmadd_ps4(_sum4, _w4, _w5, _w6, _w7, _val44, _val45, _val46, _val47);

                __m256 _val50 = _mm256_broadcast_ss(tmpptr + 40);
                __m256 _val51 = _mm256_broadcast_ss(tmpptr + 41);
                __m256 _val52 = _mm256_broadcast_ss(tmpptr + 42);
                __m256 _val53 = _mm256_broadcast_ss(tmpptr + 43);

                _mm256_comp_fmadd_ps4(_sum5, _w0, _w1, _w2, _w3, _val50, _val51, _val52, _val53);

                __m256 _val54 = _mm256_broadcast_ss(tmpptr + 44);
                __m256 _val55 = _mm256_broadcast_ss(tmpptr + 45);
                __m256 _val56 = _mm256_broadcast_ss(tmpptr + 46);
                __m256 _val57 = _mm256_broadcast_ss(tmpptr + 47);

                _mm256_comp_fmadd_ps4(_sum5, _w4, _w5, _w6, _w7, _val54, _val55, _val56, _val57);

                __m256 _val60 = _mm256_broadcast_ss(tmpptr + 48);
                __m256 _val61 = _mm256_broadcast_ss(tmpptr + 49);
                __m256 _val62 = _mm256_broadcast_ss(tmpptr + 50);
                __m256 _val63 = _mm256_broadcast_ss(tmpptr + 51);

                _mm256_comp_fmadd_ps4(_sum6, _w0, _w1, _w2, _w3, _val60, _val61, _val62, _val63);

                __m256 _val64 = _mm256_broadcast_ss(tmpptr + 52);
                __m256 _val65 = _mm256_broadcast_ss(tmpptr + 53);
                __m256 _val66 = _mm256_broadcast_ss(tmpptr + 54);
                __m256 _val67 = _mm256_broadcast_ss(tmpptr + 55);

                _mm256_comp_fmadd_ps4(_sum6, _w4, _w5, _w6, _w7, _val64, _val65, _val66, _val67);

                __m256 _val70 = _mm256_broadcast_ss(tmpptr + 56);
                __m256 _val71 = _mm256_broadcast_ss(tmpptr + 57);
                __m256 _val72 = _mm256_broadcast_ss(tmpptr + 58);
                __m256 _val73 = _mm256_broadcast_ss(tmpptr + 59);

                _mm256_comp_fmadd_ps4(_sum7, _w0, _w1, _w2, _w3, _val70, _val71, _val72, _val73);

                __m256 _val74 = _mm256_broadcast_ss(tmpptr + 60);
                __m256 _val75 = _mm256_broadcast_ss(tmpptr + 61);
                __m256 _val76 = _mm256_broadcast_ss(tmpptr + 62);
                __m256 _val77 = _mm256_broadcast_ss(tmpptr + 63);

                _mm256_comp_fmadd_ps4(_sum7, _w4, _w5, _w6, _w7, _val74, _val75, _val76, _val77);

                __m256 _val80 = _mm256_broadcast_ss(tmpptr + 64);
                __m256 _val81 = _mm256_broadcast_ss(tmpptr + 65);
                __m256 _val82 = _mm256_broadcast_ss(tmpptr + 66);
                __m256 _val83 = _mm256_broadcast_ss(tmpptr + 67);

                _mm256_comp_fmadd_ps4(_sum8, _w0, _w1, _w2, _w3, _val80, _val81, _val82, _val83);

                __m256 _val84 = _mm256_broadcast_ss(tmpptr + 68);
                __m256 _val85 = _mm256_broadcast_ss(tmpptr + 69);
                __m256 _val86 = _mm256_broadcast_ss(tmpptr + 70);
                __m256 _val87 = _mm256_broadcast_ss(tmpptr + 71);

                _mm256_comp_fmadd_ps4(_sum8, _w4, _w5, _w6, _w7, _val84, _val85, _val86, _val87);

                __m256 _val90 = _mm256_broadcast_ss(tmpptr + 72);
                __m256 _val91 = _mm256_broadcast_ss(tmpptr + 73);
                __m256 _val92 = _mm256_broadcast_ss(tmpptr + 74);
                __m256 _val93 = _mm256_broadcast_ss(tmpptr + 75);

                _mm256_comp_fmadd_ps4(_sum9, _w0, _w1, _w2, _w3, _val90, _val91, _val92, _val93);

                __m256 _val94 = _mm256_broadcast_ss(tmpptr + 76);
                __m256 _val95 = _mm256_broadcast_ss(tmpptr + 77);
                __m256 _val96 = _mm256_broadcast_ss(tmpptr + 78);
                __m256 _val97 = _mm256_broadcast_ss(tmpptr + 79);

                _mm256_comp_fmadd_ps4(_sum9, _w4, _w5, _w6, _w7, _val94, _val95, _val96, _val97);

                __m256 _val100 = _mm256_broadcast_ss(tmpptr + 80);
                __m256 _val101 = _mm256_broadcast_ss(tmpptr + 81);
                __m256 _val102 = _mm256_broadcast_ss(tmpptr + 82);
                __m256 _val103 = _mm256_broadcast_ss(tmpptr + 83);

                _mm256_comp_fmadd_ps4(_sum10, _w0, _w1, _w2, _w3, _val100, _val101, _val102, _val103);

                __m256 _val104 = _mm256_broadcast_ss(tmpptr + 84);
                __m256 _val105 = _mm256_broadcast_ss(tmpptr + 85);
                __m256 _val106 = _mm256_broadcast_ss(tmpptr + 86);
                __m256 _val107 = _mm256_broadcast_ss(tmpptr + 87);

                _mm256_comp_fmadd_ps4(_sum10, _w4, _w5, _w6, _w7, _val104, _val105, _val106, _val107);

                __m256 _val110 = _mm256_broadcast_ss(tmpptr + 88);
                __m256 _val111 = _mm256_broadcast_ss(tmpptr + 89);
                __m256 _val112 = _mm256_broadcast_ss(tmpptr + 90);
                __m256 _val113 = _mm256_broadcast_ss(tmpptr + 91);

                _mm256_comp_fmadd_ps4(_sum11, _w0, _w1, _w2, _w3, _val110, _val111, _val112, _val113);

                __m256 _val114 = _mm256_broadcast_ss(tmpptr + 92);
                __m256 _val115 = _mm256_broadcast_ss(tmpptr + 93);
                __m256 _val116 = _mm256_broadcast_ss(tmpptr + 94);
                __m256 _val117 = _mm256_broadcast_ss(tmpptr + 95);

                _mm256_comp_fmadd_ps4(_sum11, _w4, _w5, _w6, _w7, _val114, _val115, _val116, _val117);

                tmpptr += 96;
                kptr += 64;
            }

            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);
            _mm256_storeu_ps(outptr + 32, _sum4);
            _mm256_storeu_ps(outptr + 40, _sum5);
            _mm256_storeu_ps(outptr + 48, _sum6);
            _mm256_storeu_ps(outptr + 56, _sum7);
            _mm256_storeu_ps(outptr + 64, _sum8);
            _mm256_storeu_ps(outptr + 72, _sum9);
            _mm256_storeu_ps(outptr + 80, _sum10);
            _mm256_storeu_ps(outptr + 88, _sum11);

            outptr += 96;
        }
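        // 8 pixels at a time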
        for (; i + 7 < size; i += 8)
        {
            const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;
            __m256 _sum4 = _bias0;
            __m256 _sum5 = _bias0;
            __m256 _sum6 = _bias0;
            __m256 _sum7 = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;

            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _sum0 = _mm256_comp_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_comp_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w7, _val17, _sum1);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _sum2 = _mm256_comp_fmadd_ps(_w0, _val20, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w1, _val21, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w2, _val22, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w3, _val23, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w4, _val24, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w5, _val25, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w6, _val26, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w7, _val27, _sum2);
                _sum3 = _mm256_comp_fmadd_ps(_w0, _val30, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w1, _val31, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w2, _val32, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w3, _val33, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w4, _val34, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w5, _val35, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w6, _val36, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w7, _val37, _sum3);

                __m256 _val40 = _mm256_broadcast_ss(tmpptr + 32);
                __m256 _val41 = _mm256_broadcast_ss(tmpptr + 33);
                __m256 _val42 = _mm256_broadcast_ss(tmpptr + 34);
                __m256 _val43 = _mm256_broadcast_ss(tmpptr + 35);
                __m256 _val44 = _mm256_broadcast_ss(tmpptr + 36);
                __m256 _val45 = _mm256_broadcast_ss(tmpptr + 37);
                __m256 _val46 = _mm256_broadcast_ss(tmpptr + 38);
                __m256 _val47 = _mm256_broadcast_ss(tmpptr + 39);
                __m256 _val50 = _mm256_broadcast_ss(tmpptr + 40);
                __m256 _val51 = _mm256_broadcast_ss(tmpptr + 41);
                __m256 _val52 = _mm256_broadcast_ss(tmpptr + 42);
                __m256 _val53 = _mm256_broadcast_ss(tmpptr + 43);
                __m256 _val54 = _mm256_broadcast_ss(tmpptr + 44);
                __m256 _val55 = _mm256_broadcast_ss(tmpptr + 45);
                __m256 _val56 = _mm256_broadcast_ss(tmpptr + 46);
                __m256 _val57 = _mm256_broadcast_ss(tmpptr + 47);

                _sum4 = _mm256_comp_fmadd_ps(_w0, _val40, _sum4);
                _sum4 = _mm256_comp_fmadd_ps(_w1, _val41, _sum4);
                _sum4 = _mm256_comp_fmadd_ps(_w2, _val42, _sum4);
                _sum4 = _mm256_comp_fmadd_ps(_w3, _val43, _sum4);
                _sum4 = _mm256_comp_fmadd_ps(_w4, _val44, _sum4);
                _sum4 = _mm256_comp_fmadd_ps(_w5, _val45, _sum4);
                _sum4 = _mm256_comp_fmadd_ps(_w6, _val46, _sum4);
                _sum4 = _mm256_comp_fmadd_ps(_w7, _val47, _sum4);
                _sum5 = _mm256_comp_fmadd_ps(_w0, _val50, _sum5);
                _sum5 = _mm256_comp_fmadd_ps(_w1, _val51, _sum5);
                _sum5 = _mm256_comp_fmadd_ps(_w2, _val52, _sum5);
                _sum5 = _mm256_comp_fmadd_ps(_w3, _val53, _sum5);
                _sum5 = _mm256_comp_fmadd_ps(_w4, _val54, _sum5);
                _sum5 = _mm256_comp_fmadd_ps(_w5, _val55, _sum5);
                _sum5 = _mm256_comp_fmadd_ps(_w6, _val56, _sum5);
                _sum5 = _mm256_comp_fmadd_ps(_w7, _val57, _sum5);

                __m256 _val60 = _mm256_broadcast_ss(tmpptr + 48);
                __m256 _val61 = _mm256_broadcast_ss(tmpptr + 49);
                __m256 _val62 = _mm256_broadcast_ss(tmpptr + 50);
                __m256 _val63 = _mm256_broadcast_ss(tmpptr + 51);
                __m256 _val64 = _mm256_broadcast_ss(tmpptr + 52);
                __m256 _val65 = _mm256_broadcast_ss(tmpptr + 53);
                __m256 _val66 = _mm256_broadcast_ss(tmpptr + 54);
                __m256 _val67 = _mm256_broadcast_ss(tmpptr + 55);
                __m256 _val70 = _mm256_broadcast_ss(tmpptr + 56);
                __m256 _val71 = _mm256_broadcast_ss(tmpptr + 57);
                __m256 _val72 = _mm256_broadcast_ss(tmpptr + 58);
                __m256 _val73 = _mm256_broadcast_ss(tmpptr + 59);
                __m256 _val74 = _mm256_broadcast_ss(tmpptr + 60);
                __m256 _val75 = _mm256_broadcast_ss(tmpptr + 61);
                __m256 _val76 = _mm256_broadcast_ss(tmpptr + 62);
                __m256 _val77 = _mm256_broadcast_ss(tmpptr + 63);

                _sum6 = _mm256_comp_fmadd_ps(_w0, _val60, _sum6);
                _sum6 = _mm256_comp_fmadd_ps(_w1, _val61, _sum6);
                _sum6 = _mm256_comp_fmadd_ps(_w2, _val62, _sum6);
                _sum6 = _mm256_comp_fmadd_ps(_w3, _val63, _sum6);
                _sum6 = _mm256_comp_fmadd_ps(_w4, _val64, _sum6);
                _sum6 = _mm256_comp_fmadd_ps(_w5, _val65, _sum6);
                _sum6 = _mm256_comp_fmadd_ps(_w6, _val66, _sum6);
                _sum6 = _mm256_comp_fmadd_ps(_w7, _val67, _sum6);
                _sum7 = _mm256_comp_fmadd_ps(_w0, _val70, _sum7);
                _sum7 = _mm256_comp_fmadd_ps(_w1, _val71, _sum7);
                _sum7 = _mm256_comp_fmadd_ps(_w2, _val72, _sum7);
                _sum7 = _mm256_comp_fmadd_ps(_w3, _val73, _sum7);
                _sum7 = _mm256_comp_fmadd_ps(_w4, _val74, _sum7);
                _sum7 = _mm256_comp_fmadd_ps(_w5, _val75, _sum7);
                _sum7 = _mm256_comp_fmadd_ps(_w6, _val76, _sum7);
                _sum7 = _mm256_comp_fmadd_ps(_w7, _val77, _sum7);

                tmpptr += 64;
                kptr += 64;
            }

            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);
            _mm256_storeu_ps(outptr + 32, _sum4);
            _mm256_storeu_ps(outptr + 40, _sum5);
            _mm256_storeu_ps(outptr + 48, _sum6);
            _mm256_storeu_ps(outptr + 56, _sum7);

            outptr += 64;
        }
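        // 4 pixels at a time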
        for (; i + 3 < size; i += 4)
        {
            const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;

            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _sum0 = _mm256_comp_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_comp_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w7, _val17, _sum1);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _sum2 = _mm256_comp_fmadd_ps(_w0, _val20, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w1, _val21, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w2, _val22, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w3, _val23, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w4, _val24, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w5, _val25, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w6, _val26, _sum2);
                _sum2 = _mm256_comp_fmadd_ps(_w7, _val27, _sum2);
                _sum3 = _mm256_comp_fmadd_ps(_w0, _val30, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w1, _val31, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w2, _val32, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w3, _val33, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w4, _val34, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w5, _val35, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w6, _val36, _sum3);
                _sum3 = _mm256_comp_fmadd_ps(_w7, _val37, _sum3);

                tmpptr += 32;
                kptr += 64;
            }

            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);

            outptr += 32;
        }
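        // 2 pixels at a time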
        for (; i + 1 < size; i += 2)
        {
            const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;

            for (int q = 0; q < inch; q++)
            {
                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                _sum0 = _mm256_comp_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_comp_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_comp_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_comp_fmadd_ps(_w7, _val17, _sum1);

                tmpptr += 16;
                kptr += 64;
            }

            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);

            outptr += 16;
        }

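        // remaining single pixels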
        for (; i < size; i++)
        {
            const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);

            __m256 _sum = _bias0;

            const float* kptr = (const float*)kernel + p * inch * 64;

            for (int q = 0; q < inch; q++)
            {
                __m256 _val0 = _mm256_broadcast_ss(tmpptr);
                __m256 _val1 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val2 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val3 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val4 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val5 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val6 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val7 = _mm256_broadcast_ss(tmpptr + 7);

                __m256 _w0 = _mm256_loadu_ps(kptr);
                __m256 _w1 = _mm256_loadu_ps(kptr + 8);
                __m256 _w2 = _mm256_loadu_ps(kptr + 16);
                __m256 _w3 = _mm256_loadu_ps(kptr + 24);
                __m256 _w4 = _mm256_loadu_ps(kptr + 32);
                __m256 _w5 = _mm256_loadu_ps(kptr + 40);
                __m256 _w6 = _mm256_loadu_ps(kptr + 48);
                __m256 _w7 = _mm256_loadu_ps(kptr + 56);

                _sum = _mm256_comp_fmadd_ps(_w0, _val0, _sum);
                _sum = _mm256_comp_fmadd_ps(_w1, _val1, _sum);
                _sum = _mm256_comp_fmadd_ps(_w2, _val2, _sum);
                _sum = _mm256_comp_fmadd_ps(_w3, _val3, _sum);
                _sum = _mm256_comp_fmadd_ps(_w4, _val4, _sum);
                _sum = _mm256_comp_fmadd_ps(_w5, _val5, _sum);
                _sum = _mm256_comp_fmadd_ps(_w6, _val6, _sum);
                _sum = _mm256_comp_fmadd_ps(_w7, _val7, _sum);

                tmpptr += 8;
                kptr += 64;
            }

            _mm256_storeu_ps(outptr, _sum);

            outptr += 8;
        }
    }
}

static void conv1x1s2_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;

    const int tailstep = (w - 2 * outw + w) * 8;
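    // stride-2: gather every other pixel of every other row into a packed
    // buffer, then reuse the stride-1 sgemm; tailstep skips the tail of the
    // current input row plus one full row (8 floats per packed pixel)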

    Mat bottom_blob_shrinked;
    bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < channels; p++)
    {
        const float* r0 = bottom_blob.channel(p);
        float* outptr = bottom_blob_shrinked.channel(p);

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                __m256 _v = _mm256_loadu_ps(r0);
                _mm256_storeu_ps(outptr, _v);

                r0 += 16;
                outptr += 8;
            }

            r0 += tailstep;
        }
    }

    conv1x1s1_sgemm_pack8_avx(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
}