// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
14
conv2x2s1_pack8_avx(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)15 static void conv2x2s1_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
16 {
17 int inch = bottom_blob.c;
18 int outw = top_blob.w;
19 int outh = top_blob.h;
20 int outch = top_blob.c;
21 const float* bias = _bias;
22
23 #pragma omp parallel for num_threads(opt.num_threads)
24 for (int p = 0; p < outch; p++)
25 {
26 Mat out0 = top_blob.channel(p);
27
28 __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);
29 out0.fill(_bias0);
30
31 for (int q = 0; q < inch; q++)
32 {
33 float* outptr0 = out0.row(0);
34
35 const Mat img0 = bottom_blob.channel(q);
36
37 const float* r0 = img0.row(0);
38 const float* r1 = img0.row(1);
39
40 const float* kptr = (const float*)kernel.channel(p).row(q);
41 // const float* kptr = (const float*)kernel + 4 * inch * p * 64;
42
43 int i = 0;
44 for (; i < outh; i++)
45 {
46 int j = 0;
47
48 for (; j + 1 < outw; j += 2)
49 {
50 __m256 _sum0 = _mm256_loadu_ps(outptr0);
51 __m256 _sum1 = _mm256_loadu_ps(outptr0 + 8);
52
53 __m256 _r00 = _mm256_broadcast_ss(r0);
54 __m256 _r01 = _mm256_broadcast_ss(r0 + 1);
55 __m256 _r02 = _mm256_broadcast_ss(r0 + 2);
56 __m256 _r03 = _mm256_broadcast_ss(r0 + 3);
57 __m256 _r04 = _mm256_broadcast_ss(r0 + 4);
58 __m256 _r05 = _mm256_broadcast_ss(r0 + 5);
59 __m256 _r06 = _mm256_broadcast_ss(r0 + 6);
60 __m256 _r07 = _mm256_broadcast_ss(r0 + 7);
61 r0 += 8;
62
63 __m256 _k00 = _mm256_loadu_ps(kptr);
64 __m256 _k01 = _mm256_loadu_ps(kptr + 8);
65 __m256 _k02 = _mm256_loadu_ps(kptr + 16);
66 __m256 _k03 = _mm256_loadu_ps(kptr + 24);
67 kptr += 32;
68
69 _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
70 _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
71 _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
72 _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);
73
74 __m256 _k04 = _mm256_loadu_ps(kptr);
75 __m256 _k05 = _mm256_loadu_ps(kptr + 8);
76 __m256 _k06 = _mm256_loadu_ps(kptr + 16);
77 __m256 _k07 = _mm256_loadu_ps(kptr + 24);
78 kptr += 32;
79
80 _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
81 _sum0 = _mm256_fmadd_ps(_k05, _r05, _sum0);
82 _sum0 = _mm256_fmadd_ps(_k06, _r06, _sum0);
83 _sum0 = _mm256_fmadd_ps(_k07, _r07, _sum0);
84
85 //========================================
86
87 _r00 = _mm256_broadcast_ss(r0);
88 _r01 = _mm256_broadcast_ss(r0 + 1);
89 _r02 = _mm256_broadcast_ss(r0 + 2);
90 _r03 = _mm256_broadcast_ss(r0 + 3);
91 _r04 = _mm256_broadcast_ss(r0 + 4);
92 _r05 = _mm256_broadcast_ss(r0 + 5);
93 _r06 = _mm256_broadcast_ss(r0 + 6);
94 _r07 = _mm256_broadcast_ss(r0 + 7);
95 r0 += 8;
96
97 _sum1 = _mm256_fmadd_ps(_k00, _r00, _sum1);
98 _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
99 _sum1 = _mm256_fmadd_ps(_k02, _r02, _sum1);
100 _sum1 = _mm256_fmadd_ps(_k03, _r03, _sum1);
101 _sum1 = _mm256_fmadd_ps(_k04, _r04, _sum1);
102 _sum1 = _mm256_fmadd_ps(_k05, _r05, _sum1);
103 _sum1 = _mm256_fmadd_ps(_k06, _r06, _sum1);
104 _sum1 = _mm256_fmadd_ps(_k07, _r07, _sum1);
105
106 _k00 = _mm256_loadu_ps(kptr);
107 _k01 = _mm256_loadu_ps(kptr + 8);
108 _k02 = _mm256_loadu_ps(kptr + 16);
109 _k03 = _mm256_loadu_ps(kptr + 24);
110 kptr += 32;
111
112 _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
113 _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
114 _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
115 _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);
116
117 _k04 = _mm256_loadu_ps(kptr);
118 _k05 = _mm256_loadu_ps(kptr + 8);
119 _k06 = _mm256_loadu_ps(kptr + 16);
120 _k07 = _mm256_loadu_ps(kptr + 24);
121 kptr += 32;
122
123 _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
124 _sum0 = _mm256_fmadd_ps(_k05, _r05, _sum0);
125 _sum0 = _mm256_fmadd_ps(_k06, _r06, _sum0);
126 _sum0 = _mm256_fmadd_ps(_k07, _r07, _sum0);
127
128 _r00 = _mm256_broadcast_ss(r0);
129 _r01 = _mm256_broadcast_ss(r0 + 1);
130 _r02 = _mm256_broadcast_ss(r0 + 2);
131 _r03 = _mm256_broadcast_ss(r0 + 3);
132 _r04 = _mm256_broadcast_ss(r0 + 4);
133 _r05 = _mm256_broadcast_ss(r0 + 5);
134 _r06 = _mm256_broadcast_ss(r0 + 6);
135 _r07 = _mm256_broadcast_ss(r0 + 7);
136
137 _sum1 = _mm256_fmadd_ps(_k00, _r00, _sum1);
138 _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
139 _sum1 = _mm256_fmadd_ps(_k02, _r02, _sum1);
140 _sum1 = _mm256_fmadd_ps(_k03, _r03, _sum1);
141 _sum1 = _mm256_fmadd_ps(_k04, _r04, _sum1);
142 _sum1 = _mm256_fmadd_ps(_k05, _r05, _sum1);
143 _sum1 = _mm256_fmadd_ps(_k06, _r06, _sum1);
144 _sum1 = _mm256_fmadd_ps(_k07, _r07, _sum1);
145 //===============
146
147 __m256 _r10 = _mm256_broadcast_ss(r1);
148 __m256 _r11 = _mm256_broadcast_ss(r1 + 1);
149 __m256 _r12 = _mm256_broadcast_ss(r1 + 2);
150 __m256 _r13 = _mm256_broadcast_ss(r1 + 3);
151 __m256 _r14 = _mm256_broadcast_ss(r1 + 4);
152 __m256 _r15 = _mm256_broadcast_ss(r1 + 5);
153 __m256 _r16 = _mm256_broadcast_ss(r1 + 6);
154 __m256 _r17 = _mm256_broadcast_ss(r1 + 7);
155
156 __m256 _k10 = _mm256_loadu_ps(kptr);
157 __m256 _k11 = _mm256_loadu_ps(kptr + 8);
158 __m256 _k12 = _mm256_loadu_ps(kptr + 16);
159 __m256 _k13 = _mm256_loadu_ps(kptr + 24);
160 kptr += 32;
161
162 _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
163 _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
164 _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
165 _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);
166
167 __m256 _k14 = _mm256_loadu_ps(kptr);
168 __m256 _k15 = _mm256_loadu_ps(kptr + 8);
169 __m256 _k16 = _mm256_loadu_ps(kptr + 16);
170 __m256 _k17 = _mm256_loadu_ps(kptr + 24);
171 kptr += 32;
172
173 _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
174 _sum0 = _mm256_fmadd_ps(_k15, _r15, _sum0);
175 _sum0 = _mm256_fmadd_ps(_k16, _r16, _sum0);
176 _sum0 = _mm256_fmadd_ps(_k17, _r17, _sum0);
177
178 //=======================================
179 r1 += 8;
180 _r10 = _mm256_broadcast_ss(r1);
181 _r11 = _mm256_broadcast_ss(r1 + 1);
182 _r12 = _mm256_broadcast_ss(r1 + 2);
183 _r13 = _mm256_broadcast_ss(r1 + 3);
184 _r14 = _mm256_broadcast_ss(r1 + 4);
185 _r15 = _mm256_broadcast_ss(r1 + 5);
186 _r16 = _mm256_broadcast_ss(r1 + 6);
187 _r17 = _mm256_broadcast_ss(r1 + 7);
188
189 _sum1 = _mm256_fmadd_ps(_k10, _r10, _sum1);
190 _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
191 _sum1 = _mm256_fmadd_ps(_k12, _r12, _sum1);
192 _sum1 = _mm256_fmadd_ps(_k13, _r13, _sum1);
193 _sum1 = _mm256_fmadd_ps(_k14, _r14, _sum1);
194 _sum1 = _mm256_fmadd_ps(_k15, _r15, _sum1);
195 _sum1 = _mm256_fmadd_ps(_k16, _r16, _sum1);
196 _sum1 = _mm256_fmadd_ps(_k17, _r17, _sum1);
197
198 _k10 = _mm256_loadu_ps(kptr);
199 _k11 = _mm256_loadu_ps(kptr + 8);
200 _k12 = _mm256_loadu_ps(kptr + 16);
201 _k13 = _mm256_loadu_ps(kptr + 24);
202 kptr += 32;
203
204 _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
205 _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
206 _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
207 _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);
208
209 _k14 = _mm256_loadu_ps(kptr);
210 _k15 = _mm256_loadu_ps(kptr + 8);
211 _k16 = _mm256_loadu_ps(kptr + 16);
212 _k17 = _mm256_loadu_ps(kptr + 24);
213 _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
214 _sum0 = _mm256_fmadd_ps(_k15, _r15, _sum0);
215 _sum0 = _mm256_fmadd_ps(_k16, _r16, _sum0);
216 _sum0 = _mm256_fmadd_ps(_k17, _r17, _sum0);
217
218 r1 += 8;
219 _r10 = _mm256_broadcast_ss(r1);
220 _r11 = _mm256_broadcast_ss(r1 + 1);
221 _r12 = _mm256_broadcast_ss(r1 + 2);
222 _r13 = _mm256_broadcast_ss(r1 + 3);
223 _r14 = _mm256_broadcast_ss(r1 + 4);
224 _r15 = _mm256_broadcast_ss(r1 + 5);
225 _r16 = _mm256_broadcast_ss(r1 + 6);
226 _r17 = _mm256_broadcast_ss(r1 + 7);
227
228 _sum1 = _mm256_fmadd_ps(_k10, _r10, _sum1);
229 _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
230 _sum1 = _mm256_fmadd_ps(_k12, _r12, _sum1);
231 _sum1 = _mm256_fmadd_ps(_k13, _r13, _sum1);
232 _sum1 = _mm256_fmadd_ps(_k14, _r14, _sum1);
233 _sum1 = _mm256_fmadd_ps(_k15, _r15, _sum1);
234 _sum1 = _mm256_fmadd_ps(_k16, _r16, _sum1);
235 _sum1 = _mm256_fmadd_ps(_k17, _r17, _sum1);
236
237 kptr -= 224;
238 _mm256_storeu_ps(outptr0, _sum0);
239 _mm256_storeu_ps(outptr0 + 8, _sum1);
240 outptr0 += 16;
241 }
242
243 for (; j < outw; j++)
244 {
245 __m256 _sum = _mm256_loadu_ps(outptr0);
246
247 __m256 _r00 = _mm256_broadcast_ss(r0);
248 __m256 _r01 = _mm256_broadcast_ss(r0 + 1);
249 __m256 _r02 = _mm256_broadcast_ss(r0 + 2);
250 __m256 _r03 = _mm256_broadcast_ss(r0 + 3);
251 __m256 _r04 = _mm256_broadcast_ss(r0 + 4);
252 __m256 _r05 = _mm256_broadcast_ss(r0 + 5);
253 __m256 _r06 = _mm256_broadcast_ss(r0 + 6);
254 __m256 _r07 = _mm256_broadcast_ss(r0 + 7);
255
256 __m256 _k00 = _mm256_loadu_ps(kptr);
257 __m256 _k01 = _mm256_loadu_ps(kptr + 8);
258 __m256 _k02 = _mm256_loadu_ps(kptr + 16);
259 __m256 _k03 = _mm256_loadu_ps(kptr + 24);
260 kptr += 32;
261
262 _sum = _mm256_fmadd_ps(_k00, _r00, _sum);
263 _sum = _mm256_fmadd_ps(_k01, _r01, _sum);
264 _sum = _mm256_fmadd_ps(_k02, _r02, _sum);
265 _sum = _mm256_fmadd_ps(_k03, _r03, _sum);
266
267 __m256 _k04 = _mm256_loadu_ps(kptr);
268 __m256 _k05 = _mm256_loadu_ps(kptr + 8);
269 __m256 _k06 = _mm256_loadu_ps(kptr + 16);
270 __m256 _k07 = _mm256_loadu_ps(kptr + 24);
271 kptr += 32;
272
273 _sum = _mm256_fmadd_ps(_k04, _r04, _sum);
274 _sum = _mm256_fmadd_ps(_k05, _r05, _sum);
275 _sum = _mm256_fmadd_ps(_k06, _r06, _sum);
276 _sum = _mm256_fmadd_ps(_k07, _r07, _sum);
277
278 //========================================
279 r0 += 8;
280 _r00 = _mm256_broadcast_ss(r0);
281 _r01 = _mm256_broadcast_ss(r0 + 1);
282 _r02 = _mm256_broadcast_ss(r0 + 2);
283 _r03 = _mm256_broadcast_ss(r0 + 3);
284 _r04 = _mm256_broadcast_ss(r0 + 4);
285 _r05 = _mm256_broadcast_ss(r0 + 5);
286 _r06 = _mm256_broadcast_ss(r0 + 6);
287 _r07 = _mm256_broadcast_ss(r0 + 7);
288
289 _k00 = _mm256_loadu_ps(kptr);
290 _k01 = _mm256_loadu_ps(kptr + 8);
291 _k02 = _mm256_loadu_ps(kptr + 16);
292 _k03 = _mm256_loadu_ps(kptr + 24);
293 kptr += 32;
294
295 _sum = _mm256_fmadd_ps(_k00, _r00, _sum);
296 _sum = _mm256_fmadd_ps(_k01, _r01, _sum);
297 _sum = _mm256_fmadd_ps(_k02, _r02, _sum);
298 _sum = _mm256_fmadd_ps(_k03, _r03, _sum);
299
300 _k04 = _mm256_loadu_ps(kptr);
301 _k05 = _mm256_loadu_ps(kptr + 8);
302 _k06 = _mm256_loadu_ps(kptr + 16);
303 _k07 = _mm256_loadu_ps(kptr + 24);
304 kptr += 32;
305
306 _sum = _mm256_fmadd_ps(_k04, _r04, _sum);
307 _sum = _mm256_fmadd_ps(_k05, _r05, _sum);
308 _sum = _mm256_fmadd_ps(_k06, _r06, _sum);
309 _sum = _mm256_fmadd_ps(_k07, _r07, _sum);
310 //===============
311
312 __m256 _r10 = _mm256_broadcast_ss(r1);
313 __m256 _r11 = _mm256_broadcast_ss(r1 + 1);
314 __m256 _r12 = _mm256_broadcast_ss(r1 + 2);
315 __m256 _r13 = _mm256_broadcast_ss(r1 + 3);
316 __m256 _r14 = _mm256_broadcast_ss(r1 + 4);
317 __m256 _r15 = _mm256_broadcast_ss(r1 + 5);
318 __m256 _r16 = _mm256_broadcast_ss(r1 + 6);
319 __m256 _r17 = _mm256_broadcast_ss(r1 + 7);
320
321 __m256 _k10 = _mm256_loadu_ps(kptr);
322 __m256 _k11 = _mm256_loadu_ps(kptr + 8);
323 __m256 _k12 = _mm256_loadu_ps(kptr + 16);
324 __m256 _k13 = _mm256_loadu_ps(kptr + 24);
325 kptr += 32;
326
327 _sum = _mm256_fmadd_ps(_k10, _r10, _sum);
328 _sum = _mm256_fmadd_ps(_k11, _r11, _sum);
329 _sum = _mm256_fmadd_ps(_k12, _r12, _sum);
330 _sum = _mm256_fmadd_ps(_k13, _r13, _sum);
331
332 __m256 _k14 = _mm256_loadu_ps(kptr);
333 __m256 _k15 = _mm256_loadu_ps(kptr + 8);
334 __m256 _k16 = _mm256_loadu_ps(kptr + 16);
335 __m256 _k17 = _mm256_loadu_ps(kptr + 24);
336 kptr += 32;
337
338 _sum = _mm256_fmadd_ps(_k14, _r14, _sum);
339 _sum = _mm256_fmadd_ps(_k15, _r15, _sum);
340 _sum = _mm256_fmadd_ps(_k16, _r16, _sum);
341 _sum = _mm256_fmadd_ps(_k17, _r17, _sum);
342
343 //=======================================
344 r1 += 8;
345 _r10 = _mm256_broadcast_ss(r1);
346 _r11 = _mm256_broadcast_ss(r1 + 1);
347 _r12 = _mm256_broadcast_ss(r1 + 2);
348 _r13 = _mm256_broadcast_ss(r1 + 3);
349 _r14 = _mm256_broadcast_ss(r1 + 4);
350 _r15 = _mm256_broadcast_ss(r1 + 5);
351 _r16 = _mm256_broadcast_ss(r1 + 6);
352 _r17 = _mm256_broadcast_ss(r1 + 7);
353
354 _k10 = _mm256_loadu_ps(kptr);
355 _k11 = _mm256_loadu_ps(kptr + 8);
356 _k12 = _mm256_loadu_ps(kptr + 16);
357 _k13 = _mm256_loadu_ps(kptr + 24);
358 kptr += 32;
359
360 _sum = _mm256_fmadd_ps(_k10, _r10, _sum);
361 _sum = _mm256_fmadd_ps(_k11, _r11, _sum);
362 _sum = _mm256_fmadd_ps(_k12, _r12, _sum);
363 _sum = _mm256_fmadd_ps(_k13, _r13, _sum);
364
365 _k14 = _mm256_loadu_ps(kptr);
366 _k15 = _mm256_loadu_ps(kptr + 8);
367 _k16 = _mm256_loadu_ps(kptr + 16);
368 _k17 = _mm256_loadu_ps(kptr + 24);
369 _sum = _mm256_fmadd_ps(_k14, _r14, _sum);
370 _sum = _mm256_fmadd_ps(_k15, _r15, _sum);
371 _sum = _mm256_fmadd_ps(_k16, _r16, _sum);
372 _sum = _mm256_fmadd_ps(_k17, _r17, _sum);
373
374 kptr -= 224;
375 _mm256_storeu_ps(outptr0, _sum);
376 outptr0 += 8;
377 }
378
379 r0 += 8;
380 r1 += 8;
381 }
382 }
383 }
384 }
385