1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
im2col_sgemm_pack4_neon(const Mat & bottom_im2col,Mat & top_blob,const Mat & kernel,const Mat & _bias,const Option & opt)15 static void im2col_sgemm_pack4_neon(const Mat& bottom_im2col, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
16 {
17 // Mat bottom_im2col(size, maxk, inch, 16u, 4, opt.workspace_allocator);
18
19 const int size = bottom_im2col.w;
20 const int maxk = bottom_im2col.h;
21 const int inch = bottom_im2col.c;
22
23 const int outch = top_blob.c;
24
25 const float* bias = _bias;
26
27 // permute
28 Mat tmp;
29 #if __aarch64__
30 if (size >= 12)
31 tmp.create(12 * maxk, inch, size / 12 + (size % 12) / 8 + (size % 12 % 8) / 4 + (size % 12 % 4) / 2 + size % 12 % 2, 16u, 4, opt.workspace_allocator);
32 else if (size >= 8)
33 tmp.create(8 * maxk, inch, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 16u, 4, opt.workspace_allocator);
34 else if (size >= 4)
35 tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 16u, 4, opt.workspace_allocator);
36 else if (size >= 2)
37 tmp.create(2 * maxk, inch, size / 2 + size % 2, 16u, 4, opt.workspace_allocator);
38 else
39 tmp.create(maxk, inch, size, 16u, 4, opt.workspace_allocator);
40 #else
41 if (size >= 8)
42 tmp.create(8 * maxk, inch, size / 8 + (size % 8) / 4 + (size % 4) / 2 + size % 2, 16u, 4, opt.workspace_allocator);
43 else if (size >= 4)
44 tmp.create(4 * maxk, inch, size / 4 + (size % 4) / 2 + size % 2, 16u, 4, opt.workspace_allocator);
45 else if (size >= 2)
46 tmp.create(2 * maxk, inch, size / 2 + size % 2, 16u, 4, opt.workspace_allocator);
47 else
48 tmp.create(maxk, inch, size, 16u, 4, opt.workspace_allocator);
49 #endif
50 {
51 #if __aarch64__
52 int nn_size = size / 12;
53 int remain_size_start = 0;
54
55 #pragma omp parallel for num_threads(opt.num_threads)
56 for (int ii = 0; ii < nn_size; ii++)
57 {
58 int i = remain_size_start + ii * 12;
59
60 float* tmpptr = tmp.channel(i / 12);
61
62 for (int q = 0; q < inch; q++)
63 {
64 const float* img0 = (const float*)bottom_im2col.channel(q) + i * 4;
65
66 for (int k = 0; k < maxk; k++)
67 {
68 asm volatile(
69 "prfm pldl1keep, [%0, #512] \n"
70 "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n"
71 "prfm pldl1keep, [%0, #512] \n"
72 "ld4 {v4.4s, v5.4s, v6.4s, v7.4s}, [%0], #64 \n"
73 "prfm pldl1keep, [%0, #512] \n"
74 "ld4 {v8.4s, v9.4s, v10.4s, v11.4s}, [%0] \n"
75 "st1 {v0.4s}, [%1], #16 \n"
76 "st1 {v4.4s}, [%1], #16 \n"
77 "st1 {v8.4s}, [%1], #16 \n"
78 "sub %0, %0, #128 \n"
79 "st1 {v1.4s}, [%1], #16 \n"
80 "st1 {v5.4s}, [%1], #16 \n"
81 "st1 {v9.4s}, [%1], #16 \n"
82 "st1 {v2.4s}, [%1], #16 \n"
83 "st1 {v6.4s}, [%1], #16 \n"
84 "st1 {v10.4s}, [%1], #16 \n"
85 "st1 {v3.4s}, [%1], #16 \n"
86 "st1 {v7.4s}, [%1], #16 \n"
87 "st1 {v11.4s}, [%1], #16 \n"
88 : "=r"(img0), // %0
89 "=r"(tmpptr) // %1
90 : "0"(img0),
91 "1"(tmpptr)
92 : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11");
93 img0 += size * 4;
94 }
95 }
96 }
97
98 remain_size_start += nn_size * 12;
99 nn_size = (size - remain_size_start) >> 3;
100 #else
101 int nn_size = size >> 3;
102 int remain_size_start = 0;
103 #endif
104
105 #pragma omp parallel for num_threads(opt.num_threads)
106 for (int ii = 0; ii < nn_size; ii++)
107 {
108 int i = remain_size_start + ii * 8;
109
110 #if __aarch64__
111 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
112 #else
113 float* tmpptr = tmp.channel(i / 8);
114 #endif
115
116 for (int q = 0; q < inch; q++)
117 {
118 const float* img0 = (const float*)bottom_im2col.channel(q) + i * 4;
119
120 for (int k = 0; k < maxk; k++)
121 {
122 #if __aarch64__
123 asm volatile(
124 "prfm pldl1keep, [%0, #512] \n"
125 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n"
126 "prfm pldl1keep, [%0, #512] \n"
127 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%0] \n"
128 "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%1], #64 \n"
129 "sub %0, %0, #64 \n"
130 "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%1], #64 \n"
131 : "=r"(img0), // %0
132 "=r"(tmpptr) // %1
133 : "0"(img0),
134 "1"(tmpptr)
135 : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
136 #else
137 asm volatile(
138 "pld [%0, #512] \n"
139 "vldm %0!, {d0-d7} \n"
140 "pld [%0, #512] \n"
141 "vldm %0, {d16-d23} \n"
142
143 // transpose 8x4
144 "vtrn.32 q0, q1 \n"
145 "vtrn.32 q2, q3 \n"
146 "vtrn.32 q8, q9 \n"
147 "vtrn.32 q10, q11 \n"
148 "vswp d1, d4 \n"
149 "vswp d3, d6 \n"
150 "vswp d17, d20 \n"
151 "vswp d19, d22 \n"
152 "vswp q1, q8 \n"
153 "vswp q3, q10 \n"
154
155 "vst1.f32 {d0-d3}, [%1 :128]! \n"
156 "vst1.f32 {d16-d19}, [%1 :128]! \n"
157 "sub %0, %0, #64 \n"
158 "vst1.f32 {d4-d7}, [%1 :128]! \n"
159 "vst1.f32 {d20-d23}, [%1 :128]! \n"
160 : "=r"(img0), // %0
161 "=r"(tmpptr) // %1
162 : "0"(img0),
163 "1"(tmpptr)
164 : "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
165 #endif // __aarch64__
166 img0 += size * 4;
167 }
168 }
169 }
170
171 remain_size_start += nn_size << 3;
172 nn_size = (size - remain_size_start) >> 2;
173
174 #pragma omp parallel for num_threads(opt.num_threads)
175 for (int ii = 0; ii < nn_size; ii++)
176 {
177 int i = remain_size_start + ii * 4;
178
179 #if __aarch64__
180 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
181 #else
182 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
183 #endif
184
185 for (int q = 0; q < inch; q++)
186 {
187 const float* img0 = (const float*)bottom_im2col.channel(q) + i * 4;
188
189 for (int k = 0; k < maxk; k++)
190 {
191 #if __aarch64__
192 asm volatile(
193 "prfm pldl1keep, [%0, #512] \n"
194 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0] \n"
195 "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%1], #64 \n"
196 : "=r"(img0), // %0
197 "=r"(tmpptr) // %1
198 : "0"(img0),
199 "1"(tmpptr)
200 : "memory", "v0", "v1", "v2", "v3");
201 #else
202 asm volatile(
203 "pld [%0, #512] \n"
204 "vldm %0, {d0-d7} \n"
205 "vstm %1!, {d0-d7} \n"
206 : "=r"(img0), // %0
207 "=r"(tmpptr) // %1
208 : "0"(img0),
209 "1"(tmpptr)
210 : "memory", "q0", "q1", "q2", "q3");
211 #endif // __aarch64__
212 img0 += size * 4;
213 }
214 }
215 }
216
217 remain_size_start += nn_size << 2;
218 nn_size = (size - remain_size_start) >> 1;
219
220 #pragma omp parallel for num_threads(opt.num_threads)
221 for (int ii = 0; ii < nn_size; ii++)
222 {
223 int i = remain_size_start + ii * 2;
224
225 #if __aarch64__
226 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);
227 #else
228 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2);
229 #endif
230
231 for (int q = 0; q < inch; q++)
232 {
233 const float* img0 = (const float*)bottom_im2col.channel(q) + i * 4;
234
235 for (int k = 0; k < maxk; k++)
236 {
237 #if __aarch64__
238 asm volatile(
239 "prfm pldl1keep, [%0, #256] \n"
240 "ld1 {v0.4s, v1.4s}, [%0] \n"
241 "st1 {v0.4s, v1.4s}, [%1], #32 \n"
242 : "=r"(img0), // %0
243 "=r"(tmpptr) // %1
244 : "0"(img0),
245 "1"(tmpptr)
246 : "memory", "v0", "v1");
247 #else
248 asm volatile(
249 "pld [%0, #256] \n"
250 "vld1.f32 {d0-d3}, [%0 :128] \n"
251 "vst1.f32 {d0-d3}, [%1 :128]! \n"
252 : "=r"(img0), // %0
253 "=r"(tmpptr) // %1
254 : "0"(img0),
255 "1"(tmpptr)
256 : "memory", "q0", "q1");
257 #endif // __aarch64__
258 img0 += size * 4;
259 }
260 }
261 }
262
263 remain_size_start += nn_size << 1;
264
265 #pragma omp parallel for num_threads(opt.num_threads)
266 for (int i = remain_size_start; i < size; i++)
267 {
268 #if __aarch64__
269 float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
270 #else
271 float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2);
272 #endif
273
274 for (int q = 0; q < inch; q++)
275 {
276 const float* img0 = (const float*)bottom_im2col.channel(q) + i * 4;
277
278 for (int k = 0; k < maxk; k++)
279 {
280 #if __aarch64__
281 asm volatile(
282 "prfm pldl1keep, [%0, #128] \n"
283 "ld1 {v0.4s}, [%0] \n"
284 "st1 {v0.4s}, [%1], #16 \n"
285 : "=r"(img0), // %0
286 "=r"(tmpptr) // %1
287 : "0"(img0),
288 "1"(tmpptr)
289 : "memory", "v0");
290 #else
291 asm volatile(
292 "pld [%0, #128] \n"
293 "vld1.f32 {d0-d1}, [%0 :128] \n"
294 "vst1.f32 {d0-d1}, [%1 :128]! \n"
295 : "=r"(img0), // %0
296 "=r"(tmpptr) // %1
297 : "0"(img0),
298 "1"(tmpptr)
299 : "memory", "q0");
300 #endif // __aarch64__
301 img0 += size * 4;
302 }
303 }
304 }
305 }
306
307 int remain_outch_start = 0;
308
309 #if __aarch64__
310 int nn_outch = outch >> 1;
311 remain_outch_start = nn_outch << 1;
312
313 #pragma omp parallel for num_threads(opt.num_threads)
314 for (int pp = 0; pp < nn_outch; pp++)
315 {
316 int p = pp * 2;
317
318 float* outptr0 = top_blob.channel(p);
319 float* outptr1 = top_blob.channel(p + 1);
320
321 const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
322 const float* biasptr = bias ? bias + p * 4 : zeros;
323
324 int i = 0;
325 for (; i + 11 < size; i += 12)
326 {
327 const float* tmpptr = tmp.channel(i / 12);
328 const float* kptr0 = kernel.channel(p / 2);
329
330 int nn = inch * maxk; // inch always > 0
331
332 asm volatile(
333 "ld1 {v0.4s, v1.4s}, [%10] \n"
334 "mov v8.16b, v0.16b \n"
335 "mov v9.16b, v0.16b \n"
336 "mov v10.16b, v0.16b \n"
337 "mov v11.16b, v0.16b \n"
338 "mov v12.16b, v0.16b \n"
339 "mov v13.16b, v0.16b \n"
340 "mov v14.16b, v0.16b \n"
341 "mov v15.16b, v0.16b \n"
342 "mov v16.16b, v0.16b \n"
343 "mov v17.16b, v0.16b \n"
344 "mov v18.16b, v0.16b \n"
345 "mov v19.16b, v0.16b \n"
346 "mov v20.16b, v1.16b \n"
347 "mov v21.16b, v1.16b \n"
348 "mov v22.16b, v1.16b \n"
349 "mov v23.16b, v1.16b \n"
350 "mov v24.16b, v1.16b \n"
351 "mov v25.16b, v1.16b \n"
352 "mov v26.16b, v1.16b \n"
353 "mov v27.16b, v1.16b \n"
354 "mov v28.16b, v1.16b \n"
355 "mov v29.16b, v1.16b \n"
356 "mov v30.16b, v1.16b \n"
357 "mov v31.16b, v1.16b \n"
358
359 "0: \n"
360
361 "prfm pldl1keep, [%3, #512] \n"
362 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%3], #64 \n"
363
364 "prfm pldl1keep, [%4, #512] \n"
365 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%4], #64 \n" // w0011_01
366
367 "fmla v8.4s, v4.4s, v0.s[0] \n"
368 "fmla v9.4s, v4.4s, v0.s[1] \n"
369 "fmla v10.4s, v4.4s, v0.s[2] \n"
370 "fmla v11.4s, v4.4s, v0.s[3] \n"
371 "fmla v12.4s, v4.4s, v1.s[0] \n"
372 "fmla v13.4s, v4.4s, v1.s[1] \n"
373 "fmla v14.4s, v4.4s, v1.s[2] \n"
374 "fmla v15.4s, v4.4s, v1.s[3] \n"
375 "fmla v16.4s, v4.4s, v2.s[0] \n"
376 "fmla v17.4s, v4.4s, v2.s[1] \n"
377 "fmla v18.4s, v4.4s, v2.s[2] \n"
378 "fmla v19.4s, v4.4s, v2.s[3] \n"
379
380 "fmla v20.4s, v5.4s, v0.s[0] \n"
381 "fmla v21.4s, v5.4s, v0.s[1] \n"
382 "fmla v22.4s, v5.4s, v0.s[2] \n"
383 "fmla v23.4s, v5.4s, v0.s[3] \n"
384 "fmla v24.4s, v5.4s, v1.s[0] \n"
385 "fmla v25.4s, v5.4s, v1.s[1] \n"
386 "fmla v26.4s, v5.4s, v1.s[2] \n"
387 "fmla v27.4s, v5.4s, v1.s[3] \n"
388 "fmla v28.4s, v5.4s, v2.s[0] \n"
389 "fmla v29.4s, v5.4s, v2.s[1] \n"
390 "fmla v30.4s, v5.4s, v2.s[2] \n"
391 "fmla v31.4s, v5.4s, v2.s[3] \n"
392
393 "fmla v8.4s, v6.4s, v3.s[0] \n"
394 "fmla v9.4s, v6.4s, v3.s[1] \n"
395 "fmla v10.4s, v6.4s, v3.s[2] \n"
396 "fmla v11.4s, v6.4s, v3.s[3] \n"
397
398 "fmla v20.4s, v7.4s, v3.s[0] \n"
399 "fmla v21.4s, v7.4s, v3.s[1] \n"
400 "fmla v22.4s, v7.4s, v3.s[2] \n"
401 "fmla v23.4s, v7.4s, v3.s[3] \n"
402
403 "prfm pldl1keep, [%3, #512] \n"
404 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%3], #64 \n"
405
406 "fmla v12.4s, v6.4s, v0.s[0] \n"
407 "fmla v13.4s, v6.4s, v0.s[1] \n"
408 "fmla v14.4s, v6.4s, v0.s[2] \n"
409 "fmla v15.4s, v6.4s, v0.s[3] \n"
410 "fmla v16.4s, v6.4s, v1.s[0] \n"
411 "fmla v17.4s, v6.4s, v1.s[1] \n"
412 "fmla v18.4s, v6.4s, v1.s[2] \n"
413 "fmla v19.4s, v6.4s, v1.s[3] \n"
414
415 "fmla v24.4s, v7.4s, v0.s[0] \n"
416 "fmla v25.4s, v7.4s, v0.s[1] \n"
417 "fmla v26.4s, v7.4s, v0.s[2] \n"
418 "fmla v27.4s, v7.4s, v0.s[3] \n"
419 "fmla v28.4s, v7.4s, v1.s[0] \n"
420 "fmla v29.4s, v7.4s, v1.s[1] \n"
421 "fmla v30.4s, v7.4s, v1.s[2] \n"
422 "fmla v31.4s, v7.4s, v1.s[3] \n"
423
424 "prfm pldl1keep, [%4, #512] \n"
425 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%4], #64 \n" // w2233_01
426
427 "fmla v8.4s, v4.4s, v2.s[0] \n"
428 "fmla v9.4s, v4.4s, v2.s[1] \n"
429 "fmla v10.4s, v4.4s, v2.s[2] \n"
430 "fmla v11.4s, v4.4s, v2.s[3] \n"
431 "fmla v12.4s, v4.4s, v3.s[0] \n"
432 "fmla v13.4s, v4.4s, v3.s[1] \n"
433 "fmla v14.4s, v4.4s, v3.s[2] \n"
434 "fmla v15.4s, v4.4s, v3.s[3] \n"
435
436 "fmla v20.4s, v5.4s, v2.s[0] \n"
437 "fmla v21.4s, v5.4s, v2.s[1] \n"
438 "fmla v22.4s, v5.4s, v2.s[2] \n"
439 "fmla v23.4s, v5.4s, v2.s[3] \n"
440 "fmla v24.4s, v5.4s, v3.s[0] \n"
441 "fmla v25.4s, v5.4s, v3.s[1] \n"
442 "fmla v26.4s, v5.4s, v3.s[2] \n"
443 "fmla v27.4s, v5.4s, v3.s[3] \n"
444
445 "prfm pldl1keep, [%3, #512] \n"
446 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%3], #64 \n"
447
448 "fmla v16.4s, v4.4s, v0.s[0] \n"
449 "fmla v17.4s, v4.4s, v0.s[1] \n"
450 "fmla v18.4s, v4.4s, v0.s[2] \n"
451 "fmla v19.4s, v4.4s, v0.s[3] \n"
452
453 "fmla v28.4s, v5.4s, v0.s[0] \n"
454 "fmla v29.4s, v5.4s, v0.s[1] \n"
455 "fmla v30.4s, v5.4s, v0.s[2] \n"
456 "fmla v31.4s, v5.4s, v0.s[3] \n"
457
458 "fmla v8.4s, v6.4s, v1.s[0] \n"
459 "fmla v9.4s, v6.4s, v1.s[1] \n"
460 "fmla v10.4s, v6.4s, v1.s[2] \n"
461 "fmla v11.4s, v6.4s, v1.s[3] \n"
462 "fmla v12.4s, v6.4s, v2.s[0] \n"
463 "fmla v13.4s, v6.4s, v2.s[1] \n"
464 "fmla v14.4s, v6.4s, v2.s[2] \n"
465 "fmla v15.4s, v6.4s, v2.s[3] \n"
466 "fmla v16.4s, v6.4s, v3.s[0] \n"
467 "fmla v17.4s, v6.4s, v3.s[1] \n"
468 "fmla v18.4s, v6.4s, v3.s[2] \n"
469 "fmla v19.4s, v6.4s, v3.s[3] \n"
470
471 "subs %w0, %w0, #1 \n"
472
473 "fmla v20.4s, v7.4s, v1.s[0] \n"
474 "fmla v21.4s, v7.4s, v1.s[1] \n"
475 "fmla v22.4s, v7.4s, v1.s[2] \n"
476 "fmla v23.4s, v7.4s, v1.s[3] \n"
477 "fmla v24.4s, v7.4s, v2.s[0] \n"
478 "fmla v25.4s, v7.4s, v2.s[1] \n"
479 "fmla v26.4s, v7.4s, v2.s[2] \n"
480 "fmla v27.4s, v7.4s, v2.s[3] \n"
481 "fmla v28.4s, v7.4s, v3.s[0] \n"
482 "fmla v29.4s, v7.4s, v3.s[1] \n"
483 "fmla v30.4s, v7.4s, v3.s[2] \n"
484 "fmla v31.4s, v7.4s, v3.s[3] \n"
485
486 "bne 0b \n"
487
488 "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n"
489 "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%2], #64 \n"
490 "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n"
491 "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%2], #64 \n"
492 "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n"
493 "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%2], #64 \n"
494
495 : "=r"(nn), // %0
496 "=r"(outptr0), // %1
497 "=r"(outptr1), // %2
498 "=r"(tmpptr), // %3
499 "=r"(kptr0) // %4
500 : "0"(nn),
501 "1"(outptr0),
502 "2"(outptr1),
503 "3"(tmpptr),
504 "4"(kptr0),
505 "r"(biasptr) // %10
506 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
507 }
508 for (; i + 7 < size; i += 8)
509 {
510 const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
511 const float* kptr0 = kernel.channel(p / 2);
512
513 int nn = inch * maxk; // inch always > 0
514
515 asm volatile(
516 "ld1 {v0.4s, v1.4s}, [%10] \n"
517 "mov v16.16b, v0.16b \n"
518 "mov v17.16b, v0.16b \n"
519 "mov v18.16b, v0.16b \n"
520 "mov v19.16b, v0.16b \n"
521 "mov v20.16b, v0.16b \n"
522 "mov v21.16b, v0.16b \n"
523 "mov v22.16b, v0.16b \n"
524 "mov v23.16b, v0.16b \n"
525 "mov v24.16b, v1.16b \n"
526 "mov v25.16b, v1.16b \n"
527 "mov v26.16b, v1.16b \n"
528 "mov v27.16b, v1.16b \n"
529 "mov v28.16b, v1.16b \n"
530 "mov v29.16b, v1.16b \n"
531 "mov v30.16b, v1.16b \n"
532 "mov v31.16b, v1.16b \n"
533
534 "0: \n"
535
536 "prfm pldl1keep, [%3, #512] \n"
537 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%3], #64 \n" // r0 r1 r2 r3
538
539 "prfm pldl1keep, [%4, #512] \n"
540 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%4], #64 \n" // w0011_01
541
542 "fmla v16.4s, v8.4s, v0.s[0] \n"
543 "fmla v17.4s, v8.4s, v1.s[0] \n"
544 "fmla v18.4s, v8.4s, v2.s[0] \n"
545 "fmla v19.4s, v8.4s, v3.s[0] \n"
546
547 "prfm pldl1keep, [%3, #512] \n"
548 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%3], #64 \n" // r4 r5 r6 r7
549
550 "fmla v20.4s, v8.4s, v4.s[0] \n"
551 "fmla v21.4s, v8.4s, v5.s[0] \n"
552 "fmla v22.4s, v8.4s, v6.s[0] \n"
553 "fmla v23.4s, v8.4s, v7.s[0] \n"
554
555 "fmla v24.4s, v9.4s, v0.s[0] \n"
556 "fmla v25.4s, v9.4s, v1.s[0] \n"
557 "fmla v26.4s, v9.4s, v2.s[0] \n"
558 "fmla v27.4s, v9.4s, v3.s[0] \n"
559 "fmla v28.4s, v9.4s, v4.s[0] \n"
560 "fmla v29.4s, v9.4s, v5.s[0] \n"
561 "fmla v30.4s, v9.4s, v6.s[0] \n"
562 "fmla v31.4s, v9.4s, v7.s[0] \n"
563
564 "fmla v16.4s, v10.4s, v0.s[1] \n"
565 "fmla v17.4s, v10.4s, v1.s[1] \n"
566 "fmla v18.4s, v10.4s, v2.s[1] \n"
567 "fmla v19.4s, v10.4s, v3.s[1] \n"
568 "fmla v20.4s, v10.4s, v4.s[1] \n"
569 "fmla v21.4s, v10.4s, v5.s[1] \n"
570 "fmla v22.4s, v10.4s, v6.s[1] \n"
571 "fmla v23.4s, v10.4s, v7.s[1] \n"
572
573 "fmla v24.4s, v11.4s, v0.s[1] \n"
574 "fmla v25.4s, v11.4s, v1.s[1] \n"
575 "fmla v26.4s, v11.4s, v2.s[1] \n"
576 "fmla v27.4s, v11.4s, v3.s[1] \n"
577 "fmla v28.4s, v11.4s, v4.s[1] \n"
578 "fmla v29.4s, v11.4s, v5.s[1] \n"
579 "fmla v30.4s, v11.4s, v6.s[1] \n"
580 "fmla v31.4s, v11.4s, v7.s[1] \n"
581
582 "prfm pldl1keep, [%4, #512] \n"
583 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%4], #64 \n" // w2233_01
584
585 "fmla v16.4s, v12.4s, v0.s[2] \n"
586 "fmla v17.4s, v12.4s, v1.s[2] \n"
587 "fmla v18.4s, v12.4s, v2.s[2] \n"
588 "fmla v19.4s, v12.4s, v3.s[2] \n"
589 "fmla v20.4s, v12.4s, v4.s[2] \n"
590 "fmla v21.4s, v12.4s, v5.s[2] \n"
591 "fmla v22.4s, v12.4s, v6.s[2] \n"
592 "fmla v23.4s, v12.4s, v7.s[2] \n"
593
594 "fmla v24.4s, v13.4s, v0.s[2] \n"
595 "fmla v25.4s, v13.4s, v1.s[2] \n"
596 "fmla v26.4s, v13.4s, v2.s[2] \n"
597 "fmla v27.4s, v13.4s, v3.s[2] \n"
598 "fmla v28.4s, v13.4s, v4.s[2] \n"
599 "fmla v29.4s, v13.4s, v5.s[2] \n"
600 "fmla v30.4s, v13.4s, v6.s[2] \n"
601 "fmla v31.4s, v13.4s, v7.s[2] \n"
602
603 "fmla v16.4s, v14.4s, v0.s[3] \n"
604 "fmla v17.4s, v14.4s, v1.s[3] \n"
605 "fmla v18.4s, v14.4s, v2.s[3] \n"
606 "fmla v19.4s, v14.4s, v3.s[3] \n"
607 "fmla v20.4s, v14.4s, v4.s[3] \n"
608 "fmla v21.4s, v14.4s, v5.s[3] \n"
609 "fmla v22.4s, v14.4s, v6.s[3] \n"
610 "fmla v23.4s, v14.4s, v7.s[3] \n"
611
612 "subs %w0, %w0, #1 \n"
613
614 "fmla v24.4s, v15.4s, v0.s[3] \n"
615 "fmla v25.4s, v15.4s, v1.s[3] \n"
616 "fmla v26.4s, v15.4s, v2.s[3] \n"
617 "fmla v27.4s, v15.4s, v3.s[3] \n"
618 "fmla v28.4s, v15.4s, v4.s[3] \n"
619 "fmla v29.4s, v15.4s, v5.s[3] \n"
620 "fmla v30.4s, v15.4s, v6.s[3] \n"
621 "fmla v31.4s, v15.4s, v7.s[3] \n"
622
623 "bne 0b \n"
624
625 "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n"
626 "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%2], #64 \n"
627 "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%1], #64 \n"
628 "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%2], #64 \n"
629
630 : "=r"(nn), // %0
631 "=r"(outptr0), // %1
632 "=r"(outptr1), // %2
633 "=r"(tmpptr), // %3
634 "=r"(kptr0) // %4
635 : "0"(nn),
636 "1"(outptr0),
637 "2"(outptr1),
638 "3"(tmpptr),
639 "4"(kptr0),
640 "r"(biasptr) // %10
641 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
642 }
643 for (; i + 3 < size; i += 4)
644 {
645 const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
646 const float* kptr0 = kernel.channel(p / 2);
647
648 int nn = inch * maxk; // inch always > 0
649
650 asm volatile(
651 "ld1 {v0.4s, v1.4s}, [%10] \n"
652 "mov v16.16b, v0.16b \n"
653 "mov v17.16b, v0.16b \n"
654 "mov v18.16b, v0.16b \n"
655 "mov v19.16b, v0.16b \n"
656 "mov v20.16b, v1.16b \n"
657 "mov v21.16b, v1.16b \n"
658 "mov v22.16b, v1.16b \n"
659 "mov v23.16b, v1.16b \n"
660
661 "0: \n"
662
663 "prfm pldl1keep, [%3, #512] \n"
664 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%3], #64 \n" // r0 r1 r2 r3
665
666 "prfm pldl1keep, [%4, #512] \n"
667 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%4], #64 \n" // w0011_01
668
669 "fmla v16.4s, v8.4s, v0.s[0] \n"
670 "fmla v17.4s, v8.4s, v1.s[0] \n"
671 "fmla v18.4s, v8.4s, v2.s[0] \n"
672 "fmla v19.4s, v8.4s, v3.s[0] \n"
673
674 "fmla v20.4s, v9.4s, v0.s[0] \n"
675 "fmla v21.4s, v9.4s, v1.s[0] \n"
676 "fmla v22.4s, v9.4s, v2.s[0] \n"
677 "fmla v23.4s, v9.4s, v3.s[0] \n"
678
679 "prfm pldl1keep, [%4, #512] \n"
680 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%4], #64 \n" // w2233_01
681
682 "fmla v16.4s, v10.4s, v0.s[1] \n"
683 "fmla v17.4s, v10.4s, v1.s[1] \n"
684 "fmla v18.4s, v10.4s, v2.s[1] \n"
685 "fmla v19.4s, v10.4s, v3.s[1] \n"
686
687 "fmla v20.4s, v11.4s, v0.s[1] \n"
688 "fmla v21.4s, v11.4s, v1.s[1] \n"
689 "fmla v22.4s, v11.4s, v2.s[1] \n"
690 "fmla v23.4s, v11.4s, v3.s[1] \n"
691
692 "fmla v16.4s, v12.4s, v0.s[2] \n"
693 "fmla v17.4s, v12.4s, v1.s[2] \n"
694 "fmla v18.4s, v12.4s, v2.s[2] \n"
695 "fmla v19.4s, v12.4s, v3.s[2] \n"
696
697 "fmla v20.4s, v13.4s, v0.s[2] \n"
698 "fmla v21.4s, v13.4s, v1.s[2] \n"
699 "fmla v22.4s, v13.4s, v2.s[2] \n"
700 "fmla v23.4s, v13.4s, v3.s[2] \n"
701
702 "subs %w0, %w0, #1 \n"
703
704 "fmla v16.4s, v14.4s, v0.s[3] \n"
705 "fmla v17.4s, v14.4s, v1.s[3] \n"
706 "fmla v18.4s, v14.4s, v2.s[3] \n"
707 "fmla v19.4s, v14.4s, v3.s[3] \n"
708
709 "fmla v20.4s, v15.4s, v0.s[3] \n"
710 "fmla v21.4s, v15.4s, v1.s[3] \n"
711 "fmla v22.4s, v15.4s, v2.s[3] \n"
712 "fmla v23.4s, v15.4s, v3.s[3] \n"
713
714 "bne 0b \n"
715
716 "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n"
717 "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%2], #64 \n"
718
719 : "=r"(nn), // %0
720 "=r"(outptr0), // %1
721 "=r"(outptr1), // %2
722 "=r"(tmpptr), // %3
723 "=r"(kptr0) // %4
724 : "0"(nn),
725 "1"(outptr0),
726 "2"(outptr1),
727 "3"(tmpptr),
728 "4"(kptr0),
729 "r"(biasptr) // %10
730 : "cc", "memory", "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
731 }
732 for (; i + 1 < size; i += 2)
733 {
734 const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);
735 const float* kptr0 = kernel.channel(p / 2);
736
737 int nn = inch * maxk; // inch always > 0
738
739 asm volatile(
740 "ld1 {v0.4s, v1.4s}, [%10] \n"
741 "mov v16.16b, v0.16b \n"
742 "mov v17.16b, v0.16b \n"
743 "mov v18.16b, v1.16b \n"
744 "mov v19.16b, v1.16b \n"
745
746 "0: \n"
747
748 "prfm pldl1keep, [%3, #256] \n"
749 "ld1 {v0.4s, v1.4s}, [%3], #32 \n" // r0 r1
750
751 "prfm pldl1keep, [%4, #512] \n"
752 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%4], #64 \n" // w0011_01
753
754 "fmla v16.4s, v8.4s, v0.s[0] \n"
755 "fmla v17.4s, v8.4s, v1.s[0] \n"
756 "fmla v18.4s, v9.4s, v0.s[0] \n"
757 "fmla v19.4s, v9.4s, v1.s[0] \n"
758
759 "prfm pldl1keep, [%4, #512] \n"
760 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%4], #64 \n" // w2233_01
761
762 "fmla v16.4s, v10.4s, v0.s[1] \n"
763 "fmla v17.4s, v10.4s, v1.s[1] \n"
764 "fmla v18.4s, v11.4s, v0.s[1] \n"
765 "fmla v19.4s, v11.4s, v1.s[1] \n"
766
767 "fmla v16.4s, v12.4s, v0.s[2] \n"
768 "fmla v17.4s, v12.4s, v1.s[2] \n"
769 "fmla v18.4s, v13.4s, v0.s[2] \n"
770 "fmla v19.4s, v13.4s, v1.s[2] \n"
771
772 "subs %w0, %w0, #1 \n"
773
774 "fmla v16.4s, v14.4s, v0.s[3] \n"
775 "fmla v17.4s, v14.4s, v1.s[3] \n"
776 "fmla v18.4s, v15.4s, v0.s[3] \n"
777 "fmla v19.4s, v15.4s, v1.s[3] \n"
778
779 "bne 0b \n"
780
781 "st1 {v16.4s, v17.4s}, [%1], #32 \n"
782 "st1 {v18.4s, v19.4s}, [%2], #32 \n"
783
784 : "=r"(nn), // %0
785 "=r"(outptr0), // %1
786 "=r"(outptr1), // %2
787 "=r"(tmpptr), // %3
788 "=r"(kptr0) // %4
789 : "0"(nn),
790 "1"(outptr0),
791 "2"(outptr1),
792 "3"(tmpptr),
793 "4"(kptr0),
794 "r"(biasptr) // %10
795 : "cc", "memory", "v0", "v1", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
796 }
797 for (; i < size; i++)
798 {
799 const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
800 const float* kptr0 = kernel.channel(p / 2);
801
802 int nn = inch * maxk; // inch always > 0
803
804 asm volatile(
805 "ld1 {v16.4s, v17.4s}, [%10] \n"
806
807 "0: \n"
808
809 "prfm pldl1keep, [%3, #128] \n"
810 "ld1 {v0.4s}, [%3], #16 \n" // r0
811
812 "prfm pldl1keep, [%4, #512] \n"
813 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%4], #64 \n" // w0011_01
814
815 "fmla v16.4s, v8.4s, v0.s[0] \n"
816 "fmla v17.4s, v9.4s, v0.s[0] \n"
817
818 "prfm pldl1keep, [%4, #512] \n"
819 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%4], #64 \n" // w2233_01
820
821 "fmla v16.4s, v10.4s, v0.s[1] \n"
822 "fmla v17.4s, v11.4s, v0.s[1] \n"
823
824 "fmla v16.4s, v12.4s, v0.s[2] \n"
825 "fmla v17.4s, v13.4s, v0.s[2] \n"
826
827 "subs %w0, %w0, #1 \n"
828
829 "fmla v16.4s, v14.4s, v0.s[3] \n"
830 "fmla v17.4s, v15.4s, v0.s[3] \n"
831
832 "bne 0b \n"
833
834 "st1 {v16.4s}, [%1], #16 \n"
835 "st1 {v17.4s}, [%2], #16 \n"
836
837 : "=r"(nn), // %0
838 "=r"(outptr0), // %1
839 "=r"(outptr1), // %2
840 "=r"(tmpptr), // %3
841 "=r"(kptr0) // %4
842 : "0"(nn),
843 "1"(outptr0),
844 "2"(outptr1),
845 "3"(tmpptr),
846 "4"(kptr0),
847 "r"(biasptr) // %10
848 : "cc", "memory", "v0", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17");
849 }
850 }
851 #endif // __aarch64__
852
853 #pragma omp parallel for num_threads(opt.num_threads)
854 for (int p = remain_outch_start; p < outch; p++)
855 {
856 float* outptr0 = top_blob.channel(p);
857
858 const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
859 const float* biasptr = bias ? bias + p * 4 : zeros;
860
861 int i = 0;
862 #if __aarch64__
863 for (; i + 11 < size; i += 12)
864 {
865 const float* tmpptr = tmp.channel(i / 12);
866 const float* kptr0 = kernel.channel(p / 2 + p % 2);
867
868 int nn = inch * maxk; // inch always > 0
869
870 asm volatile(
871 "ld1 {v0.4s}, [%8] \n"
872 "mov v8.16b, v0.16b \n"
873 "mov v9.16b, v0.16b \n"
874 "mov v10.16b, v0.16b \n"
875 "mov v11.16b, v0.16b \n"
876 "mov v12.16b, v0.16b \n"
877 "mov v13.16b, v0.16b \n"
878 "mov v14.16b, v0.16b \n"
879 "mov v15.16b, v0.16b \n"
880 "mov v16.16b, v0.16b \n"
881 "mov v17.16b, v0.16b \n"
882 "mov v18.16b, v0.16b \n"
883 "mov v19.16b, v0.16b \n"
884
885 "0: \n"
886
887 "prfm pldl1keep, [%2, #512] \n"
888 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%2], #64 \n"
889
890 "prfm pldl1keep, [%3, #512] \n"
891 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%3], #64 \n" // w0123_0
892
893 "fmla v8.4s, v4.4s, v0.s[0] \n"
894 "fmla v9.4s, v4.4s, v0.s[1] \n"
895 "fmla v10.4s, v4.4s, v0.s[2] \n"
896 "fmla v11.4s, v4.4s, v0.s[3] \n"
897 "fmla v12.4s, v4.4s, v1.s[0] \n"
898 "fmla v13.4s, v4.4s, v1.s[1] \n"
899 "fmla v14.4s, v4.4s, v1.s[2] \n"
900 "fmla v15.4s, v4.4s, v1.s[3] \n"
901 "fmla v16.4s, v4.4s, v2.s[0] \n"
902 "fmla v17.4s, v4.4s, v2.s[1] \n"
903 "fmla v18.4s, v4.4s, v2.s[2] \n"
904 "fmla v19.4s, v4.4s, v2.s[3] \n"
905
906 "prfm pldl1keep, [%2, #512] \n"
907 "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%2], #64 \n"
908
909 "fmla v8.4s, v5.4s, v3.s[0] \n"
910 "fmla v9.4s, v5.4s, v3.s[1] \n"
911 "fmla v10.4s, v5.4s, v3.s[2] \n"
912 "fmla v11.4s, v5.4s, v3.s[3] \n"
913 "fmla v12.4s, v5.4s, v20.s[0] \n"
914 "fmla v13.4s, v5.4s, v20.s[1] \n"
915 "fmla v14.4s, v5.4s, v20.s[2] \n"
916 "fmla v15.4s, v5.4s, v20.s[3] \n"
917 "fmla v16.4s, v5.4s, v21.s[0] \n"
918 "fmla v17.4s, v5.4s, v21.s[1] \n"
919 "fmla v18.4s, v5.4s, v21.s[2] \n"
920 "fmla v19.4s, v5.4s, v21.s[3] \n"
921
922 "prfm pldl1keep, [%2, #512] \n"
923 "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%2], #64 \n"
924
925 "fmla v8.4s, v6.4s, v22.s[0] \n"
926 "fmla v9.4s, v6.4s, v22.s[1] \n"
927 "fmla v10.4s, v6.4s, v22.s[2] \n"
928 "fmla v11.4s, v6.4s, v22.s[3] \n"
929 "fmla v12.4s, v6.4s, v23.s[0] \n"
930 "fmla v13.4s, v6.4s, v23.s[1] \n"
931 "fmla v14.4s, v6.4s, v23.s[2] \n"
932 "fmla v15.4s, v6.4s, v23.s[3] \n"
933 "fmla v16.4s, v6.4s, v24.s[0] \n"
934 "fmla v17.4s, v6.4s, v24.s[1] \n"
935 "fmla v18.4s, v6.4s, v24.s[2] \n"
936 "fmla v19.4s, v6.4s, v24.s[3] \n"
937
938 "subs %w0, %w0, #1 \n"
939
940 "fmla v8.4s, v7.4s, v25.s[0] \n"
941 "fmla v9.4s, v7.4s, v25.s[1] \n"
942 "fmla v10.4s, v7.4s, v25.s[2] \n"
943 "fmla v11.4s, v7.4s, v25.s[3] \n"
944 "fmla v12.4s, v7.4s, v26.s[0] \n"
945 "fmla v13.4s, v7.4s, v26.s[1] \n"
946 "fmla v14.4s, v7.4s, v26.s[2] \n"
947 "fmla v15.4s, v7.4s, v26.s[3] \n"
948 "fmla v16.4s, v7.4s, v27.s[0] \n"
949 "fmla v17.4s, v7.4s, v27.s[1] \n"
950 "fmla v18.4s, v7.4s, v27.s[2] \n"
951 "fmla v19.4s, v7.4s, v27.s[3] \n"
952
953 "bne 0b \n"
954
955 "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n"
956 "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n"
957 "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n"
958
959 : "=r"(nn), // %0
960 "=r"(outptr0), // %1
961 "=r"(tmpptr), // %2
962 "=r"(kptr0) // %3
963 : "0"(nn),
964 "1"(outptr0),
965 "2"(tmpptr),
966 "3"(kptr0),
967 "r"(biasptr) // %8
968 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
969 }
970 #endif // __aarch64__
971 for (; i + 7 < size; i += 8)
972 {
973 #if __aarch64__
974 const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8);
975 const float* kptr0 = kernel.channel(p / 2 + p % 2);
976 #else
977 const float* tmpptr = tmp.channel(i / 8);
978 const float* kptr0 = kernel.channel(p);
979 #endif
980
981 int nn = inch * maxk; // inch always > 0
982
983 #if __aarch64__
984 asm volatile(
985 "ld1 {v0.4s}, [%8] \n"
986 "mov v16.16b, v0.16b \n"
987 "mov v17.16b, v0.16b \n"
988 "mov v18.16b, v0.16b \n"
989 "mov v19.16b, v0.16b \n"
990 "mov v20.16b, v0.16b \n"
991 "mov v21.16b, v0.16b \n"
992 "mov v22.16b, v0.16b \n"
993 "mov v23.16b, v0.16b \n"
994
995 "0: \n"
996
997 "prfm pldl1keep, [%2, #512] \n"
998 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%2], #64 \n" // r0 r1 r2 r3
999
1000 "prfm pldl1keep, [%3, #512] \n"
1001 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%3], #64 \n" // w0123
1002
1003 "fmla v16.4s, v8.4s, v0.s[0] \n"
1004 "fmla v17.4s, v8.4s, v1.s[0] \n"
1005 "fmla v18.4s, v8.4s, v2.s[0] \n"
1006 "fmla v19.4s, v8.4s, v3.s[0] \n"
1007
1008 "prfm pldl1keep, [%2, #512] \n"
1009 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%2], #64 \n" // r4 r5 r6 r7
1010
1011 "fmla v20.4s, v8.4s, v4.s[0] \n"
1012 "fmla v21.4s, v8.4s, v5.s[0] \n"
1013 "fmla v22.4s, v8.4s, v6.s[0] \n"
1014 "fmla v23.4s, v8.4s, v7.s[0] \n"
1015
1016 "fmla v16.4s, v9.4s, v0.s[1] \n"
1017 "fmla v17.4s, v9.4s, v1.s[1] \n"
1018 "fmla v18.4s, v9.4s, v2.s[1] \n"
1019 "fmla v19.4s, v9.4s, v3.s[1] \n"
1020 "fmla v20.4s, v9.4s, v4.s[1] \n"
1021 "fmla v21.4s, v9.4s, v5.s[1] \n"
1022 "fmla v22.4s, v9.4s, v6.s[1] \n"
1023 "fmla v23.4s, v9.4s, v7.s[1] \n"
1024
1025 "fmla v16.4s, v10.4s, v0.s[2] \n"
1026 "fmla v17.4s, v10.4s, v1.s[2] \n"
1027 "fmla v18.4s, v10.4s, v2.s[2] \n"
1028 "fmla v19.4s, v10.4s, v3.s[2] \n"
1029 "fmla v20.4s, v10.4s, v4.s[2] \n"
1030 "fmla v21.4s, v10.4s, v5.s[2] \n"
1031 "fmla v22.4s, v10.4s, v6.s[2] \n"
1032 "fmla v23.4s, v10.4s, v7.s[2] \n"
1033
1034 "subs %w0, %w0, #1 \n"
1035
1036 "fmla v16.4s, v11.4s, v0.s[3] \n"
1037 "fmla v17.4s, v11.4s, v1.s[3] \n"
1038 "fmla v18.4s, v11.4s, v2.s[3] \n"
1039 "fmla v19.4s, v11.4s, v3.s[3] \n"
1040 "fmla v20.4s, v11.4s, v4.s[3] \n"
1041 "fmla v21.4s, v11.4s, v5.s[3] \n"
1042 "fmla v22.4s, v11.4s, v6.s[3] \n"
1043 "fmla v23.4s, v11.4s, v7.s[3] \n"
1044
1045 "bne 0b \n"
1046
1047 "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n"
1048 "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%1], #64 \n"
1049
1050 : "=r"(nn), // %0
1051 "=r"(outptr0), // %1
1052 "=r"(tmpptr), // %2
1053 "=r"(kptr0) // %3
1054 : "0"(nn),
1055 "1"(outptr0),
1056 "2"(tmpptr),
1057 "3"(kptr0),
1058 "r"(biasptr) // %8
1059 : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
1060 #else
1061 asm volatile(
1062 "vld1.f32 {d0-d1}, [%8] \n"
1063 "vmov q8, q0 \n"
1064 "vmov q9, q0 \n"
1065 "vmov q10, q0 \n"
1066 "vmov q11, q0 \n"
1067 "vmov q12, q0 \n"
1068 "vmov q13, q0 \n"
1069 "vmov q14, q0 \n"
1070 "vmov q15, q0 \n"
1071
1072 "0: \n"
1073
1074 "pld [%2, #512] \n"
1075 "vldm %2!, {d0-d7} \n"
1076
1077 "pld [%3, #512] \n"
1078 "vldm %3!, {d8-d15} \n"
1079
1080 "vmla.f32 q8, q4, d0[0] \n"
1081 "vmla.f32 q9, q4, d0[1] \n"
1082 "vmla.f32 q10, q4, d1[0] \n"
1083 "vmla.f32 q11, q4, d1[1] \n"
1084 "vmla.f32 q12, q4, d2[0] \n"
1085 "vmla.f32 q13, q4, d2[1] \n"
1086 "vmla.f32 q14, q4, d3[0] \n"
1087 "vmla.f32 q15, q4, d3[1] \n"
1088
1089 "vmla.f32 q8, q5, d4[0] \n"
1090 "vmla.f32 q9, q5, d4[1] \n"
1091 "vmla.f32 q10, q5, d5[0] \n"
1092 "vmla.f32 q11, q5, d5[1] \n"
1093 "vmla.f32 q12, q5, d6[0] \n"
1094 "vmla.f32 q13, q5, d6[1] \n"
1095 "vmla.f32 q14, q5, d7[0] \n"
1096 "vmla.f32 q15, q5, d7[1] \n"
1097
1098 "pld [%2, #512] \n"
1099 "vldm %2!, {d0-d7} \n"
1100
1101 "vmla.f32 q8, q6, d0[0] \n"
1102 "vmla.f32 q9, q6, d0[1] \n"
1103 "vmla.f32 q10, q6, d1[0] \n"
1104 "vmla.f32 q11, q6, d1[1] \n"
1105 "vmla.f32 q12, q6, d2[0] \n"
1106 "vmla.f32 q13, q6, d2[1] \n"
1107 "vmla.f32 q14, q6, d3[0] \n"
1108 "vmla.f32 q15, q6, d3[1] \n"
1109
1110 "subs %0, %0, #1 \n"
1111
1112 "vmla.f32 q8, q7, d4[0] \n"
1113 "vmla.f32 q9, q7, d4[1] \n"
1114 "vmla.f32 q10, q7, d5[0] \n"
1115 "vmla.f32 q11, q7, d5[1] \n"
1116 "vmla.f32 q12, q7, d6[0] \n"
1117 "vmla.f32 q13, q7, d6[1] \n"
1118 "vmla.f32 q14, q7, d7[0] \n"
1119 "vmla.f32 q15, q7, d7[1] \n"
1120
1121 "bne 0b \n"
1122
1123 "vstm %1!, {d16-d23} \n"
1124 "vstm %1!, {d24-d31} \n"
1125
1126 : "=r"(nn), // %0
1127 "=r"(outptr0), // %1
1128 "=r"(tmpptr), // %2
1129 "=r"(kptr0) // %3
1130 : "0"(nn),
1131 "1"(outptr0),
1132 "2"(tmpptr),
1133 "3"(kptr0),
1134 "r"(biasptr) // %8
1135 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1136 #endif
1137 }
1138 for (; i + 3 < size; i += 4)
1139 {
1140 #if __aarch64__
1141 const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4);
1142 const float* kptr0 = kernel.channel(p / 2 + p % 2);
1143 #else
1144 const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4);
1145 const float* kptr0 = kernel.channel(p);
1146 #endif
1147
1148 int nn = inch * maxk; // inch always > 0
1149
1150 #if __aarch64__
1151 asm volatile(
1152 "ld1 {v0.4s}, [%8] \n"
1153 "mov v16.16b, v0.16b \n"
1154 "mov v17.16b, v0.16b \n"
1155 "mov v18.16b, v0.16b \n"
1156 "mov v19.16b, v0.16b \n"
1157
1158 "0: \n"
1159
1160 "prfm pldl1keep, [%2, #512] \n"
1161 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%2], #64 \n" // r0 r1 r2 r3
1162
1163 "prfm pldl1keep, [%3, #512] \n"
1164 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%3], #64 \n" // w0123
1165
1166 "fmla v16.4s, v8.4s, v0.s[0] \n"
1167 "fmla v17.4s, v8.4s, v1.s[0] \n"
1168 "fmla v18.4s, v8.4s, v2.s[0] \n"
1169 "fmla v19.4s, v8.4s, v3.s[0] \n"
1170
1171 "fmla v16.4s, v9.4s, v0.s[1] \n"
1172 "fmla v17.4s, v9.4s, v1.s[1] \n"
1173 "fmla v18.4s, v9.4s, v2.s[1] \n"
1174 "fmla v19.4s, v9.4s, v3.s[1] \n"
1175
1176 "fmla v16.4s, v10.4s, v0.s[2] \n"
1177 "fmla v17.4s, v10.4s, v1.s[2] \n"
1178 "fmla v18.4s, v10.4s, v2.s[2] \n"
1179 "fmla v19.4s, v10.4s, v3.s[2] \n"
1180
1181 "subs %w0, %w0, #1 \n"
1182
1183 "fmla v16.4s, v11.4s, v0.s[3] \n"
1184 "fmla v17.4s, v11.4s, v1.s[3] \n"
1185 "fmla v18.4s, v11.4s, v2.s[3] \n"
1186 "fmla v19.4s, v11.4s, v3.s[3] \n"
1187
1188 "bne 0b \n"
1189
1190 "st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%1], #64 \n"
1191
1192 : "=r"(nn), // %0
1193 "=r"(outptr0), // %1
1194 "=r"(tmpptr), // %2
1195 "=r"(kptr0) // %3
1196 : "0"(nn),
1197 "1"(outptr0),
1198 "2"(tmpptr),
1199 "3"(kptr0),
1200 "r"(biasptr) // %8
1201 : "cc", "memory", "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19");
1202 #else
1203 asm volatile(
1204 "vld1.f32 {d0-d1}, [%8] \n"
1205 "vmov q8, q0 \n"
1206 "vmov q9, q0 \n"
1207 "vmov q10, q0 \n"
1208 "vmov q11, q0 \n"
1209
1210 "0: \n"
1211
1212 "pld [%2, #512] \n"
1213 "vldm %2!, {d0-d7} \n"
1214
1215 "pld [%3, #512] \n"
1216 "vldm %3!, {d8-d15} \n"
1217
1218 "vmla.f32 q8, q4, d0[0] \n"
1219 "vmla.f32 q9, q4, d2[0] \n"
1220 "vmla.f32 q10, q4, d4[0] \n"
1221 "vmla.f32 q11, q4, d6[0] \n"
1222
1223 "vmla.f32 q8, q5, d0[1] \n"
1224 "vmla.f32 q9, q5, d2[1] \n"
1225 "vmla.f32 q10, q5, d4[1] \n"
1226 "vmla.f32 q11, q5, d6[1] \n"
1227
1228 "vmla.f32 q8, q6, d1[0] \n"
1229 "vmla.f32 q9, q6, d3[0] \n"
1230 "vmla.f32 q10, q6, d5[0] \n"
1231 "vmla.f32 q11, q6, d7[0] \n"
1232
1233 "subs %0, %0, #1 \n"
1234
1235 "vmla.f32 q8, q7, d1[1] \n"
1236 "vmla.f32 q9, q7, d3[1] \n"
1237 "vmla.f32 q10, q7, d5[1] \n"
1238 "vmla.f32 q11, q7, d7[1] \n"
1239
1240 "bne 0b \n"
1241
1242 "vstm %1!, {d16-d23} \n"
1243
1244 : "=r"(nn), // %0
1245 "=r"(outptr0), // %1
1246 "=r"(tmpptr), // %2
1247 "=r"(kptr0) // %3
1248 : "0"(nn),
1249 "1"(outptr0),
1250 "2"(tmpptr),
1251 "3"(kptr0),
1252 "r"(biasptr) // %8
1253 : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
1254 #endif
1255 }
1256 for (; i + 1 < size; i += 2)
1257 {
1258 #if __aarch64__
1259 const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2);
1260 const float* kptr0 = kernel.channel(p / 2 + p % 2);
1261 #else
1262 const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2);
1263 const float* kptr0 = kernel.channel(p);
1264 #endif
1265
1266 int nn = inch * maxk; // inch always > 0
1267
1268 #if __aarch64__
1269 asm volatile(
1270 "ld1 {v0.4s}, [%8] \n"
1271 "mov v16.16b, v0.16b \n"
1272 "mov v17.16b, v0.16b \n"
1273
1274 "0: \n"
1275
1276 "prfm pldl1keep, [%2, #256] \n"
1277 "ld1 {v0.4s, v1.4s}, [%2], #32 \n" // r0 r1
1278
1279 "prfm pldl1keep, [%3, #512] \n"
1280 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%3], #64 \n" // w0123
1281
1282 "fmla v16.4s, v8.4s, v0.s[0] \n"
1283 "fmla v17.4s, v8.4s, v1.s[0] \n"
1284
1285 "fmla v16.4s, v9.4s, v0.s[1] \n"
1286 "fmla v17.4s, v9.4s, v1.s[1] \n"
1287
1288 "fmla v16.4s, v10.4s, v0.s[2] \n"
1289 "fmla v17.4s, v10.4s, v1.s[2] \n"
1290
1291 "subs %w0, %w0, #1 \n"
1292
1293 "fmla v16.4s, v11.4s, v0.s[3] \n"
1294 "fmla v17.4s, v11.4s, v1.s[3] \n"
1295
1296 "bne 0b \n"
1297
1298 "st1 {v16.4s, v17.4s}, [%1], #32 \n"
1299
1300 : "=r"(nn), // %0
1301 "=r"(outptr0), // %1
1302 "=r"(tmpptr), // %2
1303 "=r"(kptr0) // %3
1304 : "0"(nn),
1305 "1"(outptr0),
1306 "2"(tmpptr),
1307 "3"(kptr0),
1308 "r"(biasptr) // %8
1309 : "cc", "memory", "v0", "v1", "v8", "v9", "v10", "v11", "v16", "v17");
1310 #else
1311 asm volatile(
1312 "vld1.f32 {d0-d1}, [%8] \n"
1313 "vmov q8, q0 \n"
1314 "vmov q9, q0 \n"
1315
1316 "0: \n"
1317
1318 "pld [%2, #256] \n"
1319 "vld1.f32 {d0-d3}, [%2 :128]! \n"
1320
1321 "pld [%3, #512] \n"
1322 "vldm %3!, {d8-d15} \n"
1323
1324 "vmla.f32 q8, q4, d0[0] \n"
1325 "vmla.f32 q9, q4, d2[0] \n"
1326
1327 "vmla.f32 q8, q5, d0[1] \n"
1328 "vmla.f32 q9, q5, d2[1] \n"
1329
1330 "vmla.f32 q8, q6, d1[0] \n"
1331 "vmla.f32 q9, q6, d3[0] \n"
1332
1333 "subs %0, %0, #1 \n"
1334
1335 "vmla.f32 q8, q7, d1[1] \n"
1336 "vmla.f32 q9, q7, d3[1] \n"
1337
1338 "bne 0b \n"
1339
1340 "vst1.f32 {d16-d19}, [%1 :128]! \n"
1341
1342 : "=r"(nn), // %0
1343 "=r"(outptr0), // %1
1344 "=r"(tmpptr), // %2
1345 "=r"(kptr0) // %3
1346 : "0"(nn),
1347 "1"(outptr0),
1348 "2"(tmpptr),
1349 "3"(kptr0),
1350 "r"(biasptr) // %8
1351 : "cc", "memory", "q0", "q1", "q4", "q5", "q6", "q7", "q8", "q9");
1352 #endif
1353 }
1354 for (; i < size; i++)
1355 {
1356 #if __aarch64__
1357 const float* tmpptr = tmp.channel(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2);
1358 const float* kptr0 = kernel.channel(p / 2 + p % 2);
1359 #else
1360 const float* tmpptr = tmp.channel(i / 8 + (i % 8) / 4 + (i % 4) / 2 + i % 2);
1361 const float* kptr0 = kernel.channel(p);
1362 #endif
1363
1364 int nn = inch * maxk; // inch always > 0
1365
1366 #if __aarch64__
1367 asm volatile(
1368 "ld1 {v16.4s}, [%8] \n"
1369
1370 "0: \n"
1371
1372 "prfm pldl1keep, [%2, #128] \n"
1373 "ld1 {v0.4s}, [%2], #16 \n" // r0
1374
1375 "prfm pldl1keep, [%3, #512] \n"
1376 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%3], #64 \n" // w0123
1377
1378 "fmla v16.4s, v8.4s, v0.s[0] \n"
1379 "fmla v16.4s, v9.4s, v0.s[1] \n"
1380
1381 "subs %w0, %w0, #1 \n"
1382
1383 "fmla v16.4s, v10.4s, v0.s[2] \n"
1384 "fmla v16.4s, v11.4s, v0.s[3] \n"
1385
1386 "bne 0b \n"
1387
1388 "st1 {v16.4s}, [%1], #16 \n"
1389
1390 : "=r"(nn), // %0
1391 "=r"(outptr0), // %1
1392 "=r"(tmpptr), // %2
1393 "=r"(kptr0) // %3
1394 : "0"(nn),
1395 "1"(outptr0),
1396 "2"(tmpptr),
1397 "3"(kptr0),
1398 "r"(biasptr) // %8
1399 : "cc", "memory", "v0", "v8", "v9", "v10", "v11", "v16");
1400 #else
1401 asm volatile(
1402 "vld1.f32 {d16-d17}, [%8] \n"
1403
1404 "0: \n"
1405
1406 "pld [%2, #128] \n"
1407 "vld1.f32 {d0-d1}, [%2 :128]! \n"
1408
1409 "pld [%3, #512] \n"
1410 "vldm %3!, {d8-d15} \n"
1411
1412 "vmla.f32 q8, q4, d0[0] \n"
1413 "vmla.f32 q8, q5, d0[1] \n"
1414
1415 "subs %0, %0, #1 \n"
1416
1417 "vmla.f32 q8, q6, d1[0] \n"
1418 "vmla.f32 q8, q7, d1[1] \n"
1419
1420 "bne 0b \n"
1421
1422 "vst1.f32 {d16-d17}, [%1 :128]! \n"
1423
1424 : "=r"(nn), // %0
1425 "=r"(outptr0), // %1
1426 "=r"(tmpptr), // %2
1427 "=r"(kptr0) // %3
1428 : "0"(nn),
1429 "1"(outptr0),
1430 "2"(tmpptr),
1431 "3"(kptr0),
1432 "r"(biasptr) // %8
1433 : "cc", "memory", "q0", "q4", "q5", "q6", "q7", "q8");
1434 #endif
1435 }
1436 }
1437 }
1438
convolution_im2col_sgemm_transform_kernel_pack4_neon(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_w,int kernel_h)1439 static void convolution_im2col_sgemm_transform_kernel_pack4_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_w, int kernel_h)
1440 {
1441 const int maxk = kernel_w * kernel_h;
1442
1443 // interleave
1444 // src = maxk-inch-outch
1445 // dst = 4b-4a-maxk-inch/4a-outch/4b
1446 Mat kernel = _kernel.reshape(maxk, inch, outch);
1447 #if __aarch64__
1448 kernel_tm.create(32 * maxk, inch / 4, outch / 8 + (outch % 8) / 4);
1449 #else
1450 kernel_tm.create(16 * maxk, inch / 4, outch / 4);
1451 #endif
1452
1453 int q = 0;
1454 #if __aarch64__
1455 for (; q + 7 < outch; q += 8)
1456 {
1457 const Mat k0 = kernel.channel(q);
1458 const Mat k1 = kernel.channel(q + 1);
1459 const Mat k2 = kernel.channel(q + 2);
1460 const Mat k3 = kernel.channel(q + 3);
1461 const Mat k4 = kernel.channel(q + 4);
1462 const Mat k5 = kernel.channel(q + 5);
1463 const Mat k6 = kernel.channel(q + 6);
1464 const Mat k7 = kernel.channel(q + 7);
1465
1466 float* g00 = kernel_tm.channel(q / 8);
1467
1468 for (int p = 0; p + 3 < inch; p += 4)
1469 {
1470 const float* k00 = k0.row(p);
1471 const float* k01 = k0.row(p + 1);
1472 const float* k02 = k0.row(p + 2);
1473 const float* k03 = k0.row(p + 3);
1474
1475 const float* k10 = k1.row(p);
1476 const float* k11 = k1.row(p + 1);
1477 const float* k12 = k1.row(p + 2);
1478 const float* k13 = k1.row(p + 3);
1479
1480 const float* k20 = k2.row(p);
1481 const float* k21 = k2.row(p + 1);
1482 const float* k22 = k2.row(p + 2);
1483 const float* k23 = k2.row(p + 3);
1484
1485 const float* k30 = k3.row(p);
1486 const float* k31 = k3.row(p + 1);
1487 const float* k32 = k3.row(p + 2);
1488 const float* k33 = k3.row(p + 3);
1489
1490 const float* k40 = k4.row(p);
1491 const float* k41 = k4.row(p + 1);
1492 const float* k42 = k4.row(p + 2);
1493 const float* k43 = k4.row(p + 3);
1494
1495 const float* k50 = k5.row(p);
1496 const float* k51 = k5.row(p + 1);
1497 const float* k52 = k5.row(p + 2);
1498 const float* k53 = k5.row(p + 3);
1499
1500 const float* k60 = k6.row(p);
1501 const float* k61 = k6.row(p + 1);
1502 const float* k62 = k6.row(p + 2);
1503 const float* k63 = k6.row(p + 3);
1504
1505 const float* k70 = k7.row(p);
1506 const float* k71 = k7.row(p + 1);
1507 const float* k72 = k7.row(p + 2);
1508 const float* k73 = k7.row(p + 3);
1509
1510 for (int k = 0; k < maxk; k++)
1511 {
1512 g00[0] = k00[k];
1513 g00[1] = k10[k];
1514 g00[2] = k20[k];
1515 g00[3] = k30[k];
1516 g00[4] = k40[k];
1517 g00[5] = k50[k];
1518 g00[6] = k60[k];
1519 g00[7] = k70[k];
1520
1521 g00[8] = k01[k];
1522 g00[9] = k11[k];
1523 g00[10] = k21[k];
1524 g00[11] = k31[k];
1525 g00[12] = k41[k];
1526 g00[13] = k51[k];
1527 g00[14] = k61[k];
1528 g00[15] = k71[k];
1529
1530 g00[16] = k02[k];
1531 g00[17] = k12[k];
1532 g00[18] = k22[k];
1533 g00[19] = k32[k];
1534 g00[20] = k42[k];
1535 g00[21] = k52[k];
1536 g00[22] = k62[k];
1537 g00[23] = k72[k];
1538
1539 g00[24] = k03[k];
1540 g00[25] = k13[k];
1541 g00[26] = k23[k];
1542 g00[27] = k33[k];
1543 g00[28] = k43[k];
1544 g00[29] = k53[k];
1545 g00[30] = k63[k];
1546 g00[31] = k73[k];
1547
1548 g00 += 32;
1549 }
1550 }
1551 }
1552 #endif // __aarch64__
1553 for (; q + 3 < outch; q += 4)
1554 {
1555 const Mat k0 = kernel.channel(q);
1556 const Mat k1 = kernel.channel(q + 1);
1557 const Mat k2 = kernel.channel(q + 2);
1558 const Mat k3 = kernel.channel(q + 3);
1559
1560 #if __aarch64__
1561 float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4);
1562 #else
1563 float* g00 = kernel_tm.channel(q / 4);
1564 #endif
1565
1566 for (int p = 0; p + 3 < inch; p += 4)
1567 {
1568 const float* k00 = k0.row(p);
1569 const float* k01 = k0.row(p + 1);
1570 const float* k02 = k0.row(p + 2);
1571 const float* k03 = k0.row(p + 3);
1572
1573 const float* k10 = k1.row(p);
1574 const float* k11 = k1.row(p + 1);
1575 const float* k12 = k1.row(p + 2);
1576 const float* k13 = k1.row(p + 3);
1577
1578 const float* k20 = k2.row(p);
1579 const float* k21 = k2.row(p + 1);
1580 const float* k22 = k2.row(p + 2);
1581 const float* k23 = k2.row(p + 3);
1582
1583 const float* k30 = k3.row(p);
1584 const float* k31 = k3.row(p + 1);
1585 const float* k32 = k3.row(p + 2);
1586 const float* k33 = k3.row(p + 3);
1587
1588 for (int k = 0; k < maxk; k++)
1589 {
1590 g00[0] = k00[k];
1591 g00[1] = k10[k];
1592 g00[2] = k20[k];
1593 g00[3] = k30[k];
1594
1595 g00[4] = k01[k];
1596 g00[5] = k11[k];
1597 g00[6] = k21[k];
1598 g00[7] = k31[k];
1599
1600 g00[8] = k02[k];
1601 g00[9] = k12[k];
1602 g00[10] = k22[k];
1603 g00[11] = k32[k];
1604
1605 g00[12] = k03[k];
1606 g00[13] = k13[k];
1607 g00[14] = k23[k];
1608 g00[15] = k33[k];
1609
1610 g00 += 16;
1611 }
1612 }
1613 }
1614 }
1615
static void convolution_im2col_sgemm_pack4_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, const Option& opt)
{
    const int w = bottom_blob.w;
    const int inch = bottom_blob.c;

    const int outw = top_blob.w;
    const int outh = top_blob.h;
    const int size = outw * outh;

    const int maxk = kernel_w * kernel_h;

    // im2col: lay out one row of `size` pack4 elements per kernel tap per input channel
    Mat bottom_im2col(size, maxk, inch, 16u, 4, opt.workspace_allocator);
    {
        // pack4 floats to skip from the end of one output row to the start of the next
        const int gap = (w * stride_h - outw * stride_w) * 4;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < inch; p++)
        {
            const Mat img = bottom_blob.channel(p);
            float* outptr = bottom_im2col.channel(p);

            for (int ky = 0; ky < kernel_h; ky++)
            {
                for (int kx = 0; kx < kernel_w; kx++)
                {
                    // top-left sample for this kernel tap, honoring dilation
                    const float* inptr = img.row<const float>(dilation_h * ky) + dilation_w * kx * 4;

                    for (int y = 0; y < outh; y++)
                    {
                        int x = 0;
                        // copy four output positions per iteration
                        for (; x + 3 < outw; x += 4)
                        {
                            vst1q_f32(outptr, vld1q_f32(inptr));
                            vst1q_f32(outptr + 4, vld1q_f32(inptr + stride_w * 4));
                            vst1q_f32(outptr + 8, vld1q_f32(inptr + stride_w * 8));
                            vst1q_f32(outptr + 12, vld1q_f32(inptr + stride_w * 12));

                            inptr += stride_w * 16;
                            outptr += 16;
                        }
                        // pair tail
                        for (; x + 1 < outw; x += 2)
                        {
                            vst1q_f32(outptr, vld1q_f32(inptr));
                            vst1q_f32(outptr + 4, vld1q_f32(inptr + stride_w * 4));

                            inptr += stride_w * 8;
                            outptr += 8;
                        }
                        // scalar (single pack4 element) tail
                        for (; x < outw; x++)
                        {
                            vst1q_f32(outptr, vld1q_f32(inptr));

                            inptr += stride_w * 4;
                            outptr += 4;
                        }

                        inptr += gap;
                    }
                }
            }
        }
    }

    im2col_sgemm_pack4_neon(bottom_im2col, top_blob, kernel, _bias, opt);
}
1689