// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

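// Depthwise 3x3 stride-1 convolution for the 8-element packed layout (pack8),
// with kernel weights stored as fp16 and expanded to fp32 at load time.
// bottom_blob is assumed to be pre-padded so that each output row reads
// exactly outw + 2 input columns.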
static void convdw3x3s1_fp16_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int outw = top_blob.w;
    int outh = top_blob.h;

    const int group = bottom_blob.c;

    const float* bias = _bias;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int g = 0; g < group; g++)
    {
        Mat out = top_blob.channel(g);

        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + g * 8) : _mm256_set1_ps(0.f);

        const unsigned short* k0 = (const unsigned short*)kernel.row(g);

        float* outptr0 = out.row(0);

        const Mat img0 = bottom_blob.channel(g);

        const float* r0 = img0.row(0);
        const float* r1 = img0.row(1);
        const float* r2 = img0.row(2);

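        // Expand the nine 3x3 kernel taps from fp16 to fp32, one __m256 (8 lanes) per tap.
        // loadfp16 is a helper defined elsewhere in the ncnn x86 code; a minimal sketch,
        // assuming an F16C-capable target, would be:
        //   static __m256 loadfp16(const unsigned short* p)
        //   {
        //       return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)p));
        //   }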
        __m256 _k00 = loadfp16(k0);
        __m256 _k01 = loadfp16(k0 + 8);
        __m256 _k02 = loadfp16(k0 + 16);
        __m256 _k10 = loadfp16(k0 + 24);
        __m256 _k11 = loadfp16(k0 + 32);
        __m256 _k12 = loadfp16(k0 + 40);
        __m256 _k20 = loadfp16(k0 + 48);
        __m256 _k21 = loadfp16(k0 + 56);
        __m256 _k22 = loadfp16(k0 + 64);

        int i = 0;

        for (; i < outh; i++)
        {
            int j = 0;
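            // Main loop: 8 output pixels (64 floats) per iteration. Each store of a
            // finished _sumN is interleaved with the loads for the next pixel so that
            // memory traffic overlaps the FMA chain.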
            for (; j + 7 < outw; j += 8)
            {
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);

                __m256 _sum1 = _bias0;
                __m256 _r03 = _mm256_loadu_ps(r0 + 24);
                __m256 _r13 = _mm256_loadu_ps(r1 + 24);
                __m256 _r23 = _mm256_loadu_ps(r2 + 24);
                _mm256_storeu_ps(outptr0, _sum0);

                _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1);
                _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1);
                _sum1 = _mm256_fmadd_ps(_k02, _r03, _sum1);
                _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1);
                _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1);
                _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1);
                _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1);
                _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1);
                _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1);

                __m256 _sum2 = _bias0;
                __m256 _r04 = _mm256_loadu_ps(r0 + 32);
                __m256 _r14 = _mm256_loadu_ps(r1 + 32);
                __m256 _r24 = _mm256_loadu_ps(r2 + 32);
                _mm256_storeu_ps(outptr0 + 8, _sum1);

                _sum2 = _mm256_fmadd_ps(_k00, _r02, _sum2);
                _sum2 = _mm256_fmadd_ps(_k01, _r03, _sum2);
                _sum2 = _mm256_fmadd_ps(_k02, _r04, _sum2);
                _sum2 = _mm256_fmadd_ps(_k10, _r12, _sum2);
                _sum2 = _mm256_fmadd_ps(_k11, _r13, _sum2);
                _sum2 = _mm256_fmadd_ps(_k12, _r14, _sum2);
                _sum2 = _mm256_fmadd_ps(_k20, _r22, _sum2);
                _sum2 = _mm256_fmadd_ps(_k21, _r23, _sum2);
                _sum2 = _mm256_fmadd_ps(_k22, _r24, _sum2);

                __m256 _sum3 = _bias0;
                __m256 _r05 = _mm256_loadu_ps(r0 + 40);
                __m256 _r15 = _mm256_loadu_ps(r1 + 40);
                __m256 _r25 = _mm256_loadu_ps(r2 + 40);
                _mm256_storeu_ps(outptr0 + 16, _sum2);

                _sum3 = _mm256_fmadd_ps(_k00, _r03, _sum3);
                _sum3 = _mm256_fmadd_ps(_k01, _r04, _sum3);
                _sum3 = _mm256_fmadd_ps(_k02, _r05, _sum3);
                _sum3 = _mm256_fmadd_ps(_k10, _r13, _sum3);
                _sum3 = _mm256_fmadd_ps(_k11, _r14, _sum3);
                _sum3 = _mm256_fmadd_ps(_k12, _r15, _sum3);
                _sum3 = _mm256_fmadd_ps(_k20, _r23, _sum3);
                _sum3 = _mm256_fmadd_ps(_k21, _r24, _sum3);
                _sum3 = _mm256_fmadd_ps(_k22, _r25, _sum3);

                __m256 _sum4 = _bias0;
                __m256 _r06 = _mm256_loadu_ps(r0 + 48);
                __m256 _r16 = _mm256_loadu_ps(r1 + 48);
                __m256 _r26 = _mm256_loadu_ps(r2 + 48);
                _mm256_storeu_ps(outptr0 + 24, _sum3);

                _sum4 = _mm256_fmadd_ps(_k00, _r04, _sum4);
                _sum4 = _mm256_fmadd_ps(_k01, _r05, _sum4);
                _sum4 = _mm256_fmadd_ps(_k02, _r06, _sum4);
                _sum4 = _mm256_fmadd_ps(_k10, _r14, _sum4);
                _sum4 = _mm256_fmadd_ps(_k11, _r15, _sum4);
                _sum4 = _mm256_fmadd_ps(_k12, _r16, _sum4);
                _sum4 = _mm256_fmadd_ps(_k20, _r24, _sum4);
                _sum4 = _mm256_fmadd_ps(_k21, _r25, _sum4);
                _sum4 = _mm256_fmadd_ps(_k22, _r26, _sum4);

                __m256 _sum5 = _bias0;
                __m256 _r07 = _mm256_loadu_ps(r0 + 56);
                __m256 _r17 = _mm256_loadu_ps(r1 + 56);
                __m256 _r27 = _mm256_loadu_ps(r2 + 56);
                _mm256_storeu_ps(outptr0 + 32, _sum4);

                _sum5 = _mm256_fmadd_ps(_k00, _r05, _sum5);
                _sum5 = _mm256_fmadd_ps(_k01, _r06, _sum5);
                _sum5 = _mm256_fmadd_ps(_k02, _r07, _sum5);
                _sum5 = _mm256_fmadd_ps(_k10, _r15, _sum5);
                _sum5 = _mm256_fmadd_ps(_k11, _r16, _sum5);
                _sum5 = _mm256_fmadd_ps(_k12, _r17, _sum5);
                _sum5 = _mm256_fmadd_ps(_k20, _r25, _sum5);
                _sum5 = _mm256_fmadd_ps(_k21, _r26, _sum5);
                _sum5 = _mm256_fmadd_ps(_k22, _r27, _sum5);

                __m256 _sum6 = _bias0;
                __m256 _r08 = _mm256_loadu_ps(r0 + 64);
                __m256 _r18 = _mm256_loadu_ps(r1 + 64);
                __m256 _r28 = _mm256_loadu_ps(r2 + 64);
                _mm256_storeu_ps(outptr0 + 40, _sum5);

                _sum6 = _mm256_fmadd_ps(_k00, _r06, _sum6);
                _sum6 = _mm256_fmadd_ps(_k01, _r07, _sum6);
                _sum6 = _mm256_fmadd_ps(_k02, _r08, _sum6);
                _sum6 = _mm256_fmadd_ps(_k10, _r16, _sum6);
                _sum6 = _mm256_fmadd_ps(_k11, _r17, _sum6);
                _sum6 = _mm256_fmadd_ps(_k12, _r18, _sum6);
                _sum6 = _mm256_fmadd_ps(_k20, _r26, _sum6);
                _sum6 = _mm256_fmadd_ps(_k21, _r27, _sum6);
                _sum6 = _mm256_fmadd_ps(_k22, _r28, _sum6);

                __m256 _sum7 = _bias0;
                __m256 _r09 = _mm256_loadu_ps(r0 + 72);
                __m256 _r19 = _mm256_loadu_ps(r1 + 72);
                __m256 _r29 = _mm256_loadu_ps(r2 + 72);
                _mm256_storeu_ps(outptr0 + 48, _sum6);

                _sum7 = _mm256_fmadd_ps(_k00, _r07, _sum7);
                _sum7 = _mm256_fmadd_ps(_k01, _r08, _sum7);
                _sum7 = _mm256_fmadd_ps(_k02, _r09, _sum7);
                _sum7 = _mm256_fmadd_ps(_k10, _r17, _sum7);
                _sum7 = _mm256_fmadd_ps(_k11, _r18, _sum7);
                _sum7 = _mm256_fmadd_ps(_k12, _r19, _sum7);
                _sum7 = _mm256_fmadd_ps(_k20, _r27, _sum7);
                _sum7 = _mm256_fmadd_ps(_k21, _r28, _sum7);
                _sum7 = _mm256_fmadd_ps(_k22, _r29, _sum7);
                _mm256_storeu_ps(outptr0 + 56, _sum7);

                r0 += 64;
                r1 += 64;
                r2 += 64;
                outptr0 += 64;
            }
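            // Remainder: 4 output pixels per iteration.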
            for (; j + 3 < outw; j += 4)
            {
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);

                __m256 _sum1 = _bias0;
                __m256 _r03 = _mm256_loadu_ps(r0 + 24);
                __m256 _r13 = _mm256_loadu_ps(r1 + 24);
                __m256 _r23 = _mm256_loadu_ps(r2 + 24);
                _mm256_storeu_ps(outptr0, _sum0);

                _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1);
                _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1);
                _sum1 = _mm256_fmadd_ps(_k02, _r03, _sum1);
                _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1);
                _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1);
                _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1);
                _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1);
                _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1);
                _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1);

                __m256 _sum2 = _bias0;
                __m256 _r04 = _mm256_loadu_ps(r0 + 32);
                __m256 _r14 = _mm256_loadu_ps(r1 + 32);
                __m256 _r24 = _mm256_loadu_ps(r2 + 32);
                _mm256_storeu_ps(outptr0 + 8, _sum1);

                _sum2 = _mm256_fmadd_ps(_k00, _r02, _sum2);
                _sum2 = _mm256_fmadd_ps(_k01, _r03, _sum2);
                _sum2 = _mm256_fmadd_ps(_k02, _r04, _sum2);
                _sum2 = _mm256_fmadd_ps(_k10, _r12, _sum2);
                _sum2 = _mm256_fmadd_ps(_k11, _r13, _sum2);
                _sum2 = _mm256_fmadd_ps(_k12, _r14, _sum2);
                _sum2 = _mm256_fmadd_ps(_k20, _r22, _sum2);
                _sum2 = _mm256_fmadd_ps(_k21, _r23, _sum2);
                _sum2 = _mm256_fmadd_ps(_k22, _r24, _sum2);

                __m256 _sum3 = _bias0;
                __m256 _r05 = _mm256_loadu_ps(r0 + 40);
                __m256 _r15 = _mm256_loadu_ps(r1 + 40);
                __m256 _r25 = _mm256_loadu_ps(r2 + 40);
                _mm256_storeu_ps(outptr0 + 16, _sum2);

                _sum3 = _mm256_fmadd_ps(_k00, _r03, _sum3);
                _sum3 = _mm256_fmadd_ps(_k01, _r04, _sum3);
                _sum3 = _mm256_fmadd_ps(_k02, _r05, _sum3);
                _sum3 = _mm256_fmadd_ps(_k10, _r13, _sum3);
                _sum3 = _mm256_fmadd_ps(_k11, _r14, _sum3);
                _sum3 = _mm256_fmadd_ps(_k12, _r15, _sum3);
                _sum3 = _mm256_fmadd_ps(_k20, _r23, _sum3);
                _sum3 = _mm256_fmadd_ps(_k21, _r24, _sum3);
                _sum3 = _mm256_fmadd_ps(_k22, _r25, _sum3);

                _mm256_storeu_ps(outptr0 + 24, _sum3);

                r0 += 32;
                r1 += 32;
                r2 += 32;
                outptr0 += 32;
            }
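            // Remainder: 2 output pixels per iteration.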
            for (; j + 1 < outw; j += 2)
            {
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);

                __m256 _sum1 = _bias0;
                __m256 _r03 = _mm256_loadu_ps(r0 + 24);
                __m256 _r13 = _mm256_loadu_ps(r1 + 24);
                __m256 _r23 = _mm256_loadu_ps(r2 + 24);
                _mm256_storeu_ps(outptr0, _sum0);

                _sum1 = _mm256_fmadd_ps(_k00, _r01, _sum1);
                _sum1 = _mm256_fmadd_ps(_k01, _r02, _sum1);
                _sum1 = _mm256_fmadd_ps(_k02, _r03, _sum1);
                _sum1 = _mm256_fmadd_ps(_k10, _r11, _sum1);
                _sum1 = _mm256_fmadd_ps(_k11, _r12, _sum1);
                _sum1 = _mm256_fmadd_ps(_k12, _r13, _sum1);
                _sum1 = _mm256_fmadd_ps(_k20, _r21, _sum1);
                _sum1 = _mm256_fmadd_ps(_k21, _r22, _sum1);
                _sum1 = _mm256_fmadd_ps(_k22, _r23, _sum1);

                _mm256_storeu_ps(outptr0 + 8, _sum1);

                r0 += 16;
                r1 += 16;
                r2 += 16;
                outptr0 += 16;
            }
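            // Remainder: one output pixel at a time.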
            for (; j < outw; j++)
            {
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);

                _mm256_storeu_ps(outptr0, _sum0);

                r0 += 8;
                r1 += 8;
                r2 += 8;
                outptr0 += 8;
            }

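            // Skip the 2 padding columns at the end of the row: a 3x3 stride-1 kernel
            // reads outw + 2 input columns per row, each 8 floats wide.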
            r0 += 2 * 8;
            r1 += 2 * 8;
            r2 += 2 * 8;
        }
    }
}

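// Depthwise 3x3 stride-2 convolution for the pack8 layout, with the same fp16
// kernel expansion as above. Consecutive output pixels sample the input
// 2 pixels (16 floats) apart.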
static void convdw3x3s2_fp16_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;

    int outw = top_blob.w;
    int outh = top_blob.h;

    const int group = bottom_blob.c;

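    // After finishing an output row the read pointers have advanced 2 * outw
    // input pixels; tailstep skips the rest of that row plus one full row
    // (stride 2 consumes two input rows per output row), counted in floats.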
    const int tailstep = (w - 2 * outw + w) * 8;

    const float* bias = _bias;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int g = 0; g < group; g++)
    {
        Mat out = top_blob.channel(g);

        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + g * 8) : _mm256_set1_ps(0.f);

        const unsigned short* k0 = (const unsigned short*)kernel.row(g);

        float* outptr0 = out.row(0);

        const Mat img0 = bottom_blob.channel(g);

        const float* r0 = img0.row(0);
        const float* r1 = img0.row(1);
        const float* r2 = img0.row(2);

        __m256 _k00 = loadfp16(k0);
        __m256 _k01 = loadfp16(k0 + 8);
        __m256 _k02 = loadfp16(k0 + 16);
        __m256 _k10 = loadfp16(k0 + 24);
        __m256 _k11 = loadfp16(k0 + 32);
        __m256 _k12 = loadfp16(k0 + 40);
        __m256 _k20 = loadfp16(k0 + 48);
        __m256 _k21 = loadfp16(k0 + 56);
        __m256 _k22 = loadfp16(k0 + 64);

        int i = 0;

        for (; i < outh; i++)
        {
            int j = 0;
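            // Main loop: 4 output pixels per iteration; with stride 2 each output
            // starts 2 input vectors (16 floats) after the previous one, so the
            // even-offset _rXN registers seed the next accumulator.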
            for (; j + 3 < outw; j += 4)
            {
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);

                __m256 _sum1 = _bias0;
                __m256 _r03 = _mm256_loadu_ps(r0 + 24);
                __m256 _r13 = _mm256_loadu_ps(r1 + 24);
                __m256 _r23 = _mm256_loadu_ps(r2 + 24);
                __m256 _r04 = _mm256_loadu_ps(r0 + 32);
                __m256 _r14 = _mm256_loadu_ps(r1 + 32);
                __m256 _r24 = _mm256_loadu_ps(r2 + 32);
                _mm256_storeu_ps(outptr0, _sum0);

                _sum1 = _mm256_fmadd_ps(_k00, _r02, _sum1);
                _sum1 = _mm256_fmadd_ps(_k01, _r03, _sum1);
                _sum1 = _mm256_fmadd_ps(_k02, _r04, _sum1);
                _sum1 = _mm256_fmadd_ps(_k10, _r12, _sum1);
                _sum1 = _mm256_fmadd_ps(_k11, _r13, _sum1);
                _sum1 = _mm256_fmadd_ps(_k12, _r14, _sum1);
                _sum1 = _mm256_fmadd_ps(_k20, _r22, _sum1);
                _sum1 = _mm256_fmadd_ps(_k21, _r23, _sum1);
                _sum1 = _mm256_fmadd_ps(_k22, _r24, _sum1);

                __m256 _sum2 = _bias0;
                __m256 _r05 = _mm256_loadu_ps(r0 + 40);
                __m256 _r15 = _mm256_loadu_ps(r1 + 40);
                __m256 _r25 = _mm256_loadu_ps(r2 + 40);
                __m256 _r06 = _mm256_loadu_ps(r0 + 48);
                __m256 _r16 = _mm256_loadu_ps(r1 + 48);
                __m256 _r26 = _mm256_loadu_ps(r2 + 48);
                _mm256_storeu_ps(outptr0 + 8, _sum1);

                _sum2 = _mm256_fmadd_ps(_k00, _r04, _sum2);
                _sum2 = _mm256_fmadd_ps(_k01, _r05, _sum2);
                _sum2 = _mm256_fmadd_ps(_k02, _r06, _sum2);
                _sum2 = _mm256_fmadd_ps(_k10, _r14, _sum2);
                _sum2 = _mm256_fmadd_ps(_k11, _r15, _sum2);
                _sum2 = _mm256_fmadd_ps(_k12, _r16, _sum2);
                _sum2 = _mm256_fmadd_ps(_k20, _r24, _sum2);
                _sum2 = _mm256_fmadd_ps(_k21, _r25, _sum2);
                _sum2 = _mm256_fmadd_ps(_k22, _r26, _sum2);

                __m256 _sum3 = _bias0;
                __m256 _r07 = _mm256_loadu_ps(r0 + 56);
                __m256 _r17 = _mm256_loadu_ps(r1 + 56);
                __m256 _r27 = _mm256_loadu_ps(r2 + 56);
                __m256 _r08 = _mm256_loadu_ps(r0 + 64);
                __m256 _r18 = _mm256_loadu_ps(r1 + 64);
                __m256 _r28 = _mm256_loadu_ps(r2 + 64);
                _mm256_storeu_ps(outptr0 + 16, _sum2);

                _sum3 = _mm256_fmadd_ps(_k00, _r06, _sum3);
                _sum3 = _mm256_fmadd_ps(_k01, _r07, _sum3);
                _sum3 = _mm256_fmadd_ps(_k02, _r08, _sum3);
                _sum3 = _mm256_fmadd_ps(_k10, _r16, _sum3);
                _sum3 = _mm256_fmadd_ps(_k11, _r17, _sum3);
                _sum3 = _mm256_fmadd_ps(_k12, _r18, _sum3);
                _sum3 = _mm256_fmadd_ps(_k20, _r26, _sum3);
                _sum3 = _mm256_fmadd_ps(_k21, _r27, _sum3);
                _sum3 = _mm256_fmadd_ps(_k22, _r28, _sum3);

                _mm256_storeu_ps(outptr0 + 24, _sum3);

                r0 += 2 * 32;
                r1 += 2 * 32;
                r2 += 2 * 32;
                outptr0 += 32;
            }
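            // Remainder: 2 output pixels per iteration.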
            for (; j + 1 < outw; j += 2)
            {
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);

                __m256 _sum1 = _bias0;
                __m256 _r03 = _mm256_loadu_ps(r0 + 24);
                __m256 _r13 = _mm256_loadu_ps(r1 + 24);
                __m256 _r23 = _mm256_loadu_ps(r2 + 24);
                __m256 _r04 = _mm256_loadu_ps(r0 + 32);
                __m256 _r14 = _mm256_loadu_ps(r1 + 32);
                __m256 _r24 = _mm256_loadu_ps(r2 + 32);
                _mm256_storeu_ps(outptr0, _sum0);

                _sum1 = _mm256_fmadd_ps(_k00, _r02, _sum1);
                _sum1 = _mm256_fmadd_ps(_k01, _r03, _sum1);
                _sum1 = _mm256_fmadd_ps(_k02, _r04, _sum1);
                _sum1 = _mm256_fmadd_ps(_k10, _r12, _sum1);
                _sum1 = _mm256_fmadd_ps(_k11, _r13, _sum1);
                _sum1 = _mm256_fmadd_ps(_k12, _r14, _sum1);
                _sum1 = _mm256_fmadd_ps(_k20, _r22, _sum1);
                _sum1 = _mm256_fmadd_ps(_k21, _r23, _sum1);
                _sum1 = _mm256_fmadd_ps(_k22, _r24, _sum1);

                _mm256_storeu_ps(outptr0 + 8, _sum1);

                r0 += 2 * 16;
                r1 += 2 * 16;
                r2 += 2 * 16;
                outptr0 += 16;
            }
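            // Remainder: one output pixel at a time.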
            for (; j < outw; j++)
            {
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);
                _mm256_storeu_ps(outptr0, _sum0);
                r0 += 2 * 8;
                r1 += 2 * 8;
                r2 += 2 * 8;
                outptr0 += 8;
            }

            r0 += tailstep;
            r1 += tailstep;
            r2 += tailstep;
        }
    }
}