1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
conv1x1s1_sgemm_transform_kernel_pack8_avx(const Mat & kernel,Mat & weight_data_pack8,int num_input,int num_output)15 static void conv1x1s1_sgemm_transform_kernel_pack8_avx(const Mat& kernel, Mat& weight_data_pack8, int num_input, int num_output)
16 {
17 // src = kw-kh-inch-outch
18 // dst = 8b-8a-kw-kh-inch/8a-outch/8b
19 Mat weight_data_r2 = kernel.reshape(1, num_input, num_output);
20
21 weight_data_pack8.create(1, num_input / 8, num_output / 8, (size_t)4 * 64, 64);
22
23 for (int q = 0; q + 7 < num_output; q += 8)
24 {
25 const Mat k0 = weight_data_r2.channel(q);
26 const Mat k1 = weight_data_r2.channel(q + 1);
27 const Mat k2 = weight_data_r2.channel(q + 2);
28 const Mat k3 = weight_data_r2.channel(q + 3);
29 const Mat k4 = weight_data_r2.channel(q + 4);
30 const Mat k5 = weight_data_r2.channel(q + 5);
31 const Mat k6 = weight_data_r2.channel(q + 6);
32 const Mat k7 = weight_data_r2.channel(q + 7);
33
34 Mat g0 = weight_data_pack8.channel(q / 8);
35
36 for (int p = 0; p + 7 < num_input; p += 8)
37 {
38 const float* k00 = k0.row(p);
39 const float* k01 = k0.row(p + 1);
40 const float* k02 = k0.row(p + 2);
41 const float* k03 = k0.row(p + 3);
42 const float* k04 = k0.row(p + 4);
43 const float* k05 = k0.row(p + 5);
44 const float* k06 = k0.row(p + 6);
45 const float* k07 = k0.row(p + 7);
46
47 const float* k10 = k1.row(p);
48 const float* k11 = k1.row(p + 1);
49 const float* k12 = k1.row(p + 2);
50 const float* k13 = k1.row(p + 3);
51 const float* k14 = k1.row(p + 4);
52 const float* k15 = k1.row(p + 5);
53 const float* k16 = k1.row(p + 6);
54 const float* k17 = k1.row(p + 7);
55
56 const float* k20 = k2.row(p);
57 const float* k21 = k2.row(p + 1);
58 const float* k22 = k2.row(p + 2);
59 const float* k23 = k2.row(p + 3);
60 const float* k24 = k2.row(p + 4);
61 const float* k25 = k2.row(p + 5);
62 const float* k26 = k2.row(p + 6);
63 const float* k27 = k2.row(p + 7);
64
65 const float* k30 = k3.row(p);
66 const float* k31 = k3.row(p + 1);
67 const float* k32 = k3.row(p + 2);
68 const float* k33 = k3.row(p + 3);
69 const float* k34 = k3.row(p + 4);
70 const float* k35 = k3.row(p + 5);
71 const float* k36 = k3.row(p + 6);
72 const float* k37 = k3.row(p + 7);
73
74 const float* k40 = k4.row(p);
75 const float* k41 = k4.row(p + 1);
76 const float* k42 = k4.row(p + 2);
77 const float* k43 = k4.row(p + 3);
78 const float* k44 = k4.row(p + 4);
79 const float* k45 = k4.row(p + 5);
80 const float* k46 = k4.row(p + 6);
81 const float* k47 = k4.row(p + 7);
82
83 const float* k50 = k5.row(p);
84 const float* k51 = k5.row(p + 1);
85 const float* k52 = k5.row(p + 2);
86 const float* k53 = k5.row(p + 3);
87 const float* k54 = k5.row(p + 4);
88 const float* k55 = k5.row(p + 5);
89 const float* k56 = k5.row(p + 6);
90 const float* k57 = k5.row(p + 7);
91
92 const float* k60 = k6.row(p);
93 const float* k61 = k6.row(p + 1);
94 const float* k62 = k6.row(p + 2);
95 const float* k63 = k6.row(p + 3);
96 const float* k64 = k6.row(p + 4);
97 const float* k65 = k6.row(p + 5);
98 const float* k66 = k6.row(p + 6);
99 const float* k67 = k6.row(p + 7);
100
101 const float* k70 = k7.row(p);
102 const float* k71 = k7.row(p + 1);
103 const float* k72 = k7.row(p + 2);
104 const float* k73 = k7.row(p + 3);
105 const float* k74 = k7.row(p + 4);
106 const float* k75 = k7.row(p + 5);
107 const float* k76 = k7.row(p + 6);
108 const float* k77 = k7.row(p + 7);
109
110 float* g00 = g0.row(p / 8);
111 g00[0] = k00[0];
112 g00[1] = k10[0];
113 g00[2] = k20[0];
114 g00[3] = k30[0];
115 g00[4] = k40[0];
116 g00[5] = k50[0];
117 g00[6] = k60[0];
118 g00[7] = k70[0];
119 g00 += 8;
120 g00[0] = k01[0];
121 g00[1] = k11[0];
122 g00[2] = k21[0];
123 g00[3] = k31[0];
124 g00[4] = k41[0];
125 g00[5] = k51[0];
126 g00[6] = k61[0];
127 g00[7] = k71[0];
128
129 g00 += 8;
130 g00[0] = k02[0];
131 g00[1] = k12[0];
132 g00[2] = k22[0];
133 g00[3] = k32[0];
134 g00[4] = k42[0];
135 g00[5] = k52[0];
136 g00[6] = k62[0];
137 g00[7] = k72[0];
138
139 g00 += 8;
140 g00[0] = k03[0];
141 g00[1] = k13[0];
142 g00[2] = k23[0];
143 g00[3] = k33[0];
144 g00[4] = k43[0];
145 g00[5] = k53[0];
146 g00[6] = k63[0];
147 g00[7] = k73[0];
148
149 g00 += 8;
150 g00[0] = k04[0];
151 g00[1] = k14[0];
152 g00[2] = k24[0];
153 g00[3] = k34[0];
154 g00[4] = k44[0];
155 g00[5] = k54[0];
156 g00[6] = k64[0];
157 g00[7] = k74[0];
158
159 g00 += 8;
160 g00[0] = k05[0];
161 g00[1] = k15[0];
162 g00[2] = k25[0];
163 g00[3] = k35[0];
164 g00[4] = k45[0];
165 g00[5] = k55[0];
166 g00[6] = k65[0];
167 g00[7] = k75[0];
168
169 g00 += 8;
170 g00[0] = k06[0];
171 g00[1] = k16[0];
172 g00[2] = k26[0];
173 g00[3] = k36[0];
174 g00[4] = k46[0];
175 g00[5] = k56[0];
176 g00[6] = k66[0];
177 g00[7] = k76[0];
178
179 g00 += 8;
180 g00[0] = k07[0];
181 g00[1] = k17[0];
182 g00[2] = k27[0];
183 g00[3] = k37[0];
184 g00[4] = k47[0];
185 g00[5] = k57[0];
186 g00[6] = k67[0];
187 g00[7] = k77[0];
188
189 g00 += 8;
190 }
191 }
192 }
193
conv1x1s1_sgemm_pack8_avx(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)194 static void conv1x1s1_sgemm_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
195 {
196 int w = bottom_blob.w;
197 int h = bottom_blob.h;
198 int inch = bottom_blob.c;
199 int outch = top_blob.c;
200 size_t elemsize = bottom_blob.elemsize;
201 int elempack = bottom_blob.elempack;
202
203 const int size = w * h;
204
205 const float* bias = _bias;
206 // interleave
207 Mat tmp(12, inch, size / 12 + (size % 12) / 8 + (size % 12 % 8) / 4 + (size % 12 % 4) / 2 + size % 12 % 2, elemsize, elempack, opt.workspace_allocator);
208 {
209 int nn_size = size / 12;
210 int remain_size_start = nn_size * 12;
211 #pragma omp parallel for num_threads(opt.num_threads)
212 for (int ii = 0; ii < nn_size; ii++)
213 {
214 int i = ii * 12;
215 const float* img0 = bottom_blob.channel(0);
216 img0 += i * 8;
217
218 float* tmpptr = tmp.channel(i / 12);
219
220 for (int q = 0; q < inch; q++)
221 {
222 __m256 _r0 = _mm256_loadu_ps(img0);
223 __m256 _r1 = _mm256_loadu_ps(img0 + 8);
224 __m256 _r2 = _mm256_loadu_ps(img0 + 16);
225 __m256 _r3 = _mm256_loadu_ps(img0 + 24);
226 __m256 _r4 = _mm256_loadu_ps(img0 + 32);
227 __m256 _r5 = _mm256_loadu_ps(img0 + 40);
228 __m256 _r6 = _mm256_loadu_ps(img0 + 48);
229 __m256 _r7 = _mm256_loadu_ps(img0 + 56);
230 __m256 _r8 = _mm256_loadu_ps(img0 + 64);
231 __m256 _r9 = _mm256_loadu_ps(img0 + 72);
232 __m256 _r10 = _mm256_loadu_ps(img0 + 80);
233 __m256 _r11 = _mm256_loadu_ps(img0 + 88);
234 _mm256_storeu_ps(tmpptr, _r0);
235 _mm256_storeu_ps(tmpptr + 8, _r1);
236 _mm256_storeu_ps(tmpptr + 16, _r2);
237 _mm256_storeu_ps(tmpptr + 24, _r3);
238 _mm256_storeu_ps(tmpptr + 32, _r4);
239 _mm256_storeu_ps(tmpptr + 40, _r5);
240 _mm256_storeu_ps(tmpptr + 48, _r6);
241 _mm256_storeu_ps(tmpptr + 56, _r7);
242 _mm256_storeu_ps(tmpptr + 64, _r8);
243 _mm256_storeu_ps(tmpptr + 72, _r9);
244 _mm256_storeu_ps(tmpptr + 80, _r10);
245 _mm256_storeu_ps(tmpptr + 88, _r11);
246
247 tmpptr += 96;
248 img0 += bottom_blob.cstep * 8;
249 }
250 }
251 nn_size = (size - remain_size_start) >> 3;
252 #pragma omp parallel for num_threads(opt.num_threads)
253 for (int ii = 0; ii < nn_size; ii++)
254 {
255 int i = remain_size_start + ii * 8;
256
257 const float* img0 = bottom_blob.channel(0);
258 img0 += i * 8;
259
260 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
261
262 for (int q = 0; q < inch; q++)
263 {
264 __m256 _r0 = _mm256_loadu_ps(img0);
265 __m256 _r1 = _mm256_loadu_ps(img0 + 8);
266 __m256 _r2 = _mm256_loadu_ps(img0 + 16);
267 __m256 _r3 = _mm256_loadu_ps(img0 + 24);
268 __m256 _r4 = _mm256_loadu_ps(img0 + 32);
269 __m256 _r5 = _mm256_loadu_ps(img0 + 40);
270 __m256 _r6 = _mm256_loadu_ps(img0 + 48);
271 __m256 _r7 = _mm256_loadu_ps(img0 + 56);
272 _mm256_storeu_ps(tmpptr, _r0);
273 _mm256_storeu_ps(tmpptr + 8, _r1);
274 _mm256_storeu_ps(tmpptr + 16, _r2);
275 _mm256_storeu_ps(tmpptr + 24, _r3);
276 _mm256_storeu_ps(tmpptr + 32, _r4);
277 _mm256_storeu_ps(tmpptr + 40, _r5);
278 _mm256_storeu_ps(tmpptr + 48, _r6);
279 _mm256_storeu_ps(tmpptr + 56, _r7);
280
281 tmpptr += 64;
282 img0 += bottom_blob.cstep * 8;
283 }
284 }
285
286 remain_size_start += nn_size << 3;
287 nn_size = (size - remain_size_start) >> 2;
288
289 #pragma omp parallel for num_threads(opt.num_threads)
290 for (int ii = 0; ii < nn_size; ii++)
291 {
292 int i = remain_size_start + ii * 4;
293
294 const float* img0 = bottom_blob.channel(0);
295 img0 += i * 8;
296 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
297
298 for (int q = 0; q < inch; q++)
299 {
300 __m256 _r0 = _mm256_loadu_ps(img0);
301 __m256 _r1 = _mm256_loadu_ps(img0 + 8);
302 __m256 _r2 = _mm256_loadu_ps(img0 + 16);
303 __m256 _r3 = _mm256_loadu_ps(img0 + 24);
304 _mm256_storeu_ps(tmpptr, _r0);
305 _mm256_storeu_ps(tmpptr + 8, _r1);
306 _mm256_storeu_ps(tmpptr + 16, _r2);
307 _mm256_storeu_ps(tmpptr + 24, _r3);
308
309 tmpptr += 32;
310 img0 += bottom_blob.cstep * 8;
311 }
312 }
313
314 remain_size_start += nn_size << 2;
315 nn_size = (size - remain_size_start) >> 1;
316 #pragma omp parallel for num_threads(opt.num_threads)
317 for (int ii = 0; ii < nn_size; ii++)
318 {
319 int i = remain_size_start + ii * 2;
320
321 const float* img0 = bottom_blob.channel(0);
322 img0 += i * 8;
323 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);
324
325 for (int q = 0; q < inch; q++)
326 {
327 __m256 _r0 = _mm256_loadu_ps(img0);
328 __m256 _r1 = _mm256_loadu_ps(img0 + 8);
329 _mm256_storeu_ps(tmpptr, _r0);
330 _mm256_storeu_ps(tmpptr + 8, _r1);
331
332 tmpptr += 16;
333 img0 += bottom_blob.cstep * 8;
334 }
335 }
336
337 remain_size_start += nn_size << 1;
338 #pragma omp parallel for num_threads(opt.num_threads)
339 for (int i = remain_size_start; i < size; i++)
340 {
341 const float* img0 = bottom_blob.channel(0);
342 img0 += i * 8;
343 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
344 for (int q = 0; q < inch; q++)
345 {
346 __m256 _r0 = _mm256_loadu_ps(img0);
347 _mm256_storeu_ps(tmpptr, _r0);
348
349 tmpptr += 8;
350 img0 += bottom_blob.cstep * 8;
351 }
352 }
353 }
354 #pragma omp parallel for num_threads(opt.num_threads)
355 for (int p = 0; p < outch; p++)
356 {
357 Mat out = top_blob.channel(p);
358
359 __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);
360
361 float* outptr = out;
362 int i = 0;
363 for (; i + 11 < size; i += 12)
364 {
365 const float* tmpptr = tmp.channel(i / 12);
366
367 __m256 _sum0 = _bias0;
368 __m256 _sum1 = _bias0;
369 __m256 _sum2 = _bias0;
370 __m256 _sum3 = _bias0;
371 __m256 _sum4 = _bias0;
372 __m256 _sum5 = _bias0;
373 __m256 _sum6 = _bias0;
374 __m256 _sum7 = _bias0;
375 __m256 _sum8 = _bias0;
376 __m256 _sum9 = _bias0;
377 __m256 _sum10 = _bias0;
378 __m256 _sum11 = _bias0;
379
380 const float* kptr = (const float*)kernel + p * inch * 64;
381 for (int q = 0; q < inch; q++)
382 {
383 __m256 _w0 = _mm256_loadu_ps(kptr);
384 __m256 _w1 = _mm256_loadu_ps(kptr + 8);
385 __m256 _w2 = _mm256_loadu_ps(kptr + 16);
386 __m256 _w3 = _mm256_loadu_ps(kptr + 24);
387 __m256 _w4 = _mm256_loadu_ps(kptr + 32);
388 __m256 _w5 = _mm256_loadu_ps(kptr + 40);
389 __m256 _w6 = _mm256_loadu_ps(kptr + 48);
390 __m256 _w7 = _mm256_loadu_ps(kptr + 56);
391
392 __m256 _val00 = _mm256_broadcast_ss(tmpptr);
393 __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
394 __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
395 __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
396 __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
397 __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
398 __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
399 __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
400 __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
401 __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
402 __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
403 __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
404 __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
405 __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
406 __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
407 __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);
408
409 _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
410 _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
411 _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
412 _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
413 _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
414 _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
415 _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
416 _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
417 _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
418 _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
419 _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
420 _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
421 _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
422 _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
423 _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
424 _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);
425
426 __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
427 __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
428 __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
429 __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
430 __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
431 __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
432 __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
433 __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
434 __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
435 __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
436 __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
437 __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
438 __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
439 __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
440 __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
441 __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);
442
443 _sum2 = _mm256_fmadd_ps(_w0, _val20, _sum2);
444 _sum2 = _mm256_fmadd_ps(_w1, _val21, _sum2);
445 _sum2 = _mm256_fmadd_ps(_w2, _val22, _sum2);
446 _sum2 = _mm256_fmadd_ps(_w3, _val23, _sum2);
447 _sum2 = _mm256_fmadd_ps(_w4, _val24, _sum2);
448 _sum2 = _mm256_fmadd_ps(_w5, _val25, _sum2);
449 _sum2 = _mm256_fmadd_ps(_w6, _val26, _sum2);
450 _sum2 = _mm256_fmadd_ps(_w7, _val27, _sum2);
451 _sum3 = _mm256_fmadd_ps(_w0, _val30, _sum3);
452 _sum3 = _mm256_fmadd_ps(_w1, _val31, _sum3);
453 _sum3 = _mm256_fmadd_ps(_w2, _val32, _sum3);
454 _sum3 = _mm256_fmadd_ps(_w3, _val33, _sum3);
455 _sum3 = _mm256_fmadd_ps(_w4, _val34, _sum3);
456 _sum3 = _mm256_fmadd_ps(_w5, _val35, _sum3);
457 _sum3 = _mm256_fmadd_ps(_w6, _val36, _sum3);
458 _sum3 = _mm256_fmadd_ps(_w7, _val37, _sum3);
459
460 __m256 _val40 = _mm256_broadcast_ss(tmpptr + 32);
461 __m256 _val41 = _mm256_broadcast_ss(tmpptr + 33);
462 __m256 _val42 = _mm256_broadcast_ss(tmpptr + 34);
463 __m256 _val43 = _mm256_broadcast_ss(tmpptr + 35);
464 __m256 _val44 = _mm256_broadcast_ss(tmpptr + 36);
465 __m256 _val45 = _mm256_broadcast_ss(tmpptr + 37);
466 __m256 _val46 = _mm256_broadcast_ss(tmpptr + 38);
467 __m256 _val47 = _mm256_broadcast_ss(tmpptr + 39);
468 __m256 _val50 = _mm256_broadcast_ss(tmpptr + 40);
469 __m256 _val51 = _mm256_broadcast_ss(tmpptr + 41);
470 __m256 _val52 = _mm256_broadcast_ss(tmpptr + 42);
471 __m256 _val53 = _mm256_broadcast_ss(tmpptr + 43);
472 __m256 _val54 = _mm256_broadcast_ss(tmpptr + 44);
473 __m256 _val55 = _mm256_broadcast_ss(tmpptr + 45);
474 __m256 _val56 = _mm256_broadcast_ss(tmpptr + 46);
475 __m256 _val57 = _mm256_broadcast_ss(tmpptr + 47);
476
477 _sum4 = _mm256_fmadd_ps(_w0, _val40, _sum4);
478 _sum4 = _mm256_fmadd_ps(_w1, _val41, _sum4);
479 _sum4 = _mm256_fmadd_ps(_w2, _val42, _sum4);
480 _sum4 = _mm256_fmadd_ps(_w3, _val43, _sum4);
481 _sum4 = _mm256_fmadd_ps(_w4, _val44, _sum4);
482 _sum4 = _mm256_fmadd_ps(_w5, _val45, _sum4);
483 _sum4 = _mm256_fmadd_ps(_w6, _val46, _sum4);
484 _sum4 = _mm256_fmadd_ps(_w7, _val47, _sum4);
485 _sum5 = _mm256_fmadd_ps(_w0, _val50, _sum5);
486 _sum5 = _mm256_fmadd_ps(_w1, _val51, _sum5);
487 _sum5 = _mm256_fmadd_ps(_w2, _val52, _sum5);
488 _sum5 = _mm256_fmadd_ps(_w3, _val53, _sum5);
489 _sum5 = _mm256_fmadd_ps(_w4, _val54, _sum5);
490 _sum5 = _mm256_fmadd_ps(_w5, _val55, _sum5);
491 _sum5 = _mm256_fmadd_ps(_w6, _val56, _sum5);
492 _sum5 = _mm256_fmadd_ps(_w7, _val57, _sum5);
493
494 __m256 _val60 = _mm256_broadcast_ss(tmpptr + 48);
495 __m256 _val61 = _mm256_broadcast_ss(tmpptr + 49);
496 __m256 _val62 = _mm256_broadcast_ss(tmpptr + 50);
497 __m256 _val63 = _mm256_broadcast_ss(tmpptr + 51);
498 __m256 _val64 = _mm256_broadcast_ss(tmpptr + 52);
499 __m256 _val65 = _mm256_broadcast_ss(tmpptr + 53);
500 __m256 _val66 = _mm256_broadcast_ss(tmpptr + 54);
501 __m256 _val67 = _mm256_broadcast_ss(tmpptr + 55);
502 __m256 _val70 = _mm256_broadcast_ss(tmpptr + 56);
503 __m256 _val71 = _mm256_broadcast_ss(tmpptr + 57);
504 __m256 _val72 = _mm256_broadcast_ss(tmpptr + 58);
505 __m256 _val73 = _mm256_broadcast_ss(tmpptr + 59);
506 __m256 _val74 = _mm256_broadcast_ss(tmpptr + 60);
507 __m256 _val75 = _mm256_broadcast_ss(tmpptr + 61);
508 __m256 _val76 = _mm256_broadcast_ss(tmpptr + 62);
509 __m256 _val77 = _mm256_broadcast_ss(tmpptr + 63);
510
511 _sum6 = _mm256_fmadd_ps(_w0, _val60, _sum6);
512 _sum6 = _mm256_fmadd_ps(_w1, _val61, _sum6);
513 _sum6 = _mm256_fmadd_ps(_w2, _val62, _sum6);
514 _sum6 = _mm256_fmadd_ps(_w3, _val63, _sum6);
515 _sum6 = _mm256_fmadd_ps(_w4, _val64, _sum6);
516 _sum6 = _mm256_fmadd_ps(_w5, _val65, _sum6);
517 _sum6 = _mm256_fmadd_ps(_w6, _val66, _sum6);
518 _sum6 = _mm256_fmadd_ps(_w7, _val67, _sum6);
519 _sum7 = _mm256_fmadd_ps(_w0, _val70, _sum7);
520 _sum7 = _mm256_fmadd_ps(_w1, _val71, _sum7);
521 _sum7 = _mm256_fmadd_ps(_w2, _val72, _sum7);
522 _sum7 = _mm256_fmadd_ps(_w3, _val73, _sum7);
523 _sum7 = _mm256_fmadd_ps(_w4, _val74, _sum7);
524 _sum7 = _mm256_fmadd_ps(_w5, _val75, _sum7);
525 _sum7 = _mm256_fmadd_ps(_w6, _val76, _sum7);
526 _sum7 = _mm256_fmadd_ps(_w7, _val77, _sum7);
527
528 __m256 _val80 = _mm256_broadcast_ss(tmpptr + 64);
529 __m256 _val81 = _mm256_broadcast_ss(tmpptr + 65);
530 __m256 _val82 = _mm256_broadcast_ss(tmpptr + 66);
531 __m256 _val83 = _mm256_broadcast_ss(tmpptr + 67);
532 __m256 _val84 = _mm256_broadcast_ss(tmpptr + 68);
533 __m256 _val85 = _mm256_broadcast_ss(tmpptr + 69);
534 __m256 _val86 = _mm256_broadcast_ss(tmpptr + 70);
535 __m256 _val87 = _mm256_broadcast_ss(tmpptr + 71);
536 __m256 _val90 = _mm256_broadcast_ss(tmpptr + 72);
537 __m256 _val91 = _mm256_broadcast_ss(tmpptr + 73);
538 __m256 _val92 = _mm256_broadcast_ss(tmpptr + 74);
539 __m256 _val93 = _mm256_broadcast_ss(tmpptr + 75);
540 __m256 _val94 = _mm256_broadcast_ss(tmpptr + 76);
541 __m256 _val95 = _mm256_broadcast_ss(tmpptr + 77);
542 __m256 _val96 = _mm256_broadcast_ss(tmpptr + 78);
543 __m256 _val97 = _mm256_broadcast_ss(tmpptr + 79);
544
545 _sum8 = _mm256_fmadd_ps(_w0, _val80, _sum8);
546 _sum8 = _mm256_fmadd_ps(_w1, _val81, _sum8);
547 _sum8 = _mm256_fmadd_ps(_w2, _val82, _sum8);
548 _sum8 = _mm256_fmadd_ps(_w3, _val83, _sum8);
549 _sum8 = _mm256_fmadd_ps(_w4, _val84, _sum8);
550 _sum8 = _mm256_fmadd_ps(_w5, _val85, _sum8);
551 _sum8 = _mm256_fmadd_ps(_w6, _val86, _sum8);
552 _sum8 = _mm256_fmadd_ps(_w7, _val87, _sum8);
553 _sum9 = _mm256_fmadd_ps(_w0, _val90, _sum9);
554 _sum9 = _mm256_fmadd_ps(_w1, _val91, _sum9);
555 _sum9 = _mm256_fmadd_ps(_w2, _val92, _sum9);
556 _sum9 = _mm256_fmadd_ps(_w3, _val93, _sum9);
557 _sum9 = _mm256_fmadd_ps(_w4, _val94, _sum9);
558 _sum9 = _mm256_fmadd_ps(_w5, _val95, _sum9);
559 _sum9 = _mm256_fmadd_ps(_w6, _val96, _sum9);
560 _sum9 = _mm256_fmadd_ps(_w7, _val97, _sum9);
561
562 __m256 _val100 = _mm256_broadcast_ss(tmpptr + 80);
563 __m256 _val101 = _mm256_broadcast_ss(tmpptr + 81);
564 __m256 _val102 = _mm256_broadcast_ss(tmpptr + 82);
565 __m256 _val103 = _mm256_broadcast_ss(tmpptr + 83);
566 __m256 _val104 = _mm256_broadcast_ss(tmpptr + 84);
567 __m256 _val105 = _mm256_broadcast_ss(tmpptr + 85);
568 __m256 _val106 = _mm256_broadcast_ss(tmpptr + 86);
569 __m256 _val107 = _mm256_broadcast_ss(tmpptr + 87);
570 __m256 _val110 = _mm256_broadcast_ss(tmpptr + 88);
571 __m256 _val111 = _mm256_broadcast_ss(tmpptr + 89);
572 __m256 _val112 = _mm256_broadcast_ss(tmpptr + 90);
573 __m256 _val113 = _mm256_broadcast_ss(tmpptr + 91);
574 __m256 _val114 = _mm256_broadcast_ss(tmpptr + 92);
575 __m256 _val115 = _mm256_broadcast_ss(tmpptr + 93);
576 __m256 _val116 = _mm256_broadcast_ss(tmpptr + 94);
577 __m256 _val117 = _mm256_broadcast_ss(tmpptr + 95);
578
579 _sum10 = _mm256_fmadd_ps(_w0, _val100, _sum10);
580 _sum10 = _mm256_fmadd_ps(_w1, _val101, _sum10);
581 _sum10 = _mm256_fmadd_ps(_w2, _val102, _sum10);
582 _sum10 = _mm256_fmadd_ps(_w3, _val103, _sum10);
583 _sum10 = _mm256_fmadd_ps(_w4, _val104, _sum10);
584 _sum10 = _mm256_fmadd_ps(_w5, _val105, _sum10);
585 _sum10 = _mm256_fmadd_ps(_w6, _val106, _sum10);
586 _sum10 = _mm256_fmadd_ps(_w7, _val107, _sum10);
587 _sum11 = _mm256_fmadd_ps(_w0, _val110, _sum11);
588 _sum11 = _mm256_fmadd_ps(_w1, _val111, _sum11);
589 _sum11 = _mm256_fmadd_ps(_w2, _val112, _sum11);
590 _sum11 = _mm256_fmadd_ps(_w3, _val113, _sum11);
591 _sum11 = _mm256_fmadd_ps(_w4, _val114, _sum11);
592 _sum11 = _mm256_fmadd_ps(_w5, _val115, _sum11);
593 _sum11 = _mm256_fmadd_ps(_w6, _val116, _sum11);
594 _sum11 = _mm256_fmadd_ps(_w7, _val117, _sum11);
595
596 tmpptr += 96;
597
598 kptr += 64;
599 }
600 _mm256_storeu_ps(outptr, _sum0);
601 _mm256_storeu_ps(outptr + 8, _sum1);
602 _mm256_storeu_ps(outptr + 16, _sum2);
603 _mm256_storeu_ps(outptr + 24, _sum3);
604 _mm256_storeu_ps(outptr + 32, _sum4);
605 _mm256_storeu_ps(outptr + 40, _sum5);
606 _mm256_storeu_ps(outptr + 48, _sum6);
607 _mm256_storeu_ps(outptr + 56, _sum7);
608 _mm256_storeu_ps(outptr + 64, _sum8);
609 _mm256_storeu_ps(outptr + 72, _sum9);
610 _mm256_storeu_ps(outptr + 80, _sum10);
611 _mm256_storeu_ps(outptr + 88, _sum11);
612
613 outptr += 96;
614 }
615 for (; i + 7 < size; i += 8)
616 {
617 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
618
619 __m256 _sum0 = _bias0;
620 __m256 _sum1 = _bias0;
621 __m256 _sum2 = _bias0;
622 __m256 _sum3 = _bias0;
623 __m256 _sum4 = _bias0;
624 __m256 _sum5 = _bias0;
625 __m256 _sum6 = _bias0;
626 __m256 _sum7 = _bias0;
627
628 const float* kptr = (const float*)kernel + p * inch * 64;
629 for (int q = 0; q < inch; q++)
630 {
631 __m256 _w0 = _mm256_loadu_ps(kptr);
632 __m256 _w1 = _mm256_loadu_ps(kptr + 8);
633 __m256 _w2 = _mm256_loadu_ps(kptr + 16);
634 __m256 _w3 = _mm256_loadu_ps(kptr + 24);
635 __m256 _w4 = _mm256_loadu_ps(kptr + 32);
636 __m256 _w5 = _mm256_loadu_ps(kptr + 40);
637 __m256 _w6 = _mm256_loadu_ps(kptr + 48);
638 __m256 _w7 = _mm256_loadu_ps(kptr + 56);
639
640 __m256 _val00 = _mm256_broadcast_ss(tmpptr);
641 __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
642 __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
643 __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
644 __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
645 __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
646 __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
647 __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
648 __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
649 __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
650 __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
651 __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
652 __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
653 __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
654 __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
655 __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);
656
657 _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
658 _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
659 _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
660 _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
661 _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
662 _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
663 _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
664 _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
665 _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
666 _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
667 _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
668 _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
669 _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
670 _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
671 _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
672 _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);
673
674 __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
675 __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
676 __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
677 __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
678 __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
679 __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
680 __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
681 __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
682 __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
683 __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
684 __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
685 __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
686 __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
687 __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
688 __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
689 __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);
690
691 _sum2 = _mm256_fmadd_ps(_w0, _val20, _sum2);
692 _sum2 = _mm256_fmadd_ps(_w1, _val21, _sum2);
693 _sum2 = _mm256_fmadd_ps(_w2, _val22, _sum2);
694 _sum2 = _mm256_fmadd_ps(_w3, _val23, _sum2);
695 _sum2 = _mm256_fmadd_ps(_w4, _val24, _sum2);
696 _sum2 = _mm256_fmadd_ps(_w5, _val25, _sum2);
697 _sum2 = _mm256_fmadd_ps(_w6, _val26, _sum2);
698 _sum2 = _mm256_fmadd_ps(_w7, _val27, _sum2);
699 _sum3 = _mm256_fmadd_ps(_w0, _val30, _sum3);
700 _sum3 = _mm256_fmadd_ps(_w1, _val31, _sum3);
701 _sum3 = _mm256_fmadd_ps(_w2, _val32, _sum3);
702 _sum3 = _mm256_fmadd_ps(_w3, _val33, _sum3);
703 _sum3 = _mm256_fmadd_ps(_w4, _val34, _sum3);
704 _sum3 = _mm256_fmadd_ps(_w5, _val35, _sum3);
705 _sum3 = _mm256_fmadd_ps(_w6, _val36, _sum3);
706 _sum3 = _mm256_fmadd_ps(_w7, _val37, _sum3);
707
708 __m256 _val40 = _mm256_broadcast_ss(tmpptr + 32);
709 __m256 _val41 = _mm256_broadcast_ss(tmpptr + 33);
710 __m256 _val42 = _mm256_broadcast_ss(tmpptr + 34);
711 __m256 _val43 = _mm256_broadcast_ss(tmpptr + 35);
712 __m256 _val44 = _mm256_broadcast_ss(tmpptr + 36);
713 __m256 _val45 = _mm256_broadcast_ss(tmpptr + 37);
714 __m256 _val46 = _mm256_broadcast_ss(tmpptr + 38);
715 __m256 _val47 = _mm256_broadcast_ss(tmpptr + 39);
716 __m256 _val50 = _mm256_broadcast_ss(tmpptr + 40);
717 __m256 _val51 = _mm256_broadcast_ss(tmpptr + 41);
718 __m256 _val52 = _mm256_broadcast_ss(tmpptr + 42);
719 __m256 _val53 = _mm256_broadcast_ss(tmpptr + 43);
720 __m256 _val54 = _mm256_broadcast_ss(tmpptr + 44);
721 __m256 _val55 = _mm256_broadcast_ss(tmpptr + 45);
722 __m256 _val56 = _mm256_broadcast_ss(tmpptr + 46);
723 __m256 _val57 = _mm256_broadcast_ss(tmpptr + 47);
724
725 _sum4 = _mm256_fmadd_ps(_w0, _val40, _sum4);
726 _sum4 = _mm256_fmadd_ps(_w1, _val41, _sum4);
727 _sum4 = _mm256_fmadd_ps(_w2, _val42, _sum4);
728 _sum4 = _mm256_fmadd_ps(_w3, _val43, _sum4);
729 _sum4 = _mm256_fmadd_ps(_w4, _val44, _sum4);
730 _sum4 = _mm256_fmadd_ps(_w5, _val45, _sum4);
731 _sum4 = _mm256_fmadd_ps(_w6, _val46, _sum4);
732 _sum4 = _mm256_fmadd_ps(_w7, _val47, _sum4);
733 _sum5 = _mm256_fmadd_ps(_w0, _val50, _sum5);
734 _sum5 = _mm256_fmadd_ps(_w1, _val51, _sum5);
735 _sum5 = _mm256_fmadd_ps(_w2, _val52, _sum5);
736 _sum5 = _mm256_fmadd_ps(_w3, _val53, _sum5);
737 _sum5 = _mm256_fmadd_ps(_w4, _val54, _sum5);
738 _sum5 = _mm256_fmadd_ps(_w5, _val55, _sum5);
739 _sum5 = _mm256_fmadd_ps(_w6, _val56, _sum5);
740 _sum5 = _mm256_fmadd_ps(_w7, _val57, _sum5);
741
742 __m256 _val60 = _mm256_broadcast_ss(tmpptr + 48);
743 __m256 _val61 = _mm256_broadcast_ss(tmpptr + 49);
744 __m256 _val62 = _mm256_broadcast_ss(tmpptr + 50);
745 __m256 _val63 = _mm256_broadcast_ss(tmpptr + 51);
746 __m256 _val64 = _mm256_broadcast_ss(tmpptr + 52);
747 __m256 _val65 = _mm256_broadcast_ss(tmpptr + 53);
748 __m256 _val66 = _mm256_broadcast_ss(tmpptr + 54);
749 __m256 _val67 = _mm256_broadcast_ss(tmpptr + 55);
750 __m256 _val70 = _mm256_broadcast_ss(tmpptr + 56);
751 __m256 _val71 = _mm256_broadcast_ss(tmpptr + 57);
752 __m256 _val72 = _mm256_broadcast_ss(tmpptr + 58);
753 __m256 _val73 = _mm256_broadcast_ss(tmpptr + 59);
754 __m256 _val74 = _mm256_broadcast_ss(tmpptr + 60);
755 __m256 _val75 = _mm256_broadcast_ss(tmpptr + 61);
756 __m256 _val76 = _mm256_broadcast_ss(tmpptr + 62);
757 __m256 _val77 = _mm256_broadcast_ss(tmpptr + 63);
758
759 _sum6 = _mm256_fmadd_ps(_w0, _val60, _sum6);
760 _sum6 = _mm256_fmadd_ps(_w1, _val61, _sum6);
761 _sum6 = _mm256_fmadd_ps(_w2, _val62, _sum6);
762 _sum6 = _mm256_fmadd_ps(_w3, _val63, _sum6);
763 _sum6 = _mm256_fmadd_ps(_w4, _val64, _sum6);
764 _sum6 = _mm256_fmadd_ps(_w5, _val65, _sum6);
765 _sum6 = _mm256_fmadd_ps(_w6, _val66, _sum6);
766 _sum6 = _mm256_fmadd_ps(_w7, _val67, _sum6);
767 _sum7 = _mm256_fmadd_ps(_w0, _val70, _sum7);
768 _sum7 = _mm256_fmadd_ps(_w1, _val71, _sum7);
769 _sum7 = _mm256_fmadd_ps(_w2, _val72, _sum7);
770 _sum7 = _mm256_fmadd_ps(_w3, _val73, _sum7);
771 _sum7 = _mm256_fmadd_ps(_w4, _val74, _sum7);
772 _sum7 = _mm256_fmadd_ps(_w5, _val75, _sum7);
773 _sum7 = _mm256_fmadd_ps(_w6, _val76, _sum7);
774 _sum7 = _mm256_fmadd_ps(_w7, _val77, _sum7);
775
776 tmpptr += 64;
777
778 kptr += 64;
779 }
780 _mm256_storeu_ps(outptr, _sum0);
781 _mm256_storeu_ps(outptr + 8, _sum1);
782 _mm256_storeu_ps(outptr + 16, _sum2);
783 _mm256_storeu_ps(outptr + 24, _sum3);
784 _mm256_storeu_ps(outptr + 32, _sum4);
785 _mm256_storeu_ps(outptr + 40, _sum5);
786 _mm256_storeu_ps(outptr + 48, _sum6);
787 _mm256_storeu_ps(outptr + 56, _sum7);
788
789 outptr += 64;
790 }
791 for (; i + 3 < size; i += 4)
792 {
793 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
794
795 __m256 _sum0 = _bias0;
796 __m256 _sum1 = _bias0;
797 __m256 _sum2 = _bias0;
798 __m256 _sum3 = _bias0;
799
800 const float* kptr = (const float*)kernel + p * inch * 64;
801 for (int q = 0; q < inch; q++)
802 {
803 __m256 _w0 = _mm256_loadu_ps(kptr);
804 __m256 _w1 = _mm256_loadu_ps(kptr + 8);
805 __m256 _w2 = _mm256_loadu_ps(kptr + 16);
806 __m256 _w3 = _mm256_loadu_ps(kptr + 24);
807 __m256 _w4 = _mm256_loadu_ps(kptr + 32);
808 __m256 _w5 = _mm256_loadu_ps(kptr + 40);
809 __m256 _w6 = _mm256_loadu_ps(kptr + 48);
810 __m256 _w7 = _mm256_loadu_ps(kptr + 56);
811
812 __m256 _val00 = _mm256_broadcast_ss(tmpptr);
813 __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
814 __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
815 __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
816 __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
817 __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
818 __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
819 __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
820 __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
821 __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
822 __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
823 __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
824 __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
825 __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
826 __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
827 __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);
828
829 _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
830 _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
831 _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
832 _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
833 _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
834 _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
835 _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
836 _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
837 _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
838 _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
839 _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
840 _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
841 _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
842 _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
843 _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
844 _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);
845
846 __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
847 __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
848 __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
849 __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
850 __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
851 __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
852 __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
853 __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
854 __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
855 __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
856 __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
857 __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
858 __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
859 __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
860 __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
861 __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);
862
863 _sum2 = _mm256_fmadd_ps(_w0, _val20, _sum2);
864 _sum2 = _mm256_fmadd_ps(_w1, _val21, _sum2);
865 _sum2 = _mm256_fmadd_ps(_w2, _val22, _sum2);
866 _sum2 = _mm256_fmadd_ps(_w3, _val23, _sum2);
867 _sum2 = _mm256_fmadd_ps(_w4, _val24, _sum2);
868 _sum2 = _mm256_fmadd_ps(_w5, _val25, _sum2);
869 _sum2 = _mm256_fmadd_ps(_w6, _val26, _sum2);
870 _sum2 = _mm256_fmadd_ps(_w7, _val27, _sum2);
871 _sum3 = _mm256_fmadd_ps(_w0, _val30, _sum3);
872 _sum3 = _mm256_fmadd_ps(_w1, _val31, _sum3);
873 _sum3 = _mm256_fmadd_ps(_w2, _val32, _sum3);
874 _sum3 = _mm256_fmadd_ps(_w3, _val33, _sum3);
875 _sum3 = _mm256_fmadd_ps(_w4, _val34, _sum3);
876 _sum3 = _mm256_fmadd_ps(_w5, _val35, _sum3);
877 _sum3 = _mm256_fmadd_ps(_w6, _val36, _sum3);
878 _sum3 = _mm256_fmadd_ps(_w7, _val37, _sum3);
879
880 tmpptr += 32;
881
882 kptr += 64;
883 }
884 _mm256_storeu_ps(outptr, _sum0);
885 _mm256_storeu_ps(outptr + 8, _sum1);
886 _mm256_storeu_ps(outptr + 16, _sum2);
887 _mm256_storeu_ps(outptr + 24, _sum3);
888
889 outptr += 32;
890 }
891 for (; i + 1 < size; i += 2)
892 {
893 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);
894
895 __m256 _sum0 = _bias0;
896 __m256 _sum1 = _bias0;
897
898 const float* kptr = (const float*)kernel + p * inch * 64;
899 for (int q = 0; q < inch; q++)
900 {
901 __m256 _val00 = _mm256_broadcast_ss(tmpptr);
902 __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
903 __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
904 __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
905 __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
906 __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
907 __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
908 __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
909 __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
910 __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
911 __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
912 __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
913 __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
914 __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
915 __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
916 __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);
917
918 __m256 _w0 = _mm256_loadu_ps(kptr);
919 __m256 _w1 = _mm256_loadu_ps(kptr + 8);
920 __m256 _w2 = _mm256_loadu_ps(kptr + 16);
921 __m256 _w3 = _mm256_loadu_ps(kptr + 24);
922 __m256 _w4 = _mm256_loadu_ps(kptr + 32);
923 __m256 _w5 = _mm256_loadu_ps(kptr + 40);
924 __m256 _w6 = _mm256_loadu_ps(kptr + 48);
925 __m256 _w7 = _mm256_loadu_ps(kptr + 56);
926
927 _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
928 _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
929 _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
930 _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
931 _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
932 _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
933 _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
934 _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
935 _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
936 _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
937 _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
938 _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
939 _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
940 _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
941 _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
942 _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);
943
944 tmpptr += 16;
945
946 kptr += 64;
947 }
948 _mm256_storeu_ps(outptr, _sum0);
949 _mm256_storeu_ps(outptr + 8, _sum1);
950
951 outptr += 16;
952 }
953
954 for (; i < size; i++)
955 {
956 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
957 __m256 _sum = _bias0;
958
959 const float* kptr = (const float*)kernel + p * inch * 64;
960 for (int q = 0; q < inch; q++)
961 {
962 __m256 _val0 = _mm256_broadcast_ss(tmpptr);
963 __m256 _val1 = _mm256_broadcast_ss(tmpptr + 1);
964 __m256 _val2 = _mm256_broadcast_ss(tmpptr + 2);
965 __m256 _val3 = _mm256_broadcast_ss(tmpptr + 3);
966 __m256 _val4 = _mm256_broadcast_ss(tmpptr + 4);
967 __m256 _val5 = _mm256_broadcast_ss(tmpptr + 5);
968 __m256 _val6 = _mm256_broadcast_ss(tmpptr + 6);
969 __m256 _val7 = _mm256_broadcast_ss(tmpptr + 7);
970
971 __m256 _w0 = _mm256_loadu_ps(kptr);
972 __m256 _w1 = _mm256_loadu_ps(kptr + 8);
973 __m256 _w2 = _mm256_loadu_ps(kptr + 16);
974 __m256 _w3 = _mm256_loadu_ps(kptr + 24);
975 __m256 _w4 = _mm256_loadu_ps(kptr + 32);
976 __m256 _w5 = _mm256_loadu_ps(kptr + 40);
977 __m256 _w6 = _mm256_loadu_ps(kptr + 48);
978 __m256 _w7 = _mm256_loadu_ps(kptr + 56);
979
980 _sum = _mm256_fmadd_ps(_w0, _val0, _sum);
981 _sum = _mm256_fmadd_ps(_w1, _val1, _sum);
982 _sum = _mm256_fmadd_ps(_w2, _val2, _sum);
983 _sum = _mm256_fmadd_ps(_w3, _val3, _sum);
984 _sum = _mm256_fmadd_ps(_w4, _val4, _sum);
985 _sum = _mm256_fmadd_ps(_w5, _val5, _sum);
986 _sum = _mm256_fmadd_ps(_w6, _val6, _sum);
987 _sum = _mm256_fmadd_ps(_w7, _val7, _sum);
988
989 tmpptr += 8;
990
991 kptr += 64;
992 }
993 _mm256_storeu_ps(outptr, _sum);
994
995 outptr += 8;
996 }
997 }
998 }
999
conv1x1s2_pack8_avx(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)1000 static void conv1x1s2_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
1001 {
1002 int w = bottom_blob.w;
1003 int channels = bottom_blob.c;
1004 size_t elemsize = bottom_blob.elemsize;
1005 int elempack = bottom_blob.elempack;
1006
1007 int outw = top_blob.w;
1008 int outh = top_blob.h;
1009
1010 const int tailstep = (w - 2 * outw + w) * 8;
1011
1012 Mat bottom_blob_shrinked;
1013 bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);
1014
1015 #pragma omp parallel for num_threads(opt.num_threads)
1016 for (int p = 0; p < channels; p++)
1017 {
1018 const float* r0 = bottom_blob.channel(p);
1019 float* outptr = bottom_blob_shrinked.channel(p);
1020
1021 for (int i = 0; i < outh; i++)
1022 {
1023 for (int j = 0; j < outw; j++)
1024 {
1025 __m256 _v = _mm256_loadu_ps(r0);
1026 _mm256_storeu_ps(outptr, _v);
1027
1028 r0 += 16;
1029 outptr += 8;
1030 }
1031
1032 r0 += tailstep;
1033 }
1034 }
1035
1036 conv1x1s1_sgemm_pack8_avx(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
1037 }
1038