// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
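
// NOTE: float32_to_float16() and loadfp16() are assumed to be provided by the
// including translation unit, as elsewhere in ncnn's x86 fp16 helpers. A
// minimal F16C-based sketch of what the assumed helpers could look like
// (reference only; the real implementations may differ, e.g. portable
// bit-twiddling instead of F16C):
//
//     static inline unsigned short float32_to_float16(float v)
//     {
//         return _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
//     }
//
//     static inline __m256 loadfp16(const unsigned short* p)
//     {
//         // 8 fp16 values -> 8 fp32 lanes
//         return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)p));
//     }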
static void conv2x2s1_weight_fp16_pack8_avx(const Mat& kernel, Mat& kernel_tm_pack8, int num_input, int num_output)
{
    // src = kw-kh-inch-outch
    // dst = 8b-8a-kw-kh-inch/8a-outch/8b
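    // i.e. for each of the four 2x2 taps, an 8x8 weight block is stored with
    // the 8 output channels (b) innermost, then the 8 input channels (a);
    // pack8 layout implies num_input and num_output are multiples of 8 here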
    Mat weight_data_r2 = kernel.reshape(4, num_input, num_output);

    kernel_tm_pack8.create(4, num_input / 8, num_output / 8, (size_t)2 * 64, 64);
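    // w = 4 taps, h = input channel groups, c = output channel groups;
    // each element is an 8x8 fp16 block: elemsize 2 bytes x elempack 64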

    for (int q = 0; q + 7 < num_output; q += 8)
    {
        const Mat k0 = weight_data_r2.channel(q);
        const Mat k1 = weight_data_r2.channel(q + 1);
        const Mat k2 = weight_data_r2.channel(q + 2);
        const Mat k3 = weight_data_r2.channel(q + 3);
        const Mat k4 = weight_data_r2.channel(q + 4);
        const Mat k5 = weight_data_r2.channel(q + 5);
        const Mat k6 = weight_data_r2.channel(q + 6);
        const Mat k7 = weight_data_r2.channel(q + 7);

        Mat g0 = kernel_tm_pack8.channel(q / 8);

        for (int p = 0; p + 7 < num_input; p += 8)
        {
            const float* k00 = k0.row(p);
            const float* k01 = k0.row(p + 1);
            const float* k02 = k0.row(p + 2);
            const float* k03 = k0.row(p + 3);
            const float* k04 = k0.row(p + 4);
            const float* k05 = k0.row(p + 5);
            const float* k06 = k0.row(p + 6);
            const float* k07 = k0.row(p + 7);

            const float* k10 = k1.row(p);
            const float* k11 = k1.row(p + 1);
            const float* k12 = k1.row(p + 2);
            const float* k13 = k1.row(p + 3);
            const float* k14 = k1.row(p + 4);
            const float* k15 = k1.row(p + 5);
            const float* k16 = k1.row(p + 6);
            const float* k17 = k1.row(p + 7);

            const float* k20 = k2.row(p);
            const float* k21 = k2.row(p + 1);
            const float* k22 = k2.row(p + 2);
            const float* k23 = k2.row(p + 3);
            const float* k24 = k2.row(p + 4);
            const float* k25 = k2.row(p + 5);
            const float* k26 = k2.row(p + 6);
            const float* k27 = k2.row(p + 7);

            const float* k30 = k3.row(p);
            const float* k31 = k3.row(p + 1);
            const float* k32 = k3.row(p + 2);
            const float* k33 = k3.row(p + 3);
            const float* k34 = k3.row(p + 4);
            const float* k35 = k3.row(p + 5);
            const float* k36 = k3.row(p + 6);
            const float* k37 = k3.row(p + 7);

            const float* k40 = k4.row(p);
            const float* k41 = k4.row(p + 1);
            const float* k42 = k4.row(p + 2);
            const float* k43 = k4.row(p + 3);
            const float* k44 = k4.row(p + 4);
            const float* k45 = k4.row(p + 5);
            const float* k46 = k4.row(p + 6);
            const float* k47 = k4.row(p + 7);

            const float* k50 = k5.row(p);
            const float* k51 = k5.row(p + 1);
            const float* k52 = k5.row(p + 2);
            const float* k53 = k5.row(p + 3);
            const float* k54 = k5.row(p + 4);
            const float* k55 = k5.row(p + 5);
            const float* k56 = k5.row(p + 6);
            const float* k57 = k5.row(p + 7);

            const float* k60 = k6.row(p);
            const float* k61 = k6.row(p + 1);
            const float* k62 = k6.row(p + 2);
            const float* k63 = k6.row(p + 3);
            const float* k64 = k6.row(p + 4);
            const float* k65 = k6.row(p + 5);
            const float* k66 = k6.row(p + 6);
            const float* k67 = k6.row(p + 7);

            const float* k70 = k7.row(p);
            const float* k71 = k7.row(p + 1);
            const float* k72 = k7.row(p + 2);
            const float* k73 = k7.row(p + 3);
            const float* k74 = k7.row(p + 4);
            const float* k75 = k7.row(p + 5);
            const float* k76 = k7.row(p + 6);
            const float* k77 = k7.row(p + 7);

            unsigned short* g00 = (unsigned short*)g0.row(p / 8);

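            // for each of the 4 taps of the 2x2 kernel, emit one 8x8 block:
            // 8 output channels innermost, then the 8 input channels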
            for (int k = 0; k < 4; k++)
            {
                g00[0] = float32_to_float16(k00[k]);
                g00[1] = float32_to_float16(k10[k]);
                g00[2] = float32_to_float16(k20[k]);
                g00[3] = float32_to_float16(k30[k]);
                g00[4] = float32_to_float16(k40[k]);
                g00[5] = float32_to_float16(k50[k]);
                g00[6] = float32_to_float16(k60[k]);
                g00[7] = float32_to_float16(k70[k]);
                g00 += 8;

                g00[0] = float32_to_float16(k01[k]);
                g00[1] = float32_to_float16(k11[k]);
                g00[2] = float32_to_float16(k21[k]);
                g00[3] = float32_to_float16(k31[k]);
                g00[4] = float32_to_float16(k41[k]);
                g00[5] = float32_to_float16(k51[k]);
                g00[6] = float32_to_float16(k61[k]);
                g00[7] = float32_to_float16(k71[k]);
                g00 += 8;

                g00[0] = float32_to_float16(k02[k]);
                g00[1] = float32_to_float16(k12[k]);
                g00[2] = float32_to_float16(k22[k]);
                g00[3] = float32_to_float16(k32[k]);
                g00[4] = float32_to_float16(k42[k]);
                g00[5] = float32_to_float16(k52[k]);
                g00[6] = float32_to_float16(k62[k]);
                g00[7] = float32_to_float16(k72[k]);
                g00 += 8;

                g00[0] = float32_to_float16(k03[k]);
                g00[1] = float32_to_float16(k13[k]);
                g00[2] = float32_to_float16(k23[k]);
                g00[3] = float32_to_float16(k33[k]);
                g00[4] = float32_to_float16(k43[k]);
                g00[5] = float32_to_float16(k53[k]);
                g00[6] = float32_to_float16(k63[k]);
                g00[7] = float32_to_float16(k73[k]);
                g00 += 8;

                g00[0] = float32_to_float16(k04[k]);
                g00[1] = float32_to_float16(k14[k]);
                g00[2] = float32_to_float16(k24[k]);
                g00[3] = float32_to_float16(k34[k]);
                g00[4] = float32_to_float16(k44[k]);
                g00[5] = float32_to_float16(k54[k]);
                g00[6] = float32_to_float16(k64[k]);
                g00[7] = float32_to_float16(k74[k]);
                g00 += 8;

                g00[0] = float32_to_float16(k05[k]);
                g00[1] = float32_to_float16(k15[k]);
                g00[2] = float32_to_float16(k25[k]);
                g00[3] = float32_to_float16(k35[k]);
                g00[4] = float32_to_float16(k45[k]);
                g00[5] = float32_to_float16(k55[k]);
                g00[6] = float32_to_float16(k65[k]);
                g00[7] = float32_to_float16(k75[k]);
                g00 += 8;

                g00[0] = float32_to_float16(k06[k]);
                g00[1] = float32_to_float16(k16[k]);
                g00[2] = float32_to_float16(k26[k]);
                g00[3] = float32_to_float16(k36[k]);
                g00[4] = float32_to_float16(k46[k]);
                g00[5] = float32_to_float16(k56[k]);
                g00[6] = float32_to_float16(k66[k]);
                g00[7] = float32_to_float16(k76[k]);
                g00 += 8;

                g00[0] = float32_to_float16(k07[k]);
                g00[1] = float32_to_float16(k17[k]);
                g00[2] = float32_to_float16(k27[k]);
                g00[3] = float32_to_float16(k37[k]);
                g00[4] = float32_to_float16(k47[k]);
                g00[5] = float32_to_float16(k57[k]);
                g00[6] = float32_to_float16(k67[k]);
                g00[7] = float32_to_float16(k77[k]);
                g00 += 8;
            }
        }
    }
}
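
// 2x2, stride-1 convolution over pack8 data: weights were pre-packed to fp16
// by conv2x2s1_weight_fp16_pack8_avx and are expanded back to fp32 on load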
static void conv2x2s1_fp16_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int inch = bottom_blob.c;
    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;
    const float* bias = _bias;
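    // with pack8 layout, _bias stores 8 fp32 values per output channel group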

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        Mat out0 = top_blob.channel(p);

        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);
        out0.fill(_bias0);
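        // seed the output channel with the bias, then accumulate every input channel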

        for (int q = 0; q < inch; q++)
        {
            float* outptr0 = out0.row(0);

            const Mat img0 = bottom_blob.channel(q);

            const float* r0 = img0.row(0);
            const float* r1 = img0.row(1);

            const unsigned short* kptr = (const unsigned short*)kernel.channel(p).row(q);

            int i = 0;
            for (; i < outh; i++)
            {
                int j = 0;

                for (; j + 1 < outw; j += 2)
                {
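                    // two adjacent output pixels are computed together so each
                    // tap's 8x8 weight block is converted from fp16 only once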
                    __m256 _sum0 = _mm256_loadu_ps(outptr0);
                    __m256 _sum1 = _mm256_loadu_ps(outptr0 + 8);

                    __m256 _r00 = _mm256_broadcast_ss(r0);
                    __m256 _r01 = _mm256_broadcast_ss(r0 + 1);
                    __m256 _r02 = _mm256_broadcast_ss(r0 + 2);
                    __m256 _r03 = _mm256_broadcast_ss(r0 + 3);
                    __m256 _r04 = _mm256_broadcast_ss(r0 + 4);
                    __m256 _r05 = _mm256_broadcast_ss(r0 + 5);
                    __m256 _r06 = _mm256_broadcast_ss(r0 + 6);
                    __m256 _r07 = _mm256_broadcast_ss(r0 + 7);
                    r0 += 8;

                    __m256 _k00 = loadfp16(kptr);
                    __m256 _k01 = loadfp16(kptr + 8);
                    __m256 _k02 = loadfp16(kptr + 16);
                    __m256 _k03 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);

                    __m256 _k04 = loadfp16(kptr);
                    __m256 _k05 = loadfp16(kptr + 8);
                    __m256 _k06 = loadfp16(kptr + 16);
                    __m256 _k07 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k05, _r05, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k06, _r06, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k07, _r07, _sum0);

                    //======== input column j + 1 (r0) ========

                    _r00 = _mm256_broadcast_ss(r0);
                    _r01 = _mm256_broadcast_ss(r0 + 1);
                    _r02 = _mm256_broadcast_ss(r0 + 2);
                    _r03 = _mm256_broadcast_ss(r0 + 3);
                    _r04 = _mm256_broadcast_ss(r0 + 4);
                    _r05 = _mm256_broadcast_ss(r0 + 5);
                    _r06 = _mm256_broadcast_ss(r0 + 6);
                    _r07 = _mm256_broadcast_ss(r0 + 7);
                    r0 += 8;

                    _sum1 = _mm256_fmadd_ps(_k00, _r00, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k02, _r02, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k03, _r03, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k04, _r04, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k05, _r05, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k06, _r06, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k07, _r07, _sum1);

                    _k00 = loadfp16(kptr);
                    _k01 = loadfp16(kptr + 8);
                    _k02 = loadfp16(kptr + 16);
                    _k03 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);

                    _k04 = loadfp16(kptr);
                    _k05 = loadfp16(kptr + 8);
                    _k06 = loadfp16(kptr + 16);
                    _k07 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k05, _r05, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k06, _r06, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k07, _r07, _sum0);

                    _r00 = _mm256_broadcast_ss(r0);
                    _r01 = _mm256_broadcast_ss(r0 + 1);
                    _r02 = _mm256_broadcast_ss(r0 + 2);
                    _r03 = _mm256_broadcast_ss(r0 + 3);
                    _r04 = _mm256_broadcast_ss(r0 + 4);
                    _r05 = _mm256_broadcast_ss(r0 + 5);
                    _r06 = _mm256_broadcast_ss(r0 + 6);
                    _r07 = _mm256_broadcast_ss(r0 + 7);

                    _sum1 = _mm256_fmadd_ps(_k00, _r00, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k02, _r02, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k03, _r03, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k04, _r04, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k05, _r05, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k06, _r06, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k07, _r07, _sum1);
                    //======== second kernel row (r1) ========

                    __m256 _r10 = _mm256_broadcast_ss(r1);
                    __m256 _r11 = _mm256_broadcast_ss(r1 + 1);
                    __m256 _r12 = _mm256_broadcast_ss(r1 + 2);
                    __m256 _r13 = _mm256_broadcast_ss(r1 + 3);
                    __m256 _r14 = _mm256_broadcast_ss(r1 + 4);
                    __m256 _r15 = _mm256_broadcast_ss(r1 + 5);
                    __m256 _r16 = _mm256_broadcast_ss(r1 + 6);
                    __m256 _r17 = _mm256_broadcast_ss(r1 + 7);

                    __m256 _k10 = loadfp16(kptr);
                    __m256 _k11 = loadfp16(kptr + 8);
                    __m256 _k12 = loadfp16(kptr + 16);
                    __m256 _k13 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);

                    __m256 _k14 = loadfp16(kptr);
                    __m256 _k15 = loadfp16(kptr + 8);
                    __m256 _k16 = loadfp16(kptr + 16);
                    __m256 _k17 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k15, _r15, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k16, _r16, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k17, _r17, _sum0);

                    //======== input column j + 1 (r1) ========
                    r1 += 8;
                    _r10 = _mm256_broadcast_ss(r1);
                    _r11 = _mm256_broadcast_ss(r1 + 1);
                    _r12 = _mm256_broadcast_ss(r1 + 2);
                    _r13 = _mm256_broadcast_ss(r1 + 3);
                    _r14 = _mm256_broadcast_ss(r1 + 4);
                    _r15 = _mm256_broadcast_ss(r1 + 5);
                    _r16 = _mm256_broadcast_ss(r1 + 6);
                    _r17 = _mm256_broadcast_ss(r1 + 7);

                    _sum1 = _mm256_fmadd_ps(_k10, _r10, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k12, _r12, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k13, _r13, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k14, _r14, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k15, _r15, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k16, _r16, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k17, _r17, _sum1);

                    _k10 = loadfp16(kptr);
                    _k11 = loadfp16(kptr + 8);
                    _k12 = loadfp16(kptr + 16);
                    _k13 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);

                    _k14 = loadfp16(kptr);
                    _k15 = loadfp16(kptr + 8);
                    _k16 = loadfp16(kptr + 16);
                    _k17 = loadfp16(kptr + 24);
                    _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k15, _r15, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k16, _r16, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k17, _r17, _sum0);

                    r1 += 8;
                    _r10 = _mm256_broadcast_ss(r1);
                    _r11 = _mm256_broadcast_ss(r1 + 1);
                    _r12 = _mm256_broadcast_ss(r1 + 2);
                    _r13 = _mm256_broadcast_ss(r1 + 3);
                    _r14 = _mm256_broadcast_ss(r1 + 4);
                    _r15 = _mm256_broadcast_ss(r1 + 5);
                    _r16 = _mm256_broadcast_ss(r1 + 6);
                    _r17 = _mm256_broadcast_ss(r1 + 7);

                    _sum1 = _mm256_fmadd_ps(_k10, _r10, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k12, _r12, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k13, _r13, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k14, _r14, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k15, _r15, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k16, _r16, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k17, _r17, _sum1);

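                    // rewind to the start of this input channel's weight block:
                    // 7 of the 8 tap half-loads above advanced kptr by 32 halves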
                    kptr -= 224;
                    _mm256_storeu_ps(outptr0, _sum0);
                    _mm256_storeu_ps(outptr0 + 8, _sum1);
                    outptr0 += 16;
                }

                for (; j < outw; j++)
                {
                    __m256 _sum = _mm256_loadu_ps(outptr0);

                    __m256 _r00 = _mm256_broadcast_ss(r0);
                    __m256 _r01 = _mm256_broadcast_ss(r0 + 1);
                    __m256 _r02 = _mm256_broadcast_ss(r0 + 2);
                    __m256 _r03 = _mm256_broadcast_ss(r0 + 3);
                    __m256 _r04 = _mm256_broadcast_ss(r0 + 4);
                    __m256 _r05 = _mm256_broadcast_ss(r0 + 5);
                    __m256 _r06 = _mm256_broadcast_ss(r0 + 6);
                    __m256 _r07 = _mm256_broadcast_ss(r0 + 7);

                    __m256 _k00 = loadfp16(kptr);
                    __m256 _k01 = loadfp16(kptr + 8);
                    __m256 _k02 = loadfp16(kptr + 16);
                    __m256 _k03 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k00, _r00, _sum);
                    _sum = _mm256_fmadd_ps(_k01, _r01, _sum);
                    _sum = _mm256_fmadd_ps(_k02, _r02, _sum);
                    _sum = _mm256_fmadd_ps(_k03, _r03, _sum);

                    __m256 _k04 = loadfp16(kptr);
                    __m256 _k05 = loadfp16(kptr + 8);
                    __m256 _k06 = loadfp16(kptr + 16);
                    __m256 _k07 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k04, _r04, _sum);
                    _sum = _mm256_fmadd_ps(_k05, _r05, _sum);
                    _sum = _mm256_fmadd_ps(_k06, _r06, _sum);
                    _sum = _mm256_fmadd_ps(_k07, _r07, _sum);

                    //======== input column j + 1 (r0) ========
                    r0 += 8;
                    _r00 = _mm256_broadcast_ss(r0);
                    _r01 = _mm256_broadcast_ss(r0 + 1);
                    _r02 = _mm256_broadcast_ss(r0 + 2);
                    _r03 = _mm256_broadcast_ss(r0 + 3);
                    _r04 = _mm256_broadcast_ss(r0 + 4);
                    _r05 = _mm256_broadcast_ss(r0 + 5);
                    _r06 = _mm256_broadcast_ss(r0 + 6);
                    _r07 = _mm256_broadcast_ss(r0 + 7);

                    _k00 = loadfp16(kptr);
                    _k01 = loadfp16(kptr + 8);
                    _k02 = loadfp16(kptr + 16);
                    _k03 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k00, _r00, _sum);
                    _sum = _mm256_fmadd_ps(_k01, _r01, _sum);
                    _sum = _mm256_fmadd_ps(_k02, _r02, _sum);
                    _sum = _mm256_fmadd_ps(_k03, _r03, _sum);

                    _k04 = loadfp16(kptr);
                    _k05 = loadfp16(kptr + 8);
                    _k06 = loadfp16(kptr + 16);
                    _k07 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k04, _r04, _sum);
                    _sum = _mm256_fmadd_ps(_k05, _r05, _sum);
                    _sum = _mm256_fmadd_ps(_k06, _r06, _sum);
                    _sum = _mm256_fmadd_ps(_k07, _r07, _sum);
                    //======== second kernel row (r1) ========

                    __m256 _r10 = _mm256_broadcast_ss(r1);
                    __m256 _r11 = _mm256_broadcast_ss(r1 + 1);
                    __m256 _r12 = _mm256_broadcast_ss(r1 + 2);
                    __m256 _r13 = _mm256_broadcast_ss(r1 + 3);
                    __m256 _r14 = _mm256_broadcast_ss(r1 + 4);
                    __m256 _r15 = _mm256_broadcast_ss(r1 + 5);
                    __m256 _r16 = _mm256_broadcast_ss(r1 + 6);
                    __m256 _r17 = _mm256_broadcast_ss(r1 + 7);

                    __m256 _k10 = loadfp16(kptr);
                    __m256 _k11 = loadfp16(kptr + 8);
                    __m256 _k12 = loadfp16(kptr + 16);
                    __m256 _k13 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k10, _r10, _sum);
                    _sum = _mm256_fmadd_ps(_k11, _r11, _sum);
                    _sum = _mm256_fmadd_ps(_k12, _r12, _sum);
                    _sum = _mm256_fmadd_ps(_k13, _r13, _sum);

                    __m256 _k14 = loadfp16(kptr);
                    __m256 _k15 = loadfp16(kptr + 8);
                    __m256 _k16 = loadfp16(kptr + 16);
                    __m256 _k17 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k14, _r14, _sum);
                    _sum = _mm256_fmadd_ps(_k15, _r15, _sum);
                    _sum = _mm256_fmadd_ps(_k16, _r16, _sum);
                    _sum = _mm256_fmadd_ps(_k17, _r17, _sum);

                    //======== input column j + 1 (r1) ========
                    r1 += 8;
                    _r10 = _mm256_broadcast_ss(r1);
                    _r11 = _mm256_broadcast_ss(r1 + 1);
                    _r12 = _mm256_broadcast_ss(r1 + 2);
                    _r13 = _mm256_broadcast_ss(r1 + 3);
                    _r14 = _mm256_broadcast_ss(r1 + 4);
                    _r15 = _mm256_broadcast_ss(r1 + 5);
                    _r16 = _mm256_broadcast_ss(r1 + 6);
                    _r17 = _mm256_broadcast_ss(r1 + 7);

                    _k10 = loadfp16(kptr);
                    _k11 = loadfp16(kptr + 8);
                    _k12 = loadfp16(kptr + 16);
                    _k13 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k10, _r10, _sum);
                    _sum = _mm256_fmadd_ps(_k11, _r11, _sum);
                    _sum = _mm256_fmadd_ps(_k12, _r12, _sum);
                    _sum = _mm256_fmadd_ps(_k13, _r13, _sum);

                    _k14 = loadfp16(kptr);
                    _k15 = loadfp16(kptr + 8);
                    _k16 = loadfp16(kptr + 16);
                    _k17 = loadfp16(kptr + 24);
                    _sum = _mm256_fmadd_ps(_k14, _r14, _sum);
                    _sum = _mm256_fmadd_ps(_k15, _r15, _sum);
                    _sum = _mm256_fmadd_ps(_k16, _r16, _sum);
                    _sum = _mm256_fmadd_ps(_k17, _r17, _sum);

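                    // rewind the weight pointer for the next output pixel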
                    kptr -= 224;
                    _mm256_storeu_ps(outptr0, _sum);
                    outptr0 += 8;
                }

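                // input rows hold outw + 1 pixels (kernel 2, stride 1);
                // step over the last column to reach the next row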
                r0 += 8;
                r1 += 8;
            }
        }
    }
}