1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
convdw5x5s1_pack8_avx(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)15 static void convdw5x5s1_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
16 {
17     int outw = top_blob.w;
18     int outh = top_blob.h;
19 
20     const int group = bottom_blob.c;
21 
22     const float* bias = _bias;
23     #pragma omp parallel for num_threads(opt.num_threads)
24     for (int g = 0; g < group; g++)
25     {
26         Mat out = top_blob.channel(g);
27 
28         __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + g * 8) : _mm256_set1_ps(0.f);
29 
30         const float* k0 = kernel.row(g);
31 
32         float* outptr0 = out.row(0);
33 
34         const Mat img0 = bottom_blob.channel(g);
35 
36         const float* r0 = img0.row(0);
37         const float* r1 = img0.row(1);
38         const float* r2 = img0.row(2);
39         const float* r3 = img0.row(3);
40         const float* r4 = img0.row(4);
41 
42         int i = 0;
43         for (; i < outh; i++)
44         {
45             int j = 0;
46 
47             for (; j < outw; j++)
48             {
49                 __m256 _sum0 = _bias0;
50 
51                 __m256 _r00 = _mm256_loadu_ps(r0);
52                 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
53                 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
54                 __m256 _r03 = _mm256_loadu_ps(r0 + 24);
55                 __m256 _r04 = _mm256_loadu_ps(r0 + 32);
56 
57                 __m256 _k00 = _mm256_loadu_ps(k0);
58                 __m256 _k01 = _mm256_loadu_ps(k0 + 8);
59                 __m256 _k02 = _mm256_loadu_ps(k0 + 16);
60                 __m256 _k03 = _mm256_loadu_ps(k0 + 24);
61                 __m256 _k04 = _mm256_loadu_ps(k0 + 32);
62                 k0 += 40;
63 
64                 _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
65                 _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
66                 _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
67                 _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);
68                 _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
69 
70                 __m256 _r10 = _mm256_loadu_ps(r1);
71                 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
72                 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
73                 __m256 _r13 = _mm256_loadu_ps(r1 + 24);
74                 __m256 _r14 = _mm256_loadu_ps(r1 + 32);
75 
76                 __m256 _k10 = _mm256_loadu_ps(k0);
77                 __m256 _k11 = _mm256_loadu_ps(k0 + 8);
78                 __m256 _k12 = _mm256_loadu_ps(k0 + 16);
79                 __m256 _k13 = _mm256_loadu_ps(k0 + 24);
80                 __m256 _k14 = _mm256_loadu_ps(k0 + 32);
81                 k0 += 40;
82 
83                 _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
84                 _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
85                 _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
86                 _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);
87                 _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
88 
89                 __m256 _r20 = _mm256_loadu_ps(r2);
90                 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
91                 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
92                 __m256 _r23 = _mm256_loadu_ps(r2 + 24);
93                 __m256 _r24 = _mm256_loadu_ps(r2 + 32);
94 
95                 __m256 _k20 = _mm256_loadu_ps(k0);
96                 __m256 _k21 = _mm256_loadu_ps(k0 + 8);
97                 __m256 _k22 = _mm256_loadu_ps(k0 + 16);
98                 __m256 _k23 = _mm256_loadu_ps(k0 + 24);
99                 __m256 _k24 = _mm256_loadu_ps(k0 + 32);
100                 k0 += 40;
101 
102                 _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
103                 _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
104                 _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);
105                 _sum0 = _mm256_fmadd_ps(_k23, _r23, _sum0);
106                 _sum0 = _mm256_fmadd_ps(_k24, _r24, _sum0);
107 
108                 __m256 _r30 = _mm256_loadu_ps(r3);
109                 __m256 _r31 = _mm256_loadu_ps(r3 + 8);
110                 __m256 _r32 = _mm256_loadu_ps(r3 + 16);
111                 __m256 _r33 = _mm256_loadu_ps(r3 + 24);
112                 __m256 _r34 = _mm256_loadu_ps(r3 + 32);
113 
114                 __m256 _k30 = _mm256_loadu_ps(k0);
115                 __m256 _k31 = _mm256_loadu_ps(k0 + 8);
116                 __m256 _k32 = _mm256_loadu_ps(k0 + 16);
117                 __m256 _k33 = _mm256_loadu_ps(k0 + 24);
118                 __m256 _k34 = _mm256_loadu_ps(k0 + 32);
119                 k0 += 40;
120 
121                 _sum0 = _mm256_fmadd_ps(_k30, _r30, _sum0);
122                 _sum0 = _mm256_fmadd_ps(_k31, _r31, _sum0);
123                 _sum0 = _mm256_fmadd_ps(_k32, _r32, _sum0);
124                 _sum0 = _mm256_fmadd_ps(_k33, _r33, _sum0);
125                 _sum0 = _mm256_fmadd_ps(_k34, _r34, _sum0);
126 
127                 __m256 _r40 = _mm256_loadu_ps(r4);
128                 __m256 _r41 = _mm256_loadu_ps(r4 + 8);
129                 __m256 _r42 = _mm256_loadu_ps(r4 + 16);
130                 __m256 _r43 = _mm256_loadu_ps(r4 + 24);
131                 __m256 _r44 = _mm256_loadu_ps(r4 + 32);
132 
133                 __m256 _k40 = _mm256_loadu_ps(k0);
134                 __m256 _k41 = _mm256_loadu_ps(k0 + 8);
135                 __m256 _k42 = _mm256_loadu_ps(k0 + 16);
136                 __m256 _k43 = _mm256_loadu_ps(k0 + 24);
137                 __m256 _k44 = _mm256_loadu_ps(k0 + 32);
138                 k0 -= 160;
139 
140                 _sum0 = _mm256_fmadd_ps(_k40, _r40, _sum0);
141                 _sum0 = _mm256_fmadd_ps(_k41, _r41, _sum0);
142                 _sum0 = _mm256_fmadd_ps(_k42, _r42, _sum0);
143                 _sum0 = _mm256_fmadd_ps(_k43, _r43, _sum0);
144                 _sum0 = _mm256_fmadd_ps(_k44, _r44, _sum0);
145 
146                 _mm256_storeu_ps(outptr0, _sum0);
147 
148                 r0 += 8;
149                 r1 += 8;
150                 r2 += 8;
151                 r3 += 8;
152                 r4 += 8;
153                 outptr0 += 8;
154             }
155 
156             r0 += 4 * 8;
157             r1 += 4 * 8;
158             r2 += 4 * 8;
159             r3 += 4 * 8;
160             r4 += 4 * 8;
161         }
162     }
163 }
164 
convdw5x5s2_pack8_avx(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)165 static void convdw5x5s2_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
166 {
167     int w = bottom_blob.w;
168 
169     int outw = top_blob.w;
170     int outh = top_blob.h;
171 
172     const int group = bottom_blob.c;
173 
174     const int tailstep = (w - 2 * outw + w) * 8;
175 
176     const float* bias = _bias;
177     #pragma omp parallel for num_threads(opt.num_threads)
178     for (int g = 0; g < group; g++)
179     {
180         Mat out = top_blob.channel(g);
181 
182         __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + g * 8) : _mm256_set1_ps(0.f);
183 
184         const float* k0 = kernel.row(g);
185 
186         float* outptr0 = out.row(0);
187 
188         const Mat img0 = bottom_blob.channel(g);
189 
190         const float* r0 = img0.row(0);
191         const float* r1 = img0.row(1);
192         const float* r2 = img0.row(2);
193         const float* r3 = img0.row(3);
194         const float* r4 = img0.row(4);
195 
196         int i = 0;
197         for (; i < outh; i++)
198         {
199             int j = 0;
200 
201             for (; j < outw; j++)
202             {
203                 __m256 _sum0 = _bias0;
204 
205                 __m256 _r00 = _mm256_loadu_ps(r0);
206                 __m256 _r01 = _mm256_loadu_ps(r0 + 8);
207                 __m256 _r02 = _mm256_loadu_ps(r0 + 16);
208                 __m256 _r03 = _mm256_loadu_ps(r0 + 24);
209                 __m256 _r04 = _mm256_loadu_ps(r0 + 32);
210 
211                 __m256 _k00 = _mm256_loadu_ps(k0);
212                 __m256 _k01 = _mm256_loadu_ps(k0 + 8);
213                 __m256 _k02 = _mm256_loadu_ps(k0 + 16);
214                 __m256 _k03 = _mm256_loadu_ps(k0 + 24);
215                 __m256 _k04 = _mm256_loadu_ps(k0 + 32);
216                 k0 += 40;
217 
218                 _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
219                 _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
220                 _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
221                 _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);
222                 _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);
223 
224                 __m256 _r10 = _mm256_loadu_ps(r1);
225                 __m256 _r11 = _mm256_loadu_ps(r1 + 8);
226                 __m256 _r12 = _mm256_loadu_ps(r1 + 16);
227                 __m256 _r13 = _mm256_loadu_ps(r1 + 24);
228                 __m256 _r14 = _mm256_loadu_ps(r1 + 32);
229 
230                 __m256 _k10 = _mm256_loadu_ps(k0);
231                 __m256 _k11 = _mm256_loadu_ps(k0 + 8);
232                 __m256 _k12 = _mm256_loadu_ps(k0 + 16);
233                 __m256 _k13 = _mm256_loadu_ps(k0 + 24);
234                 __m256 _k14 = _mm256_loadu_ps(k0 + 32);
235                 k0 += 40;
236 
237                 _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
238                 _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
239                 _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
240                 _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);
241                 _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);
242 
243                 __m256 _r20 = _mm256_loadu_ps(r2);
244                 __m256 _r21 = _mm256_loadu_ps(r2 + 8);
245                 __m256 _r22 = _mm256_loadu_ps(r2 + 16);
246                 __m256 _r23 = _mm256_loadu_ps(r2 + 24);
247                 __m256 _r24 = _mm256_loadu_ps(r2 + 32);
248 
249                 __m256 _k20 = _mm256_loadu_ps(k0);
250                 __m256 _k21 = _mm256_loadu_ps(k0 + 8);
251                 __m256 _k22 = _mm256_loadu_ps(k0 + 16);
252                 __m256 _k23 = _mm256_loadu_ps(k0 + 24);
253                 __m256 _k24 = _mm256_loadu_ps(k0 + 32);
254                 k0 += 40;
255 
256                 _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
257                 _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
258                 _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);
259                 _sum0 = _mm256_fmadd_ps(_k23, _r23, _sum0);
260                 _sum0 = _mm256_fmadd_ps(_k24, _r24, _sum0);
261 
262                 __m256 _r30 = _mm256_loadu_ps(r3);
263                 __m256 _r31 = _mm256_loadu_ps(r3 + 8);
264                 __m256 _r32 = _mm256_loadu_ps(r3 + 16);
265                 __m256 _r33 = _mm256_loadu_ps(r3 + 24);
266                 __m256 _r34 = _mm256_loadu_ps(r3 + 32);
267 
268                 __m256 _k30 = _mm256_loadu_ps(k0);
269                 __m256 _k31 = _mm256_loadu_ps(k0 + 8);
270                 __m256 _k32 = _mm256_loadu_ps(k0 + 16);
271                 __m256 _k33 = _mm256_loadu_ps(k0 + 24);
272                 __m256 _k34 = _mm256_loadu_ps(k0 + 32);
273                 k0 += 40;
274 
275                 _sum0 = _mm256_fmadd_ps(_k30, _r30, _sum0);
276                 _sum0 = _mm256_fmadd_ps(_k31, _r31, _sum0);
277                 _sum0 = _mm256_fmadd_ps(_k32, _r32, _sum0);
278                 _sum0 = _mm256_fmadd_ps(_k33, _r33, _sum0);
279                 _sum0 = _mm256_fmadd_ps(_k34, _r34, _sum0);
280 
281                 __m256 _r40 = _mm256_loadu_ps(r4);
282                 __m256 _r41 = _mm256_loadu_ps(r4 + 8);
283                 __m256 _r42 = _mm256_loadu_ps(r4 + 16);
284                 __m256 _r43 = _mm256_loadu_ps(r4 + 24);
285                 __m256 _r44 = _mm256_loadu_ps(r4 + 32);
286 
287                 __m256 _k40 = _mm256_loadu_ps(k0);
288                 __m256 _k41 = _mm256_loadu_ps(k0 + 8);
289                 __m256 _k42 = _mm256_loadu_ps(k0 + 16);
290                 __m256 _k43 = _mm256_loadu_ps(k0 + 24);
291                 __m256 _k44 = _mm256_loadu_ps(k0 + 32);
292                 k0 -= 160;
293 
294                 _sum0 = _mm256_fmadd_ps(_k40, _r40, _sum0);
295                 _sum0 = _mm256_fmadd_ps(_k41, _r41, _sum0);
296                 _sum0 = _mm256_fmadd_ps(_k42, _r42, _sum0);
297                 _sum0 = _mm256_fmadd_ps(_k43, _r43, _sum0);
298                 _sum0 = _mm256_fmadd_ps(_k44, _r44, _sum0);
299 
300                 _mm256_storeu_ps(outptr0, _sum0);
301 
302                 r0 += 16;
303                 r1 += 16;
304                 r2 += 16;
305                 r3 += 16;
306                 r4 += 16;
307                 outptr0 += 8;
308             }
309 
310             r0 += tailstep;
311             r1 += tailstep;
312             r2 += tailstep;
313             r3 += tailstep;
314             r4 += tailstep;
315         }
316     }
317 }
318