// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
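// Repack 1x1 convolution weights for the pack8 fp16 sgemm kernels below.
// Each group of 64 fp16 values covers 8 input channels x 8 output channels,
// with the output-channel index varying fastest, so the gemm inner loop can
// stream weights sequentially. num_input and num_output are assumed to be
// multiples of 8 (pack8 layout).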
static void conv1x1s1_sgemm_transform_kernel_fp16_pack8_avx(const Mat& kernel, Mat& weight_data_pack8, int num_input, int num_output)
{
    // src = kw-kh-inch-outch
    // dst = 8b-8a-kw-kh-inch/8a-outch/8b
    Mat weight_data_r2 = kernel.reshape(1, num_input, num_output);

    weight_data_pack8.create(1, num_input / 8, num_output / 8, (size_t)2 * 64, 64);

    for (int q = 0; q + 7 < num_output; q += 8)
    {
        const Mat k0 = weight_data_r2.channel(q);
        const Mat k1 = weight_data_r2.channel(q + 1);
        const Mat k2 = weight_data_r2.channel(q + 2);
        const Mat k3 = weight_data_r2.channel(q + 3);
        const Mat k4 = weight_data_r2.channel(q + 4);
        const Mat k5 = weight_data_r2.channel(q + 5);
        const Mat k6 = weight_data_r2.channel(q + 6);
        const Mat k7 = weight_data_r2.channel(q + 7);

        Mat g0 = weight_data_pack8.channel(q / 8);

        for (int p = 0; p + 7 < num_input; p += 8)
        {
            const float* k00 = k0.row(p);
            const float* k01 = k0.row(p + 1);
            const float* k02 = k0.row(p + 2);
            const float* k03 = k0.row(p + 3);
            const float* k04 = k0.row(p + 4);
            const float* k05 = k0.row(p + 5);
            const float* k06 = k0.row(p + 6);
            const float* k07 = k0.row(p + 7);

            const float* k10 = k1.row(p);
            const float* k11 = k1.row(p + 1);
            const float* k12 = k1.row(p + 2);
            const float* k13 = k1.row(p + 3);
            const float* k14 = k1.row(p + 4);
            const float* k15 = k1.row(p + 5);
            const float* k16 = k1.row(p + 6);
            const float* k17 = k1.row(p + 7);

            const float* k20 = k2.row(p);
            const float* k21 = k2.row(p + 1);
            const float* k22 = k2.row(p + 2);
            const float* k23 = k2.row(p + 3);
            const float* k24 = k2.row(p + 4);
            const float* k25 = k2.row(p + 5);
            const float* k26 = k2.row(p + 6);
            const float* k27 = k2.row(p + 7);

            const float* k30 = k3.row(p);
            const float* k31 = k3.row(p + 1);
            const float* k32 = k3.row(p + 2);
            const float* k33 = k3.row(p + 3);
            const float* k34 = k3.row(p + 4);
            const float* k35 = k3.row(p + 5);
            const float* k36 = k3.row(p + 6);
            const float* k37 = k3.row(p + 7);

            const float* k40 = k4.row(p);
            const float* k41 = k4.row(p + 1);
            const float* k42 = k4.row(p + 2);
            const float* k43 = k4.row(p + 3);
            const float* k44 = k4.row(p + 4);
            const float* k45 = k4.row(p + 5);
            const float* k46 = k4.row(p + 6);
            const float* k47 = k4.row(p + 7);

            const float* k50 = k5.row(p);
            const float* k51 = k5.row(p + 1);
            const float* k52 = k5.row(p + 2);
            const float* k53 = k5.row(p + 3);
            const float* k54 = k5.row(p + 4);
            const float* k55 = k5.row(p + 5);
            const float* k56 = k5.row(p + 6);
            const float* k57 = k5.row(p + 7);

            const float* k60 = k6.row(p);
            const float* k61 = k6.row(p + 1);
            const float* k62 = k6.row(p + 2);
            const float* k63 = k6.row(p + 3);
            const float* k64 = k6.row(p + 4);
            const float* k65 = k6.row(p + 5);
            const float* k66 = k6.row(p + 6);
            const float* k67 = k6.row(p + 7);

            const float* k70 = k7.row(p);
            const float* k71 = k7.row(p + 1);
            const float* k72 = k7.row(p + 2);
            const float* k73 = k7.row(p + 3);
            const float* k74 = k7.row(p + 4);
            const float* k75 = k7.row(p + 5);
            const float* k76 = k7.row(p + 6);
            const float* k77 = k7.row(p + 7);

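            // write one 8x8 block as fp16: for each input channel p+0..p+7,
            // the weights of output channels q+0..q+7 are stored contiguously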
            unsigned short* g00 = (unsigned short*)g0.row(p / 8);
            g00[0] = float32_to_float16(k00[0]);
            g00[1] = float32_to_float16(k10[0]);
            g00[2] = float32_to_float16(k20[0]);
            g00[3] = float32_to_float16(k30[0]);
            g00[4] = float32_to_float16(k40[0]);
            g00[5] = float32_to_float16(k50[0]);
            g00[6] = float32_to_float16(k60[0]);
            g00[7] = float32_to_float16(k70[0]);
            g00 += 8;
            g00[0] = float32_to_float16(k01[0]);
            g00[1] = float32_to_float16(k11[0]);
            g00[2] = float32_to_float16(k21[0]);
            g00[3] = float32_to_float16(k31[0]);
            g00[4] = float32_to_float16(k41[0]);
            g00[5] = float32_to_float16(k51[0]);
            g00[6] = float32_to_float16(k61[0]);
            g00[7] = float32_to_float16(k71[0]);

            g00 += 8;
            g00[0] = float32_to_float16(k02[0]);
            g00[1] = float32_to_float16(k12[0]);
            g00[2] = float32_to_float16(k22[0]);
            g00[3] = float32_to_float16(k32[0]);
            g00[4] = float32_to_float16(k42[0]);
            g00[5] = float32_to_float16(k52[0]);
            g00[6] = float32_to_float16(k62[0]);
            g00[7] = float32_to_float16(k72[0]);

            g00 += 8;
            g00[0] = float32_to_float16(k03[0]);
            g00[1] = float32_to_float16(k13[0]);
            g00[2] = float32_to_float16(k23[0]);
            g00[3] = float32_to_float16(k33[0]);
            g00[4] = float32_to_float16(k43[0]);
            g00[5] = float32_to_float16(k53[0]);
            g00[6] = float32_to_float16(k63[0]);
            g00[7] = float32_to_float16(k73[0]);

            g00 += 8;
            g00[0] = float32_to_float16(k04[0]);
            g00[1] = float32_to_float16(k14[0]);
            g00[2] = float32_to_float16(k24[0]);
            g00[3] = float32_to_float16(k34[0]);
            g00[4] = float32_to_float16(k44[0]);
            g00[5] = float32_to_float16(k54[0]);
            g00[6] = float32_to_float16(k64[0]);
            g00[7] = float32_to_float16(k74[0]);

            g00 += 8;
            g00[0] = float32_to_float16(k05[0]);
            g00[1] = float32_to_float16(k15[0]);
            g00[2] = float32_to_float16(k25[0]);
            g00[3] = float32_to_float16(k35[0]);
            g00[4] = float32_to_float16(k45[0]);
            g00[5] = float32_to_float16(k55[0]);
            g00[6] = float32_to_float16(k65[0]);
            g00[7] = float32_to_float16(k75[0]);

            g00 += 8;
            g00[0] = float32_to_float16(k06[0]);
            g00[1] = float32_to_float16(k16[0]);
            g00[2] = float32_to_float16(k26[0]);
            g00[3] = float32_to_float16(k36[0]);
            g00[4] = float32_to_float16(k46[0]);
            g00[5] = float32_to_float16(k56[0]);
            g00[6] = float32_to_float16(k66[0]);
            g00[7] = float32_to_float16(k76[0]);

            g00 += 8;
            g00[0] = float32_to_float16(k07[0]);
            g00[1] = float32_to_float16(k17[0]);
            g00[2] = float32_to_float16(k27[0]);
            g00[3] = float32_to_float16(k37[0]);
            g00[4] = float32_to_float16(k47[0]);
            g00[5] = float32_to_float16(k57[0]);
            g00[6] = float32_to_float16(k67[0]);
            g00[7] = float32_to_float16(k77[0]);

            g00 += 8;
        }
    }
}

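// 1x1 stride-1 convolution as sgemm with fp16 pack8 weights: the input is
// repacked into tiles of 12/8/4/2/1 pixels, then each tile is multiplied by
// the repacked kernel with FMA, accumulating in fp32.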
static void conv1x1s1_sgemm_fp16_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    int outch = top_blob.c;

    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    const int size = w * h;

    const float* bias = _bias;
    // interleave: repack the input into tiles of 12, 8, 4, 2 and 1 pixels
    // (each pixel still pack8 floats) so each gemm tile below reads contiguous memory
    Mat tmp(12, inch, size / 12 + (size % 12) / 8 + (size % 12 % 8) / 4 + (size % 12 % 4) / 2 + size % 12 % 2, elemsize, elempack, opt.workspace_allocator);
    {
        int nn_size = size / 12;
        int remain_size_start = nn_size * 12;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 12;
            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);
                __m256 _r4 = _mm256_loadu_ps(img0 + 32);
                __m256 _r5 = _mm256_loadu_ps(img0 + 40);
                __m256 _r6 = _mm256_loadu_ps(img0 + 48);
                __m256 _r7 = _mm256_loadu_ps(img0 + 56);
                __m256 _r8 = _mm256_loadu_ps(img0 + 64);
                __m256 _r9 = _mm256_loadu_ps(img0 + 72);
                __m256 _r10 = _mm256_loadu_ps(img0 + 80);
                __m256 _r11 = _mm256_loadu_ps(img0 + 88);
                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);
                _mm256_storeu_ps(tmpptr + 32, _r4);
                _mm256_storeu_ps(tmpptr + 40, _r5);
                _mm256_storeu_ps(tmpptr + 48, _r6);
                _mm256_storeu_ps(tmpptr + 56, _r7);
                _mm256_storeu_ps(tmpptr + 64, _r8);
                _mm256_storeu_ps(tmpptr + 72, _r9);
                _mm256_storeu_ps(tmpptr + 80, _r10);
                _mm256_storeu_ps(tmpptr + 88, _r11);

                tmpptr += 96;
                img0 += bottom_blob.cstep * 8;
            }
        }
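        // leftover pixels in tiles of 8, then 4, 2 and 1 below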
        nn_size = (size - remain_size_start) >> 3;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 8;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;

            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);
                __m256 _r4 = _mm256_loadu_ps(img0 + 32);
                __m256 _r5 = _mm256_loadu_ps(img0 + 40);
                __m256 _r6 = _mm256_loadu_ps(img0 + 48);
                __m256 _r7 = _mm256_loadu_ps(img0 + 56);
                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);
                _mm256_storeu_ps(tmpptr + 32, _r4);
                _mm256_storeu_ps(tmpptr + 40, _r5);
                _mm256_storeu_ps(tmpptr + 48, _r6);
                _mm256_storeu_ps(tmpptr + 56, _r7);

                tmpptr += 64;
                img0 += bottom_blob.cstep * 8;
            }
        }

        remain_size_start += nn_size << 3;
        nn_size = (size - remain_size_start) >> 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 4;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                __m256 _r2 = _mm256_loadu_ps(img0 + 16);
                __m256 _r3 = _mm256_loadu_ps(img0 + 24);
                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);
                _mm256_storeu_ps(tmpptr + 16, _r2);
                _mm256_storeu_ps(tmpptr + 24, _r3);

                tmpptr += 32;
                img0 += bottom_blob.cstep * 8;
            }
        }

        remain_size_start += nn_size << 2;
        nn_size = (size - remain_size_start) >> 1;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = remain_size_start + ii * 2;

            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                __m256 _r1 = _mm256_loadu_ps(img0 + 8);
                _mm256_storeu_ps(tmpptr, _r0);
                _mm256_storeu_ps(tmpptr + 8, _r1);

                tmpptr += 16;
                img0 += bottom_blob.cstep * 8;
            }
        }

        remain_size_start += nn_size << 1;
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < size; i++)
        {
            const float* img0 = bottom_blob.channel(0);
            img0 += i * 8;
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
            for (int q = 0; q < inch; q++)
            {
                __m256 _r0 = _mm256_loadu_ps(img0);
                _mm256_storeu_ps(tmpptr, _r0);

                tmpptr += 8;
                img0 += bottom_blob.cstep * 8;
            }
        }
    }
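    // gemm: one pack8 output channel (8 output maps) per iteration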
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        Mat out = top_blob.channel(p);

        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);

        float* outptr = out;
        int i = 0;
        for (; i + 11 < size; i += 12)
        {
            const float* tmpptr = tmp.channel(i / 12);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;
            __m256 _sum4 = _bias0;
            __m256 _sum5 = _bias0;
            __m256 _sum6 = _bias0;
            __m256 _sum7 = _bias0;
            __m256 _sum8 = _bias0;
            __m256 _sum9 = _bias0;
            __m256 _sum10 = _bias0;
            __m256 _sum11 = _bias0;

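            // 64 fp16 weights per input-channel step (8 in x 8 out), as laid
            // out by conv1x1s1_sgemm_transform_kernel_fp16_pack8_avx above;
            // loadfp16 (defined elsewhere in ncnn) expands 8 fp16 values to a
            // __m256, presumably via _mm256_cvtph_ps on F16C-capable CPUs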
            const unsigned short* kptr = (const unsigned short*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = loadfp16(kptr);
                __m256 _w1 = loadfp16(kptr + 8);
                __m256 _w2 = loadfp16(kptr + 16);
                __m256 _w3 = loadfp16(kptr + 24);
                __m256 _w4 = loadfp16(kptr + 32);
                __m256 _w5 = loadfp16(kptr + 40);
                __m256 _w6 = loadfp16(kptr + 48);
                __m256 _w7 = loadfp16(kptr + 56);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _sum2 = _mm256_fmadd_ps(_w0, _val20, _sum2);
                _sum2 = _mm256_fmadd_ps(_w1, _val21, _sum2);
                _sum2 = _mm256_fmadd_ps(_w2, _val22, _sum2);
                _sum2 = _mm256_fmadd_ps(_w3, _val23, _sum2);
                _sum2 = _mm256_fmadd_ps(_w4, _val24, _sum2);
                _sum2 = _mm256_fmadd_ps(_w5, _val25, _sum2);
                _sum2 = _mm256_fmadd_ps(_w6, _val26, _sum2);
                _sum2 = _mm256_fmadd_ps(_w7, _val27, _sum2);
                _sum3 = _mm256_fmadd_ps(_w0, _val30, _sum3);
                _sum3 = _mm256_fmadd_ps(_w1, _val31, _sum3);
                _sum3 = _mm256_fmadd_ps(_w2, _val32, _sum3);
                _sum3 = _mm256_fmadd_ps(_w3, _val33, _sum3);
                _sum3 = _mm256_fmadd_ps(_w4, _val34, _sum3);
                _sum3 = _mm256_fmadd_ps(_w5, _val35, _sum3);
                _sum3 = _mm256_fmadd_ps(_w6, _val36, _sum3);
                _sum3 = _mm256_fmadd_ps(_w7, _val37, _sum3);

                __m256 _val40 = _mm256_broadcast_ss(tmpptr + 32);
                __m256 _val41 = _mm256_broadcast_ss(tmpptr + 33);
                __m256 _val42 = _mm256_broadcast_ss(tmpptr + 34);
                __m256 _val43 = _mm256_broadcast_ss(tmpptr + 35);
                __m256 _val44 = _mm256_broadcast_ss(tmpptr + 36);
                __m256 _val45 = _mm256_broadcast_ss(tmpptr + 37);
                __m256 _val46 = _mm256_broadcast_ss(tmpptr + 38);
                __m256 _val47 = _mm256_broadcast_ss(tmpptr + 39);
                __m256 _val50 = _mm256_broadcast_ss(tmpptr + 40);
                __m256 _val51 = _mm256_broadcast_ss(tmpptr + 41);
                __m256 _val52 = _mm256_broadcast_ss(tmpptr + 42);
                __m256 _val53 = _mm256_broadcast_ss(tmpptr + 43);
                __m256 _val54 = _mm256_broadcast_ss(tmpptr + 44);
                __m256 _val55 = _mm256_broadcast_ss(tmpptr + 45);
                __m256 _val56 = _mm256_broadcast_ss(tmpptr + 46);
                __m256 _val57 = _mm256_broadcast_ss(tmpptr + 47);

                _sum4 = _mm256_fmadd_ps(_w0, _val40, _sum4);
                _sum4 = _mm256_fmadd_ps(_w1, _val41, _sum4);
                _sum4 = _mm256_fmadd_ps(_w2, _val42, _sum4);
                _sum4 = _mm256_fmadd_ps(_w3, _val43, _sum4);
                _sum4 = _mm256_fmadd_ps(_w4, _val44, _sum4);
                _sum4 = _mm256_fmadd_ps(_w5, _val45, _sum4);
                _sum4 = _mm256_fmadd_ps(_w6, _val46, _sum4);
                _sum4 = _mm256_fmadd_ps(_w7, _val47, _sum4);
                _sum5 = _mm256_fmadd_ps(_w0, _val50, _sum5);
                _sum5 = _mm256_fmadd_ps(_w1, _val51, _sum5);
                _sum5 = _mm256_fmadd_ps(_w2, _val52, _sum5);
                _sum5 = _mm256_fmadd_ps(_w3, _val53, _sum5);
                _sum5 = _mm256_fmadd_ps(_w4, _val54, _sum5);
                _sum5 = _mm256_fmadd_ps(_w5, _val55, _sum5);
                _sum5 = _mm256_fmadd_ps(_w6, _val56, _sum5);
                _sum5 = _mm256_fmadd_ps(_w7, _val57, _sum5);

                __m256 _val60 = _mm256_broadcast_ss(tmpptr + 48);
                __m256 _val61 = _mm256_broadcast_ss(tmpptr + 49);
                __m256 _val62 = _mm256_broadcast_ss(tmpptr + 50);
                __m256 _val63 = _mm256_broadcast_ss(tmpptr + 51);
                __m256 _val64 = _mm256_broadcast_ss(tmpptr + 52);
                __m256 _val65 = _mm256_broadcast_ss(tmpptr + 53);
                __m256 _val66 = _mm256_broadcast_ss(tmpptr + 54);
                __m256 _val67 = _mm256_broadcast_ss(tmpptr + 55);
                __m256 _val70 = _mm256_broadcast_ss(tmpptr + 56);
                __m256 _val71 = _mm256_broadcast_ss(tmpptr + 57);
                __m256 _val72 = _mm256_broadcast_ss(tmpptr + 58);
                __m256 _val73 = _mm256_broadcast_ss(tmpptr + 59);
                __m256 _val74 = _mm256_broadcast_ss(tmpptr + 60);
                __m256 _val75 = _mm256_broadcast_ss(tmpptr + 61);
                __m256 _val76 = _mm256_broadcast_ss(tmpptr + 62);
                __m256 _val77 = _mm256_broadcast_ss(tmpptr + 63);

                _sum6 = _mm256_fmadd_ps(_w0, _val60, _sum6);
                _sum6 = _mm256_fmadd_ps(_w1, _val61, _sum6);
                _sum6 = _mm256_fmadd_ps(_w2, _val62, _sum6);
                _sum6 = _mm256_fmadd_ps(_w3, _val63, _sum6);
                _sum6 = _mm256_fmadd_ps(_w4, _val64, _sum6);
                _sum6 = _mm256_fmadd_ps(_w5, _val65, _sum6);
                _sum6 = _mm256_fmadd_ps(_w6, _val66, _sum6);
                _sum6 = _mm256_fmadd_ps(_w7, _val67, _sum6);
                _sum7 = _mm256_fmadd_ps(_w0, _val70, _sum7);
                _sum7 = _mm256_fmadd_ps(_w1, _val71, _sum7);
                _sum7 = _mm256_fmadd_ps(_w2, _val72, _sum7);
                _sum7 = _mm256_fmadd_ps(_w3, _val73, _sum7);
                _sum7 = _mm256_fmadd_ps(_w4, _val74, _sum7);
                _sum7 = _mm256_fmadd_ps(_w5, _val75, _sum7);
                _sum7 = _mm256_fmadd_ps(_w6, _val76, _sum7);
                _sum7 = _mm256_fmadd_ps(_w7, _val77, _sum7);

                __m256 _val80 = _mm256_broadcast_ss(tmpptr + 64);
                __m256 _val81 = _mm256_broadcast_ss(tmpptr + 65);
                __m256 _val82 = _mm256_broadcast_ss(tmpptr + 66);
                __m256 _val83 = _mm256_broadcast_ss(tmpptr + 67);
                __m256 _val84 = _mm256_broadcast_ss(tmpptr + 68);
                __m256 _val85 = _mm256_broadcast_ss(tmpptr + 69);
                __m256 _val86 = _mm256_broadcast_ss(tmpptr + 70);
                __m256 _val87 = _mm256_broadcast_ss(tmpptr + 71);
                __m256 _val90 = _mm256_broadcast_ss(tmpptr + 72);
                __m256 _val91 = _mm256_broadcast_ss(tmpptr + 73);
                __m256 _val92 = _mm256_broadcast_ss(tmpptr + 74);
                __m256 _val93 = _mm256_broadcast_ss(tmpptr + 75);
                __m256 _val94 = _mm256_broadcast_ss(tmpptr + 76);
                __m256 _val95 = _mm256_broadcast_ss(tmpptr + 77);
                __m256 _val96 = _mm256_broadcast_ss(tmpptr + 78);
                __m256 _val97 = _mm256_broadcast_ss(tmpptr + 79);

                _sum8 = _mm256_fmadd_ps(_w0, _val80, _sum8);
                _sum8 = _mm256_fmadd_ps(_w1, _val81, _sum8);
                _sum8 = _mm256_fmadd_ps(_w2, _val82, _sum8);
                _sum8 = _mm256_fmadd_ps(_w3, _val83, _sum8);
                _sum8 = _mm256_fmadd_ps(_w4, _val84, _sum8);
                _sum8 = _mm256_fmadd_ps(_w5, _val85, _sum8);
                _sum8 = _mm256_fmadd_ps(_w6, _val86, _sum8);
                _sum8 = _mm256_fmadd_ps(_w7, _val87, _sum8);
                _sum9 = _mm256_fmadd_ps(_w0, _val90, _sum9);
                _sum9 = _mm256_fmadd_ps(_w1, _val91, _sum9);
                _sum9 = _mm256_fmadd_ps(_w2, _val92, _sum9);
                _sum9 = _mm256_fmadd_ps(_w3, _val93, _sum9);
                _sum9 = _mm256_fmadd_ps(_w4, _val94, _sum9);
                _sum9 = _mm256_fmadd_ps(_w5, _val95, _sum9);
                _sum9 = _mm256_fmadd_ps(_w6, _val96, _sum9);
                _sum9 = _mm256_fmadd_ps(_w7, _val97, _sum9);

                __m256 _val100 = _mm256_broadcast_ss(tmpptr + 80);
                __m256 _val101 = _mm256_broadcast_ss(tmpptr + 81);
                __m256 _val102 = _mm256_broadcast_ss(tmpptr + 82);
                __m256 _val103 = _mm256_broadcast_ss(tmpptr + 83);
                __m256 _val104 = _mm256_broadcast_ss(tmpptr + 84);
                __m256 _val105 = _mm256_broadcast_ss(tmpptr + 85);
                __m256 _val106 = _mm256_broadcast_ss(tmpptr + 86);
                __m256 _val107 = _mm256_broadcast_ss(tmpptr + 87);
                __m256 _val110 = _mm256_broadcast_ss(tmpptr + 88);
                __m256 _val111 = _mm256_broadcast_ss(tmpptr + 89);
                __m256 _val112 = _mm256_broadcast_ss(tmpptr + 90);
                __m256 _val113 = _mm256_broadcast_ss(tmpptr + 91);
                __m256 _val114 = _mm256_broadcast_ss(tmpptr + 92);
                __m256 _val115 = _mm256_broadcast_ss(tmpptr + 93);
                __m256 _val116 = _mm256_broadcast_ss(tmpptr + 94);
                __m256 _val117 = _mm256_broadcast_ss(tmpptr + 95);

                _sum10 = _mm256_fmadd_ps(_w0, _val100, _sum10);
                _sum10 = _mm256_fmadd_ps(_w1, _val101, _sum10);
                _sum10 = _mm256_fmadd_ps(_w2, _val102, _sum10);
                _sum10 = _mm256_fmadd_ps(_w3, _val103, _sum10);
                _sum10 = _mm256_fmadd_ps(_w4, _val104, _sum10);
                _sum10 = _mm256_fmadd_ps(_w5, _val105, _sum10);
                _sum10 = _mm256_fmadd_ps(_w6, _val106, _sum10);
                _sum10 = _mm256_fmadd_ps(_w7, _val107, _sum10);
                _sum11 = _mm256_fmadd_ps(_w0, _val110, _sum11);
                _sum11 = _mm256_fmadd_ps(_w1, _val111, _sum11);
                _sum11 = _mm256_fmadd_ps(_w2, _val112, _sum11);
                _sum11 = _mm256_fmadd_ps(_w3, _val113, _sum11);
                _sum11 = _mm256_fmadd_ps(_w4, _val114, _sum11);
                _sum11 = _mm256_fmadd_ps(_w5, _val115, _sum11);
                _sum11 = _mm256_fmadd_ps(_w6, _val116, _sum11);
                _sum11 = _mm256_fmadd_ps(_w7, _val117, _sum11);

                tmpptr += 96;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);
            _mm256_storeu_ps(outptr + 32, _sum4);
            _mm256_storeu_ps(outptr + 40, _sum5);
            _mm256_storeu_ps(outptr + 48, _sum6);
            _mm256_storeu_ps(outptr + 56, _sum7);
            _mm256_storeu_ps(outptr + 64, _sum8);
            _mm256_storeu_ps(outptr + 72, _sum9);
            _mm256_storeu_ps(outptr + 80, _sum10);
            _mm256_storeu_ps(outptr + 88, _sum11);

            outptr += 96;
        }
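        // leftover: tiles of 8 pixels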
        for (; i + 7 < size; i += 8)
        {
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;
            __m256 _sum4 = _bias0;
            __m256 _sum5 = _bias0;
            __m256 _sum6 = _bias0;
            __m256 _sum7 = _bias0;

            const unsigned short* kptr = (const unsigned short*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = loadfp16(kptr);
                __m256 _w1 = loadfp16(kptr + 8);
                __m256 _w2 = loadfp16(kptr + 16);
                __m256 _w3 = loadfp16(kptr + 24);
                __m256 _w4 = loadfp16(kptr + 32);
                __m256 _w5 = loadfp16(kptr + 40);
                __m256 _w6 = loadfp16(kptr + 48);
                __m256 _w7 = loadfp16(kptr + 56);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _sum2 = _mm256_fmadd_ps(_w0, _val20, _sum2);
                _sum2 = _mm256_fmadd_ps(_w1, _val21, _sum2);
                _sum2 = _mm256_fmadd_ps(_w2, _val22, _sum2);
                _sum2 = _mm256_fmadd_ps(_w3, _val23, _sum2);
                _sum2 = _mm256_fmadd_ps(_w4, _val24, _sum2);
                _sum2 = _mm256_fmadd_ps(_w5, _val25, _sum2);
                _sum2 = _mm256_fmadd_ps(_w6, _val26, _sum2);
                _sum2 = _mm256_fmadd_ps(_w7, _val27, _sum2);
                _sum3 = _mm256_fmadd_ps(_w0, _val30, _sum3);
                _sum3 = _mm256_fmadd_ps(_w1, _val31, _sum3);
                _sum3 = _mm256_fmadd_ps(_w2, _val32, _sum3);
                _sum3 = _mm256_fmadd_ps(_w3, _val33, _sum3);
                _sum3 = _mm256_fmadd_ps(_w4, _val34, _sum3);
                _sum3 = _mm256_fmadd_ps(_w5, _val35, _sum3);
                _sum3 = _mm256_fmadd_ps(_w6, _val36, _sum3);
                _sum3 = _mm256_fmadd_ps(_w7, _val37, _sum3);

                __m256 _val40 = _mm256_broadcast_ss(tmpptr + 32);
                __m256 _val41 = _mm256_broadcast_ss(tmpptr + 33);
                __m256 _val42 = _mm256_broadcast_ss(tmpptr + 34);
                __m256 _val43 = _mm256_broadcast_ss(tmpptr + 35);
                __m256 _val44 = _mm256_broadcast_ss(tmpptr + 36);
                __m256 _val45 = _mm256_broadcast_ss(tmpptr + 37);
                __m256 _val46 = _mm256_broadcast_ss(tmpptr + 38);
                __m256 _val47 = _mm256_broadcast_ss(tmpptr + 39);
                __m256 _val50 = _mm256_broadcast_ss(tmpptr + 40);
                __m256 _val51 = _mm256_broadcast_ss(tmpptr + 41);
                __m256 _val52 = _mm256_broadcast_ss(tmpptr + 42);
                __m256 _val53 = _mm256_broadcast_ss(tmpptr + 43);
                __m256 _val54 = _mm256_broadcast_ss(tmpptr + 44);
                __m256 _val55 = _mm256_broadcast_ss(tmpptr + 45);
                __m256 _val56 = _mm256_broadcast_ss(tmpptr + 46);
                __m256 _val57 = _mm256_broadcast_ss(tmpptr + 47);

                _sum4 = _mm256_fmadd_ps(_w0, _val40, _sum4);
                _sum4 = _mm256_fmadd_ps(_w1, _val41, _sum4);
                _sum4 = _mm256_fmadd_ps(_w2, _val42, _sum4);
                _sum4 = _mm256_fmadd_ps(_w3, _val43, _sum4);
                _sum4 = _mm256_fmadd_ps(_w4, _val44, _sum4);
                _sum4 = _mm256_fmadd_ps(_w5, _val45, _sum4);
                _sum4 = _mm256_fmadd_ps(_w6, _val46, _sum4);
                _sum4 = _mm256_fmadd_ps(_w7, _val47, _sum4);
                _sum5 = _mm256_fmadd_ps(_w0, _val50, _sum5);
                _sum5 = _mm256_fmadd_ps(_w1, _val51, _sum5);
                _sum5 = _mm256_fmadd_ps(_w2, _val52, _sum5);
                _sum5 = _mm256_fmadd_ps(_w3, _val53, _sum5);
                _sum5 = _mm256_fmadd_ps(_w4, _val54, _sum5);
                _sum5 = _mm256_fmadd_ps(_w5, _val55, _sum5);
                _sum5 = _mm256_fmadd_ps(_w6, _val56, _sum5);
                _sum5 = _mm256_fmadd_ps(_w7, _val57, _sum5);

                __m256 _val60 = _mm256_broadcast_ss(tmpptr + 48);
                __m256 _val61 = _mm256_broadcast_ss(tmpptr + 49);
                __m256 _val62 = _mm256_broadcast_ss(tmpptr + 50);
                __m256 _val63 = _mm256_broadcast_ss(tmpptr + 51);
                __m256 _val64 = _mm256_broadcast_ss(tmpptr + 52);
                __m256 _val65 = _mm256_broadcast_ss(tmpptr + 53);
                __m256 _val66 = _mm256_broadcast_ss(tmpptr + 54);
                __m256 _val67 = _mm256_broadcast_ss(tmpptr + 55);
                __m256 _val70 = _mm256_broadcast_ss(tmpptr + 56);
                __m256 _val71 = _mm256_broadcast_ss(tmpptr + 57);
                __m256 _val72 = _mm256_broadcast_ss(tmpptr + 58);
                __m256 _val73 = _mm256_broadcast_ss(tmpptr + 59);
                __m256 _val74 = _mm256_broadcast_ss(tmpptr + 60);
                __m256 _val75 = _mm256_broadcast_ss(tmpptr + 61);
                __m256 _val76 = _mm256_broadcast_ss(tmpptr + 62);
                __m256 _val77 = _mm256_broadcast_ss(tmpptr + 63);

                _sum6 = _mm256_fmadd_ps(_w0, _val60, _sum6);
                _sum6 = _mm256_fmadd_ps(_w1, _val61, _sum6);
                _sum6 = _mm256_fmadd_ps(_w2, _val62, _sum6);
                _sum6 = _mm256_fmadd_ps(_w3, _val63, _sum6);
                _sum6 = _mm256_fmadd_ps(_w4, _val64, _sum6);
                _sum6 = _mm256_fmadd_ps(_w5, _val65, _sum6);
                _sum6 = _mm256_fmadd_ps(_w6, _val66, _sum6);
                _sum6 = _mm256_fmadd_ps(_w7, _val67, _sum6);
                _sum7 = _mm256_fmadd_ps(_w0, _val70, _sum7);
                _sum7 = _mm256_fmadd_ps(_w1, _val71, _sum7);
                _sum7 = _mm256_fmadd_ps(_w2, _val72, _sum7);
                _sum7 = _mm256_fmadd_ps(_w3, _val73, _sum7);
                _sum7 = _mm256_fmadd_ps(_w4, _val74, _sum7);
                _sum7 = _mm256_fmadd_ps(_w5, _val75, _sum7);
                _sum7 = _mm256_fmadd_ps(_w6, _val76, _sum7);
                _sum7 = _mm256_fmadd_ps(_w7, _val77, _sum7);

                tmpptr += 64;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);
            _mm256_storeu_ps(outptr + 32, _sum4);
            _mm256_storeu_ps(outptr + 40, _sum5);
            _mm256_storeu_ps(outptr + 48, _sum6);
            _mm256_storeu_ps(outptr + 56, _sum7);

            outptr += 64;
        }
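        // leftover: tiles of 4 pixels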
        for (; i + 3 < size; i += 4)
        {
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;
            __m256 _sum2 = _bias0;
            __m256 _sum3 = _bias0;

            const unsigned short* kptr = (const unsigned short*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _w0 = loadfp16(kptr);
                __m256 _w1 = loadfp16(kptr + 8);
                __m256 _w2 = loadfp16(kptr + 16);
                __m256 _w3 = loadfp16(kptr + 24);
                __m256 _w4 = loadfp16(kptr + 32);
                __m256 _w5 = loadfp16(kptr + 40);
                __m256 _w6 = loadfp16(kptr + 48);
                __m256 _w7 = loadfp16(kptr + 56);

                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);

                __m256 _val20 = _mm256_broadcast_ss(tmpptr + 16);
                __m256 _val21 = _mm256_broadcast_ss(tmpptr + 17);
                __m256 _val22 = _mm256_broadcast_ss(tmpptr + 18);
                __m256 _val23 = _mm256_broadcast_ss(tmpptr + 19);
                __m256 _val24 = _mm256_broadcast_ss(tmpptr + 20);
                __m256 _val25 = _mm256_broadcast_ss(tmpptr + 21);
                __m256 _val26 = _mm256_broadcast_ss(tmpptr + 22);
                __m256 _val27 = _mm256_broadcast_ss(tmpptr + 23);
                __m256 _val30 = _mm256_broadcast_ss(tmpptr + 24);
                __m256 _val31 = _mm256_broadcast_ss(tmpptr + 25);
                __m256 _val32 = _mm256_broadcast_ss(tmpptr + 26);
                __m256 _val33 = _mm256_broadcast_ss(tmpptr + 27);
                __m256 _val34 = _mm256_broadcast_ss(tmpptr + 28);
                __m256 _val35 = _mm256_broadcast_ss(tmpptr + 29);
                __m256 _val36 = _mm256_broadcast_ss(tmpptr + 30);
                __m256 _val37 = _mm256_broadcast_ss(tmpptr + 31);

                _sum2 = _mm256_fmadd_ps(_w0, _val20, _sum2);
                _sum2 = _mm256_fmadd_ps(_w1, _val21, _sum2);
                _sum2 = _mm256_fmadd_ps(_w2, _val22, _sum2);
                _sum2 = _mm256_fmadd_ps(_w3, _val23, _sum2);
                _sum2 = _mm256_fmadd_ps(_w4, _val24, _sum2);
                _sum2 = _mm256_fmadd_ps(_w5, _val25, _sum2);
                _sum2 = _mm256_fmadd_ps(_w6, _val26, _sum2);
                _sum2 = _mm256_fmadd_ps(_w7, _val27, _sum2);
                _sum3 = _mm256_fmadd_ps(_w0, _val30, _sum3);
                _sum3 = _mm256_fmadd_ps(_w1, _val31, _sum3);
                _sum3 = _mm256_fmadd_ps(_w2, _val32, _sum3);
                _sum3 = _mm256_fmadd_ps(_w3, _val33, _sum3);
                _sum3 = _mm256_fmadd_ps(_w4, _val34, _sum3);
                _sum3 = _mm256_fmadd_ps(_w5, _val35, _sum3);
                _sum3 = _mm256_fmadd_ps(_w6, _val36, _sum3);
                _sum3 = _mm256_fmadd_ps(_w7, _val37, _sum3);

                tmpptr += 32;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);
            _mm256_storeu_ps(outptr + 16, _sum2);
            _mm256_storeu_ps(outptr + 24, _sum3);

            outptr += 32;
        }
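        // leftover: tiles of 2 pixels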
        for (; i + 1 < size; i += 2)
        {
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

            __m256 _sum0 = _bias0;
            __m256 _sum1 = _bias0;

            const unsigned short* kptr = (const unsigned short*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _val00 = _mm256_broadcast_ss(tmpptr);
                __m256 _val01 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val02 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val03 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val04 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val05 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val06 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val07 = _mm256_broadcast_ss(tmpptr + 7);
                __m256 _val10 = _mm256_broadcast_ss(tmpptr + 8);
                __m256 _val11 = _mm256_broadcast_ss(tmpptr + 9);
                __m256 _val12 = _mm256_broadcast_ss(tmpptr + 10);
                __m256 _val13 = _mm256_broadcast_ss(tmpptr + 11);
                __m256 _val14 = _mm256_broadcast_ss(tmpptr + 12);
                __m256 _val15 = _mm256_broadcast_ss(tmpptr + 13);
                __m256 _val16 = _mm256_broadcast_ss(tmpptr + 14);
                __m256 _val17 = _mm256_broadcast_ss(tmpptr + 15);

                __m256 _w0 = loadfp16(kptr);
                __m256 _w1 = loadfp16(kptr + 8);
                __m256 _w2 = loadfp16(kptr + 16);
                __m256 _w3 = loadfp16(kptr + 24);
                __m256 _w4 = loadfp16(kptr + 32);
                __m256 _w5 = loadfp16(kptr + 40);
                __m256 _w6 = loadfp16(kptr + 48);
                __m256 _w7 = loadfp16(kptr + 56);

                _sum0 = _mm256_fmadd_ps(_w0, _val00, _sum0);
                _sum0 = _mm256_fmadd_ps(_w1, _val01, _sum0);
                _sum0 = _mm256_fmadd_ps(_w2, _val02, _sum0);
                _sum0 = _mm256_fmadd_ps(_w3, _val03, _sum0);
                _sum0 = _mm256_fmadd_ps(_w4, _val04, _sum0);
                _sum0 = _mm256_fmadd_ps(_w5, _val05, _sum0);
                _sum0 = _mm256_fmadd_ps(_w6, _val06, _sum0);
                _sum0 = _mm256_fmadd_ps(_w7, _val07, _sum0);
                _sum1 = _mm256_fmadd_ps(_w0, _val10, _sum1);
                _sum1 = _mm256_fmadd_ps(_w1, _val11, _sum1);
                _sum1 = _mm256_fmadd_ps(_w2, _val12, _sum1);
                _sum1 = _mm256_fmadd_ps(_w3, _val13, _sum1);
                _sum1 = _mm256_fmadd_ps(_w4, _val14, _sum1);
                _sum1 = _mm256_fmadd_ps(_w5, _val15, _sum1);
                _sum1 = _mm256_fmadd_ps(_w6, _val16, _sum1);
                _sum1 = _mm256_fmadd_ps(_w7, _val17, _sum1);

                tmpptr += 16;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum0);
            _mm256_storeu_ps(outptr + 8, _sum1);

            outptr += 16;
        }

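        // leftover: single pixels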
        for (; i < size; i++)
        {
            float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
            __m256 _sum = _bias0;

            const unsigned short* kptr = (const unsigned short*)kernel + p * inch * 64;
            for (int q = 0; q < inch; q++)
            {
                __m256 _val0 = _mm256_broadcast_ss(tmpptr);
                __m256 _val1 = _mm256_broadcast_ss(tmpptr + 1);
                __m256 _val2 = _mm256_broadcast_ss(tmpptr + 2);
                __m256 _val3 = _mm256_broadcast_ss(tmpptr + 3);
                __m256 _val4 = _mm256_broadcast_ss(tmpptr + 4);
                __m256 _val5 = _mm256_broadcast_ss(tmpptr + 5);
                __m256 _val6 = _mm256_broadcast_ss(tmpptr + 6);
                __m256 _val7 = _mm256_broadcast_ss(tmpptr + 7);

                __m256 _w0 = loadfp16(kptr);
                __m256 _w1 = loadfp16(kptr + 8);
                __m256 _w2 = loadfp16(kptr + 16);
                __m256 _w3 = loadfp16(kptr + 24);
                __m256 _w4 = loadfp16(kptr + 32);
                __m256 _w5 = loadfp16(kptr + 40);
                __m256 _w6 = loadfp16(kptr + 48);
                __m256 _w7 = loadfp16(kptr + 56);

                _sum = _mm256_fmadd_ps(_w0, _val0, _sum);
                _sum = _mm256_fmadd_ps(_w1, _val1, _sum);
                _sum = _mm256_fmadd_ps(_w2, _val2, _sum);
                _sum = _mm256_fmadd_ps(_w3, _val3, _sum);
                _sum = _mm256_fmadd_ps(_w4, _val4, _sum);
                _sum = _mm256_fmadd_ps(_w5, _val5, _sum);
                _sum = _mm256_fmadd_ps(_w6, _val6, _sum);
                _sum = _mm256_fmadd_ps(_w7, _val7, _sum);

                tmpptr += 8;

                kptr += 64;
            }
            _mm256_storeu_ps(outptr, _sum);

            outptr += 8;
        }
    }
}

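// 1x1 stride-2 convolution: sample every other pixel in both dimensions into
// a shrunken blob, then reuse the stride-1 sgemm path above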
static void conv1x1s2_fp16_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    int outw = top_blob.w;
    int outh = top_blob.h;

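    // each output row consumes 2 * outw input pixels; skip the remainder of
    // the row plus the whole next (stride-2) row, counted in pack8 floats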
    const int tailstep = (w - 2 * outw + w) * 8;

    Mat bottom_blob_shrinked;
    bottom_blob_shrinked.create(outw, outh, channels, elemsize, elempack, opt.workspace_allocator);

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < channels; p++)
    {
        const float* r0 = bottom_blob.channel(p);
        float* outptr = bottom_blob_shrinked.channel(p);

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                __m256 _v = _mm256_loadu_ps(r0);
                _mm256_storeu_ps(outptr, _v);

                r0 += 16;
                outptr += 8;
            }

            r0 += tailstep;
        }
    }

    conv1x1s1_sgemm_fp16_pack8_avx(bottom_blob_shrinked, top_blob, kernel, _bias, opt);
}