// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void conv3x3s1_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
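    // 3x3 stride-1 convolution in pack8 layout: every element of bottom_blob and
    // top_blob stores 8 interleaved channel lanes, so one __m256 holds one pixel
    // of one packed channel group. Each output channel group starts from the
    // broadcast bias and accumulates over all input channel groups.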
    int inch = bottom_blob.c;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const float* bias = _bias;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        Mat out = top_blob.channel(p);

        __m256 _bias0 = bias ? _mm256_loadu_ps(bias + p * 8) : _mm256_set1_ps(0.f);
        out.fill(_bias0);

        for (int q = 0; q < inch; q++)
        {
            float* outptr = out;

            const Mat img0 = bottom_blob.channel(q);

            const float* r0 = img0.row(0);
            const float* r1 = img0.row(1);
            const float* r2 = img0.row(2);

            const float* kptr = kernel.channel(p).row(q);

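            // kptr walks a repacked 3x3 kernel: 9 taps, each tap stored as
            // 64 floats = 8 input lanes x 8 output lanes, so every broadcast
            // input-lane scalar multiplies one 8-wide vector of output lanes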
            int i = 0;
            for (; i < outh; i++)
            {
                int j = 0;
                for (; j + 1 < outw; j += 2)
                {
                    __m256 _sum00 = _mm256_loadu_ps(outptr);
                    __m256 _sum01 = _mm256_setzero_ps();
                    __m256 _sum10 = _mm256_loadu_ps(outptr + 8);
                    __m256 _sum11 = _mm256_setzero_ps();

                    __m256 _r000 = _mm256_broadcast_ss(r0 + 0);
                    __m256 _r001 = _mm256_broadcast_ss(r0 + 1);
                    __m256 _r002 = _mm256_broadcast_ss(r0 + 2);
                    __m256 _r003 = _mm256_broadcast_ss(r0 + 3);
                    __m256 _r004 = _mm256_broadcast_ss(r0 + 4);
                    __m256 _r005 = _mm256_broadcast_ss(r0 + 5);
                    __m256 _r006 = _mm256_broadcast_ss(r0 + 6);
                    __m256 _r007 = _mm256_broadcast_ss(r0 + 7);

                    __m256 _k00 = _mm256_loadu_ps(kptr);
                    __m256 _k01 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k02 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k03 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k04 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k05 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k06 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k07 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum00 = _mm256_fmadd_ps(_r000, _k00, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r001, _k01, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r002, _k02, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r003, _k03, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r004, _k04, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r005, _k05, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r006, _k06, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r007, _k07, _sum01);

                    __m256 _r010 = _mm256_broadcast_ss(r0 + 8);
                    __m256 _r011 = _mm256_broadcast_ss(r0 + 9);
                    __m256 _r012 = _mm256_broadcast_ss(r0 + 10);
                    __m256 _r013 = _mm256_broadcast_ss(r0 + 11);
                    __m256 _r014 = _mm256_broadcast_ss(r0 + 12);
                    __m256 _r015 = _mm256_broadcast_ss(r0 + 13);
                    __m256 _r016 = _mm256_broadcast_ss(r0 + 14);
                    __m256 _r017 = _mm256_broadcast_ss(r0 + 15);

                    _sum10 = _mm256_fmadd_ps(_r010, _k00, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r011, _k01, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r012, _k02, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r013, _k03, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r014, _k04, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r015, _k05, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r016, _k06, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r017, _k07, _sum11);

                    __m256 _k10 = _mm256_loadu_ps(kptr);
                    __m256 _k11 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k12 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k13 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k14 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k15 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k16 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k17 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum00 = _mm256_fmadd_ps(_r010, _k10, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r011, _k11, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r012, _k12, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r013, _k13, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r014, _k14, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r015, _k15, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r016, _k16, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r017, _k17, _sum01);

                    __m256 _r020 = _mm256_broadcast_ss(r0 + 16);
                    __m256 _r021 = _mm256_broadcast_ss(r0 + 17);
                    __m256 _r022 = _mm256_broadcast_ss(r0 + 18);
                    __m256 _r023 = _mm256_broadcast_ss(r0 + 19);
                    __m256 _r024 = _mm256_broadcast_ss(r0 + 20);
                    __m256 _r025 = _mm256_broadcast_ss(r0 + 21);
                    __m256 _r026 = _mm256_broadcast_ss(r0 + 22);
                    __m256 _r027 = _mm256_broadcast_ss(r0 + 23);

                    _sum10 = _mm256_fmadd_ps(_r020, _k10, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r021, _k11, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r022, _k12, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r023, _k13, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r024, _k14, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r025, _k15, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r026, _k16, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r027, _k17, _sum11);

                    __m256 _k20 = _mm256_loadu_ps(kptr);
                    __m256 _k21 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k22 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k23 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k24 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k25 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k26 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k27 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum00 = _mm256_fmadd_ps(_r020, _k20, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r021, _k21, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r022, _k22, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r023, _k23, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r024, _k24, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r025, _k25, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r026, _k26, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r027, _k27, _sum01);

                    __m256 _r030 = _mm256_broadcast_ss(r0 + 24);
                    __m256 _r031 = _mm256_broadcast_ss(r0 + 25);
                    __m256 _r032 = _mm256_broadcast_ss(r0 + 26);
                    __m256 _r033 = _mm256_broadcast_ss(r0 + 27);
                    __m256 _r034 = _mm256_broadcast_ss(r0 + 28);
                    __m256 _r035 = _mm256_broadcast_ss(r0 + 29);
                    __m256 _r036 = _mm256_broadcast_ss(r0 + 30);
                    __m256 _r037 = _mm256_broadcast_ss(r0 + 31);

                    _sum10 = _mm256_fmadd_ps(_r030, _k20, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r031, _k21, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r032, _k22, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r033, _k23, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r034, _k24, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r035, _k25, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r036, _k26, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r037, _k27, _sum11);

                    __m256 _r100 = _mm256_broadcast_ss(r1 + 0);
                    __m256 _r101 = _mm256_broadcast_ss(r1 + 1);
                    __m256 _r102 = _mm256_broadcast_ss(r1 + 2);
                    __m256 _r103 = _mm256_broadcast_ss(r1 + 3);
                    __m256 _r104 = _mm256_broadcast_ss(r1 + 4);
                    __m256 _r105 = _mm256_broadcast_ss(r1 + 5);
                    __m256 _r106 = _mm256_broadcast_ss(r1 + 6);
                    __m256 _r107 = _mm256_broadcast_ss(r1 + 7);

                    __m256 _k30 = _mm256_loadu_ps(kptr);
                    __m256 _k31 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k32 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k33 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k34 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k35 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k36 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k37 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum00 = _mm256_fmadd_ps(_r100, _k30, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r101, _k31, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r102, _k32, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r103, _k33, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r104, _k34, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r105, _k35, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r106, _k36, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r107, _k37, _sum01);

                    __m256 _r110 = _mm256_broadcast_ss(r1 + 8);
                    __m256 _r111 = _mm256_broadcast_ss(r1 + 9);
                    __m256 _r112 = _mm256_broadcast_ss(r1 + 10);
                    __m256 _r113 = _mm256_broadcast_ss(r1 + 11);
                    __m256 _r114 = _mm256_broadcast_ss(r1 + 12);
                    __m256 _r115 = _mm256_broadcast_ss(r1 + 13);
                    __m256 _r116 = _mm256_broadcast_ss(r1 + 14);
                    __m256 _r117 = _mm256_broadcast_ss(r1 + 15);

                    _sum10 = _mm256_fmadd_ps(_r110, _k30, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r111, _k31, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r112, _k32, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r113, _k33, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r114, _k34, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r115, _k35, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r116, _k36, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r117, _k37, _sum11);

                    __m256 _k40 = _mm256_loadu_ps(kptr);
                    __m256 _k41 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k42 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k43 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k44 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k45 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k46 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k47 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum00 = _mm256_fmadd_ps(_r110, _k40, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r111, _k41, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r112, _k42, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r113, _k43, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r114, _k44, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r115, _k45, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r116, _k46, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r117, _k47, _sum01);

                    __m256 _r120 = _mm256_broadcast_ss(r1 + 16);
                    __m256 _r121 = _mm256_broadcast_ss(r1 + 17);
                    __m256 _r122 = _mm256_broadcast_ss(r1 + 18);
                    __m256 _r123 = _mm256_broadcast_ss(r1 + 19);
                    __m256 _r124 = _mm256_broadcast_ss(r1 + 20);
                    __m256 _r125 = _mm256_broadcast_ss(r1 + 21);
                    __m256 _r126 = _mm256_broadcast_ss(r1 + 22);
                    __m256 _r127 = _mm256_broadcast_ss(r1 + 23);

                    _sum10 = _mm256_fmadd_ps(_r120, _k40, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r121, _k41, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r122, _k42, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r123, _k43, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r124, _k44, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r125, _k45, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r126, _k46, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r127, _k47, _sum11);

                    __m256 _k50 = _mm256_loadu_ps(kptr);
                    __m256 _k51 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k52 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k53 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k54 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k55 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k56 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k57 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum00 = _mm256_fmadd_ps(_r120, _k50, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r121, _k51, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r122, _k52, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r123, _k53, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r124, _k54, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r125, _k55, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r126, _k56, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r127, _k57, _sum01);

                    __m256 _r130 = _mm256_broadcast_ss(r1 + 24);
                    __m256 _r131 = _mm256_broadcast_ss(r1 + 25);
                    __m256 _r132 = _mm256_broadcast_ss(r1 + 26);
                    __m256 _r133 = _mm256_broadcast_ss(r1 + 27);
                    __m256 _r134 = _mm256_broadcast_ss(r1 + 28);
                    __m256 _r135 = _mm256_broadcast_ss(r1 + 29);
                    __m256 _r136 = _mm256_broadcast_ss(r1 + 30);
                    __m256 _r137 = _mm256_broadcast_ss(r1 + 31);

                    _sum10 = _mm256_fmadd_ps(_r130, _k50, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r131, _k51, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r132, _k52, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r133, _k53, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r134, _k54, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r135, _k55, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r136, _k56, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r137, _k57, _sum11);

                    __m256 _r200 = _mm256_broadcast_ss(r2 + 0);
                    __m256 _r201 = _mm256_broadcast_ss(r2 + 1);
                    __m256 _r202 = _mm256_broadcast_ss(r2 + 2);
                    __m256 _r203 = _mm256_broadcast_ss(r2 + 3);
                    __m256 _r204 = _mm256_broadcast_ss(r2 + 4);
                    __m256 _r205 = _mm256_broadcast_ss(r2 + 5);
                    __m256 _r206 = _mm256_broadcast_ss(r2 + 6);
                    __m256 _r207 = _mm256_broadcast_ss(r2 + 7);

                    __m256 _k60 = _mm256_loadu_ps(kptr);
                    __m256 _k61 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k62 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k63 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k64 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k65 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k66 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k67 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum00 = _mm256_fmadd_ps(_r200, _k60, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r201, _k61, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r202, _k62, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r203, _k63, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r204, _k64, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r205, _k65, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r206, _k66, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r207, _k67, _sum01);

                    __m256 _r210 = _mm256_broadcast_ss(r2 + 8);
                    __m256 _r211 = _mm256_broadcast_ss(r2 + 9);
                    __m256 _r212 = _mm256_broadcast_ss(r2 + 10);
                    __m256 _r213 = _mm256_broadcast_ss(r2 + 11);
                    __m256 _r214 = _mm256_broadcast_ss(r2 + 12);
                    __m256 _r215 = _mm256_broadcast_ss(r2 + 13);
                    __m256 _r216 = _mm256_broadcast_ss(r2 + 14);
                    __m256 _r217 = _mm256_broadcast_ss(r2 + 15);

                    _sum10 = _mm256_fmadd_ps(_r210, _k60, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r211, _k61, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r212, _k62, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r213, _k63, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r214, _k64, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r215, _k65, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r216, _k66, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r217, _k67, _sum11);

                    __m256 _k70 = _mm256_loadu_ps(kptr);
                    __m256 _k71 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k72 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k73 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k74 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k75 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k76 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k77 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum00 = _mm256_fmadd_ps(_r210, _k70, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r211, _k71, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r212, _k72, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r213, _k73, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r214, _k74, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r215, _k75, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r216, _k76, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r217, _k77, _sum01);

                    __m256 _r220 = _mm256_broadcast_ss(r2 + 16);
                    __m256 _r221 = _mm256_broadcast_ss(r2 + 17);
                    __m256 _r222 = _mm256_broadcast_ss(r2 + 18);
                    __m256 _r223 = _mm256_broadcast_ss(r2 + 19);
                    __m256 _r224 = _mm256_broadcast_ss(r2 + 20);
                    __m256 _r225 = _mm256_broadcast_ss(r2 + 21);
                    __m256 _r226 = _mm256_broadcast_ss(r2 + 22);
                    __m256 _r227 = _mm256_broadcast_ss(r2 + 23);

                    _sum10 = _mm256_fmadd_ps(_r220, _k70, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r221, _k71, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r222, _k72, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r223, _k73, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r224, _k74, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r225, _k75, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r226, _k76, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r227, _k77, _sum11);

                    __m256 _k80 = _mm256_loadu_ps(kptr);
                    __m256 _k81 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k82 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k83 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k84 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k85 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k86 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k87 = _mm256_loadu_ps(kptr + 56);

                    _sum00 = _mm256_fmadd_ps(_r220, _k80, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r221, _k81, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r222, _k82, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r223, _k83, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r224, _k84, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r225, _k85, _sum01);
                    _sum00 = _mm256_fmadd_ps(_r226, _k86, _sum00);
                    _sum01 = _mm256_fmadd_ps(_r227, _k87, _sum01);

                    __m256 _r230 = _mm256_broadcast_ss(r2 + 24);
                    __m256 _r231 = _mm256_broadcast_ss(r2 + 25);
                    __m256 _r232 = _mm256_broadcast_ss(r2 + 26);
                    __m256 _r233 = _mm256_broadcast_ss(r2 + 27);
                    __m256 _r234 = _mm256_broadcast_ss(r2 + 28);
                    __m256 _r235 = _mm256_broadcast_ss(r2 + 29);
                    __m256 _r236 = _mm256_broadcast_ss(r2 + 30);
                    __m256 _r237 = _mm256_broadcast_ss(r2 + 31);

                    _sum10 = _mm256_fmadd_ps(_r230, _k80, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r231, _k81, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r232, _k82, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r233, _k83, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r234, _k84, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r235, _k85, _sum11);
                    _sum10 = _mm256_fmadd_ps(_r236, _k86, _sum10);
                    _sum11 = _mm256_fmadd_ps(_r237, _k87, _sum11);

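                    // only 8 of the 9 tap blocks advanced kptr (64 floats each);
                    // rewind so kptr points at tap 0 again for the next pixel pair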
                    kptr -= 64 * 8;

                    _sum00 = _mm256_add_ps(_sum00, _sum01);
                    _sum10 = _mm256_add_ps(_sum10, _sum11);

                    _mm256_storeu_ps(outptr, _sum00);
                    _mm256_storeu_ps(outptr + 8, _sum10);

                    r0 += 16;
                    r1 += 16;
                    r2 += 16;
                    outptr += 16;
                }
                for (; j < outw; j++)
                {
                    __m256 _sum0 = _mm256_loadu_ps(outptr);
                    __m256 _sum1 = _mm256_setzero_ps();

                    __m256 _r000 = _mm256_broadcast_ss(r0 + 0);
                    __m256 _r001 = _mm256_broadcast_ss(r0 + 1);
                    __m256 _r002 = _mm256_broadcast_ss(r0 + 2);
                    __m256 _r003 = _mm256_broadcast_ss(r0 + 3);
                    __m256 _r004 = _mm256_broadcast_ss(r0 + 4);
                    __m256 _r005 = _mm256_broadcast_ss(r0 + 5);
                    __m256 _r006 = _mm256_broadcast_ss(r0 + 6);
                    __m256 _r007 = _mm256_broadcast_ss(r0 + 7);

                    __m256 _k00 = _mm256_loadu_ps(kptr);
                    __m256 _k01 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k02 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k03 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k04 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k05 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k06 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k07 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum0 = _mm256_fmadd_ps(_r000, _k00, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r001, _k01, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r002, _k02, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r003, _k03, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r004, _k04, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r005, _k05, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r006, _k06, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r007, _k07, _sum1);

                    __m256 _r010 = _mm256_broadcast_ss(r0 + 8);
                    __m256 _r011 = _mm256_broadcast_ss(r0 + 9);
                    __m256 _r012 = _mm256_broadcast_ss(r0 + 10);
                    __m256 _r013 = _mm256_broadcast_ss(r0 + 11);
                    __m256 _r014 = _mm256_broadcast_ss(r0 + 12);
                    __m256 _r015 = _mm256_broadcast_ss(r0 + 13);
                    __m256 _r016 = _mm256_broadcast_ss(r0 + 14);
                    __m256 _r017 = _mm256_broadcast_ss(r0 + 15);

                    __m256 _k10 = _mm256_loadu_ps(kptr);
                    __m256 _k11 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k12 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k13 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k14 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k15 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k16 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k17 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum0 = _mm256_fmadd_ps(_r010, _k10, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r011, _k11, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r012, _k12, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r013, _k13, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r014, _k14, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r015, _k15, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r016, _k16, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r017, _k17, _sum1);

                    __m256 _r020 = _mm256_broadcast_ss(r0 + 16);
                    __m256 _r021 = _mm256_broadcast_ss(r0 + 17);
                    __m256 _r022 = _mm256_broadcast_ss(r0 + 18);
                    __m256 _r023 = _mm256_broadcast_ss(r0 + 19);
                    __m256 _r024 = _mm256_broadcast_ss(r0 + 20);
                    __m256 _r025 = _mm256_broadcast_ss(r0 + 21);
                    __m256 _r026 = _mm256_broadcast_ss(r0 + 22);
                    __m256 _r027 = _mm256_broadcast_ss(r0 + 23);

                    __m256 _k20 = _mm256_loadu_ps(kptr);
                    __m256 _k21 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k22 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k23 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k24 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k25 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k26 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k27 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum0 = _mm256_fmadd_ps(_r020, _k20, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r021, _k21, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r022, _k22, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r023, _k23, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r024, _k24, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r025, _k25, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r026, _k26, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r027, _k27, _sum1);

                    __m256 _r100 = _mm256_broadcast_ss(r1 + 0);
                    __m256 _r101 = _mm256_broadcast_ss(r1 + 1);
                    __m256 _r102 = _mm256_broadcast_ss(r1 + 2);
                    __m256 _r103 = _mm256_broadcast_ss(r1 + 3);
                    __m256 _r104 = _mm256_broadcast_ss(r1 + 4);
                    __m256 _r105 = _mm256_broadcast_ss(r1 + 5);
                    __m256 _r106 = _mm256_broadcast_ss(r1 + 6);
                    __m256 _r107 = _mm256_broadcast_ss(r1 + 7);

                    __m256 _k30 = _mm256_loadu_ps(kptr);
                    __m256 _k31 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k32 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k33 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k34 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k35 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k36 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k37 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum0 = _mm256_fmadd_ps(_r100, _k30, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r101, _k31, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r102, _k32, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r103, _k33, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r104, _k34, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r105, _k35, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r106, _k36, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r107, _k37, _sum1);

                    __m256 _r110 = _mm256_broadcast_ss(r1 + 8);
                    __m256 _r111 = _mm256_broadcast_ss(r1 + 9);
                    __m256 _r112 = _mm256_broadcast_ss(r1 + 10);
                    __m256 _r113 = _mm256_broadcast_ss(r1 + 11);
                    __m256 _r114 = _mm256_broadcast_ss(r1 + 12);
                    __m256 _r115 = _mm256_broadcast_ss(r1 + 13);
                    __m256 _r116 = _mm256_broadcast_ss(r1 + 14);
                    __m256 _r117 = _mm256_broadcast_ss(r1 + 15);

                    __m256 _k40 = _mm256_loadu_ps(kptr);
                    __m256 _k41 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k42 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k43 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k44 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k45 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k46 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k47 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum0 = _mm256_fmadd_ps(_r110, _k40, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r111, _k41, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r112, _k42, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r113, _k43, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r114, _k44, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r115, _k45, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r116, _k46, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r117, _k47, _sum1);

                    __m256 _r120 = _mm256_broadcast_ss(r1 + 16);
                    __m256 _r121 = _mm256_broadcast_ss(r1 + 17);
                    __m256 _r122 = _mm256_broadcast_ss(r1 + 18);
                    __m256 _r123 = _mm256_broadcast_ss(r1 + 19);
                    __m256 _r124 = _mm256_broadcast_ss(r1 + 20);
                    __m256 _r125 = _mm256_broadcast_ss(r1 + 21);
                    __m256 _r126 = _mm256_broadcast_ss(r1 + 22);
                    __m256 _r127 = _mm256_broadcast_ss(r1 + 23);

                    __m256 _k50 = _mm256_loadu_ps(kptr);
                    __m256 _k51 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k52 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k53 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k54 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k55 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k56 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k57 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum0 = _mm256_fmadd_ps(_r120, _k50, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r121, _k51, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r122, _k52, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r123, _k53, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r124, _k54, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r125, _k55, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r126, _k56, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r127, _k57, _sum1);

                    __m256 _r200 = _mm256_broadcast_ss(r2 + 0);
                    __m256 _r201 = _mm256_broadcast_ss(r2 + 1);
                    __m256 _r202 = _mm256_broadcast_ss(r2 + 2);
                    __m256 _r203 = _mm256_broadcast_ss(r2 + 3);
                    __m256 _r204 = _mm256_broadcast_ss(r2 + 4);
                    __m256 _r205 = _mm256_broadcast_ss(r2 + 5);
                    __m256 _r206 = _mm256_broadcast_ss(r2 + 6);
                    __m256 _r207 = _mm256_broadcast_ss(r2 + 7);

                    __m256 _k60 = _mm256_loadu_ps(kptr);
                    __m256 _k61 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k62 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k63 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k64 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k65 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k66 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k67 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum0 = _mm256_fmadd_ps(_r200, _k60, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r201, _k61, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r202, _k62, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r203, _k63, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r204, _k64, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r205, _k65, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r206, _k66, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r207, _k67, _sum1);

                    __m256 _r210 = _mm256_broadcast_ss(r2 + 8);
                    __m256 _r211 = _mm256_broadcast_ss(r2 + 9);
                    __m256 _r212 = _mm256_broadcast_ss(r2 + 10);
                    __m256 _r213 = _mm256_broadcast_ss(r2 + 11);
                    __m256 _r214 = _mm256_broadcast_ss(r2 + 12);
                    __m256 _r215 = _mm256_broadcast_ss(r2 + 13);
                    __m256 _r216 = _mm256_broadcast_ss(r2 + 14);
                    __m256 _r217 = _mm256_broadcast_ss(r2 + 15);

                    __m256 _k70 = _mm256_loadu_ps(kptr);
                    __m256 _k71 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k72 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k73 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k74 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k75 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k76 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k77 = _mm256_loadu_ps(kptr + 56);

                    kptr += 64;

                    _sum0 = _mm256_fmadd_ps(_r210, _k70, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r211, _k71, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r212, _k72, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r213, _k73, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r214, _k74, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r215, _k75, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r216, _k76, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r217, _k77, _sum1);

                    __m256 _r220 = _mm256_broadcast_ss(r2 + 16);
                    __m256 _r221 = _mm256_broadcast_ss(r2 + 17);
                    __m256 _r222 = _mm256_broadcast_ss(r2 + 18);
                    __m256 _r223 = _mm256_broadcast_ss(r2 + 19);
                    __m256 _r224 = _mm256_broadcast_ss(r2 + 20);
                    __m256 _r225 = _mm256_broadcast_ss(r2 + 21);
                    __m256 _r226 = _mm256_broadcast_ss(r2 + 22);
                    __m256 _r227 = _mm256_broadcast_ss(r2 + 23);

                    __m256 _k80 = _mm256_loadu_ps(kptr);
                    __m256 _k81 = _mm256_loadu_ps(kptr + 8);
                    __m256 _k82 = _mm256_loadu_ps(kptr + 16);
                    __m256 _k83 = _mm256_loadu_ps(kptr + 24);
                    __m256 _k84 = _mm256_loadu_ps(kptr + 32);
                    __m256 _k85 = _mm256_loadu_ps(kptr + 40);
                    __m256 _k86 = _mm256_loadu_ps(kptr + 48);
                    __m256 _k87 = _mm256_loadu_ps(kptr + 56);

                    _sum0 = _mm256_fmadd_ps(_r220, _k80, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r221, _k81, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r222, _k82, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r223, _k83, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r224, _k84, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r225, _k85, _sum1);
                    _sum0 = _mm256_fmadd_ps(_r226, _k86, _sum0);
                    _sum1 = _mm256_fmadd_ps(_r227, _k87, _sum1);

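                    // same rewind as in the two-pixel loop: 8 advances of 64 floats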
                    kptr -= 64 * 8;

                    _sum0 = _mm256_add_ps(_sum0, _sum1);

                    _mm256_storeu_ps(outptr, _sum0);

                    r0 += 8;
                    r1 += 8;
                    r2 += 8;
                    outptr += 8;
                }

                r0 += 16;
                r1 += 16;
                r2 += 16;
            }
        }
    }
}

static void conv3x3s1_winograd64_transform_kernel_pack8_avx(const Mat& kernel, Mat& kernel_tm_pack8, int inch, int outch)
{
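    // Winograd F(6x6, 3x3): each 3x3 kernel g is lifted to an 8x8 tile
    // kernel_tm = G * g * Gt, where ktm below is the 8x3 transform matrix G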
    // winograd63 transform kernel
    Mat kernel_tm;
    kernel_tm.create(8 * 8, inch, outch);

    const float ktm[8][3] = {
        {1.0f, 0.0f, 0.0f},
        {-2.0f / 9, -2.0f / 9, -2.0f / 9},
        {-2.0f / 9, 2.0f / 9, -2.0f / 9},
        {1.0f / 90, 1.0f / 45, 2.0f / 45},
        {1.0f / 90, -1.0f / 45, 2.0f / 45},
        {1.0f / 45, 1.0f / 90, 1.0f / 180},
        {1.0f / 45, -1.0f / 90, 1.0f / 180},
        {0.0f, 0.0f, 1.0f}
    };
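    // the "h" pass below applies G to the kernel rows, the "v" pass applies
    // G again to the result columns, yielding the 8x8 tile per (p, q) pair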

    #pragma omp parallel for
    for (int p = 0; p < outch; p++)
    {
        for (int q = 0; q < inch; q++)
        {
            const float* kernel0 = (const float*)kernel + p * inch * 9 + q * 9;
            float* kernel_tm0 = kernel_tm.channel(p).row(q);

            // transform kernel, transposed
            const float* k0 = kernel0;
            const float* k1 = kernel0 + 3;
            const float* k2 = kernel0 + 6;

            // h
            float tmp[8][3];
            for (int i = 0; i < 8; i++)
            {
                tmp[i][0] = k0[0] * ktm[i][0] + k0[1] * ktm[i][1] + k0[2] * ktm[i][2];
                tmp[i][1] = k1[0] * ktm[i][0] + k1[1] * ktm[i][1] + k1[2] * ktm[i][2];
                tmp[i][2] = k2[0] * ktm[i][0] + k2[1] * ktm[i][1] + k2[2] * ktm[i][2];
            }

            // v
            for (int j = 0; j < 8; j++)
            {
                float* tmpp = &tmp[j][0];

                for (int i = 0; i < 8; i++)
                {
                    kernel_tm0[j * 8 + i] = tmpp[0] * ktm[i][0] + tmpp[1] * ktm[i][1] + tmpp[2] * ktm[i][2];
                }
            }
        }
    }

    // interleave
    // src = 64-inch-outch
    // dst = 8b-8a-inch/8a-64-outch/8b
    kernel_tm_pack8.create(inch / 8, 64, outch / 8, (size_t)4u * 64, 64);
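    // the pack8 path assumes inch and outch are both multiples of 8; each packed
    // row interleaves 8 output channels x 8 input channels per tile slot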

    int q = 0;
    for (; q + 7 < outch; q += 8)
    {
        const Mat k0 = kernel_tm.channel(q);
        const Mat k1 = kernel_tm.channel(q + 1);
        const Mat k2 = kernel_tm.channel(q + 2);
        const Mat k3 = kernel_tm.channel(q + 3);
        const Mat k4 = kernel_tm.channel(q + 4);
        const Mat k5 = kernel_tm.channel(q + 5);
        const Mat k6 = kernel_tm.channel(q + 6);
        const Mat k7 = kernel_tm.channel(q + 7);

        Mat g0 = kernel_tm_pack8.channel(q / 8);

        for (int k = 0; k < 64; k++)
        {
            float* g00 = g0.row(k);

            for (int p = 0; p + 7 < inch; p += 8)
            {
                const float* k00 = k0.row(p);
                const float* k01 = k0.row(p + 1);
                const float* k02 = k0.row(p + 2);
                const float* k03 = k0.row(p + 3);
                const float* k04 = k0.row(p + 4);
                const float* k05 = k0.row(p + 5);
                const float* k06 = k0.row(p + 6);
                const float* k07 = k0.row(p + 7);

                const float* k10 = k1.row(p);
                const float* k11 = k1.row(p + 1);
                const float* k12 = k1.row(p + 2);
                const float* k13 = k1.row(p + 3);
                const float* k14 = k1.row(p + 4);
                const float* k15 = k1.row(p + 5);
                const float* k16 = k1.row(p + 6);
                const float* k17 = k1.row(p + 7);

                const float* k20 = k2.row(p);
                const float* k21 = k2.row(p + 1);
                const float* k22 = k2.row(p + 2);
                const float* k23 = k2.row(p + 3);
                const float* k24 = k2.row(p + 4);
                const float* k25 = k2.row(p + 5);
                const float* k26 = k2.row(p + 6);
                const float* k27 = k2.row(p + 7);

                const float* k30 = k3.row(p);
                const float* k31 = k3.row(p + 1);
                const float* k32 = k3.row(p + 2);
                const float* k33 = k3.row(p + 3);
                const float* k34 = k3.row(p + 4);
                const float* k35 = k3.row(p + 5);
                const float* k36 = k3.row(p + 6);
                const float* k37 = k3.row(p + 7);

                const float* k40 = k4.row(p);
                const float* k41 = k4.row(p + 1);
                const float* k42 = k4.row(p + 2);
                const float* k43 = k4.row(p + 3);
                const float* k44 = k4.row(p + 4);
                const float* k45 = k4.row(p + 5);
                const float* k46 = k4.row(p + 6);
                const float* k47 = k4.row(p + 7);

                const float* k50 = k5.row(p);
                const float* k51 = k5.row(p + 1);
                const float* k52 = k5.row(p + 2);
                const float* k53 = k5.row(p + 3);
                const float* k54 = k5.row(p + 4);
                const float* k55 = k5.row(p + 5);
                const float* k56 = k5.row(p + 6);
                const float* k57 = k5.row(p + 7);

                const float* k60 = k6.row(p);
                const float* k61 = k6.row(p + 1);
                const float* k62 = k6.row(p + 2);
                const float* k63 = k6.row(p + 3);
                const float* k64 = k6.row(p + 4);
                const float* k65 = k6.row(p + 5);
                const float* k66 = k6.row(p + 6);
                const float* k67 = k6.row(p + 7);

                const float* k70 = k7.row(p);
                const float* k71 = k7.row(p + 1);
                const float* k72 = k7.row(p + 2);
                const float* k73 = k7.row(p + 3);
                const float* k74 = k7.row(p + 4);
                const float* k75 = k7.row(p + 5);
                const float* k76 = k7.row(p + 6);
                const float* k77 = k7.row(p + 7);

                g00[0] = k00[k];
                g00[1] = k10[k];
                g00[2] = k20[k];
                g00[3] = k30[k];
                g00[4] = k40[k];
                g00[5] = k50[k];
                g00[6] = k60[k];
                g00[7] = k70[k];

                g00[8] = k01[k];
                g00[9] = k11[k];
                g00[10] = k21[k];
                g00[11] = k31[k];
                g00[12] = k41[k];
                g00[13] = k51[k];
                g00[14] = k61[k];
                g00[15] = k71[k];

                g00[16] = k02[k];
                g00[17] = k12[k];
                g00[18] = k22[k];
                g00[19] = k32[k];
                g00[20] = k42[k];
                g00[21] = k52[k];
                g00[22] = k62[k];
                g00[23] = k72[k];

                g00[24] = k03[k];
                g00[25] = k13[k];
                g00[26] = k23[k];
                g00[27] = k33[k];
                g00[28] = k43[k];
                g00[29] = k53[k];
                g00[30] = k63[k];
                g00[31] = k73[k];

                g00[32] = k04[k];
                g00[33] = k14[k];
                g00[34] = k24[k];
                g00[35] = k34[k];
                g00[36] = k44[k];
                g00[37] = k54[k];
                g00[38] = k64[k];
                g00[39] = k74[k];

                g00[40] = k05[k];
                g00[41] = k15[k];
                g00[42] = k25[k];
                g00[43] = k35[k];
                g00[44] = k45[k];
                g00[45] = k55[k];
                g00[46] = k65[k];
                g00[47] = k75[k];

                g00[48] = k06[k];
                g00[49] = k16[k];
                g00[50] = k26[k];
                g00[51] = k36[k];
                g00[52] = k46[k];
                g00[53] = k56[k];
                g00[54] = k66[k];
                g00[55] = k76[k];

                g00[56] = k07[k];
                g00[57] = k17[k];
                g00[58] = k27[k];
                g00[59] = k37[k];
                g00[60] = k47[k];
                g00[61] = k57[k];
                g00[62] = k67[k];
                g00[63] = k77[k];

                g00 += 64;
            }
        }
    }
}

static void conv3x3s1_winograd64_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const Option& opt)
{
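    // Winograd F(6x6, 3x3) forward path: pad the input into a grid of
    // overlapping 8x8 tiles (stride 6), transform each tile, then multiply
    // with the transformed kernels slot by slot (the transform and GEMM
    // stages below)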
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int inch = bottom_blob.c;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;
    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    // pad to 6n+2
    Mat bottom_blob_bordered = bottom_blob;

    outw = (outw + 5) / 6 * 6;
    outh = (outh + 5) / 6 * 6;

    w = outw + 2;
    h = outh + 2;
    copy_make_border(bottom_blob, bottom_blob_bordered, 0, h - bottom_blob.h, 0, w - bottom_blob.w, BORDER_CONSTANT, 0.f, opt);

    const float* bias = _bias;
    // BEGIN transform input
    Mat bottom_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = w_tm / 8 * h_tm / 8;

        bottom_blob_tm.create(tiles, 64, inch, elemsize, elempack, opt.workspace_allocator);

        // const float itm[8][8] = {
        //     {1.0f,  0.0f, -5.25f,  0.00f,  5.25f,  0.00f, -1.0f, 0.0f},
        //
        //     {0.0f,  1.0f,  1.00f, -4.25f, -4.25f,  1.00f,  1.0f, 0.0f},
        //     {0.0f, -1.0f,  1.00f,  4.25f, -4.25f, -1.00f,  1.0f, 0.0f},
        //
        //     {0.0f,  0.5f,  0.25f, -2.50f, -1.25f,  2.00f,  1.0f, 0.0f},
        //     {0.0f, -0.5f,  0.25f,  2.50f, -1.25f, -2.00f,  1.0f, 0.0f},
        //
        //     {0.0f,  2.0f,  4.00f, -2.50f, -5.00f,  0.50f,  1.0f, 0.0f},
        //     {0.0f, -2.0f,  4.00f,  2.50f, -5.00f, -0.50f,  1.0f, 0.0f},
        //
        //     {0.0f, -1.0f,  0.00f,  5.25f,  0.00f, -5.25f,  0.0f, 1.0f}
        // };

        // 0 = r00 - r06 + (r04 - r02) * 5.25
        // 7 = r07 - r01 + (r03 - r05) * 5.25

        // 1 = (r02 + r06 - r04 * 4.25) + (r01 - r03 * 4.25 + r05)
        // 2 = (r02 + r06 - r04 * 4.25) - (r01 - r03 * 4.25 + r05)

        // 3 = (r06 + r02 * 0.25 - r04 * 1.25) + (r01 * 0.5 - r03 * 2.5 + r05 * 2)
        // 4 = (r06 + r02 * 0.25 - r04 * 1.25) - (r01 * 0.5 - r03 * 2.5 + r05 * 2)

        // reuse r04 * 1.25
        // reuse r03 * 2.5
        // 5 = (r06 + (r02 - r04 * 1.25) * 4) + (r01 * 2 - r03 * 2.5 + r05 * 0.5)
        // 6 = (r06 + (r02 - r04 * 1.25) * 4) - (r01 * 2 - r03 * 2.5 + r05 * 0.5)

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < inch; q++)
        {
            const Mat img0 = bottom_blob_bordered.channel(q);
            Mat img0_tm = bottom_blob_tm.channel(q);

            float tmp[8][8][8];
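            // tmp[k][m][lane] holds the horizontal pass: transform coefficient k
            // of tile row m; the second loop below combines across rows (vertical)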

            // tile
            for (int i = 0; i < h_tm / 8; i++)
            {
                for (int j = 0; j < w_tm / 8; j++)
                {
                    const float* r0 = img0.row(i * 6) + (j * 6) * 8;

                    for (int m = 0; m < 8; m++)
                    {
                        __m256 _r00 = _mm256_loadu_ps(r0);
                        __m256 _r01 = _mm256_loadu_ps(r0 + 8);
                        __m256 _r02 = _mm256_loadu_ps(r0 + 16);
                        __m256 _r03 = _mm256_loadu_ps(r0 + 24);
                        __m256 _r04 = _mm256_loadu_ps(r0 + 32);
                        __m256 _r05 = _mm256_loadu_ps(r0 + 40);
                        __m256 _r06 = _mm256_loadu_ps(r0 + 48);
                        __m256 _r07 = _mm256_loadu_ps(r0 + 56);

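                        // _mm256_fmadd_1_ps(a, b, c) computes a + b * c and
                        // _mm256_fmrsub_1_ps(a, b, c) computes a - b * c for a
                        // scalar c; both are small helpers defined elsewhere in
                        // ncnn's x86 code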
                        __m256 _tmp0m = _mm256_fmadd_1_ps(_mm256_sub_ps(_r00, _r06), _mm256_sub_ps(_r04, _r02), 5.25f);
                        __m256 _tmp7m = _mm256_fmadd_1_ps(_mm256_sub_ps(_r07, _r01), _mm256_sub_ps(_r03, _r05), 5.25f);
                        _mm256_storeu_ps(tmp[0][m], _tmp0m);
                        _mm256_storeu_ps(tmp[7][m], _tmp7m);

                        // tmp[0][m] = r0[0] - r0[6] + (r0[4] - r0[2]) * 5.25;
                        // tmp[7][m] = r0[7] - r0[1] + (r0[3] - r0[5]) * 5.25;

                        __m256 _tmp12a = _mm256_fmrsub_1_ps(_mm256_add_ps(_r02, _r06), _r04, 4.25f);
                        __m256 _tmp12b = _mm256_fmrsub_1_ps(_mm256_add_ps(_r01, _r05), _r03, 4.25f);

                        // float tmp12a = (r0[2] + r0[6] - r0[4] * 4.25);
                        // float tmp12b = (r0[1] + r0[5] - r0[3] * 4.25);

                        __m256 _tmp1m = _mm256_add_ps(_tmp12a, _tmp12b);
                        __m256 _tmp2m = _mm256_sub_ps(_tmp12a, _tmp12b);
                        _mm256_storeu_ps(tmp[1][m], _tmp1m);
                        _mm256_storeu_ps(tmp[2][m], _tmp2m);

                        // tmp[1][m] = tmp12a + tmp12b;
                        // tmp[2][m] = tmp12a - tmp12b;

                        __m256 _tmp34a = _mm256_fmrsub_1_ps(_mm256_fmadd_1_ps(_r06, _r02, 0.25f), _r04, 1.25f);
                        __m256 _tmp34b = _mm256_fmadd_1_ps(_mm256_fmrsub_1_ps(_mm256_mul_ps(_r01, _mm256_set1_ps(0.5f)), _r03, 2.5f), _r05, 2.f);

                        // float tmp34a = (r0[6] + r0[2] * 0.25 - r0[4] * 1.25);
                        // float tmp34b = (r0[1] * 0.5 - r0[3] * 2.5 + r0[5] * 2);

                        __m256 _tmp3m = _mm256_add_ps(_tmp34a, _tmp34b);
                        __m256 _tmp4m = _mm256_sub_ps(_tmp34a, _tmp34b);
                        _mm256_storeu_ps(tmp[3][m], _tmp3m);
                        _mm256_storeu_ps(tmp[4][m], _tmp4m);

                        // tmp[3][m] = tmp34a + tmp34b;
                        // tmp[4][m] = tmp34a - tmp34b;

                        __m256 _tmp56a = _mm256_fmadd_1_ps(_r06, _mm256_fmrsub_1_ps(_r02, _r04, 1.25f), 4.f);
                        __m256 _tmp56b = _mm256_fmadd_1_ps(_mm256_fmrsub_1_ps(_mm256_mul_ps(_r01, _mm256_set1_ps(2.f)), _r03, 2.5f), _r05, 0.5f);

                        // float tmp56a = (r0[6] + (r0[2] - r0[4] * 1.25) * 4);
                        // float tmp56b = (r0[1] * 2 - r0[3] * 2.5 + r0[5] * 0.5);

                        __m256 _tmp5m = _mm256_add_ps(_tmp56a, _tmp56b);
                        __m256 _tmp6m = _mm256_sub_ps(_tmp56a, _tmp56b);
                        _mm256_storeu_ps(tmp[5][m], _tmp5m);
                        _mm256_storeu_ps(tmp[6][m], _tmp6m);

                        // tmp[5][m] = tmp56a + tmp56b;
                        // tmp[6][m] = tmp56a - tmp56b;

                        r0 += w * 8;
                    }

                    float* r0_tm_0 = (float*)img0_tm + (i * w_tm / 8 + j) * 8;
                    float* r0_tm_1 = r0_tm_0 + tiles * 8;
                    float* r0_tm_2 = r0_tm_0 + tiles * 16;
                    float* r0_tm_3 = r0_tm_0 + tiles * 24;
                    float* r0_tm_4 = r0_tm_0 + tiles * 32;
                    float* r0_tm_5 = r0_tm_0 + tiles * 40;
                    float* r0_tm_6 = r0_tm_0 + tiles * 48;
                    float* r0_tm_7 = r0_tm_0 + tiles * 56;

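                    // r0_tm_0..7 write 8 of the 64 tile-transform slots per pass;
                    // img0_tm is laid out as 64 planes of (tiles * 8) floats, so
                    // advancing by tiles * 64 steps 8 slots forward each iteration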
                    for (int m = 0; m < 8; m++)
                    {
                        __m256 _tmp00 = _mm256_loadu_ps(tmp[m][0]);
                        __m256 _tmp01 = _mm256_loadu_ps(tmp[m][1]);
                        __m256 _tmp02 = _mm256_loadu_ps(tmp[m][2]);
                        __m256 _tmp03 = _mm256_loadu_ps(tmp[m][3]);
                        __m256 _tmp04 = _mm256_loadu_ps(tmp[m][4]);
                        __m256 _tmp05 = _mm256_loadu_ps(tmp[m][5]);
                        __m256 _tmp06 = _mm256_loadu_ps(tmp[m][6]);
                        __m256 _tmp07 = _mm256_loadu_ps(tmp[m][7]);

                        __m256 _r0tm0 = _mm256_fmadd_1_ps(_mm256_sub_ps(_tmp00, _tmp06), _mm256_sub_ps(_tmp04, _tmp02), 5.25f);
                        __m256 _r0tm7 = _mm256_fmadd_1_ps(_mm256_sub_ps(_tmp07, _tmp01), _mm256_sub_ps(_tmp03, _tmp05), 5.25f);

                        // r0_tm[0] = tmp0[0] - tmp0[6] + (tmp0[4] - tmp0[2]) * 5.25;
                        // r0_tm[7] = tmp0[7] - tmp0[1] + (tmp0[3] - tmp0[5]) * 5.25;

                        __m256 _tmp12a = _mm256_fmrsub_1_ps(_mm256_add_ps(_tmp02, _tmp06), _tmp04, 4.25f);
                        __m256 _tmp12b = _mm256_fmrsub_1_ps(_mm256_add_ps(_tmp01, _tmp05), _tmp03, 4.25f);

                        // float tmp12a = (tmp0[2] + tmp0[6] - tmp0[4] * 4.25);
                        // float tmp12b = (tmp0[1] + tmp0[5] - tmp0[3] * 4.25);

                        __m256 _r0tm1 = _mm256_add_ps(_tmp12a, _tmp12b);
                        __m256 _r0tm2 = _mm256_sub_ps(_tmp12a, _tmp12b);

                        // r0_tm[1] = tmp12a + tmp12b;
                        // r0_tm[2] = tmp12a - tmp12b;

                        __m256 _tmp34a = _mm256_fmrsub_1_ps(_mm256_fmadd_1_ps(_tmp06, _tmp02, 0.25f), _tmp04, 1.25f);
                        __m256 _tmp34b = _mm256_fmadd_1_ps(_mm256_fmrsub_1_ps(_mm256_mul_ps(_tmp01, _mm256_set1_ps(0.5f)), _tmp03, 2.5f), _tmp05, 2.f);

                        // float tmp34a = (tmp0[6] + tmp0[2] * 0.25 - tmp0[4] * 1.25);
                        // float tmp34b = (tmp0[1] * 0.5 - tmp0[3] * 2.5 + tmp0[5] * 2);

                        __m256 _r0tm3 = _mm256_add_ps(_tmp34a, _tmp34b);
                        __m256 _r0tm4 = _mm256_sub_ps(_tmp34a, _tmp34b);

                        // r0_tm[3] = tmp34a + tmp34b;
                        // r0_tm[4] = tmp34a - tmp34b;

                        __m256 _tmp56a = _mm256_fmadd_1_ps(_tmp06, _mm256_fmrsub_1_ps(_tmp02, _tmp04, 1.25f), 4.f);
                        __m256 _tmp56b = _mm256_fmadd_1_ps(_mm256_fmrsub_1_ps(_mm256_mul_ps(_tmp01, _mm256_set1_ps(2.f)), _tmp03, 2.5f), _tmp05, 0.5f);

                        // float tmp56a = (tmp0[6] + (tmp0[2] - tmp0[4] * 1.25) * 4);
                        // float tmp56b = (tmp0[1] * 2 - tmp0[3] * 2.5 + tmp0[5] * 0.5);

                        __m256 _r0tm5 = _mm256_add_ps(_tmp56a, _tmp56b);
                        __m256 _r0tm6 = _mm256_sub_ps(_tmp56a, _tmp56b);

                        // r0_tm[5] = tmp56a + tmp56b;
                        // r0_tm[6] = tmp56a - tmp56b;

                        _mm256_storeu_ps(r0_tm_0, _r0tm0);
                        _mm256_storeu_ps(r0_tm_1, _r0tm1);
                        _mm256_storeu_ps(r0_tm_2, _r0tm2);
                        _mm256_storeu_ps(r0_tm_3, _r0tm3);
                        _mm256_storeu_ps(r0_tm_4, _r0tm4);
                        _mm256_storeu_ps(r0_tm_5, _r0tm5);
                        _mm256_storeu_ps(r0_tm_6, _r0tm6);
                        _mm256_storeu_ps(r0_tm_7, _r0tm7);

                        r0_tm_0 += tiles * 64;
                        r0_tm_1 += tiles * 64;
                        r0_tm_2 += tiles * 64;
                        r0_tm_3 += tiles * 64;
                        r0_tm_4 += tiles * 64;
                        r0_tm_5 += tiles * 64;
                        r0_tm_6 += tiles * 64;
                        r0_tm_7 += tiles * 64;
                    }
                }
            }
        }
    }
    bottom_blob_bordered = Mat();
    // END transform input
    // BEGIN dot
    Mat top_blob_tm;
    {
        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;

        const int tiles = h_tm / 8 * w_tm / 8;

        Mat bottom_blob_tm2;

        if (tiles >= 12)
            bottom_blob_tm2.create(12 * inch, tiles / 12 + (tiles % 12) / 8 + (tiles % 12 % 8) / 4 + (tiles % 12 % 4) / 2 + tiles % 12 % 2, 64, elemsize, elempack, opt.workspace_allocator);
        else if (tiles >= 8)
            bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + (tiles % 4) / 2 + tiles % 2, 64, elemsize, elempack, opt.workspace_allocator);
        else if (tiles >= 4)
            bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, 64, elemsize, elempack, opt.workspace_allocator);
        else if (tiles >= 2)
            bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, 64, elemsize, elempack, opt.workspace_allocator);
        else // if (tiles >= 1)
            bottom_blob_tm2.create(1 * inch, tiles, 64, elemsize, elempack, opt.workspace_allocator);

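        // permute tiles into batches of 12/8/4/2/1 so the GEMM below can keep
        // that many accumulators resident in ymm registers per kernel load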
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int r = 0; r < 64; r++)
        {
            Mat tm2 = bottom_blob_tm2.channel(r);

            // tile
            int i = 0;

            for (; i + 11 < tiles; i += 12)
            {
                float* tm2p = tm2.row(i / 12);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    __m256 _r0 = _mm256_loadu_ps(r0);
                    __m256 _r1 = _mm256_loadu_ps(r0 + 8);
                    __m256 _r2 = _mm256_loadu_ps(r0 + 16);
                    __m256 _r3 = _mm256_loadu_ps(r0 + 24);
                    __m256 _r4 = _mm256_loadu_ps(r0 + 32);
                    __m256 _r5 = _mm256_loadu_ps(r0 + 40);
                    __m256 _r6 = _mm256_loadu_ps(r0 + 48);
                    __m256 _r7 = _mm256_loadu_ps(r0 + 56);
                    __m256 _r8 = _mm256_loadu_ps(r0 + 64);
                    __m256 _r9 = _mm256_loadu_ps(r0 + 72);
                    __m256 _r10 = _mm256_loadu_ps(r0 + 80);
                    __m256 _r11 = _mm256_loadu_ps(r0 + 88);
                    _mm256_storeu_ps(tm2p, _r0);
                    _mm256_storeu_ps(tm2p + 8, _r1);
                    _mm256_storeu_ps(tm2p + 16, _r2);
                    _mm256_storeu_ps(tm2p + 24, _r3);
                    _mm256_storeu_ps(tm2p + 32, _r4);
                    _mm256_storeu_ps(tm2p + 40, _r5);
                    _mm256_storeu_ps(tm2p + 48, _r6);
                    _mm256_storeu_ps(tm2p + 56, _r7);
                    _mm256_storeu_ps(tm2p + 64, _r8);
                    _mm256_storeu_ps(tm2p + 72, _r9);
                    _mm256_storeu_ps(tm2p + 80, _r10);
                    _mm256_storeu_ps(tm2p + 88, _r11);
                    tm2p += 96;
                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
            for (; i + 7 < tiles; i += 8)
            {
                float* tm2p = tm2.row(i / 12 + (i % 12) / 8);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    __m256 _r0 = _mm256_loadu_ps(r0);
                    __m256 _r1 = _mm256_loadu_ps(r0 + 8);
                    _mm256_storeu_ps(tm2p, _r0);
                    _mm256_storeu_ps(tm2p + 8, _r1);
                    __m256 _r2 = _mm256_loadu_ps(r0 + 16);
                    __m256 _r3 = _mm256_loadu_ps(r0 + 24);
                    _mm256_storeu_ps(tm2p + 16, _r2);
                    _mm256_storeu_ps(tm2p + 24, _r3);
                    __m256 _r4 = _mm256_loadu_ps(r0 + 32);
                    __m256 _r5 = _mm256_loadu_ps(r0 + 40);
                    _mm256_storeu_ps(tm2p + 32, _r4);
                    _mm256_storeu_ps(tm2p + 40, _r5);
                    __m256 _r6 = _mm256_loadu_ps(r0 + 48);
                    __m256 _r7 = _mm256_loadu_ps(r0 + 56);
                    _mm256_storeu_ps(tm2p + 48, _r6);
                    _mm256_storeu_ps(tm2p + 56, _r7);
                    tm2p += 64;
                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
            for (; i + 3 < tiles; i += 4)
            {
                float* tm2p = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    __m256 _r0 = _mm256_loadu_ps(r0);
                    __m256 _r1 = _mm256_loadu_ps(r0 + 8);
                    _mm256_storeu_ps(tm2p, _r0);
                    _mm256_storeu_ps(tm2p + 8, _r1);
                    __m256 _r2 = _mm256_loadu_ps(r0 + 16);
                    __m256 _r3 = _mm256_loadu_ps(r0 + 24);
                    _mm256_storeu_ps(tm2p + 16, _r2);
                    _mm256_storeu_ps(tm2p + 24, _r3);
                    tm2p += 32;
                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
            for (; i + 1 < tiles; i += 2)
            {
                float* tm2p = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

                const float* r0 = bottom_blob_tm;

                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    __m256 _r0 = _mm256_loadu_ps(r0);
                    __m256 _r1 = _mm256_loadu_ps(r0 + 8);
                    _mm256_storeu_ps(tm2p, _r0);
                    _mm256_storeu_ps(tm2p + 8, _r1);
                    tm2p += 16;
                    r0 += bottom_blob_tm.cstep * 8;
                }
            }

            for (; i < tiles; i++)
            {
                float* tm2p = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);

                const float* r0 = bottom_blob_tm;
                r0 += (r * tiles + i) * 8;

                for (int q = 0; q < inch; q++)
                {
                    __m256 _r0 = _mm256_loadu_ps(r0);
                    _mm256_storeu_ps(tm2p, _r0);
                    tm2p += 8;
                    r0 += bottom_blob_tm.cstep * 8;
                }
            }
        }

        bottom_blob_tm = Mat();
        // permute end

        top_blob_tm.create(tiles, 64, outch, elemsize, elempack, opt.workspace_allocator);

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            float* output0_tm = top_blob_tm.channel(p);

            const Mat kernel0_tm = kernel_tm.channel(p);

            for (int r = 0; r < 64; r++)
            {
                const Mat bb2 = bottom_blob_tm2.channel(r);

                int i = 0;
                for (; i + 11 < tiles; i += 12)
                {
                    const float* r0 = bb2.row(i / 12);
                    const float* k01 = kernel0_tm.row(r);

                    int nn = inch; // inch always > 0
                    __m256 _sum0 = _mm256_set1_ps(0.f);
                    __m256 _sum1 = _mm256_set1_ps(0.f);
                    __m256 _sum2 = _mm256_set1_ps(0.f);
                    __m256 _sum3 = _mm256_set1_ps(0.f);
                    __m256 _sum4 = _mm256_set1_ps(0.f);
                    __m256 _sum5 = _mm256_set1_ps(0.f);
                    __m256 _sum6 = _mm256_set1_ps(0.f);
                    __m256 _sum7 = _mm256_set1_ps(0.f);
                    __m256 _sum8 = _mm256_set1_ps(0.f);
                    __m256 _sum9 = _mm256_set1_ps(0.f);
                    __m256 _sum10 = _mm256_set1_ps(0.f);
                    __m256 _sum11 = _mm256_set1_ps(0.f);

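                    // per input lane: broadcast one scalar from each of the 12
                    // tiles and FMA it against the 8 output lanes of the kernel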
1363 for (; nn > 0; nn--)
1364 {
1365 __m256 _k01 = _mm256_loadu_ps(k01);
1366 __m256 _r00 = _mm256_broadcast_ss(r0 + 0);
1367 __m256 _r01 = _mm256_broadcast_ss(r0 + 8);
1368 __m256 _r02 = _mm256_broadcast_ss(r0 + 16);
1369 __m256 _r03 = _mm256_broadcast_ss(r0 + 24);
1370 _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
1371 _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
1372 _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
1373 _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);
1374 __m256 _r04 = _mm256_broadcast_ss(r0 + 32);
1375 __m256 _r05 = _mm256_broadcast_ss(r0 + 40);
1376 __m256 _r06 = _mm256_broadcast_ss(r0 + 48);
1377 __m256 _r07 = _mm256_broadcast_ss(r0 + 56);
1378 _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
1379 _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
1380 _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
1381 _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);
1382 __m256 _r08 = _mm256_broadcast_ss(r0 + 64);
1383 __m256 _r09 = _mm256_broadcast_ss(r0 + 72);
1384 __m256 _r010 = _mm256_broadcast_ss(r0 + 80);
1385 __m256 _r011 = _mm256_broadcast_ss(r0 + 88);
1386
1387 _sum8 = _mm256_fmadd_ps(_k01, _r08, _sum8);
1388 _sum9 = _mm256_fmadd_ps(_k01, _r09, _sum9);
1389 _sum10 = _mm256_fmadd_ps(_k01, _r010, _sum10);
1390 _sum11 = _mm256_fmadd_ps(_k01, _r011, _sum11);
1391
1392 _k01 = _mm256_loadu_ps(k01 + 8);
1393 _r00 = _mm256_broadcast_ss(r0 + 1);
1394 _r01 = _mm256_broadcast_ss(r0 + 9);
1395 _r02 = _mm256_broadcast_ss(r0 + 17);
1396 _r03 = _mm256_broadcast_ss(r0 + 25);
1397 _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
1398 _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
1399 _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
1400 _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);
1401 _r04 = _mm256_broadcast_ss(r0 + 33);
1402 _r05 = _mm256_broadcast_ss(r0 + 41);
1403 _r06 = _mm256_broadcast_ss(r0 + 49);
1404 _r07 = _mm256_broadcast_ss(r0 + 57);
1405 _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
1406 _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
1407 _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
1408 _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);
1409
1410 _r08 = _mm256_broadcast_ss(r0 + 65);
1411 _r09 = _mm256_broadcast_ss(r0 + 73);
1412 _r010 = _mm256_broadcast_ss(r0 + 81);
1413 _r011 = _mm256_broadcast_ss(r0 + 89);
1414
1415 _sum8 = _mm256_fmadd_ps(_k01, _r08, _sum8);
1416 _sum9 = _mm256_fmadd_ps(_k01, _r09, _sum9);
1417 _sum10 = _mm256_fmadd_ps(_k01, _r010, _sum10);
1418 _sum11 = _mm256_fmadd_ps(_k01, _r011, _sum11);
1419
1420 _k01 = _mm256_loadu_ps(k01 + 16);
1421 _r00 = _mm256_broadcast_ss(r0 + 2);
1422 _r01 = _mm256_broadcast_ss(r0 + 10);
1423 _r02 = _mm256_broadcast_ss(r0 + 18);
1424 _r03 = _mm256_broadcast_ss(r0 + 26);
1425 _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
1426 _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
1427 _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
1428 _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);
1429 _r04 = _mm256_broadcast_ss(r0 + 34);
1430 _r05 = _mm256_broadcast_ss(r0 + 42);
1431 _r06 = _mm256_broadcast_ss(r0 + 50);
1432 _r07 = _mm256_broadcast_ss(r0 + 58);
1433 _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
1434 _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
1435 _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
1436 _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);
1437 _r08 = _mm256_broadcast_ss(r0 + 66);
1438 _r09 = _mm256_broadcast_ss(r0 + 74);
1439 _r010 = _mm256_broadcast_ss(r0 + 82);
1440 _r011 = _mm256_broadcast_ss(r0 + 90);
1441
1442 _sum8 = _mm256_fmadd_ps(_k01, _r08, _sum8);
1443 _sum9 = _mm256_fmadd_ps(_k01, _r09, _sum9);
1444 _sum10 = _mm256_fmadd_ps(_k01, _r010, _sum10);
1445 _sum11 = _mm256_fmadd_ps(_k01, _r011, _sum11);
1446
1447 _k01 = _mm256_loadu_ps(k01 + 24);
                        _r00 = _mm256_broadcast_ss(r0 + 3);
                        _r01 = _mm256_broadcast_ss(r0 + 11);
                        _r02 = _mm256_broadcast_ss(r0 + 19);
                        _r03 = _mm256_broadcast_ss(r0 + 27);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 35);
                        _r05 = _mm256_broadcast_ss(r0 + 43);
                        _r06 = _mm256_broadcast_ss(r0 + 51);
                        _r07 = _mm256_broadcast_ss(r0 + 59);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _r08 = _mm256_broadcast_ss(r0 + 67);
                        _r09 = _mm256_broadcast_ss(r0 + 75);
                        _r010 = _mm256_broadcast_ss(r0 + 83);
                        _r011 = _mm256_broadcast_ss(r0 + 91);

                        _sum8 = _mm256_fmadd_ps(_k01, _r08, _sum8);
                        _sum9 = _mm256_fmadd_ps(_k01, _r09, _sum9);
                        _sum10 = _mm256_fmadd_ps(_k01, _r010, _sum10);
                        _sum11 = _mm256_fmadd_ps(_k01, _r011, _sum11);

                        _k01 = _mm256_loadu_ps(k01 + 32);
                        _r00 = _mm256_broadcast_ss(r0 + 4);
                        _r01 = _mm256_broadcast_ss(r0 + 12);
                        _r02 = _mm256_broadcast_ss(r0 + 20);
                        _r03 = _mm256_broadcast_ss(r0 + 28);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 36);
                        _r05 = _mm256_broadcast_ss(r0 + 44);
                        _r06 = _mm256_broadcast_ss(r0 + 52);
                        _r07 = _mm256_broadcast_ss(r0 + 60);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _r08 = _mm256_broadcast_ss(r0 + 68);
                        _r09 = _mm256_broadcast_ss(r0 + 76);
                        _r010 = _mm256_broadcast_ss(r0 + 84);
                        _r011 = _mm256_broadcast_ss(r0 + 92);

                        _sum8 = _mm256_fmadd_ps(_k01, _r08, _sum8);
                        _sum9 = _mm256_fmadd_ps(_k01, _r09, _sum9);
                        _sum10 = _mm256_fmadd_ps(_k01, _r010, _sum10);
                        _sum11 = _mm256_fmadd_ps(_k01, _r011, _sum11);

                        _k01 = _mm256_loadu_ps(k01 + 40);
                        _r00 = _mm256_broadcast_ss(r0 + 5);
                        _r01 = _mm256_broadcast_ss(r0 + 13);
                        _r02 = _mm256_broadcast_ss(r0 + 21);
                        _r03 = _mm256_broadcast_ss(r0 + 29);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 37);
                        _r05 = _mm256_broadcast_ss(r0 + 45);
                        _r06 = _mm256_broadcast_ss(r0 + 53);
                        _r07 = _mm256_broadcast_ss(r0 + 61);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _r08 = _mm256_broadcast_ss(r0 + 69);
                        _r09 = _mm256_broadcast_ss(r0 + 77);
                        _r010 = _mm256_broadcast_ss(r0 + 85);
                        _r011 = _mm256_broadcast_ss(r0 + 93);

                        _sum8 = _mm256_fmadd_ps(_k01, _r08, _sum8);
                        _sum9 = _mm256_fmadd_ps(_k01, _r09, _sum9);
                        _sum10 = _mm256_fmadd_ps(_k01, _r010, _sum10);
                        _sum11 = _mm256_fmadd_ps(_k01, _r011, _sum11);

                        _k01 = _mm256_loadu_ps(k01 + 48);
                        _r00 = _mm256_broadcast_ss(r0 + 6);
                        _r01 = _mm256_broadcast_ss(r0 + 14);
                        _r02 = _mm256_broadcast_ss(r0 + 22);
                        _r03 = _mm256_broadcast_ss(r0 + 30);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 38);
                        _r05 = _mm256_broadcast_ss(r0 + 46);
                        _r06 = _mm256_broadcast_ss(r0 + 54);
                        _r07 = _mm256_broadcast_ss(r0 + 62);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);
                        _r08 = _mm256_broadcast_ss(r0 + 70);
                        _r09 = _mm256_broadcast_ss(r0 + 78);
                        _r010 = _mm256_broadcast_ss(r0 + 86);
                        _r011 = _mm256_broadcast_ss(r0 + 94);

                        _sum8 = _mm256_fmadd_ps(_k01, _r08, _sum8);
                        _sum9 = _mm256_fmadd_ps(_k01, _r09, _sum9);
                        _sum10 = _mm256_fmadd_ps(_k01, _r010, _sum10);
                        _sum11 = _mm256_fmadd_ps(_k01, _r011, _sum11);

                        _k01 = _mm256_loadu_ps(k01 + 56);
                        _r00 = _mm256_broadcast_ss(r0 + 7);
                        _r01 = _mm256_broadcast_ss(r0 + 15);
                        _r02 = _mm256_broadcast_ss(r0 + 23);
                        _r03 = _mm256_broadcast_ss(r0 + 31);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 39);
                        _r05 = _mm256_broadcast_ss(r0 + 47);
                        _r06 = _mm256_broadcast_ss(r0 + 55);
                        _r07 = _mm256_broadcast_ss(r0 + 63);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _r08 = _mm256_broadcast_ss(r0 + 71);
                        _r09 = _mm256_broadcast_ss(r0 + 79);
                        _r010 = _mm256_broadcast_ss(r0 + 87);
                        _r011 = _mm256_broadcast_ss(r0 + 95);
                        _sum8 = _mm256_fmadd_ps(_k01, _r08, _sum8);
                        _sum9 = _mm256_fmadd_ps(_k01, _r09, _sum9);
                        _sum10 = _mm256_fmadd_ps(_k01, _r010, _sum10);
                        _sum11 = _mm256_fmadd_ps(_k01, _r011, _sum11);

                        k01 += 64;
                        r0 += 96;
                    }
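                    // write the 12 accumulated tile results back to the transformed output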
                    _mm256_storeu_ps(output0_tm, _sum0);
                    _mm256_storeu_ps(output0_tm + 8, _sum1);
                    _mm256_storeu_ps(output0_tm + 16, _sum2);
                    _mm256_storeu_ps(output0_tm + 24, _sum3);
                    _mm256_storeu_ps(output0_tm + 32, _sum4);
                    _mm256_storeu_ps(output0_tm + 40, _sum5);
                    _mm256_storeu_ps(output0_tm + 48, _sum6);
                    _mm256_storeu_ps(output0_tm + 56, _sum7);
                    _mm256_storeu_ps(output0_tm + 64, _sum8);
                    _mm256_storeu_ps(output0_tm + 72, _sum9);
                    _mm256_storeu_ps(output0_tm + 80, _sum10);
                    _mm256_storeu_ps(output0_tm + 88, _sum11);
                    output0_tm += 96;
                }
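                // remaining tiles: the same inch-deep packed dot product repeats below,
                // narrowing from 8 to 4, 2 and finally 1 output tile per iteration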
                for (; i + 7 < tiles; i += 8)
                {
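                    // tiles were repacked into rows of 12/8/4/2/1; this index presumably
                    // locates the bb2 row holding the block that starts at tile i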
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8);
                    const float* k01 = kernel0_tm.row(r);

                    int nn = inch; // inch always > 0
                    __m256 _sum0 = _mm256_set1_ps(0.f);
                    __m256 _sum1 = _mm256_set1_ps(0.f);
                    __m256 _sum2 = _mm256_set1_ps(0.f);
                    __m256 _sum3 = _mm256_set1_ps(0.f);
                    __m256 _sum4 = _mm256_set1_ps(0.f);
                    __m256 _sum5 = _mm256_set1_ps(0.f);
                    __m256 _sum6 = _mm256_set1_ps(0.f);
                    __m256 _sum7 = _mm256_set1_ps(0.f);

                    for (; nn > 0; nn--)
                    {
                        __m256 _k01 = _mm256_loadu_ps(k01);
                        __m256 _r00 = _mm256_broadcast_ss(r0 + 0);
                        __m256 _r01 = _mm256_broadcast_ss(r0 + 8);
                        __m256 _r02 = _mm256_broadcast_ss(r0 + 16);
                        __m256 _r03 = _mm256_broadcast_ss(r0 + 24);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);
                        __m256 _r04 = _mm256_broadcast_ss(r0 + 32);
                        __m256 _r05 = _mm256_broadcast_ss(r0 + 40);
                        __m256 _r06 = _mm256_broadcast_ss(r0 + 48);
                        __m256 _r07 = _mm256_broadcast_ss(r0 + 56);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _k01 = _mm256_loadu_ps(k01 + 8);
                        _r00 = _mm256_broadcast_ss(r0 + 1);
                        _r01 = _mm256_broadcast_ss(r0 + 9);
                        _r02 = _mm256_broadcast_ss(r0 + 17);
                        _r03 = _mm256_broadcast_ss(r0 + 25);

                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);
                        _r04 = _mm256_broadcast_ss(r0 + 33);
                        _r05 = _mm256_broadcast_ss(r0 + 41);
                        _r06 = _mm256_broadcast_ss(r0 + 49);
                        _r07 = _mm256_broadcast_ss(r0 + 57);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _k01 = _mm256_loadu_ps(k01 + 16);
                        _r00 = _mm256_broadcast_ss(r0 + 2);
                        _r01 = _mm256_broadcast_ss(r0 + 10);
                        _r02 = _mm256_broadcast_ss(r0 + 18);
                        _r03 = _mm256_broadcast_ss(r0 + 26);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 34);
                        _r05 = _mm256_broadcast_ss(r0 + 42);
                        _r06 = _mm256_broadcast_ss(r0 + 50);
                        _r07 = _mm256_broadcast_ss(r0 + 58);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _k01 = _mm256_loadu_ps(k01 + 24);
                        _r00 = _mm256_broadcast_ss(r0 + 3);
                        _r01 = _mm256_broadcast_ss(r0 + 11);
                        _r02 = _mm256_broadcast_ss(r0 + 19);
                        _r03 = _mm256_broadcast_ss(r0 + 27);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 35);
                        _r05 = _mm256_broadcast_ss(r0 + 43);
                        _r06 = _mm256_broadcast_ss(r0 + 51);
                        _r07 = _mm256_broadcast_ss(r0 + 59);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _k01 = _mm256_loadu_ps(k01 + 32);
                        _r00 = _mm256_broadcast_ss(r0 + 4);
                        _r01 = _mm256_broadcast_ss(r0 + 12);
                        _r02 = _mm256_broadcast_ss(r0 + 20);
                        _r03 = _mm256_broadcast_ss(r0 + 28);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 36);
                        _r05 = _mm256_broadcast_ss(r0 + 44);
                        _r06 = _mm256_broadcast_ss(r0 + 52);
                        _r07 = _mm256_broadcast_ss(r0 + 60);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _k01 = _mm256_loadu_ps(k01 + 40);
                        _r00 = _mm256_broadcast_ss(r0 + 5);
                        _r01 = _mm256_broadcast_ss(r0 + 13);
                        _r02 = _mm256_broadcast_ss(r0 + 21);
                        _r03 = _mm256_broadcast_ss(r0 + 29);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 37);
                        _r05 = _mm256_broadcast_ss(r0 + 45);
                        _r06 = _mm256_broadcast_ss(r0 + 53);
                        _r07 = _mm256_broadcast_ss(r0 + 61);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _k01 = _mm256_loadu_ps(k01 + 48);
                        _r00 = _mm256_broadcast_ss(r0 + 6);
                        _r01 = _mm256_broadcast_ss(r0 + 14);
                        _r02 = _mm256_broadcast_ss(r0 + 22);
                        _r03 = _mm256_broadcast_ss(r0 + 30);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 38);
                        _r05 = _mm256_broadcast_ss(r0 + 46);
                        _r06 = _mm256_broadcast_ss(r0 + 54);
                        _r07 = _mm256_broadcast_ss(r0 + 62);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        _k01 = _mm256_loadu_ps(k01 + 56);
                        _r00 = _mm256_broadcast_ss(r0 + 7);
                        _r01 = _mm256_broadcast_ss(r0 + 15);
                        _r02 = _mm256_broadcast_ss(r0 + 23);
                        _r03 = _mm256_broadcast_ss(r0 + 31);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _r04 = _mm256_broadcast_ss(r0 + 39);
                        _r05 = _mm256_broadcast_ss(r0 + 47);
                        _r06 = _mm256_broadcast_ss(r0 + 55);
                        _r07 = _mm256_broadcast_ss(r0 + 63);
                        _sum4 = _mm256_fmadd_ps(_k01, _r04, _sum4);
                        _sum5 = _mm256_fmadd_ps(_k01, _r05, _sum5);
                        _sum6 = _mm256_fmadd_ps(_k01, _r06, _sum6);
                        _sum7 = _mm256_fmadd_ps(_k01, _r07, _sum7);

                        k01 += 64;
                        r0 += 64;
                    }
                    _mm256_storeu_ps(output0_tm, _sum0);
                    _mm256_storeu_ps(output0_tm + 8, _sum1);
                    _mm256_storeu_ps(output0_tm + 16, _sum2);
                    _mm256_storeu_ps(output0_tm + 24, _sum3);
                    _mm256_storeu_ps(output0_tm + 32, _sum4);
                    _mm256_storeu_ps(output0_tm + 40, _sum5);
                    _mm256_storeu_ps(output0_tm + 48, _sum6);
                    _mm256_storeu_ps(output0_tm + 56, _sum7);
                    output0_tm += 64;
                }
                for (; i + 3 < tiles; i += 4)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);

                    const float* k01 = kernel0_tm.row(r);

                    int nn = inch; // inch always > 0
                    __m256 _sum0 = _mm256_set1_ps(0.f);
                    __m256 _sum1 = _mm256_set1_ps(0.f);
                    __m256 _sum2 = _mm256_set1_ps(0.f);
                    __m256 _sum3 = _mm256_set1_ps(0.f);

                    for (; nn > 0; nn--)
                    {
                        __m256 _k01 = _mm256_loadu_ps(k01);
                        __m256 _r00 = _mm256_broadcast_ss(r0 + 0);
                        __m256 _r01 = _mm256_broadcast_ss(r0 + 8);
                        __m256 _r02 = _mm256_broadcast_ss(r0 + 16);
                        __m256 _r03 = _mm256_broadcast_ss(r0 + 24);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _k01 = _mm256_loadu_ps(k01 + 8);
                        _r00 = _mm256_broadcast_ss(r0 + 1);
                        _r01 = _mm256_broadcast_ss(r0 + 9);
                        _r02 = _mm256_broadcast_ss(r0 + 17);
                        _r03 = _mm256_broadcast_ss(r0 + 25);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _k01 = _mm256_loadu_ps(k01 + 16);
                        _r00 = _mm256_broadcast_ss(r0 + 2);
                        _r01 = _mm256_broadcast_ss(r0 + 10);
                        _r02 = _mm256_broadcast_ss(r0 + 18);
                        _r03 = _mm256_broadcast_ss(r0 + 26);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _k01 = _mm256_loadu_ps(k01 + 24);
                        _r00 = _mm256_broadcast_ss(r0 + 3);
                        _r01 = _mm256_broadcast_ss(r0 + 11);
                        _r02 = _mm256_broadcast_ss(r0 + 19);
                        _r03 = _mm256_broadcast_ss(r0 + 27);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _k01 = _mm256_loadu_ps(k01 + 32);
                        _r00 = _mm256_broadcast_ss(r0 + 4);
                        _r01 = _mm256_broadcast_ss(r0 + 12);
                        _r02 = _mm256_broadcast_ss(r0 + 20);
                        _r03 = _mm256_broadcast_ss(r0 + 28);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _k01 = _mm256_loadu_ps(k01 + 40);
                        _r00 = _mm256_broadcast_ss(r0 + 5);
                        _r01 = _mm256_broadcast_ss(r0 + 13);
                        _r02 = _mm256_broadcast_ss(r0 + 21);
                        _r03 = _mm256_broadcast_ss(r0 + 29);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _k01 = _mm256_loadu_ps(k01 + 48);
                        _r00 = _mm256_broadcast_ss(r0 + 6);
                        _r01 = _mm256_broadcast_ss(r0 + 14);
                        _r02 = _mm256_broadcast_ss(r0 + 22);
                        _r03 = _mm256_broadcast_ss(r0 + 30);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);

                        _k01 = _mm256_loadu_ps(k01 + 56);
                        _r00 = _mm256_broadcast_ss(r0 + 7);
                        _r01 = _mm256_broadcast_ss(r0 + 15);
                        _r02 = _mm256_broadcast_ss(r0 + 23);
                        _r03 = _mm256_broadcast_ss(r0 + 31);
                        _sum0 = _mm256_fmadd_ps(_k01, _r00, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);
                        _sum2 = _mm256_fmadd_ps(_k01, _r02, _sum2);
                        _sum3 = _mm256_fmadd_ps(_k01, _r03, _sum3);
                        k01 += 64;
                        r0 += 32;
                    }
                    _mm256_storeu_ps(output0_tm, _sum0);
                    _mm256_storeu_ps(output0_tm + 8, _sum1);
                    _mm256_storeu_ps(output0_tm + 16, _sum2);
                    _mm256_storeu_ps(output0_tm + 24, _sum3);
                    output0_tm += 32;
                }
                for (; i + 1 < tiles; i += 2)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);

                    const float* k01 = kernel0_tm.row(r);

                    int nn = inch; // inch always > 0
                    __m256 _sum0 = _mm256_set1_ps(0.f);
                    __m256 _sum1 = _mm256_set1_ps(0.f);

                    for (; nn > 0; nn--)
                    {
                        __m256 _k01 = _mm256_loadu_ps(k01);
                        __m256 _r0 = _mm256_broadcast_ss(r0);
                        __m256 _r01 = _mm256_broadcast_ss(r0 + 8);
                        _sum0 = _mm256_fmadd_ps(_k01, _r0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);

                        _k01 = _mm256_loadu_ps(k01 + 8);
                        _r0 = _mm256_broadcast_ss(r0 + 1);
                        _r01 = _mm256_broadcast_ss(r0 + 9);
                        _sum0 = _mm256_fmadd_ps(_k01, _r0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);

                        _k01 = _mm256_loadu_ps(k01 + 16);
                        _r0 = _mm256_broadcast_ss(r0 + 2);
                        _r01 = _mm256_broadcast_ss(r0 + 10);
                        _sum0 = _mm256_fmadd_ps(_k01, _r0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);

                        _k01 = _mm256_loadu_ps(k01 + 24);
                        _r0 = _mm256_broadcast_ss(r0 + 3);
                        _r01 = _mm256_broadcast_ss(r0 + 11);
                        _sum0 = _mm256_fmadd_ps(_k01, _r0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);

                        _k01 = _mm256_loadu_ps(k01 + 32);
                        _r0 = _mm256_broadcast_ss(r0 + 4);
                        _r01 = _mm256_broadcast_ss(r0 + 12);
                        _sum0 = _mm256_fmadd_ps(_k01, _r0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);

                        _k01 = _mm256_loadu_ps(k01 + 40);
                        _r0 = _mm256_broadcast_ss(r0 + 5);
                        _r01 = _mm256_broadcast_ss(r0 + 13);
                        _sum0 = _mm256_fmadd_ps(_k01, _r0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);

                        _k01 = _mm256_loadu_ps(k01 + 48);
                        _r0 = _mm256_broadcast_ss(r0 + 6);
                        _r01 = _mm256_broadcast_ss(r0 + 14);
                        _sum0 = _mm256_fmadd_ps(_k01, _r0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);

                        _k01 = _mm256_loadu_ps(k01 + 56);
                        _r0 = _mm256_broadcast_ss(r0 + 7);
                        _r01 = _mm256_broadcast_ss(r0 + 15);
                        _sum0 = _mm256_fmadd_ps(_k01, _r0, _sum0);
                        _sum1 = _mm256_fmadd_ps(_k01, _r01, _sum1);

                        k01 += 64;
                        r0 += 16;
                    }
                    _mm256_storeu_ps(output0_tm, _sum0);
                    _mm256_storeu_ps(output0_tm + 8, _sum1);
                    output0_tm += 16;
                }

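                // single-tile tail: independent mul/add reduction trees rather than one
                // fmadd chain, presumably to shorten the serial dependency on _sum0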
                for (; i < tiles; i++)
                {
                    const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);

                    const float* k01 = kernel0_tm.row(r);

                    int nn = inch; // inch always > 0
                    __m256 _sum0 = _mm256_set1_ps(0.f);

                    for (; nn > 0; nn--)
                    {
                        __m256 _k01 = _mm256_loadu_ps(k01);
                        __m256 _r0 = _mm256_broadcast_ss(r0);
                        __m256 _mul0 = _mm256_mul_ps(_k01, _r0);

                        _k01 = _mm256_loadu_ps(k01 + 8);
                        _r0 = _mm256_broadcast_ss(r0 + 1);
                        __m256 _mul1 = _mm256_mul_ps(_k01, _r0);

                        _k01 = _mm256_loadu_ps(k01 + 16);
                        _r0 = _mm256_broadcast_ss(r0 + 2);
                        __m256 _mul2 = _mm256_mul_ps(_k01, _r0);
                        __m256 _add01 = _mm256_add_ps(_mul0, _mul1);

                        _k01 = _mm256_loadu_ps(k01 + 24);
                        _r0 = _mm256_broadcast_ss(r0 + 3);
                        __m256 _mul3 = _mm256_mul_ps(_k01, _r0);

                        __m256 _add23 = _mm256_add_ps(_mul2, _mul3);
                        __m256 _add0123 = _mm256_add_ps(_add01, _add23);
                        _sum0 = _mm256_add_ps(_sum0, _add0123);

                        _k01 = _mm256_loadu_ps(k01 + 32);
                        _r0 = _mm256_broadcast_ss(r0 + 4);
                        __m256 _mul4 = _mm256_mul_ps(_k01, _r0);

                        _k01 = _mm256_loadu_ps(k01 + 40);
                        _r0 = _mm256_broadcast_ss(r0 + 5);
                        __m256 _mul5 = _mm256_mul_ps(_k01, _r0);

                        _k01 = _mm256_loadu_ps(k01 + 48);
                        _r0 = _mm256_broadcast_ss(r0 + 6);
                        __m256 _mul6 = _mm256_mul_ps(_k01, _r0);
                        __m256 _add45 = _mm256_add_ps(_mul4, _mul5);

                        _k01 = _mm256_loadu_ps(k01 + 56);
                        _r0 = _mm256_broadcast_ss(r0 + 7);
                        __m256 _mul7 = _mm256_mul_ps(_k01, _r0);

                        __m256 _add67 = _mm256_add_ps(_mul6, _mul7);
                        __m256 _add4567 = _mm256_add_ps(_add45, _add67);
                        _sum0 = _mm256_add_ps(_sum0, _add4567);

                        k01 += 64;
                        r0 += 8;
                    }
                    _mm256_storeu_ps(output0_tm, _sum0);
                    output0_tm += 8;
                }
            }
        }
    }
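    // the transformed input is no longer needed; release it before the output transform to reduce peak memory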
    bottom_blob_tm = Mat();
    // END dot

    // BEGIN transform output
    Mat top_blob_bordered;
    if (outw == top_blob.w && outh == top_blob.h)
    {
        top_blob_bordered = top_blob;
    }
    else
    {
        top_blob_bordered.create(outw, outh, outch, elemsize, elempack, opt.workspace_allocator);
    }
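    // outw/outh here are the 6-aligned dimensions, so the bordered buffer can be
    // larger than top_blob; the surplus border is trimmed by copy_cut_border below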
    {
        // const float otm[6][8] = {
        //     {1.0f, 1.0f,  1.0f,  1.0f,   1.0f, 32.0f,  32.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f,  2.0f,  -2.0f, 16.0f, -16.0f, 0.0f},
        //     {0.0f, 1.0f,  1.0f,  4.0f,   4.0f,  8.0f,   8.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f,  8.0f,  -8.0f,  4.0f,  -4.0f, 0.0f},
        //     {0.0f, 1.0f,  1.0f, 16.0f,  16.0f,  2.0f,   2.0f, 0.0f},
        //     {0.0f, 1.0f, -1.0f, 32.0f, -32.0f,  1.0f,  -1.0f, 1.0f}
        // };

        // 0 = r0 + (r1 + r2) + (r3 + r4) + (r5 + r6) * 32
        // 1 = (r1 - r2) + (r3 - r4) * 2 + (r5 - r6) * 16
        // 2 = (r1 + r2) + (r3 + r4) * 4 + (r5 + r6) * 8
        // 3 = (r1 - r2) + (r3 - r4) * 8 + (r5 - r6) * 4
        // 4 = (r1 + r2) + (r3 + r4) * 16 + (r5 + r6) * 2
        // 5 = r7 + (r1 - r2) + (r3 - r4) * 32 + (r5 - r6)

        int w_tm = outw / 6 * 8;
        int h_tm = outh / 6 * 8;
        const int tiles = w_tm / 8 * h_tm / 8;
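        // e.g. outw = outh = 12 gives w_tm = h_tm = 16 and tiles = 4:
        // four 8x8 transformed tiles, each producing one 6x6 output block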
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < outch; p++)
        {
            const Mat out0_tm = top_blob_tm.channel(p);
            Mat out0 = top_blob_bordered.channel(p);

            // const float bias0 = bias ? bias[p] : 0.f;
            __m256 _bias0 = bias ? _mm256_loadu_ps((const float*)bias + p * 8) : _mm256_set1_ps(0.f);

            float tmp[6][8][8];
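            // tmp buffers the first 8 -> 6 transform pass; the second pass below
            // reduces the other dimension to the final 6x6 block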

            // tile
            for (int i = 0; i < outh / 6; i++)
            {
                for (int j = 0; j < outw / 6; j++)
                {
                    // top_blob_tm.create(tiles, 64, outch, elemsize, elempack);

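                    // top_blob_tm lays out the 64 transform coefficients as planes of
                    // tiles * 8 floats, so these eight pointers pick the same tile from
                    // one row of the 8x8 coefficient grid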
                    const float* output0_tm_0 = (const float*)out0_tm + (i * w_tm / 8 + j) * 8;
                    const float* output0_tm_1 = output0_tm_0 + tiles * 8;
                    const float* output0_tm_2 = output0_tm_0 + tiles * 16;
                    const float* output0_tm_3 = output0_tm_0 + tiles * 24;
                    const float* output0_tm_4 = output0_tm_0 + tiles * 32;
                    const float* output0_tm_5 = output0_tm_0 + tiles * 40;
                    const float* output0_tm_6 = output0_tm_0 + tiles * 48;
                    const float* output0_tm_7 = output0_tm_0 + tiles * 56;

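                    // output0 points at the top-left pack-8 pixel of this tile's 6x6 output block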
                    float* output0 = out0.row(i * 6) + (j * 6) * 8;

                    // TODO avx optimize
                    for (int m = 0; m < 8; m++)
                    {
                        __m256 _out0tm0 = _mm256_loadu_ps(output0_tm_0);
                        __m256 _out0tm1 = _mm256_loadu_ps(output0_tm_1);
                        __m256 _out0tm2 = _mm256_loadu_ps(output0_tm_2);
                        __m256 _out0tm3 = _mm256_loadu_ps(output0_tm_3);
                        __m256 _out0tm4 = _mm256_loadu_ps(output0_tm_4);
                        __m256 _out0tm5 = _mm256_loadu_ps(output0_tm_5);
                        __m256 _out0tm6 = _mm256_loadu_ps(output0_tm_6);
                        __m256 _out0tm7 = _mm256_loadu_ps(output0_tm_7);

                        __m256 _tmp024a = _mm256_add_ps(_out0tm1, _out0tm2);
                        __m256 _tmp135a = _mm256_sub_ps(_out0tm1, _out0tm2);

                        // float tmp024a = output0_tm[1] + output0_tm[2];
                        // float tmp135a = output0_tm[1] - output0_tm[2];

                        __m256 _tmp024b = _mm256_add_ps(_out0tm3, _out0tm4);
                        __m256 _tmp135b = _mm256_sub_ps(_out0tm3, _out0tm4);

                        // float tmp024b = output0_tm[3] + output0_tm[4];
                        // float tmp135b = output0_tm[3] - output0_tm[4];

                        __m256 _tmp024c = _mm256_add_ps(_out0tm5, _out0tm6);
                        __m256 _tmp135c = _mm256_sub_ps(_out0tm5, _out0tm6);

                        // float tmp024c = output0_tm[5] + output0_tm[6];
                        // float tmp135c = output0_tm[5] - output0_tm[6];

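                        // _mm256_fmadd_1_ps(a, b, c) is assumed to compute a + b * c
                        // with the scalar c broadcast across all lanes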
                        __m256 _tmp0m = _mm256_add_ps(_mm256_add_ps(_out0tm0, _tmp024a), _mm256_fmadd_1_ps(_tmp024b, _tmp024c, 32.f));
                        __m256 _tmp2m = _mm256_fmadd_1_ps(_mm256_fmadd_1_ps(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f);
                        __m256 _tmp4m = _mm256_fmadd_1_ps(_mm256_fmadd_1_ps(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f);
                        _mm256_storeu_ps(tmp[0][m], _tmp0m);
                        _mm256_storeu_ps(tmp[2][m], _tmp2m);
                        _mm256_storeu_ps(tmp[4][m], _tmp4m);

                        // tmp[0][m] = output0_tm[0] + tmp024a + tmp024b + tmp024c * 32;
                        // tmp[2][m] = tmp024a + tmp024b * 4 + tmp024c * 8;
                        // tmp[4][m] = tmp024a + tmp024b * 16 + tmp024c * 2;

                        __m256 _tmp1m = _mm256_fmadd_1_ps(_mm256_fmadd_1_ps(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f);
                        __m256 _tmp3m = _mm256_fmadd_1_ps(_mm256_fmadd_1_ps(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f);
                        __m256 _tmp5m = _mm256_add_ps(_mm256_add_ps(_out0tm7, _tmp135a), _mm256_fmadd_1_ps(_tmp135c, _tmp135b, 32.f));
                        _mm256_storeu_ps(tmp[1][m], _tmp1m);
                        _mm256_storeu_ps(tmp[3][m], _tmp3m);
                        _mm256_storeu_ps(tmp[5][m], _tmp5m);

                        // tmp[1][m] = tmp135a + tmp135b * 2 + tmp135c * 16;
                        // tmp[3][m] = tmp135a + tmp135b * 8 + tmp135c * 4;
                        // tmp[5][m] = output0_tm[7] + tmp135a + tmp135b * 32 + tmp135c;

                        output0_tm_0 += tiles * 64;
                        output0_tm_1 += tiles * 64;
                        output0_tm_2 += tiles * 64;
                        output0_tm_3 += tiles * 64;
                        output0_tm_4 += tiles * 64;
                        output0_tm_5 += tiles * 64;
                        output0_tm_6 += tiles * 64;
                        output0_tm_7 += tiles * 64;
                    }

                    for (int m = 0; m < 6; m++)
                    {
                        __m256 _tmp00 = _mm256_loadu_ps(tmp[m][0]);
                        __m256 _tmp01 = _mm256_loadu_ps(tmp[m][1]);
                        __m256 _tmp02 = _mm256_loadu_ps(tmp[m][2]);
                        __m256 _tmp03 = _mm256_loadu_ps(tmp[m][3]);
                        __m256 _tmp04 = _mm256_loadu_ps(tmp[m][4]);
                        __m256 _tmp05 = _mm256_loadu_ps(tmp[m][5]);
                        __m256 _tmp06 = _mm256_loadu_ps(tmp[m][6]);
                        __m256 _tmp07 = _mm256_loadu_ps(tmp[m][7]);

                        __m256 _tmp024a = _mm256_add_ps(_tmp01, _tmp02);
                        __m256 _tmp135a = _mm256_sub_ps(_tmp01, _tmp02);

                        // float tmp024a = tmp0[1] + tmp0[2];
                        // float tmp135a = tmp0[1] - tmp0[2];

                        __m256 _tmp024b = _mm256_add_ps(_tmp03, _tmp04);
                        __m256 _tmp135b = _mm256_sub_ps(_tmp03, _tmp04);

                        // float tmp024b = tmp0[3] + tmp0[4];
                        // float tmp135b = tmp0[3] - tmp0[4];

                        __m256 _tmp024c = _mm256_add_ps(_tmp05, _tmp06);
                        __m256 _tmp135c = _mm256_sub_ps(_tmp05, _tmp06);

                        // float tmp024c = tmp0[5] + tmp0[6];
                        // float tmp135c = tmp0[5] - tmp0[6];

                        __m256 _out00 = _mm256_add_ps(_bias0, _mm256_add_ps(_mm256_add_ps(_tmp00, _tmp024a), _mm256_fmadd_1_ps(_tmp024b, _tmp024c, 32.f)));
                        __m256 _out02 = _mm256_add_ps(_bias0, _mm256_fmadd_1_ps(_mm256_fmadd_1_ps(_tmp024a, _tmp024b, 4.f), _tmp024c, 8.f));
                        __m256 _out04 = _mm256_add_ps(_bias0, _mm256_fmadd_1_ps(_mm256_fmadd_1_ps(_tmp024a, _tmp024b, 16.f), _tmp024c, 2.f));
                        _mm256_storeu_ps(output0, _out00);
                        _mm256_storeu_ps(output0 + 16, _out02);
                        _mm256_storeu_ps(output0 + 32, _out04);

                        // output0[0] = bias0 + tmp0[0] + tmp024a + tmp024b + tmp024c * 32;
                        // output0[2] = bias0 + tmp024a + tmp024b * 4 + tmp024c * 8;
                        // output0[4] = bias0 + tmp024a + tmp024b * 16 + tmp024c * 2;

                        __m256 _out01 = _mm256_add_ps(_bias0, _mm256_fmadd_1_ps(_mm256_fmadd_1_ps(_tmp135a, _tmp135b, 2.f), _tmp135c, 16.f));
                        __m256 _out03 = _mm256_add_ps(_bias0, _mm256_fmadd_1_ps(_mm256_fmadd_1_ps(_tmp135a, _tmp135b, 8.f), _tmp135c, 4.f));
                        __m256 _out05 = _mm256_add_ps(_bias0, _mm256_add_ps(_mm256_add_ps(_tmp07, _tmp135a), _mm256_fmadd_1_ps(_tmp135c, _tmp135b, 32.f)));
                        _mm256_storeu_ps(output0 + 8, _out01);
                        _mm256_storeu_ps(output0 + 24, _out03);
                        _mm256_storeu_ps(output0 + 40, _out05);

                        // output0[1] = bias0 + tmp135a + tmp135b * 2 + tmp135c * 16;
                        // output0[3] = bias0 + tmp135a + tmp135b * 8 + tmp135c * 4;
                        // output0[5] = bias0 + tmp0[7] + tmp135a + tmp135b * 32 + tmp135c;

                        output0 += outw * 8;
                    }
                }
            }
        }
    }
    // END transform output

    // cut result pad
    copy_cut_border(top_blob_bordered, top_blob, 0, top_blob_bordered.h - top_blob.h, 0, top_blob_bordered.w - top_blob.w, opt);
}