1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
conv2x2s1_pack8_avx(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)15 static void conv2x2s1_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
16 {
17     int inch = bottom_blob.c;
18     int outw = top_blob.w;
19     int outh = top_blob.h;
20     int outch = top_blob.c;
21     const float* bias = _bias;
22 
23     #pragma omp parallel for num_threads(opt.num_threads)
24     for (int p = 0; p < outch; p++)
25     {
26         Mat out0 = top_blob.channel(p);
27 
28         __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);
29         out0.fill(_bias0);
30 
31         for (int q = 0; q < inch; q++)
32         {
33             float* outptr0 = out0.row(0);
34 
35             const Mat img0 = bottom_blob.channel(q);
36 
37             const float* r0 = img0.row(0);
38             const float* r1 = img0.row(1);
39 
40             const float* kptr = (const float*)kernel.channel(p).row(q);
41             // const float* kptr = (const float*)kernel + 4 * inch * p * 64;
42 
43             int i = 0;
44             for (; i < outh; i++)
45             {
46                 int j = 0;
47 
48                 for (; j + 1 < outw; j += 2)
49                 {
50                     __m256 _sum0 = _mm256_loadu_ps(outptr0);
51                     __m256 _sum1 = _mm256_loadu_ps(outptr0 + 8);
52 
53                     __m256 _r00 = _mm256_broadcast_ss(r0);
54                     __m256 _r01 = _mm256_broadcast_ss(r0 + 1);
55                     __m256 _r02 = _mm256_broadcast_ss(r0 + 2);
56                     __m256 _r03 = _mm256_broadcast_ss(r0 + 3);
57                     __m256 _r04 = _mm256_broadcast_ss(r0 + 4);
58                     __m256 _r05 = _mm256_broadcast_ss(r0 + 5);
59                     __m256 _r06 = _mm256_broadcast_ss(r0 + 6);
60                     __m256 _r07 = _mm256_broadcast_ss(r0 + 7);
61                     r0 += 8;
62 
63                     __m256 _k00 = _mm256_loadu_ps(kptr);
64                     __m256 _k01 = _mm256_loadu_ps(kptr + 8);
65                     __m256 _k02 = _mm256_loadu_ps(kptr + 16);
66                     __m256 _k03 = _mm256_loadu_ps(kptr + 24);
67                     kptr += 32;
68 
69                     _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
70                     _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
71                     _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
72                     _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);
73 
74                     __m256 _k04 = _mm256_loadu_ps(kptr);
75                     __m256 _k05 = _mm256_loadu_ps(kptr + 8);
76                     __m256 _k06 = _mm256_loadu_ps(kptr + 16);
77                     __m256 _k07 = _mm256_loadu_ps(kptr + 24);
78                     kptr += 32;
79 
80                     _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
81                     _sum0 = _mm256_fmadd_ps(_k05, _r05, _sum0);
82                     _sum0 = _mm256_fmadd_ps(_k06, _r06, _sum0);
83                     _sum0 = _mm256_fmadd_ps(_k07, _r07, _sum0);
84 
85                     //========================================
86 
87                     _r00 = _mm256_broadcast_ss(r0);
88                     _r01 = _mm256_broadcast_ss(r0 + 1);
89                     _r02 = _mm256_broadcast_ss(r0 + 2);
90                     _r03 = _mm256_broadcast_ss(r0 + 3);
91                     _r04 = _mm256_broadcast_ss(r0 + 4);
92                     _r05 = _mm256_broadcast_ss(r0 + 5);
93                     _r06 = _mm256_broadcast_ss(r0 + 6);
94                     _r07 = _mm256_broadcast_ss(r0 + 7);
95                     r0 += 8;
96 
97                     _sum1 = _mm256_fmadd_ps(_k00, _r00, _sum1);
98                     _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
99                     _sum1 = _mm256_fmadd_ps(_k02, _r02, _sum1);
100                     _sum1 = _mm256_fmadd_ps(_k03, _r03, _sum1);
101                     _sum1 = _mm256_fmadd_ps(_k04, _r04, _sum1);
102                     _sum1 = _mm256_fmadd_ps(_k05, _r05, _sum1);
103                     _sum1 = _mm256_fmadd_ps(_k06, _r06, _sum1);
104                     _sum1 = _mm256_fmadd_ps(_k07, _r07, _sum1);
105 
106                     _k00 = _mm256_loadu_ps(kptr);
107                     _k01 = _mm256_loadu_ps(kptr + 8);
108                     _k02 = _mm256_loadu_ps(kptr + 16);
109                     _k03 = _mm256_loadu_ps(kptr + 24);
110                     kptr += 32;
111 
112                     _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
113                     _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
114                     _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
115                     _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);
116 
117                     _k04 = _mm256_loadu_ps(kptr);
118                     _k05 = _mm256_loadu_ps(kptr + 8);
119                     _k06 = _mm256_loadu_ps(kptr + 16);
120                     _k07 = _mm256_loadu_ps(kptr + 24);
121                     kptr += 32;
122 
123                     _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
124                     _sum0 = _mm256_fmadd_ps(_k05, _r05, _sum0);
125                     _sum0 = _mm256_fmadd_ps(_k06, _r06, _sum0);
126                     _sum0 = _mm256_fmadd_ps(_k07, _r07, _sum0);
127 
128                     _r00 = _mm256_broadcast_ss(r0);
129                     _r01 = _mm256_broadcast_ss(r0 + 1);
130                     _r02 = _mm256_broadcast_ss(r0 + 2);
131                     _r03 = _mm256_broadcast_ss(r0 + 3);
132                     _r04 = _mm256_broadcast_ss(r0 + 4);
133                     _r05 = _mm256_broadcast_ss(r0 + 5);
134                     _r06 = _mm256_broadcast_ss(r0 + 6);
135                     _r07 = _mm256_broadcast_ss(r0 + 7);
136 
137                     _sum1 = _mm256_fmadd_ps(_k00, _r00, _sum1);
138                     _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
139                     _sum1 = _mm256_fmadd_ps(_k02, _r02, _sum1);
140                     _sum1 = _mm256_fmadd_ps(_k03, _r03, _sum1);
141                     _sum1 = _mm256_fmadd_ps(_k04, _r04, _sum1);
142                     _sum1 = _mm256_fmadd_ps(_k05, _r05, _sum1);
143                     _sum1 = _mm256_fmadd_ps(_k06, _r06, _sum1);
144                     _sum1 = _mm256_fmadd_ps(_k07, _r07, _sum1);
145                     //===============
146 
147                     __m256 _r10 = _mm256_broadcast_ss(r1);
148                     __m256 _r11 = _mm256_broadcast_ss(r1 + 1);
149                     __m256 _r12 = _mm256_broadcast_ss(r1 + 2);
150                     __m256 _r13 = _mm256_broadcast_ss(r1 + 3);
151                     __m256 _r14 = _mm256_broadcast_ss(r1 + 4);
152                     __m256 _r15 = _mm256_broadcast_ss(r1 + 5);
153                     __m256 _r16 = _mm256_broadcast_ss(r1 + 6);
154                     __m256 _r17 = _mm256_broadcast_ss(r1 + 7);
155 
156                     __m256 _k10 = _mm256_loadu_ps(kptr);
157                     __m256 _k11 = _mm256_loadu_ps(kptr + 8);
158                     __m256 _k12 = _mm256_loadu_ps(kptr + 16);
159                     __m256 _k13 = _mm256_loadu_ps(kptr + 24);
160                     kptr += 32;
161 
162                     _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
163                     _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
164                     _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
165                     _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);
166 
167                     __m256 _k14 = _mm256_loadu_ps(kptr);
168                     __m256 _k15 = _mm256_loadu_ps(kptr + 8);
169                     __m256 _k16 = _mm256_loadu_ps(kptr + 16);
170                     __m256 _k17 = _mm256_loadu_ps(kptr + 24);
171                     kptr += 32;
172 
173                     _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
174                     _sum0 = _mm256_fmadd_ps(_k15, _r15, _sum0);
175                     _sum0 = _mm256_fmadd_ps(_k16, _r16, _sum0);
176                     _sum0 = _mm256_fmadd_ps(_k17, _r17, _sum0);
177 
178                     //=======================================
179                     r1 += 8;
180                     _r10 = _mm256_broadcast_ss(r1);
181                     _r11 = _mm256_broadcast_ss(r1 + 1);
182                     _r12 = _mm256_broadcast_ss(r1 + 2);
183                     _r13 = _mm256_broadcast_ss(r1 + 3);
184                     _r14 = _mm256_broadcast_ss(r1 + 4);
185                     _r15 = _mm256_broadcast_ss(r1 + 5);
186                     _r16 = _mm256_broadcast_ss(r1 + 6);
187                     _r17 = _mm256_broadcast_ss(r1 + 7);
188 
189                     _sum1 = _mm256_fmadd_ps(_k10, _r10, _sum1);
190                     _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
191                     _sum1 = _mm256_fmadd_ps(_k12, _r12, _sum1);
192                     _sum1 = _mm256_fmadd_ps(_k13, _r13, _sum1);
193                     _sum1 = _mm256_fmadd_ps(_k14, _r14, _sum1);
194                     _sum1 = _mm256_fmadd_ps(_k15, _r15, _sum1);
195                     _sum1 = _mm256_fmadd_ps(_k16, _r16, _sum1);
196                     _sum1 = _mm256_fmadd_ps(_k17, _r17, _sum1);
197 
198                     _k10 = _mm256_loadu_ps(kptr);
199                     _k11 = _mm256_loadu_ps(kptr + 8);
200                     _k12 = _mm256_loadu_ps(kptr + 16);
201                     _k13 = _mm256_loadu_ps(kptr + 24);
202                     kptr += 32;
203 
204                     _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
205                     _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
206                     _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
207                     _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);
208 
209                     _k14 = _mm256_loadu_ps(kptr);
210                     _k15 = _mm256_loadu_ps(kptr + 8);
211                     _k16 = _mm256_loadu_ps(kptr + 16);
212                     _k17 = _mm256_loadu_ps(kptr + 24);
213                     _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
214                     _sum0 = _mm256_fmadd_ps(_k15, _r15, _sum0);
215                     _sum0 = _mm256_fmadd_ps(_k16, _r16, _sum0);
216                     _sum0 = _mm256_fmadd_ps(_k17, _r17, _sum0);
217 
218                     r1 += 8;
219                     _r10 = _mm256_broadcast_ss(r1);
220                     _r11 = _mm256_broadcast_ss(r1 + 1);
221                     _r12 = _mm256_broadcast_ss(r1 + 2);
222                     _r13 = _mm256_broadcast_ss(r1 + 3);
223                     _r14 = _mm256_broadcast_ss(r1 + 4);
224                     _r15 = _mm256_broadcast_ss(r1 + 5);
225                     _r16 = _mm256_broadcast_ss(r1 + 6);
226                     _r17 = _mm256_broadcast_ss(r1 + 7);
227 
228                     _sum1 = _mm256_fmadd_ps(_k10, _r10, _sum1);
229                     _sum1 = _mm256_fmadd_ps(_k11, _r11, _sum1);
230                     _sum1 = _mm256_fmadd_ps(_k12, _r12, _sum1);
231                     _sum1 = _mm256_fmadd_ps(_k13, _r13, _sum1);
232                     _sum1 = _mm256_fmadd_ps(_k14, _r14, _sum1);
233                     _sum1 = _mm256_fmadd_ps(_k15, _r15, _sum1);
234                     _sum1 = _mm256_fmadd_ps(_k16, _r16, _sum1);
235                     _sum1 = _mm256_fmadd_ps(_k17, _r17, _sum1);
236 
237                     kptr -= 224;
238                     _mm256_storeu_ps(outptr0, _sum0);
239                     _mm256_storeu_ps(outptr0 + 8, _sum1);
240                     outptr0 += 16;
241                 }
242 
243                 for (; j < outw; j++)
244                 {
245                     __m256 _sum = _mm256_loadu_ps(outptr0);
246 
247                     __m256 _r00 = _mm256_broadcast_ss(r0);
248                     __m256 _r01 = _mm256_broadcast_ss(r0 + 1);
249                     __m256 _r02 = _mm256_broadcast_ss(r0 + 2);
250                     __m256 _r03 = _mm256_broadcast_ss(r0 + 3);
251                     __m256 _r04 = _mm256_broadcast_ss(r0 + 4);
252                     __m256 _r05 = _mm256_broadcast_ss(r0 + 5);
253                     __m256 _r06 = _mm256_broadcast_ss(r0 + 6);
254                     __m256 _r07 = _mm256_broadcast_ss(r0 + 7);
255 
256                     __m256 _k00 = _mm256_loadu_ps(kptr);
257                     __m256 _k01 = _mm256_loadu_ps(kptr + 8);
258                     __m256 _k02 = _mm256_loadu_ps(kptr + 16);
259                     __m256 _k03 = _mm256_loadu_ps(kptr + 24);
260                     kptr += 32;
261 
262                     _sum = _mm256_fmadd_ps(_k00, _r00, _sum);
263                     _sum = _mm256_fmadd_ps(_k01, _r01, _sum);
264                     _sum = _mm256_fmadd_ps(_k02, _r02, _sum);
265                     _sum = _mm256_fmadd_ps(_k03, _r03, _sum);
266 
267                     __m256 _k04 = _mm256_loadu_ps(kptr);
268                     __m256 _k05 = _mm256_loadu_ps(kptr + 8);
269                     __m256 _k06 = _mm256_loadu_ps(kptr + 16);
270                     __m256 _k07 = _mm256_loadu_ps(kptr + 24);
271                     kptr += 32;
272 
273                     _sum = _mm256_fmadd_ps(_k04, _r04, _sum);
274                     _sum = _mm256_fmadd_ps(_k05, _r05, _sum);
275                     _sum = _mm256_fmadd_ps(_k06, _r06, _sum);
276                     _sum = _mm256_fmadd_ps(_k07, _r07, _sum);
277 
278                     //========================================
279                     r0 += 8;
280                     _r00 = _mm256_broadcast_ss(r0);
281                     _r01 = _mm256_broadcast_ss(r0 + 1);
282                     _r02 = _mm256_broadcast_ss(r0 + 2);
283                     _r03 = _mm256_broadcast_ss(r0 + 3);
284                     _r04 = _mm256_broadcast_ss(r0 + 4);
285                     _r05 = _mm256_broadcast_ss(r0 + 5);
286                     _r06 = _mm256_broadcast_ss(r0 + 6);
287                     _r07 = _mm256_broadcast_ss(r0 + 7);
288 
289                     _k00 = _mm256_loadu_ps(kptr);
290                     _k01 = _mm256_loadu_ps(kptr + 8);
291                     _k02 = _mm256_loadu_ps(kptr + 16);
292                     _k03 = _mm256_loadu_ps(kptr + 24);
293                     kptr += 32;
294 
295                     _sum = _mm256_fmadd_ps(_k00, _r00, _sum);
296                     _sum = _mm256_fmadd_ps(_k01, _r01, _sum);
297                     _sum = _mm256_fmadd_ps(_k02, _r02, _sum);
298                     _sum = _mm256_fmadd_ps(_k03, _r03, _sum);
299 
300                     _k04 = _mm256_loadu_ps(kptr);
301                     _k05 = _mm256_loadu_ps(kptr + 8);
302                     _k06 = _mm256_loadu_ps(kptr + 16);
303                     _k07 = _mm256_loadu_ps(kptr + 24);
304                     kptr += 32;
305 
306                     _sum = _mm256_fmadd_ps(_k04, _r04, _sum);
307                     _sum = _mm256_fmadd_ps(_k05, _r05, _sum);
308                     _sum = _mm256_fmadd_ps(_k06, _r06, _sum);
309                     _sum = _mm256_fmadd_ps(_k07, _r07, _sum);
310                     //===============
311 
312                     __m256 _r10 = _mm256_broadcast_ss(r1);
313                     __m256 _r11 = _mm256_broadcast_ss(r1 + 1);
314                     __m256 _r12 = _mm256_broadcast_ss(r1 + 2);
315                     __m256 _r13 = _mm256_broadcast_ss(r1 + 3);
316                     __m256 _r14 = _mm256_broadcast_ss(r1 + 4);
317                     __m256 _r15 = _mm256_broadcast_ss(r1 + 5);
318                     __m256 _r16 = _mm256_broadcast_ss(r1 + 6);
319                     __m256 _r17 = _mm256_broadcast_ss(r1 + 7);
320 
321                     __m256 _k10 = _mm256_loadu_ps(kptr);
322                     __m256 _k11 = _mm256_loadu_ps(kptr + 8);
323                     __m256 _k12 = _mm256_loadu_ps(kptr + 16);
324                     __m256 _k13 = _mm256_loadu_ps(kptr + 24);
325                     kptr += 32;
326 
327                     _sum = _mm256_fmadd_ps(_k10, _r10, _sum);
328                     _sum = _mm256_fmadd_ps(_k11, _r11, _sum);
329                     _sum = _mm256_fmadd_ps(_k12, _r12, _sum);
330                     _sum = _mm256_fmadd_ps(_k13, _r13, _sum);
331 
332                     __m256 _k14 = _mm256_loadu_ps(kptr);
333                     __m256 _k15 = _mm256_loadu_ps(kptr + 8);
334                     __m256 _k16 = _mm256_loadu_ps(kptr + 16);
335                     __m256 _k17 = _mm256_loadu_ps(kptr + 24);
336                     kptr += 32;
337 
338                     _sum = _mm256_fmadd_ps(_k14, _r14, _sum);
339                     _sum = _mm256_fmadd_ps(_k15, _r15, _sum);
340                     _sum = _mm256_fmadd_ps(_k16, _r16, _sum);
341                     _sum = _mm256_fmadd_ps(_k17, _r17, _sum);
342 
343                     //=======================================
344                     r1 += 8;
345                     _r10 = _mm256_broadcast_ss(r1);
346                     _r11 = _mm256_broadcast_ss(r1 + 1);
347                     _r12 = _mm256_broadcast_ss(r1 + 2);
348                     _r13 = _mm256_broadcast_ss(r1 + 3);
349                     _r14 = _mm256_broadcast_ss(r1 + 4);
350                     _r15 = _mm256_broadcast_ss(r1 + 5);
351                     _r16 = _mm256_broadcast_ss(r1 + 6);
352                     _r17 = _mm256_broadcast_ss(r1 + 7);
353 
354                     _k10 = _mm256_loadu_ps(kptr);
355                     _k11 = _mm256_loadu_ps(kptr + 8);
356                     _k12 = _mm256_loadu_ps(kptr + 16);
357                     _k13 = _mm256_loadu_ps(kptr + 24);
358                     kptr += 32;
359 
360                     _sum = _mm256_fmadd_ps(_k10, _r10, _sum);
361                     _sum = _mm256_fmadd_ps(_k11, _r11, _sum);
362                     _sum = _mm256_fmadd_ps(_k12, _r12, _sum);
363                     _sum = _mm256_fmadd_ps(_k13, _r13, _sum);
364 
365                     _k14 = _mm256_loadu_ps(kptr);
366                     _k15 = _mm256_loadu_ps(kptr + 8);
367                     _k16 = _mm256_loadu_ps(kptr + 16);
368                     _k17 = _mm256_loadu_ps(kptr + 24);
369                     _sum = _mm256_fmadd_ps(_k14, _r14, _sum);
370                     _sum = _mm256_fmadd_ps(_k15, _r15, _sum);
371                     _sum = _mm256_fmadd_ps(_k16, _r16, _sum);
372                     _sum = _mm256_fmadd_ps(_k17, _r17, _sum);
373 
374                     kptr -= 224;
375                     _mm256_storeu_ps(outptr0, _sum);
376                     outptr0 += 8;
377                 }
378 
379                 r0 += 8;
380                 r1 += 8;
381             }
382         }
383     }
384 }
385