1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
conv2x2s1_weight_fp16_pack8_avx(const Mat & kernel,Mat & kernel_tm_pack8,int num_input,int num_output)14 static void conv2x2s1_weight_fp16_pack8_avx(const Mat& kernel, Mat& kernel_tm_pack8, int num_input, int num_output)
15 {
16 // src = kw-kh-inch-outch
17 // dst = 8b-8a-kw-kh-inch/8a-outch/8b
18 Mat weight_data_r2 = kernel.reshape(4, num_input, num_output);
19
20 kernel_tm_pack8.create(4, num_input / 8, num_output / 8, (size_t)2 * 64, 64);
21
22 for (int q = 0; q + 7 < num_output; q += 8)
23 {
24 const Mat k0 = weight_data_r2.channel(q);
25 const Mat k1 = weight_data_r2.channel(q + 1);
26 const Mat k2 = weight_data_r2.channel(q + 2);
27 const Mat k3 = weight_data_r2.channel(q + 3);
28 const Mat k4 = weight_data_r2.channel(q + 4);
29 const Mat k5 = weight_data_r2.channel(q + 5);
30 const Mat k6 = weight_data_r2.channel(q + 6);
31 const Mat k7 = weight_data_r2.channel(q + 7);
32
33 Mat g0 = kernel_tm_pack8.channel(q / 8);
34
35 for (int p = 0; p + 7 < num_input; p += 8)
36 {
37 const float* k00 = k0.row(p);
38 const float* k01 = k0.row(p + 1);
39 const float* k02 = k0.row(p + 2);
40 const float* k03 = k0.row(p + 3);
41 const float* k04 = k0.row(p + 4);
42 const float* k05 = k0.row(p + 5);
43 const float* k06 = k0.row(p + 6);
44 const float* k07 = k0.row(p + 7);
45
46 const float* k10 = k1.row(p);
47 const float* k11 = k1.row(p + 1);
48 const float* k12 = k1.row(p + 2);
49 const float* k13 = k1.row(p + 3);
50 const float* k14 = k1.row(p + 4);
51 const float* k15 = k1.row(p + 5);
52 const float* k16 = k1.row(p + 6);
53 const float* k17 = k1.row(p + 7);
54
55 const float* k20 = k2.row(p);
56 const float* k21 = k2.row(p + 1);
57 const float* k22 = k2.row(p + 2);
58 const float* k23 = k2.row(p + 3);
59 const float* k24 = k2.row(p + 4);
60 const float* k25 = k2.row(p + 5);
61 const float* k26 = k2.row(p + 6);
62 const float* k27 = k2.row(p + 7);
63
64 const float* k30 = k3.row(p);
65 const float* k31 = k3.row(p + 1);
66 const float* k32 = k3.row(p + 2);
67 const float* k33 = k3.row(p + 3);
68 const float* k34 = k3.row(p + 4);
69 const float* k35 = k3.row(p + 5);
70 const float* k36 = k3.row(p + 6);
71 const float* k37 = k3.row(p + 7);
72
73 const float* k40 = k4.row(p);
74 const float* k41 = k4.row(p + 1);
75 const float* k42 = k4.row(p + 2);
76 const float* k43 = k4.row(p + 3);
77 const float* k44 = k4.row(p + 4);
78 const float* k45 = k4.row(p + 5);
79 const float* k46 = k4.row(p + 6);
80 const float* k47 = k4.row(p + 7);
81
82 const float* k50 = k5.row(p);
83 const float* k51 = k5.row(p + 1);
84 const float* k52 = k5.row(p + 2);
85 const float* k53 = k5.row(p + 3);
86 const float* k54 = k5.row(p + 4);
87 const float* k55 = k5.row(p + 5);
88 const float* k56 = k5.row(p + 6);
89 const float* k57 = k5.row(p + 7);
90
91 const float* k60 = k6.row(p);
92 const float* k61 = k6.row(p + 1);
93 const float* k62 = k6.row(p + 2);
94 const float* k63 = k6.row(p + 3);
95 const float* k64 = k6.row(p + 4);
96 const float* k65 = k6.row(p + 5);
97 const float* k66 = k6.row(p + 6);
98 const float* k67 = k6.row(p + 7);
99
100 const float* k70 = k7.row(p);
101 const float* k71 = k7.row(p + 1);
102 const float* k72 = k7.row(p + 2);
103 const float* k73 = k7.row(p + 3);
104 const float* k74 = k7.row(p + 4);
105 const float* k75 = k7.row(p + 5);
106 const float* k76 = k7.row(p + 6);
107 const float* k77 = k7.row(p + 7);
108
109 unsigned short* g00 = (unsigned short*)g0.row(p / 8);
110
111 for (int k = 0; k < 4; k++)
112 {
113 g00[0] = float32_to_float16(k00[k]);
114 g00[1] = float32_to_float16(k10[k]);
115 g00[2] = float32_to_float16(k20[k]);
116 g00[3] = float32_to_float16(k30[k]);
117 g00[4] = float32_to_float16(k40[k]);
118 g00[5] = float32_to_float16(k50[k]);
119 g00[6] = float32_to_float16(k60[k]);
120 g00[7] = float32_to_float16(k70[k]);
121 g00 += 8;
122 g00[0] = float32_to_float16(k01[k]);
123 g00[1] = float32_to_float16(k11[k]);
124 g00[2] = float32_to_float16(k21[k]);
125 g00[3] = float32_to_float16(k31[k]);
126 g00[4] = float32_to_float16(k41[k]);
127 g00[5] = float32_to_float16(k51[k]);
128 g00[6] = float32_to_float16(k61[k]);
129 g00[7] = float32_to_float16(k71[k]);
130
131 g00 += 8;
132 g00[0] = float32_to_float16(k02[k]);
133 g00[1] = float32_to_float16(k12[k]);
134 g00[2] = float32_to_float16(k22[k]);
135 g00[3] = float32_to_float16(k32[k]);
136 g00[4] = float32_to_float16(k42[k]);
137 g00[5] = float32_to_float16(k52[k]);
138 g00[6] = float32_to_float16(k62[k]);
139 g00[7] = float32_to_float16(k72[k]);
140
141 g00 += 8;
142 g00[0] = float32_to_float16(k03[k]);
143 g00[1] = float32_to_float16(k13[k]);
144 g00[2] = float32_to_float16(k23[k]);
145 g00[3] = float32_to_float16(k33[k]);
146 g00[4] = float32_to_float16(k43[k]);
147 g00[5] = float32_to_float16(k53[k]);
148 g00[6] = float32_to_float16(k63[k]);
149 g00[7] = float32_to_float16(k73[k]);
150
151 g00 += 8;
152 g00[0] = float32_to_float16(k04[k]);
153 g00[1] = float32_to_float16(k14[k]);
154 g00[2] = float32_to_float16(k24[k]);
155 g00[3] = float32_to_float16(k34[k]);
156 g00[4] = float32_to_float16(k44[k]);
157 g00[5] = float32_to_float16(k54[k]);
158 g00[6] = float32_to_float16(k64[k]);
159 g00[7] = float32_to_float16(k74[k]);
160
161 g00 += 8;
162 g00[0] = float32_to_float16(k05[k]);
163 g00[1] = float32_to_float16(k15[k]);
164 g00[2] = float32_to_float16(k25[k]);
165 g00[3] = float32_to_float16(k35[k]);
166 g00[4] = float32_to_float16(k45[k]);
167 g00[5] = float32_to_float16(k55[k]);
168 g00[6] = float32_to_float16(k65[k]);
169 g00[7] = float32_to_float16(k75[k]);
170
171 g00 += 8;
172 g00[0] = float32_to_float16(k06[k]);
173 g00[1] = float32_to_float16(k16[k]);
174 g00[2] = float32_to_float16(k26[k]);
175 g00[3] = float32_to_float16(k36[k]);
176 g00[4] = float32_to_float16(k46[k]);
177 g00[5] = float32_to_float16(k56[k]);
178 g00[6] = float32_to_float16(k66[k]);
179 g00[7] = float32_to_float16(k76[k]);
180
181 g00 += 8;
182 g00[0] = float32_to_float16(k07[k]);
183 g00[1] = float32_to_float16(k17[k]);
184 g00[2] = float32_to_float16(k27[k]);
185 g00[3] = float32_to_float16(k37[k]);
186 g00[4] = float32_to_float16(k47[k]);
187 g00[5] = float32_to_float16(k57[k]);
188 g00[6] = float32_to_float16(k67[k]);
189 g00[7] = float32_to_float16(k77[k]);
190
191 g00 += 8;
192 }
193 }
194 }
195 }
// 2x2 stride-1 convolution, fp16-packed weights, pack8 layout, AVX/FMA.
//
// bottom_blob : input feature map, elempack 8 (8 floats per pixel position)
// top_blob    : pre-sized output (outw = w - 1, outh = h - 1 for 2x2 s1),
//               elempack 8; filled with bias here, then accumulated in place
//               across input channels
// kernel      : fp16 weights packed by conv2x2s1_weight_fp16_pack8_avx —
//               per (inch/8, outch/8) pair: 4 taps x 8 input lanes x 8 output
//               lanes = 256 fp16 values (8 blocks of 32 shorts)
// _bias       : optional per-output-channel bias (8 floats per packed channel)
//
// loadfp16 is a project helper that loads 8 fp16 values and widens them to a
// __m256 of floats.
static void conv2x2s1_fp16_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int inch = bottom_blob.c;
    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;
    const float* bias = _bias;

    // One thread per packed output channel; each channel's accumulation is
    // independent.
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        Mat out0 = top_blob.channel(p);

        // Seed the whole output channel with the bias (or zero), then
        // accumulate every input channel's contribution on top.
        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);
        out0.fill(_bias0);

        for (int q = 0; q < inch; q++)
        {
            // Walk the output channel linearly from its first row; rows of a
            // packed channel are traversed back-to-back by the loops below.
            float* outptr0 = out0.row(0);

            const Mat img0 = bottom_blob.channel(q);

            // r0/r1: the two input rows the 2x2 window reads for output row i.
            const float* r0 = img0.row(0);
            const float* r1 = img0.row(1);

            // Packed fp16 weights for this (input ch q, output ch p) pair.
            const unsigned short* kptr = (const unsigned short*)kernel.channel(p).row(q);
            // const float* kptr = (const float*)kernel + 4 * inch * p * 64;

            int i = 0;
            for (; i < outh; i++)
            {
                int j = 0;

                // Main loop: two output pixels per iteration. The three input
                // columns j, j+1, j+2 are consumed; column j+1 is shared
                // between the two output pixels (loaded once, used for both
                // _sum0 and _sum1).
                for (; j + 1 < outw; j += 2)
                {
                    __m256 _sum0 = _mm256_loadu_ps(outptr0);
                    __m256 _sum1 = _mm256_loadu_ps(outptr0 + 8);

                    // Input column j (8 packed lanes), broadcast lane-by-lane.
                    __m256 _r00 = _mm256_broadcast_ss(r0);
                    __m256 _r01 = _mm256_broadcast_ss(r0 + 1);
                    __m256 _r02 = _mm256_broadcast_ss(r0 + 2);
                    __m256 _r03 = _mm256_broadcast_ss(r0 + 3);
                    __m256 _r04 = _mm256_broadcast_ss(r0 + 4);
                    __m256 _r05 = _mm256_broadcast_ss(r0 + 5);
                    __m256 _r06 = _mm256_broadcast_ss(r0 + 6);
                    __m256 _r07 = _mm256_broadcast_ss(r0 + 7);
                    r0 += 8;

                    // Tap (0,0): 64 fp16 weights = 8 input lanes x 8 output
                    // lanes, loaded as two 32-short blocks.
                    __m256 _k00 = loadfp16(kptr);
                    __m256 _k01 = loadfp16(kptr + 8);
                    __m256 _k02 = loadfp16(kptr + 16);
                    __m256 _k03 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);

                    __m256 _k04 = loadfp16(kptr);
                    __m256 _k05 = loadfp16(kptr + 8);
                    __m256 _k06 = loadfp16(kptr + 16);
                    __m256 _k07 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k05, _r05, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k06, _r06, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k07, _r07, _sum0);

                    //========================================

                    // Input column j+1: tap (0,0) of the SECOND output pixel.
                    // The tap-(0,0) weights (_k00.._k07) are still live and
                    // are reused for _sum1.
                    _r00 = _mm256_broadcast_ss(r0);
                    _r01 = _mm256_broadcast_ss(r0 + 1);
                    _r02 = _mm256_broadcast_ss(r0 + 2);
                    _r03 = _mm256_broadcast_ss(r0 + 3);
                    _r04 = _mm256_broadcast_ss(r0 + 4);
                    _r05 = _mm256_broadcast_ss(r0 + 5);
                    _r06 = _mm256_broadcast_ss(r0 + 6);
                    _r07 = _mm256_broadcast_ss(r0 + 7);
                    r0 += 8;

                    _sum1 = _mm256_fmadd_ps(_k00, _r00, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k02, _r02, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k03, _r03, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k04, _r04, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k05, _r05, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k06, _r06, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k07, _r07, _sum1);

                    // Tap (0,1) weights applied to column j+1 for the first
                    // output pixel (_sum0)...
                    _k00 = loadfp16(kptr);
                    _k01 = loadfp16(kptr + 8);
                    _k02 = loadfp16(kptr + 16);
                    _k03 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);

                    _k04 = loadfp16(kptr);
                    _k05 = loadfp16(kptr + 8);
                    _k06 = loadfp16(kptr + 16);
                    _k07 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k05, _r05, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k06, _r06, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k07, _r07, _sum0);

                    // ...and to column j+2 for the second output pixel (_sum1).
                    // Note: r0 is NOT advanced here; the loop has already
                    // moved it 16 floats (= 2 columns), which is exactly the
                    // stride-1 advance for two output pixels.
                    _r00 = _mm256_broadcast_ss(r0);
                    _r01 = _mm256_broadcast_ss(r0 + 1);
                    _r02 = _mm256_broadcast_ss(r0 + 2);
                    _r03 = _mm256_broadcast_ss(r0 + 3);
                    _r04 = _mm256_broadcast_ss(r0 + 4);
                    _r05 = _mm256_broadcast_ss(r0 + 5);
                    _r06 = _mm256_broadcast_ss(r0 + 6);
                    _r07 = _mm256_broadcast_ss(r0 + 7);

                    _sum1 = _mm256_fmadd_ps(_k00, _r00, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k02, _r02, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k03, _r03, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k04, _r04, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k05, _r05, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k06, _r06, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k07, _r07, _sum1);
                    //===============

                    // Second kernel row: same pattern on input row r1,
                    // taps (1,0) and (1,1).
                    __m256 _r10 = _mm256_broadcast_ss(r1);
                    __m256 _r11 = _mm256_broadcast_ss(r1 + 1);
                    __m256 _r12 = _mm256_broadcast_ss(r1 + 2);
                    __m256 _r13 = _mm256_broadcast_ss(r1 + 3);
                    __m256 _r14 = _mm256_broadcast_ss(r1 + 4);
                    __m256 _r15 = _mm256_broadcast_ss(r1 + 5);
                    __m256 _r16 = _mm256_broadcast_ss(r1 + 6);
                    __m256 _r17 = _mm256_broadcast_ss(r1 + 7);

                    __m256 _k10 = loadfp16(kptr);
                    __m256 _k11 = loadfp16(kptr + 8);
                    __m256 _k12 = loadfp16(kptr + 16);
                    __m256 _k13 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);

                    __m256 _k14 = loadfp16(kptr);
                    __m256 _k15 = loadfp16(kptr + 8);
                    __m256 _k16 = loadfp16(kptr + 16);
                    __m256 _k17 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k15, _r15, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k16, _r16, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k17, _r17, _sum0);

                    //=======================================
                    r1 += 8;
                    _r10 = _mm256_broadcast_ss(r1);
                    _r11 = _mm256_broadcast_ss(r1 + 1);
                    _r12 = _mm256_broadcast_ss(r1 + 2);
                    _r13 = _mm256_broadcast_ss(r1 + 3);
                    _r14 = _mm256_broadcast_ss(r1 + 4);
                    _r15 = _mm256_broadcast_ss(r1 + 5);
                    _r16 = _mm256_broadcast_ss(r1 + 6);
                    _r17 = _mm256_broadcast_ss(r1 + 7);

                    _sum1 = _mm256_fmadd_ps(_k10, _r10, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k12, _r12, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k13, _r13, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k14, _r14, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k15, _r15, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k16, _r16, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k17, _r17, _sum1);

                    _k10 = loadfp16(kptr);
                    _k11 = loadfp16(kptr + 8);
                    _k12 = loadfp16(kptr + 16);
                    _k13 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);

                    // Final 32-short block is loaded WITHOUT advancing kptr;
                    // combined with the seven += 32 above, kptr now sits 224
                    // shorts past the group start.
                    _k14 = loadfp16(kptr);
                    _k15 = loadfp16(kptr + 8);
                    _k16 = loadfp16(kptr + 16);
                    _k17 = loadfp16(kptr + 24);
                    _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k15, _r15, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k16, _r16, _sum0);
                    _sum0 = _mm256_fmadd_ps(_k17, _r17, _sum0);

                    r1 += 8;
                    _r10 = _mm256_broadcast_ss(r1);
                    _r11 = _mm256_broadcast_ss(r1 + 1);
                    _r12 = _mm256_broadcast_ss(r1 + 2);
                    _r13 = _mm256_broadcast_ss(r1 + 3);
                    _r14 = _mm256_broadcast_ss(r1 + 4);
                    _r15 = _mm256_broadcast_ss(r1 + 5);
                    _r16 = _mm256_broadcast_ss(r1 + 6);
                    _r17 = _mm256_broadcast_ss(r1 + 7);

                    _sum1 = _mm256_fmadd_ps(_k10, _r10, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k12, _r12, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k13, _r13, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k14, _r14, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k15, _r15, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k16, _r16, _sum1);
                    _sum1 = _mm256_fmadd_ps(_k17, _r17, _sum1);

                    // Rewind to the start of this channel-pair's 256-short
                    // weight group so the next output position reuses it.
                    kptr -= 224;
                    _mm256_storeu_ps(outptr0, _sum0);
                    _mm256_storeu_ps(outptr0 + 8, _sum1);
                    outptr0 += 16;
                }

                // Remainder: one output pixel at a time (same tap order:
                // (0,0), (0,1), (1,0), (1,1)).
                for (; j < outw; j++)
                {
                    __m256 _sum = _mm256_loadu_ps(outptr0);

                    // Tap (0,0) on input column j of row r0.
                    __m256 _r00 = _mm256_broadcast_ss(r0);
                    __m256 _r01 = _mm256_broadcast_ss(r0 + 1);
                    __m256 _r02 = _mm256_broadcast_ss(r0 + 2);
                    __m256 _r03 = _mm256_broadcast_ss(r0 + 3);
                    __m256 _r04 = _mm256_broadcast_ss(r0 + 4);
                    __m256 _r05 = _mm256_broadcast_ss(r0 + 5);
                    __m256 _r06 = _mm256_broadcast_ss(r0 + 6);
                    __m256 _r07 = _mm256_broadcast_ss(r0 + 7);

                    __m256 _k00 = loadfp16(kptr);
                    __m256 _k01 = loadfp16(kptr + 8);
                    __m256 _k02 = loadfp16(kptr + 16);
                    __m256 _k03 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k00, _r00, _sum);
                    _sum = _mm256_fmadd_ps(_k01, _r01, _sum);
                    _sum = _mm256_fmadd_ps(_k02, _r02, _sum);
                    _sum = _mm256_fmadd_ps(_k03, _r03, _sum);

                    __m256 _k04 = loadfp16(kptr);
                    __m256 _k05 = loadfp16(kptr + 8);
                    __m256 _k06 = loadfp16(kptr + 16);
                    __m256 _k07 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k04, _r04, _sum);
                    _sum = _mm256_fmadd_ps(_k05, _r05, _sum);
                    _sum = _mm256_fmadd_ps(_k06, _r06, _sum);
                    _sum = _mm256_fmadd_ps(_k07, _r07, _sum);

                    //========================================
                    // Tap (0,1) on column j+1; r0 ends the iteration one
                    // column (8 floats) ahead, matching stride 1.
                    r0 += 8;
                    _r00 = _mm256_broadcast_ss(r0);
                    _r01 = _mm256_broadcast_ss(r0 + 1);
                    _r02 = _mm256_broadcast_ss(r0 + 2);
                    _r03 = _mm256_broadcast_ss(r0 + 3);
                    _r04 = _mm256_broadcast_ss(r0 + 4);
                    _r05 = _mm256_broadcast_ss(r0 + 5);
                    _r06 = _mm256_broadcast_ss(r0 + 6);
                    _r07 = _mm256_broadcast_ss(r0 + 7);

                    _k00 = loadfp16(kptr);
                    _k01 = loadfp16(kptr + 8);
                    _k02 = loadfp16(kptr + 16);
                    _k03 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k00, _r00, _sum);
                    _sum = _mm256_fmadd_ps(_k01, _r01, _sum);
                    _sum = _mm256_fmadd_ps(_k02, _r02, _sum);
                    _sum = _mm256_fmadd_ps(_k03, _r03, _sum);

                    _k04 = loadfp16(kptr);
                    _k05 = loadfp16(kptr + 8);
                    _k06 = loadfp16(kptr + 16);
                    _k07 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k04, _r04, _sum);
                    _sum = _mm256_fmadd_ps(_k05, _r05, _sum);
                    _sum = _mm256_fmadd_ps(_k06, _r06, _sum);
                    _sum = _mm256_fmadd_ps(_k07, _r07, _sum);
                    //===============

                    // Tap (1,0) on input row r1, column j.
                    __m256 _r10 = _mm256_broadcast_ss(r1);
                    __m256 _r11 = _mm256_broadcast_ss(r1 + 1);
                    __m256 _r12 = _mm256_broadcast_ss(r1 + 2);
                    __m256 _r13 = _mm256_broadcast_ss(r1 + 3);
                    __m256 _r14 = _mm256_broadcast_ss(r1 + 4);
                    __m256 _r15 = _mm256_broadcast_ss(r1 + 5);
                    __m256 _r16 = _mm256_broadcast_ss(r1 + 6);
                    __m256 _r17 = _mm256_broadcast_ss(r1 + 7);

                    __m256 _k10 = loadfp16(kptr);
                    __m256 _k11 = loadfp16(kptr + 8);
                    __m256 _k12 = loadfp16(kptr + 16);
                    __m256 _k13 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k10, _r10, _sum);
                    _sum = _mm256_fmadd_ps(_k11, _r11, _sum);
                    _sum = _mm256_fmadd_ps(_k12, _r12, _sum);
                    _sum = _mm256_fmadd_ps(_k13, _r13, _sum);

                    __m256 _k14 = loadfp16(kptr);
                    __m256 _k15 = loadfp16(kptr + 8);
                    __m256 _k16 = loadfp16(kptr + 16);
                    __m256 _k17 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k14, _r14, _sum);
                    _sum = _mm256_fmadd_ps(_k15, _r15, _sum);
                    _sum = _mm256_fmadd_ps(_k16, _r16, _sum);
                    _sum = _mm256_fmadd_ps(_k17, _r17, _sum);

                    //=======================================
                    // Tap (1,1) on row r1, column j+1.
                    r1 += 8;
                    _r10 = _mm256_broadcast_ss(r1);
                    _r11 = _mm256_broadcast_ss(r1 + 1);
                    _r12 = _mm256_broadcast_ss(r1 + 2);
                    _r13 = _mm256_broadcast_ss(r1 + 3);
                    _r14 = _mm256_broadcast_ss(r1 + 4);
                    _r15 = _mm256_broadcast_ss(r1 + 5);
                    _r16 = _mm256_broadcast_ss(r1 + 6);
                    _r17 = _mm256_broadcast_ss(r1 + 7);

                    _k10 = loadfp16(kptr);
                    _k11 = loadfp16(kptr + 8);
                    _k12 = loadfp16(kptr + 16);
                    _k13 = loadfp16(kptr + 24);
                    kptr += 32;

                    _sum = _mm256_fmadd_ps(_k10, _r10, _sum);
                    _sum = _mm256_fmadd_ps(_k11, _r11, _sum);
                    _sum = _mm256_fmadd_ps(_k12, _r12, _sum);
                    _sum = _mm256_fmadd_ps(_k13, _r13, _sum);

                    // Last block loaded without advancing; kptr is +224 from
                    // the group start (7 x 32).
                    _k14 = loadfp16(kptr);
                    _k15 = loadfp16(kptr + 8);
                    _k16 = loadfp16(kptr + 16);
                    _k17 = loadfp16(kptr + 24);
                    _sum = _mm256_fmadd_ps(_k14, _r14, _sum);
                    _sum = _mm256_fmadd_ps(_k15, _r15, _sum);
                    _sum = _mm256_fmadd_ps(_k16, _r16, _sum);
                    _sum = _mm256_fmadd_ps(_k17, _r17, _sum);

                    // Rewind weights for the next output position.
                    kptr -= 224;
                    _mm256_storeu_ps(outptr0, _sum);
                    outptr0 += 8;
                }

                // End of output row: r0/r1 point at input column outw; skip
                // one more column (input width = outw + 1) to reach the start
                // of the next input row pair.
                r0 += 8;
                r1 += 8;
            }
        }
    }
}
566