// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

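// Depthwise 5x5 convolution, stride 1, for pack8 layout (8 floats per pixel)
// using AVX intrinsics. _mm256_fmadd_ps requires FMA support. The pointer
// arithmetic below appears to assume the input is already padded so that
// bottom_blob.w == outw + 4.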
static void convdw5x5s1_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int outw = top_blob.w;
    int outh = top_blob.h;

    const int group = bottom_blob.c;

    const float* bias = _bias;
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int g = 0; g < group; g++)
    {
        Mat out = top_blob.channel(g);

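        // load the per-channel bias (8 packed floats) or zero when there is no bias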
        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + g * 8) : _mm256_set1_ps(0.f);

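        // each group's kernel holds 5x5 taps x 8 packed floats = 200 floats, row-major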
        const float* k0 = kernel.row(g);

        float* outptr0 = out.row(0);

        const Mat img0 = bottom_blob.channel(g);

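        // five input row pointers covering the 5-tap vertical window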
        const float* r0 = img0.row(0);
        const float* r1 = img0.row(1);
        const float* r2 = img0.row(2);
        const float* r3 = img0.row(3);
        const float* r4 = img0.row(4);

        int i = 0;
        for (; i < outh; i++)
        {
            int j = 0;

            for (; j < outw; j++)
            {
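                // start from the bias and accumulate all 25 taps of the 5x5 window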
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r03 = _mm256_loadu_ps(r0 + 24);
                __m256 _r04 = _mm256_loadu_ps(r0 + 32);

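                // each kernel row holds 5 taps x 8 floats = 40 floats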
                __m256 _k00 = _mm256_loadu_ps(k0);
                __m256 _k01 = _mm256_loadu_ps(k0 + 8);
                __m256 _k02 = _mm256_loadu_ps(k0 + 16);
                __m256 _k03 = _mm256_loadu_ps(k0 + 24);
                __m256 _k04 = _mm256_loadu_ps(k0 + 32);
                k0 += 40;

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);
                _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);

                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r13 = _mm256_loadu_ps(r1 + 24);
                __m256 _r14 = _mm256_loadu_ps(r1 + 32);

                __m256 _k10 = _mm256_loadu_ps(k0);
                __m256 _k11 = _mm256_loadu_ps(k0 + 8);
                __m256 _k12 = _mm256_loadu_ps(k0 + 16);
                __m256 _k13 = _mm256_loadu_ps(k0 + 24);
                __m256 _k14 = _mm256_loadu_ps(k0 + 32);
                k0 += 40;

                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);
                _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);

                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);
                __m256 _r23 = _mm256_loadu_ps(r2 + 24);
                __m256 _r24 = _mm256_loadu_ps(r2 + 32);

                __m256 _k20 = _mm256_loadu_ps(k0);
                __m256 _k21 = _mm256_loadu_ps(k0 + 8);
                __m256 _k22 = _mm256_loadu_ps(k0 + 16);
                __m256 _k23 = _mm256_loadu_ps(k0 + 24);
                __m256 _k24 = _mm256_loadu_ps(k0 + 32);
                k0 += 40;

                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);
                _sum0 = _mm256_fmadd_ps(_k23, _r23, _sum0);
                _sum0 = _mm256_fmadd_ps(_k24, _r24, _sum0);

                __m256 _r30 = _mm256_loadu_ps(r3);
                __m256 _r31 = _mm256_loadu_ps(r3 + 8);
                __m256 _r32 = _mm256_loadu_ps(r3 + 16);
                __m256 _r33 = _mm256_loadu_ps(r3 + 24);
                __m256 _r34 = _mm256_loadu_ps(r3 + 32);

                __m256 _k30 = _mm256_loadu_ps(k0);
                __m256 _k31 = _mm256_loadu_ps(k0 + 8);
                __m256 _k32 = _mm256_loadu_ps(k0 + 16);
                __m256 _k33 = _mm256_loadu_ps(k0 + 24);
                __m256 _k34 = _mm256_loadu_ps(k0 + 32);
                k0 += 40;

                _sum0 = _mm256_fmadd_ps(_k30, _r30, _sum0);
                _sum0 = _mm256_fmadd_ps(_k31, _r31, _sum0);
                _sum0 = _mm256_fmadd_ps(_k32, _r32, _sum0);
                _sum0 = _mm256_fmadd_ps(_k33, _r33, _sum0);
                _sum0 = _mm256_fmadd_ps(_k34, _r34, _sum0);

                __m256 _r40 = _mm256_loadu_ps(r4);
                __m256 _r41 = _mm256_loadu_ps(r4 + 8);
                __m256 _r42 = _mm256_loadu_ps(r4 + 16);
                __m256 _r43 = _mm256_loadu_ps(r4 + 24);
                __m256 _r44 = _mm256_loadu_ps(r4 + 32);

                __m256 _k40 = _mm256_loadu_ps(k0);
                __m256 _k41 = _mm256_loadu_ps(k0 + 8);
                __m256 _k42 = _mm256_loadu_ps(k0 + 16);
                __m256 _k43 = _mm256_loadu_ps(k0 + 24);
                __m256 _k44 = _mm256_loadu_ps(k0 + 32);
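                // rewind k0 by the 4 x 40 floats advanced above, back to the
                // start of this group's kernel for the next output pixel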
                k0 -= 160;

                _sum0 = _mm256_fmadd_ps(_k40, _r40, _sum0);
                _sum0 = _mm256_fmadd_ps(_k41, _r41, _sum0);
                _sum0 = _mm256_fmadd_ps(_k42, _r42, _sum0);
                _sum0 = _mm256_fmadd_ps(_k43, _r43, _sum0);
                _sum0 = _mm256_fmadd_ps(_k44, _r44, _sum0);

                _mm256_storeu_ps(outptr0, _sum0);

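                // stride 1: advance input and output by one pixel = 8 floats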
                r0 += 8;
                r1 += 8;
                r2 += 8;
                r3 += 8;
                r4 += 8;
                outptr0 += 8;
            }

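            // skip the 4-pixel right border to reach the next input row
            // (the padded input satisfies w == outw + 4 for a 5x5 stride-1 kernel)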
            r0 += 4 * 8;
            r1 += 4 * 8;
            r2 += 4 * 8;
            r3 += 4 * 8;
            r4 += 4 * 8;
        }
    }
}

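// Depthwise 5x5 convolution, stride 2, for pack8 layout using AVX intrinsics.
// Same accumulation scheme as the stride-1 kernel above, but the input
// pointers advance two pixels per output and skip one full row per output row.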
static void convdw5x5s2_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    int w = bottom_blob.w;

    int outw = top_blob.w;
    int outh = top_blob.h;

    const int group = bottom_blob.c;

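    // after outw stride-2 steps the row pointers sit 2 * outw pixels into the
    // current row; skip the remainder of that row plus one more full row
    // (vertical stride 2), all in units of 8 packed floats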
    const int tailstep = (w - 2 * outw + w) * 8;

    const float* bias = _bias;
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int g = 0; g < group; g++)
    {
        Mat out = top_blob.channel(g);

        __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + g * 8) : _mm256_set1_ps(0.f);

        const float* k0 = kernel.row(g);

        float* outptr0 = out.row(0);

        const Mat img0 = bottom_blob.channel(g);

        const float* r0 = img0.row(0);
        const float* r1 = img0.row(1);
        const float* r2 = img0.row(2);
        const float* r3 = img0.row(3);
        const float* r4 = img0.row(4);

        int i = 0;
        for (; i < outh; i++)
        {
            int j = 0;

            for (; j < outw; j++)
            {
                __m256 _sum0 = _bias0;

                __m256 _r00 = _mm256_loadu_ps(r0);
                __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                __m256 _r03 = _mm256_loadu_ps(r0 + 24);
                __m256 _r04 = _mm256_loadu_ps(r0 + 32);

                __m256 _k00 = _mm256_loadu_ps(k0);
                __m256 _k01 = _mm256_loadu_ps(k0 + 8);
                __m256 _k02 = _mm256_loadu_ps(k0 + 16);
                __m256 _k03 = _mm256_loadu_ps(k0 + 24);
                __m256 _k04 = _mm256_loadu_ps(k0 + 32);
                k0 += 40;

                _sum0 = _mm256_fmadd_ps(_k00, _r00, _sum0);
                _sum0 = _mm256_fmadd_ps(_k01, _r01, _sum0);
                _sum0 = _mm256_fmadd_ps(_k02, _r02, _sum0);
                _sum0 = _mm256_fmadd_ps(_k03, _r03, _sum0);
                _sum0 = _mm256_fmadd_ps(_k04, _r04, _sum0);

                __m256 _r10 = _mm256_loadu_ps(r1);
                __m256 _r11 = _mm256_loadu_ps(r1 + 8);
                __m256 _r12 = _mm256_loadu_ps(r1 + 16);
                __m256 _r13 = _mm256_loadu_ps(r1 + 24);
                __m256 _r14 = _mm256_loadu_ps(r1 + 32);

                __m256 _k10 = _mm256_loadu_ps(k0);
                __m256 _k11 = _mm256_loadu_ps(k0 + 8);
                __m256 _k12 = _mm256_loadu_ps(k0 + 16);
                __m256 _k13 = _mm256_loadu_ps(k0 + 24);
                __m256 _k14 = _mm256_loadu_ps(k0 + 32);
                k0 += 40;

                _sum0 = _mm256_fmadd_ps(_k10, _r10, _sum0);
                _sum0 = _mm256_fmadd_ps(_k11, _r11, _sum0);
                _sum0 = _mm256_fmadd_ps(_k12, _r12, _sum0);
                _sum0 = _mm256_fmadd_ps(_k13, _r13, _sum0);
                _sum0 = _mm256_fmadd_ps(_k14, _r14, _sum0);

                __m256 _r20 = _mm256_loadu_ps(r2);
                __m256 _r21 = _mm256_loadu_ps(r2 + 8);
                __m256 _r22 = _mm256_loadu_ps(r2 + 16);
                __m256 _r23 = _mm256_loadu_ps(r2 + 24);
                __m256 _r24 = _mm256_loadu_ps(r2 + 32);

                __m256 _k20 = _mm256_loadu_ps(k0);
                __m256 _k21 = _mm256_loadu_ps(k0 + 8);
                __m256 _k22 = _mm256_loadu_ps(k0 + 16);
                __m256 _k23 = _mm256_loadu_ps(k0 + 24);
                __m256 _k24 = _mm256_loadu_ps(k0 + 32);
                k0 += 40;

                _sum0 = _mm256_fmadd_ps(_k20, _r20, _sum0);
                _sum0 = _mm256_fmadd_ps(_k21, _r21, _sum0);
                _sum0 = _mm256_fmadd_ps(_k22, _r22, _sum0);
                _sum0 = _mm256_fmadd_ps(_k23, _r23, _sum0);
                _sum0 = _mm256_fmadd_ps(_k24, _r24, _sum0);

                __m256 _r30 = _mm256_loadu_ps(r3);
                __m256 _r31 = _mm256_loadu_ps(r3 + 8);
                __m256 _r32 = _mm256_loadu_ps(r3 + 16);
                __m256 _r33 = _mm256_loadu_ps(r3 + 24);
                __m256 _r34 = _mm256_loadu_ps(r3 + 32);

                __m256 _k30 = _mm256_loadu_ps(k0);
                __m256 _k31 = _mm256_loadu_ps(k0 + 8);
                __m256 _k32 = _mm256_loadu_ps(k0 + 16);
                __m256 _k33 = _mm256_loadu_ps(k0 + 24);
                __m256 _k34 = _mm256_loadu_ps(k0 + 32);
                k0 += 40;

                _sum0 = _mm256_fmadd_ps(_k30, _r30, _sum0);
                _sum0 = _mm256_fmadd_ps(_k31, _r31, _sum0);
                _sum0 = _mm256_fmadd_ps(_k32, _r32, _sum0);
                _sum0 = _mm256_fmadd_ps(_k33, _r33, _sum0);
                _sum0 = _mm256_fmadd_ps(_k34, _r34, _sum0);

                __m256 _r40 = _mm256_loadu_ps(r4);
                __m256 _r41 = _mm256_loadu_ps(r4 + 8);
                __m256 _r42 = _mm256_loadu_ps(r4 + 16);
                __m256 _r43 = _mm256_loadu_ps(r4 + 24);
                __m256 _r44 = _mm256_loadu_ps(r4 + 32);

                __m256 _k40 = _mm256_loadu_ps(k0);
                __m256 _k41 = _mm256_loadu_ps(k0 + 8);
                __m256 _k42 = _mm256_loadu_ps(k0 + 16);
                __m256 _k43 = _mm256_loadu_ps(k0 + 24);
                __m256 _k44 = _mm256_loadu_ps(k0 + 32);
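                // rewind to the start of this group's kernel for the next output pixel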
                k0 -= 160;

                _sum0 = _mm256_fmadd_ps(_k40, _r40, _sum0);
                _sum0 = _mm256_fmadd_ps(_k41, _r41, _sum0);
                _sum0 = _mm256_fmadd_ps(_k42, _r42, _sum0);
                _sum0 = _mm256_fmadd_ps(_k43, _r43, _sum0);
                _sum0 = _mm256_fmadd_ps(_k44, _r44, _sum0);

                _mm256_storeu_ps(outptr0, _sum0);

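                // stride 2: advance input by two pixels = 16 floats, output by one pixel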
                r0 += 16;
                r1 += 16;
                r2 += 16;
                r3 += 16;
                r4 += 16;
                outptr0 += 8;
            }

            r0 += tailstep;
            r1 += tailstep;
            r2 += tailstep;
            r3 += tailstep;
            r4 += tailstep;
        }
    }
}