1 // BUG1989 is pleased to support the open source community by supporting ncnn available.
2 //
3 // Copyright (C) 2019 BUG1989. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
conv_im2col_sgemm_transform_kernel_neon(const Mat & _kernel,Mat & kernel_tm,int inch,int outch,int kernel_size)15 static void conv_im2col_sgemm_transform_kernel_neon(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
16 {
17 const float* kernel = _kernel;
18
19 #if __ARM_NEON && __aarch64__
20 // kernel memory packed 8 x 8
21 kernel_tm.create(8 * kernel_size, inch, outch / 8 + (outch % 8) / 4 + outch % 4);
22 #else
23 // kernel memory packed 4 x 8
24 kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4);
25 #endif
26
27 int nn_outch = 0;
28 int remain_outch_start = 0;
29
30 #if __ARM_NEON && __aarch64__
31 nn_outch = outch >> 3;
32 remain_outch_start = nn_outch << 3;
33
34 for (int pp = 0; pp < nn_outch; pp++)
35 {
36 int p = pp * 8;
37
38 const float* k0 = kernel + (p + 0) * inch * kernel_size;
39 const float* k1 = kernel + (p + 1) * inch * kernel_size;
40 const float* k2 = kernel + (p + 2) * inch * kernel_size;
41 const float* k3 = kernel + (p + 3) * inch * kernel_size;
42 const float* k4 = kernel + (p + 4) * inch * kernel_size;
43 const float* k5 = kernel + (p + 5) * inch * kernel_size;
44 const float* k6 = kernel + (p + 6) * inch * kernel_size;
45 const float* k7 = kernel + (p + 7) * inch * kernel_size;
46
47 float* ktmp = kernel_tm.channel(p / 8);
48
49 for (int q = 0; q < inch * kernel_size; q++)
50 {
51 ktmp[0] = k0[0];
52 ktmp[1] = k1[0];
53 ktmp[2] = k2[0];
54 ktmp[3] = k3[0];
55 ktmp[4] = k4[0];
56 ktmp[5] = k5[0];
57 ktmp[6] = k6[0];
58 ktmp[7] = k7[0];
59 ktmp += 8;
60
61 k0 += 1;
62 k1 += 1;
63 k2 += 1;
64 k3 += 1;
65 k4 += 1;
66 k5 += 1;
67 k6 += 1;
68 k7 += 1;
69 }
70 }
71 #endif
72
73 nn_outch = (outch - remain_outch_start) >> 2;
74
75 for (int pp = 0; pp < nn_outch; pp++)
76 {
77 int p = remain_outch_start + pp * 4;
78
79 const float* k0 = kernel + (p + 0) * inch * kernel_size;
80 const float* k1 = kernel + (p + 1) * inch * kernel_size;
81 const float* k2 = kernel + (p + 2) * inch * kernel_size;
82 const float* k3 = kernel + (p + 3) * inch * kernel_size;
83
84 #if __ARM_NEON && __aarch64__
85 float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4);
86 #else
87 float* ktmp = kernel_tm.channel(p / 4);
88 #endif // __ARM_NEON && __aarch64__
89
90 for (int q = 0; q < inch * kernel_size; q++)
91 {
92 ktmp[0] = k0[0];
93 ktmp[1] = k1[0];
94 ktmp[2] = k2[0];
95 ktmp[3] = k3[0];
96 ktmp += 4;
97
98 k0 += 1;
99 k1 += 1;
100 k2 += 1;
101 k3 += 1;
102 }
103 }
104
105 remain_outch_start += nn_outch << 2;
106
107 for (int p = remain_outch_start; p < outch; p++)
108 {
109 const float* k0 = kernel + (p + 0) * inch * kernel_size;
110
111 #if __ARM_NEON && __aarch64__
112 float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4);
113 #else
114 float* ktmp = kernel_tm.channel(p / 4 + p % 4);
115 #endif // __ARM_NEON && __aarch64__
116
117 for (int q = 0; q < inch * kernel_size; q++)
118 {
119 ktmp[0] = k0[0];
120 ktmp++;
121 k0++;
122 }
123 }
124 }
125
conv_im2col_sgemm_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & kernel_tm,const Mat & _bias,const int kernel_w,const int kernel_h,const int stride_w,const int stride_h,const Option & opt)126 static void conv_im2col_sgemm_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias,
127 const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
128 {
129 int w = bottom_blob.w;
130 int inch = bottom_blob.c;
131 size_t elemsize = bottom_blob.elemsize;
132
133 int outw = top_blob.w;
134 int outh = top_blob.h;
135 int outch = top_blob.c;
136
137 const float* bias = _bias;
138
139 // im2col
140 Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, elemsize, opt.workspace_allocator);
141 {
142 const int stride = kernel_h * kernel_w * outw * outh;
143 float* ret = (float*)bottom_im2col;
144
145 #pragma omp parallel for num_threads(opt.num_threads)
146 for (int p = 0; p < inch; p++)
147 {
148 const float* input = bottom_blob.channel(p);
149 int retID = stride * p;
150 for (int u = 0; u < kernel_h; u++)
151 {
152 for (int v = 0; v < kernel_w; v++)
153 {
154 for (int i = 0; i < outh; i++)
155 {
156 for (int j = 0; j < outw; j++)
157 {
158 int row = u + i * stride_h;
159 int col = v + j * stride_w;
160 int index = row * w + col;
161 ret[retID] = input[index];
162 retID++;
163 }
164 }
165 }
166 }
167 }
168 }
169
170 int kernel_size = kernel_w * kernel_h;
171 int out_size = outw * outh;
172
173 // bottom_im2col memory packed 8 x 8
174 Mat bottom_tm(8 * kernel_size, inch, out_size / 8 + out_size % 8, elemsize, opt.workspace_allocator);
175 {
176 int nn_size = out_size >> 3;
177 int remain_size_start = nn_size << 3;
178
179 #pragma omp parallel for num_threads(opt.num_threads)
180 for (int ii = 0; ii < nn_size; ii++)
181 {
182 int i = ii * 8;
183
184 const float* img0 = bottom_im2col.channel(0);
185 img0 += i;
186
187 float* tmpptr = bottom_tm.channel(i / 8);
188
189 for (int q = 0; q < inch * kernel_size; q++)
190 {
191 #if __ARM_NEON
192 #if __aarch64__
193 asm volatile(
194 "prfm pldl1keep, [%0, #256] \n"
195 "ld1 {v0.4s, v1.4s}, [%0] \n"
196 "st1 {v0.4s, v1.4s}, [%1] \n"
197 : "=r"(img0), // %0
198 "=r"(tmpptr) // %1
199 : "0"(img0),
200 "1"(tmpptr)
201 : "cc", "memory", "v0", "v1");
202 #else
203 asm volatile(
204 "pld [%0, #256] \n"
205 "vld1.f32 {d0-d3}, [%0] \n"
206 "vst1.f32 {d0-d3}, [%1] \n"
207 : "=r"(img0), // %0
208 "=r"(tmpptr) // %1
209 : "0"(img0),
210 "1"(tmpptr)
211 : "memory", "q0", "q1");
212 #endif // __aarch64__
213 #else
214 tmpptr[0] = img0[0];
215 tmpptr[1] = img0[1];
216 tmpptr[2] = img0[2];
217 tmpptr[3] = img0[3];
218 tmpptr[4] = img0[4];
219 tmpptr[5] = img0[5];
220 tmpptr[6] = img0[6];
221 tmpptr[7] = img0[7];
222 #endif // __ARM_NEON
223 tmpptr += 8;
224 img0 += out_size;
225 }
226 }
227
228 #pragma omp parallel for num_threads(opt.num_threads)
229 for (int i = remain_size_start; i < out_size; i++)
230 {
231 const float* img0 = bottom_im2col.channel(0);
232 img0 += i;
233
234 float* tmpptr = bottom_tm.channel(i / 8 + i % 8);
235
236 for (int q = 0; q < inch * kernel_size; q++)
237 {
238 tmpptr[0] = img0[0];
239
240 tmpptr += 1;
241 img0 += out_size;
242 }
243 }
244 }
245
246 // sgemm(int M, int N, int L, float* A, float* B, float* C)
247 {
248 //int M = outch; // outch
249 int N = outw * outh; // outsize or out stride
250 int L = kernel_w * kernel_h * inch; // ksize * inch
251
252 int nn_outch = 0;
253 int remain_outch_start = 0;
254
255 #if __aarch64__
256 nn_outch = outch >> 3;
257 remain_outch_start = nn_outch << 3;
258
259 #pragma omp parallel for num_threads(opt.num_threads)
260 for (int pp = 0; pp < nn_outch; pp++)
261 {
262 int i = pp * 8;
263
264 float* output0 = top_blob.channel(i);
265 float* output1 = top_blob.channel(i + 1);
266 float* output2 = top_blob.channel(i + 2);
267 float* output3 = top_blob.channel(i + 3);
268 float* output4 = top_blob.channel(i + 4);
269 float* output5 = top_blob.channel(i + 5);
270 float* output6 = top_blob.channel(i + 6);
271 float* output7 = top_blob.channel(i + 7);
272
273 const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
274 const float* biasptr = bias ? bias + i : zeros;
275
276 int j = 0;
277 for (; j + 7 < N; j = j + 8)
278 {
279 const float* vb = bottom_tm.channel(j / 8);
280 const float* va = kernel_tm.channel(i / 8);
281 #if __ARM_NEON
282 asm volatile(
283 "ld1 {v0.4s, v1.4s}, [%21] \n"
284 "dup v16.4s, v0.s[0] \n" // sum0
285 "dup v17.4s, v0.s[0] \n"
286 "dup v18.4s, v0.s[1] \n" // sum1
287 "dup v19.4s, v0.s[1] \n"
288 "dup v20.4s, v0.s[2] \n" // sum2
289 "dup v21.4s, v0.s[2] \n"
290 "dup v22.4s, v0.s[3] \n" // sum3
291 "dup v23.4s, v0.s[3] \n"
292 "dup v24.4s, v1.s[0] \n" // sum4
293 "dup v25.4s, v1.s[0] \n"
294 "dup v26.4s, v1.s[1] \n" // sum5
295 "dup v27.4s, v1.s[1] \n"
296 "dup v28.4s, v1.s[2] \n" // sum6
297 "dup v29.4s, v1.s[2] \n"
298 "dup v30.4s, v1.s[3] \n" // sum7
299 "dup v31.4s, v1.s[3] \n"
300
301 "lsr w4, %w20, #2 \n" // r4 = nn = L >> 2
302 "cmp w4, #0 \n"
303 "beq 1f \n"
304
305 "0: \n" // for (; k+3<L; k=k+4)
306
307 "prfm pldl1keep, [%9, #512] \n"
308 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64 \n" // kernel
309 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%9], #64 \n"
310
311 "prfm pldl1keep, [%8, #512] \n"
312 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%8], #64 \n" // data
313 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%8], #64 \n"
314 // k0
315 "fmla v16.4s, v8.4s, v0.s[0] \n" // sum0 += (a00-a70) * k00
316 "fmla v17.4s, v9.4s, v0.s[0] \n" //
317 "fmla v18.4s, v8.4s, v0.s[1] \n" // sum1 += (a00-a70) * k10
318 "fmla v19.4s, v9.4s, v0.s[1] \n" //
319 "fmla v20.4s, v8.4s, v0.s[2] \n" // sum2 += (a00-a70) * k20
320 "fmla v21.4s, v9.4s, v0.s[2] \n" //
321 "fmla v22.4s, v8.4s, v0.s[3] \n" // sum3 += (a00-a70) * k30
322 "fmla v23.4s, v9.4s, v0.s[3] \n" //
323 "fmla v24.4s, v8.4s, v1.s[0] \n" // sum4 += (a00-a70) * k40
324 "fmla v25.4s, v9.4s, v1.s[0] \n" //
325 "fmla v26.4s, v8.4s, v1.s[1] \n" // sum5 += (a00-a70) * k50
326 "fmla v27.4s, v9.4s, v1.s[1] \n" //
327 "fmla v28.4s, v8.4s, v1.s[2] \n" // sum6 += (a00-a70) * k60
328 "fmla v29.4s, v9.4s, v1.s[2] \n" //
329 "fmla v30.4s, v8.4s, v1.s[3] \n" // sum7 += (a00-a70) * k70
330 "fmla v31.4s, v9.4s, v1.s[3] \n" //
331 // k1
332 "fmla v16.4s, v10.4s, v2.s[0] \n" // sum0 += (a01-a71) * k01
333 "fmla v17.4s, v11.4s, v2.s[0] \n" //
334 "fmla v18.4s, v10.4s, v2.s[1] \n" // sum1 += (a01-a71) * k11
335 "fmla v19.4s, v11.4s, v2.s[1] \n" //
336 "fmla v20.4s, v10.4s, v2.s[2] \n" // sum2 += (a01-a71) * k21
337 "fmla v21.4s, v11.4s, v2.s[2] \n" //
338 "fmla v22.4s, v10.4s, v2.s[3] \n" // sum3 += (a01-a71) * k31
339 "fmla v23.4s, v11.4s, v2.s[3] \n" //
340 "fmla v24.4s, v10.4s, v3.s[0] \n" // sum4 += (a01-a71) * k41
341 "fmla v25.4s, v11.4s, v3.s[0] \n" //
342 "fmla v26.4s, v10.4s, v3.s[1] \n" // sum5 += (a01-a71) * k51
343 "fmla v27.4s, v11.4s, v3.s[1] \n" //
344 "fmla v28.4s, v10.4s, v3.s[2] \n" // sum6 += (a01-a71) * k61
345 "fmla v29.4s, v11.4s, v3.s[2] \n" //
346 "fmla v30.4s, v10.4s, v3.s[3] \n" // sum7 += (a01-a71) * k71
347 "fmla v31.4s, v11.4s, v3.s[3] \n" //
348 // k2
349 "fmla v16.4s, v12.4s, v4.s[0] \n" // sum0 += (a02-a72) * k02
350 "fmla v17.4s, v13.4s, v4.s[0] \n" //
351 "fmla v18.4s, v12.4s, v4.s[1] \n" // sum1 += (a02-a72) * k12
352 "fmla v19.4s, v13.4s, v4.s[1] \n" //
353 "fmla v20.4s, v12.4s, v4.s[2] \n" // sum2 += (a02-a72) * k22
354 "fmla v21.4s, v13.4s, v4.s[2] \n" //
355 "fmla v22.4s, v12.4s, v4.s[3] \n" // sum3 += (a02-a72) * k32
356 "fmla v23.4s, v13.4s, v4.s[3] \n" //
357 "fmla v24.4s, v12.4s, v5.s[0] \n" // sum4 += (a02-a72) * k42
358 "fmla v25.4s, v13.4s, v5.s[0] \n" //
359 "fmla v26.4s, v12.4s, v5.s[1] \n" // sum5 += (a02-a72) * k52
360 "fmla v27.4s, v13.4s, v5.s[1] \n" //
361 "fmla v28.4s, v12.4s, v5.s[2] \n" // sum6 += (a02-a72) * k62
362 "fmla v29.4s, v13.4s, v5.s[2] \n" //
363 "fmla v30.4s, v12.4s, v5.s[3] \n" // sum7 += (a02-a72) * k72
364 "fmla v31.4s, v13.4s, v5.s[3] \n" //
365 // k3
366 "fmla v16.4s, v14.4s, v6.s[0] \n" // sum0 += (a03-a73) * k03
367 "fmla v17.4s, v15.4s, v6.s[0] \n" //
368 "fmla v18.4s, v14.4s, v6.s[1] \n" // sum1 += (a03-a73) * k13
369 "fmla v19.4s, v15.4s, v6.s[1] \n" //
370 "fmla v20.4s, v14.4s, v6.s[2] \n" // sum2 += (a03-a73) * k23
371 "fmla v21.4s, v15.4s, v6.s[2] \n" //
372 "fmla v22.4s, v14.4s, v6.s[3] \n" // sum3 += (a03-a73) * k33
373 "fmla v23.4s, v15.4s, v6.s[3] \n" //
374 "fmla v24.4s, v14.4s, v7.s[0] \n" // sum4 += (a03-a73) * k43
375 "fmla v25.4s, v15.4s, v7.s[0] \n" //
376 "fmla v26.4s, v14.4s, v7.s[1] \n" // sum5 += (a03-a73) * k53
377 "fmla v27.4s, v15.4s, v7.s[1] \n" //
378 "fmla v28.4s, v14.4s, v7.s[2] \n" // sum6 += (a03-a73) * k63
379 "fmla v29.4s, v15.4s, v7.s[2] \n" //
380 "fmla v30.4s, v14.4s, v7.s[3] \n" // sum7 += (a03-a73) * k73
381 "fmla v31.4s, v15.4s, v7.s[3] \n" //
382
383 "subs w4, w4, #1 \n"
384 "bne 0b \n"
385
386 "1: \n"
387
388 // remain loop
389 "and w4, %w20, #3 \n" // w4 = remain = inch & 3;
390 "cmp w4, #0 \n"
391 "beq 3f \n"
392
393 "2: \n"
394
395 "prfm pldl1keep, [%9, #256] \n"
396 "ld1 {v0.4s, v1.4s}, [%9], #32 \n"
397
398 "prfm pldl1keep, [%8, #256] \n"
399 "ld1 {v8.4s, v9.4s}, [%8], #32 \n"
400
401 // k0
402 "fmla v16.4s, v8.4s, v0.s[0] \n" // sum0 += (a00-a70) * k00
403 "fmla v17.4s, v9.4s, v0.s[0] \n" //
404 "fmla v18.4s, v8.4s, v0.s[1] \n" // sum1 += (a00-a70) * k10
405 "fmla v19.4s, v9.4s, v0.s[1] \n" //
406 "fmla v20.4s, v8.4s, v0.s[2] \n" // sum2 += (a00-a70) * k20
407 "fmla v21.4s, v9.4s, v0.s[2] \n" //
408 "fmla v22.4s, v8.4s, v0.s[3] \n" // sum3 += (a00-a70) * k30
409 "fmla v23.4s, v9.4s, v0.s[3] \n" //
410 "fmla v24.4s, v8.4s, v1.s[0] \n" // sum4 += (a00-a70) * k40
411 "fmla v25.4s, v9.4s, v1.s[0] \n" //
412 "fmla v26.4s, v8.4s, v1.s[1] \n" // sum5 += (a00-a70) * k50
413 "fmla v27.4s, v9.4s, v1.s[1] \n" //
414 "fmla v28.4s, v8.4s, v1.s[2] \n" // sum6 += (a00-a70) * k60
415 "fmla v29.4s, v9.4s, v1.s[2] \n" //
416 "fmla v30.4s, v8.4s, v1.s[3] \n" // sum7 += (a00-a70) * k70
417 "fmla v31.4s, v9.4s, v1.s[3] \n" //
418
419 "subs w4, w4, #1 \n"
420
421 "bne 2b \n"
422
423 "3: \n"
424
425 "st1 {v16.4s, v17.4s}, [%0] \n"
426 "st1 {v18.4s, v19.4s}, [%1] \n"
427 "st1 {v20.4s, v21.4s}, [%2] \n"
428 "st1 {v22.4s, v23.4s}, [%3] \n"
429 "st1 {v24.4s, v25.4s}, [%4] \n"
430 "st1 {v26.4s, v27.4s}, [%5] \n"
431 "st1 {v28.4s, v29.4s}, [%6] \n"
432 "st1 {v30.4s, v31.4s}, [%7] \n"
433
434 : "=r"(output0), // %0
435 "=r"(output1), // %1
436 "=r"(output2), // %2
437 "=r"(output3), // %3
438 "=r"(output4), // %4
439 "=r"(output5), // %5
440 "=r"(output6), // %6
441 "=r"(output7), // %7
442 "=r"(vb), // %8
443 "=r"(va) // %9
444 : "0"(output0),
445 "1"(output1),
446 "2"(output2),
447 "3"(output3),
448 "4"(output4),
449 "5"(output5),
450 "6"(output6),
451 "7"(output7),
452 "8"(vb),
453 "9"(va),
454 "r"(L), // %20
455 "r"(biasptr) // %21
456 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
457 #else
458 float sum0[8] = {0};
459 float sum1[8] = {0};
460 float sum2[8] = {0};
461 float sum3[8] = {0};
462 float sum4[8] = {0};
463 float sum5[8] = {0};
464 float sum6[8] = {0};
465 float sum7[8] = {0};
466
467 int k = 0;
468 for (; k + 7 < L; k = k + 8)
469 {
470 for (int n = 0; n < 8; n++)
471 {
472 sum0[n] += va[0] * vb[n];
473 sum1[n] += va[1] * vb[n];
474 sum2[n] += va[2] * vb[n];
475 sum3[n] += va[3] * vb[n];
476 sum4[n] += va[4] * vb[n];
477 sum5[n] += va[5] * vb[n];
478 sum6[n] += va[6] * vb[n];
479 sum7[n] += va[7] * vb[n];
480 va += 8;
481
482 sum0[n] += va[0] * vb[n + 8];
483 sum1[n] += va[1] * vb[n + 8];
484 sum2[n] += va[2] * vb[n + 8];
485 sum3[n] += va[3] * vb[n + 8];
486 sum4[n] += va[4] * vb[n + 8];
487 sum5[n] += va[5] * vb[n + 8];
488 sum6[n] += va[6] * vb[n + 8];
489 sum7[n] += va[7] * vb[n + 8];
490 va += 8;
491
492 sum0[n] += va[0] * vb[n + 16];
493 sum1[n] += va[1] * vb[n + 16];
494 sum2[n] += va[2] * vb[n + 16];
495 sum3[n] += va[3] * vb[n + 16];
496 sum4[n] += va[4] * vb[n + 16];
497 sum5[n] += va[5] * vb[n + 16];
498 sum6[n] += va[6] * vb[n + 16];
499 sum7[n] += va[7] * vb[n + 16];
500 va += 8;
501
502 sum0[n] += va[0] * vb[n + 24];
503 sum1[n] += va[1] * vb[n + 24];
504 sum2[n] += va[2] * vb[n + 24];
505 sum3[n] += va[3] * vb[n + 24];
506 sum4[n] += va[4] * vb[n + 24];
507 sum5[n] += va[5] * vb[n + 24];
508 sum6[n] += va[6] * vb[n + 24];
509 sum7[n] += va[7] * vb[n + 24];
510 va += 8;
511
512 sum0[n] += va[0] * vb[n + 32];
513 sum1[n] += va[1] * vb[n + 32];
514 sum2[n] += va[2] * vb[n + 32];
515 sum3[n] += va[3] * vb[n + 32];
516 sum4[n] += va[4] * vb[n + 32];
517 sum5[n] += va[5] * vb[n + 32];
518 sum6[n] += va[6] * vb[n + 32];
519 sum7[n] += va[7] * vb[n + 32];
520 va += 8;
521
522 sum0[n] += va[0] * vb[n + 40];
523 sum1[n] += va[1] * vb[n + 40];
524 sum2[n] += va[2] * vb[n + 40];
525 sum3[n] += va[3] * vb[n + 40];
526 sum4[n] += va[4] * vb[n + 40];
527 sum5[n] += va[5] * vb[n + 40];
528 sum6[n] += va[6] * vb[n + 40];
529 sum7[n] += va[7] * vb[n + 40];
530 va += 8;
531
532 sum0[n] += va[0] * vb[n + 48];
533 sum1[n] += va[1] * vb[n + 48];
534 sum2[n] += va[2] * vb[n + 48];
535 sum3[n] += va[3] * vb[n + 48];
536 sum4[n] += va[4] * vb[n + 48];
537 sum5[n] += va[5] * vb[n + 48];
538 sum6[n] += va[6] * vb[n + 48];
539 sum7[n] += va[7] * vb[n + 48];
540 va += 8;
541
542 sum0[n] += va[0] * vb[n + 56];
543 sum1[n] += va[1] * vb[n + 56];
544 sum2[n] += va[2] * vb[n + 56];
545 sum3[n] += va[3] * vb[n + 56];
546 sum4[n] += va[4] * vb[n + 56];
547 sum5[n] += va[5] * vb[n + 56];
548 sum6[n] += va[6] * vb[n + 56];
549 sum7[n] += va[7] * vb[n + 56];
550 va -= 56;
551 }
552
553 va += 64;
554 vb += 64;
555 }
556
557 for (; k < L; k++)
558 {
559 for (int n = 0; n < 8; n++)
560 {
561 sum0[n] += va[0] * vb[n];
562 sum1[n] += va[1] * vb[n];
563 sum2[n] += va[2] * vb[n];
564 sum3[n] += va[3] * vb[n];
565 sum4[n] += va[4] * vb[n];
566 sum5[n] += va[5] * vb[n];
567 sum6[n] += va[6] * vb[n];
568 sum7[n] += va[7] * vb[n];
569 }
570
571 va += 8;
572 vb += 8;
573 }
574
575 for (int n = 0; n < 8; n++)
576 {
577 output0[n] = sum0[n] + biasptr[0];
578 output1[n] = sum1[n] + biasptr[1];
579 output2[n] = sum2[n] + biasptr[2];
580 output3[n] = sum3[n] + biasptr[3];
581 output4[n] = sum4[n] + biasptr[4];
582 output5[n] = sum5[n] + biasptr[5];
583 output6[n] = sum6[n] + biasptr[6];
584 output7[n] = sum7[n] + biasptr[7];
585 }
586 #endif // __ARM_NEON
587 output0 += 8;
588 output1 += 8;
589 output2 += 8;
590 output3 += 8;
591 output4 += 8;
592 output5 += 8;
593 output6 += 8;
594 output7 += 8;
595 }
596
597 for (; j < N; j++)
598 {
599 const float* vb = bottom_tm.channel(j / 8 + j % 8);
600 const float* va = kernel_tm.channel(i / 8);
601
602 #if __ARM_NEON
603 asm volatile(
604 "ld1 {v14.4s, v15.4s}, [%21] \n" // sum0_7 inital with bias
605 "eor v16.16b, v16.16b, v16.16b \n" // sum0
606 "eor v17.16b, v17.16b, v17.16b \n" // sum1
607 "eor v18.16b, v18.16b, v18.16b \n" // sum2
608 "eor v19.16b, v19.16b, v19.16b \n" // sum3
609 "eor v20.16b, v20.16b, v20.16b \n" // sum4
610 "eor v21.16b, v21.16b, v21.16b \n" // sum5
611 "eor v22.16b, v22.16b, v22.16b \n" // sum6
612 "eor v23.16b, v23.16b, v23.16b \n" // sum7
613
614 "lsr w4, %w20, #2 \n" // r4 = nn = L >> 2
615 "cmp w4, #0 \n"
616 "beq 1f \n"
617
618 "0: \n" // for (; k+3<L; k=k+4)
619
620 "prfm pldl1keep, [%9, #256] \n"
621 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%9], #64 \n" // k
622 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%9], #64 \n"
623
624 "prfm pldl1keep, [%8, #128] \n"
625 "ld1 {v8.4s}, [%8], #16 \n" // d
626
627 // k0
628 "fmla v16.4s, v0.4s, v8.s[0] \n" // sum0 += (k00-k70) * a00
629 "fmla v17.4s, v1.4s, v8.s[0] \n" //
630 "fmla v18.4s, v2.4s, v8.s[1] \n" // sum1 += (k01-k71) * a10
631 "fmla v19.4s, v3.4s, v8.s[1] \n" //
632 "fmla v20.4s, v4.4s, v8.s[2] \n" // sum2 += (k02-k72) * a20
633 "fmla v21.4s, v5.4s, v8.s[2] \n" //
634 "fmla v22.4s, v6.4s, v8.s[3] \n" // sum3 += (k03-k73) * a30
635 "fmla v23.4s, v7.4s, v8.s[3] \n" //
636
637 "subs w4, w4, #1 \n"
638 "bne 0b \n"
639
640 "fadd v16.4s, v16.4s, v18.4s \n"
641 "fadd v17.4s, v17.4s, v19.4s \n"
642 "fadd v20.4s, v20.4s, v22.4s \n"
643 "fadd v21.4s, v21.4s, v23.4s \n"
644 "fadd v16.4s, v16.4s, v20.4s \n"
645 "fadd v17.4s, v17.4s, v21.4s \n"
646 "fadd v14.4s, v14.4s, v16.4s \n"
647 "fadd v15.4s, v15.4s, v17.4s \n"
648
649 "1: \n"
650
651 // remain loop
652 "and w4, %w20, #3 \n" // w4 = remain = inch & 3;
653 "cmp w4, #0 \n"
654 "beq 3f \n"
655
656 "2: \n"
657
658 "prfm pldl1keep, [%9, #256] \n"
659 "ld1 {v0.4s, v1.4s}, [%9], #32 \n"
660 "prfm pldl1keep, [%8, #32] \n"
661 "ld1r {v8.4s}, [%8], #4 \n"
662
663 // k0
664 "fmla v14.4s, v8.4s, v0.4s \n" // sum0 += (k00-k70) * a00
665 "fmla v15.4s, v8.4s, v1.4s \n" //
666
667 "subs w4, w4, #1 \n"
668
669 "bne 2b \n"
670
671 "3: \n"
672
673 "st1 {v14.s}[0], [%0] \n"
674 "st1 {v14.s}[1], [%1] \n"
675 "st1 {v14.s}[2], [%2] \n"
676 "st1 {v14.s}[3], [%3] \n"
677 "st1 {v15.s}[0], [%4] \n"
678 "st1 {v15.s}[1], [%5] \n"
679 "st1 {v15.s}[2], [%6] \n"
680 "st1 {v15.s}[3], [%7] \n"
681
682 : "=r"(output0), // %0
683 "=r"(output1), // %1
684 "=r"(output2), // %2
685 "=r"(output3), // %3
686 "=r"(output4), // %4
687 "=r"(output5), // %5
688 "=r"(output6), // %6
689 "=r"(output7), // %7
690 "=r"(vb), // %8
691 "=r"(va) // %9
692 : "0"(output0),
693 "1"(output1),
694 "2"(output2),
695 "3"(output3),
696 "4"(output4),
697 "5"(output5),
698 "6"(output6),
699 "7"(output7),
700 "8"(vb),
701 "9"(va),
702 "r"(L), // %20
703 "r"(biasptr) // %21
704 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
705 #else
706 float sum0 = biasptr[0];
707 float sum1 = biasptr[1];
708 float sum2 = biasptr[2];
709 float sum3 = biasptr[3];
710 float sum4 = biasptr[4];
711 float sum5 = biasptr[5];
712 float sum6 = biasptr[6];
713 float sum7 = biasptr[7];
714
715 for (int k = 0; k < L; k++)
716 {
717 sum0 += va[0] * vb[0];
718 sum1 += va[1] * vb[0];
719 sum2 += va[2] * vb[0];
720 sum3 += va[3] * vb[0];
721 sum4 += va[4] * vb[0];
722 sum5 += va[5] * vb[0];
723 sum6 += va[6] * vb[0];
724 sum7 += va[7] * vb[0];
725
726 va += 8;
727 vb += 1;
728 }
729
730 output0[0] = sum0;
731 output1[0] = sum1;
732 output2[0] = sum2;
733 output3[0] = sum3;
734 output4[0] = sum4;
735 output5[0] = sum5;
736 output6[0] = sum6;
737 output7[0] = sum7;
738 #endif // __ARM_NEON
739 output0++;
740 output1++;
741 output2++;
742 output3++;
743 output4++;
744 output5++;
745 output6++;
746 output7++;
747 }
748 }
749 #endif // __aarch64__
750
751 nn_outch = (outch - remain_outch_start) >> 2;
752
753 #pragma omp parallel for num_threads(opt.num_threads)
754 for (int pp = 0; pp < nn_outch; pp++)
755 {
756 int i = remain_outch_start + pp * 4;
757
758 float* output0 = top_blob.channel(i);
759 float* output1 = top_blob.channel(i + 1);
760 float* output2 = top_blob.channel(i + 2);
761 float* output3 = top_blob.channel(i + 3);
762
763 const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
764 const float* biasptr = bias ? bias + i : zeros;
765
766 int j = 0;
767 for (; j + 7 < N; j = j + 8)
768 {
769 const float* vb = bottom_tm.channel(j / 8);
770 #if __ARM_NEON && __aarch64__
771 const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
772 #else
773 const float* va = kernel_tm.channel(i / 4);
774 #endif // __ARM_NEON && __aarch64__
775
776 #if __ARM_NEON
777 #if __aarch64__
778 asm volatile(
779 "ld1 {v0.4s}, [%13] \n"
780 "dup v16.4s, v0.s[0] \n" // sum0
781 "dup v17.4s, v0.s[0] \n"
782 "dup v18.4s, v0.s[1] \n" // sum1
783 "dup v19.4s, v0.s[1] \n"
784 "dup v20.4s, v0.s[2] \n" // sum2
785 "dup v21.4s, v0.s[2] \n"
786 "dup v22.4s, v0.s[3] \n" // sum3
787 "dup v23.4s, v0.s[3] \n"
788
789 "lsr w4, %w12, #2 \n" // r4 = nn = L >> 2
790 "cmp w4, #0 \n"
791 "beq 1f \n"
792
793 "0: \n" // for (; k+3<L; k=k+4)
794
795 "prfm pldl1keep, [%5, #512] \n"
796 "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64 \n" // kernel
797
798 "prfm pldl1keep, [%4, #512] \n"
799 "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%4], #64 \n" // data
800 "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%4], #64 \n"
801
802 "subs w4, w4, #1 \n"
803 // k0
804 "fmla v16.4s, v8.4s, v0.s[0] \n" // sum0 += (a00-a70) * k00
805 "fmla v17.4s, v9.4s, v0.s[0] \n" //
806 "fmla v18.4s, v8.4s, v0.s[1] \n" // sum1 += (a00-a70) * k10
807 "fmla v19.4s, v9.4s, v0.s[1] \n" //
808 "fmla v20.4s, v8.4s, v0.s[2] \n" // sum2 += (a00-a70) * k20
809 "fmla v21.4s, v9.4s, v0.s[2] \n" //
810 "fmla v22.4s, v8.4s, v0.s[3] \n" // sum3 += (a00-a70) * k30
811 "fmla v23.4s, v9.4s, v0.s[3] \n" //
812 // k1
813 "fmla v16.4s, v10.4s, v1.s[0] \n" // sum0 += (a01-a71) * k01
814 "fmla v17.4s, v11.4s, v1.s[0] \n" //
815 "fmla v18.4s, v10.4s, v1.s[1] \n" // sum1 += (a01-a71) * k11
816 "fmla v19.4s, v11.4s, v1.s[1] \n" //
817 "fmla v20.4s, v10.4s, v1.s[2] \n" // sum2 += (a01-a71) * k21
818 "fmla v21.4s, v11.4s, v1.s[2] \n" //
819 "fmla v22.4s, v10.4s, v1.s[3] \n" // sum3 += (a01-a71) * k31
820 "fmla v23.4s, v11.4s, v1.s[3] \n" //
821 // k2
822 "fmla v16.4s, v12.4s, v2.s[0] \n" // sum0 += (a02-a72) * k02
823 "fmla v17.4s, v13.4s, v2.s[0] \n" //
824 "fmla v18.4s, v12.4s, v2.s[1] \n" // sum1 += (a02-a72) * k12
825 "fmla v19.4s, v13.4s, v2.s[1] \n" //
826 "fmla v20.4s, v12.4s, v2.s[2] \n" // sum2 += (a02-a72) * k22
827 "fmla v21.4s, v13.4s, v2.s[2] \n" //
828 "fmla v22.4s, v12.4s, v2.s[3] \n" // sum3 += (a02-a72) * k32
829 "fmla v23.4s, v13.4s, v2.s[3] \n" //
830 // k3
831 "fmla v16.4s, v14.4s, v3.s[0] \n" // sum0 += (a03-a73) * k03
832 "fmla v17.4s, v15.4s, v3.s[0] \n" //
833 "fmla v18.4s, v14.4s, v3.s[1] \n" // sum1 += (a03-a73) * k13
834 "fmla v19.4s, v15.4s, v3.s[1] \n" //
835 "fmla v20.4s, v14.4s, v3.s[2] \n" // sum2 += (a03-a73) * k23
836 "fmla v21.4s, v15.4s, v3.s[2] \n" //
837 "fmla v22.4s, v14.4s, v3.s[3] \n" // sum3 += (a03-a73) * k33
838 "fmla v23.4s, v15.4s, v3.s[3] \n" //
839
840 "bne 0b \n"
841
842 "1: \n"
843
844 // remain loop
845 "and w4, %w12, #3 \n" // w4 = remain = inch & 3;
846 "cmp w4, #0 \n"
847 "beq 3f \n"
848
849 "2: \n"
850
851 "prfm pldl1keep, [%5, #256] \n"
852 "ld1 {v0.4s}, [%5], #16 \n"
853 "prfm pldl1keep, [%4, #256] \n"
854 "ld1 {v8.4s, v9.4s}, [%4], #32 \n"
855 // k0
856 "fmla v16.4s, v8.4s, v0.s[0] \n" // sum0 += (a00-a70) * k00
857 "fmla v17.4s, v9.4s, v0.s[0] \n" //
858 "fmla v18.4s, v8.4s, v0.s[1] \n" // sum1 += (a00-a70) * k10
859 "fmla v19.4s, v9.4s, v0.s[1] \n" //
860 "fmla v20.4s, v8.4s, v0.s[2] \n" // sum2 += (a00-a70) * k20
861 "fmla v21.4s, v9.4s, v0.s[2] \n" //
862 "fmla v22.4s, v8.4s, v0.s[3] \n" // sum3 += (a00-a70) * k30
863 "fmla v23.4s, v9.4s, v0.s[3] \n" //
864
865 "subs w4, w4, #1 \n"
866
867 "bne 2b \n"
868
869 "3: \n"
870
871 "st1 {v16.4s, v17.4s}, [%0] \n"
872 "st1 {v18.4s, v19.4s}, [%1] \n"
873 "st1 {v20.4s, v21.4s}, [%2] \n"
874 "st1 {v22.4s, v23.4s}, [%3] \n"
875
876 : "=r"(output0), // %0
877 "=r"(output1), // %1
878 "=r"(output2), // %2
879 "=r"(output3), // %3
880 "=r"(vb), // %4
881 "=r"(va) // %5
882 : "0"(output0),
883 "1"(output1),
884 "2"(output2),
885 "3"(output3),
886 "4"(vb),
887 "5"(va),
888 "r"(L), // %12
889 "r"(biasptr) // %13
890 : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
891 #else
892 asm volatile(
893 "vld1.f32 {d0-d1}, [%13] \n"
894 "vdup.f32 q8, d0[0] \n"
895 "vdup.f32 q9, d0[0] \n"
896 "vdup.f32 q10, d0[1] \n"
897 "vdup.f32 q11, d0[1] \n"
898 "vdup.f32 q12, d1[0] \n"
899 "vdup.f32 q13, d1[0] \n"
900 "vdup.f32 q14, d1[1] \n"
901 "vdup.f32 q15, d1[1] \n"
902
903 "lsr r4, %12, #2 \n" // r4 = nn = L >> 2
904 "cmp r4, #0 \n"
905 "beq 1f \n"
906
907 "0: \n" // for(; nn != 0; nn--)
908 "pld [%5, #512] \n"
909 "vldm %5!, {d0-d7} \n" // kernel
910 "pld [%4, #512] \n"
911 "vldm %4!, {d8-d15} \n" // data
912
913 "vmla.f32 q8, q4, d0[0] \n" // sum0 = (a00-a07) * k00
914 "vmla.f32 q9, q5, d0[0] \n"
915 "vmla.f32 q10, q4, d0[1] \n" // sum1 = (a00-a07) * k10
916 "vmla.f32 q11, q5, d0[1] \n"
917 "vmla.f32 q12, q4, d1[0] \n" // sum2 = (a00-a07) * k20
918 "vmla.f32 q13, q5, d1[0] \n"
919 "vmla.f32 q14, q4, d1[1] \n" // sum3 = (a00-a07) * k30
920 "vmla.f32 q15, q5, d1[1] \n"
921
922 "vmla.f32 q8, q6, d2[0] \n" // sum0 += (a10-a17) * k01
923 "vmla.f32 q9, q7, d2[0] \n"
924 "vmla.f32 q10, q6, d2[1] \n" // sum1 += (a10-a17) * k11
925 "vmla.f32 q11, q7, d2[1] \n"
926 "vmla.f32 q12, q6, d3[0] \n" // sum2 += (a10-a17) * k21
927 "vmla.f32 q13, q7, d3[0] \n"
928 "vmla.f32 q14, q6, d3[1] \n" // sum3 += (a10-a17) * k31
929 "vmla.f32 q15, q7, d3[1] \n"
930
931 "pld [%4, #512] \n"
932 "vldm %4!, {d8-d15} \n" // data
933
934 "vmla.f32 q8, q4, d4[0] \n" // sum0 += (a20-a27) * k02
935 "vmla.f32 q9, q5, d4[0] \n"
936 "vmla.f32 q10, q4, d4[1] \n" // sum1 += (a20-a27) * k12
937 "vmla.f32 q11, q5, d4[1] \n"
938 "vmla.f32 q12, q4, d5[0] \n" // sum2 += (a20-a27) * k22
939 "vmla.f32 q13, q5, d5[0] \n"
940 "vmla.f32 q14, q4, d5[1] \n" // sum3 += (a20-a27) * k32
941 "vmla.f32 q15, q5, d5[1] \n"
942
943 "vmla.f32 q8, q6, d6[0] \n" // sum0 += (a30-a37) * k03
944 "vmla.f32 q9, q7, d6[0] \n"
945 "vmla.f32 q10, q6, d6[1] \n" // sum1 += (a30-a37) * k13
946 "vmla.f32 q11, q7, d6[1] \n"
947 "vmla.f32 q12, q6, d7[0] \n" // sum2 += (a30-a37) * k23
948 "vmla.f32 q13, q7, d7[0] \n"
949 "vmla.f32 q14, q6, d7[1] \n" // sum3 += (a30-a37) * k33
950 "vmla.f32 q15, q7, d7[1] \n"
951
952 "subs r4, r4, #1 \n"
953 "bne 0b \n" // end for
954
955 "1: \n"
956 // remain loop
957 "and r4, %12, #3 \n" // r4 = remain = inch & 3
958 "cmp r4, #0 \n"
959 "beq 3f \n"
960
961 "2: \n" // for(; remain != 0; remain--)
962
963 "pld [%5, #128] \n"
964 "vld1.f32 {d0-d1}, [%5]! \n"
965 "pld [%4, #256] \n"
966 "vld1.f32 {d8-d11}, [%4]! \n"
967
968 "vmla.f32 q8, q4, d0[0] \n" // sum0 += (a00-a70) * k00
969 "vmla.f32 q9, q5, d0[0] \n"
970 "vmla.f32 q10, q4, d0[1] \n" // sum1 += (a00-a70) * k10
971 "vmla.f32 q11, q5, d0[1] \n"
972 "vmla.f32 q12, q4, d1[0] \n" // sum2 += (a00-a70) * k20
973 "vmla.f32 q13, q5, d1[0] \n"
974 "vmla.f32 q14, q4, d1[1] \n" // sum3 += (a00-a70) * k30
975 "vmla.f32 q15, q5, d1[1] \n"
976
977 "subs r4, r4, #1 \n"
978 "bne 2b \n"
979
980 "3: \n" // store the result to memory
981 "vst1.f32 {d16-d19}, [%0] \n"
982 "vst1.f32 {d20-d23}, [%1] \n"
983 "vst1.f32 {d24-d27}, [%2] \n"
984 "vst1.f32 {d28-d31}, [%3] \n"
985
986 : "=r"(output0), // %0
987 "=r"(output1), // %1
988 "=r"(output2), // %2
989 "=r"(output3), // %3
990 "=r"(vb), // %4
991 "=r"(va) // %5
992 : "0"(output0),
993 "1"(output1),
994 "2"(output2),
995 "3"(output3),
996 "4"(vb),
997 "5"(va),
998 "r"(L), // %12
999 "r"(biasptr) // %13
1000 : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
1001 #endif // __aarch64__
1002 #else
                // Plain C fallback: compute a 4-output-channel x 8-column tile.
                // sumX[n] accumulates the output value for channel (base+X),
                // column (j+n). va walks the packed kernel (4 channel values per
                // k step), vb walks the packed im2col data (8 column values per
                // k step).
                float sum0[8] = {0};
                float sum1[8] = {0};
                float sum2[8] = {0};
                float sum3[8] = {0};

                int k = 0;
                // Main loop, unrolled over 8 consecutive k steps.
                // Inside the n loop, the seven "va += 4" plus the final
                // "va -= 28" net to zero, so every n re-reads the same 8x4
                // kernel block; "va += 32" / "vb += 64" then advance both
                // pointers past the whole 8-step block.
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 8];
                        sum1[n] += va[1] * vb[n + 8];
                        sum2[n] += va[2] * vb[n + 8];
                        sum3[n] += va[3] * vb[n + 8];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 16];
                        sum1[n] += va[1] * vb[n + 16];
                        sum2[n] += va[2] * vb[n + 16];
                        sum3[n] += va[3] * vb[n + 16];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 24];
                        sum1[n] += va[1] * vb[n + 24];
                        sum2[n] += va[2] * vb[n + 24];
                        sum3[n] += va[3] * vb[n + 24];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 32];
                        sum1[n] += va[1] * vb[n + 32];
                        sum2[n] += va[2] * vb[n + 32];
                        sum3[n] += va[3] * vb[n + 32];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 40];
                        sum1[n] += va[1] * vb[n + 40];
                        sum2[n] += va[2] * vb[n + 40];
                        sum3[n] += va[3] * vb[n + 40];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 48];
                        sum1[n] += va[1] * vb[n + 48];
                        sum2[n] += va[2] * vb[n + 48];
                        sum3[n] += va[3] * vb[n + 48];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 56];
                        sum1[n] += va[1] * vb[n + 56];
                        sum2[n] += va[2] * vb[n + 56];
                        sum3[n] += va[3] * vb[n + 56];
                        va -= 28; // rewind to the start of the kernel block for the next n
                    }

                    va += 32; // past 8 k steps x 4 channels
                    vb += 64; // past 8 k steps x 8 columns
                }

                // Remainder: one k step at a time.
                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                    }

                    va += 4;
                    vb += 8;
                }

                // Store the 8 columns, adding the per-channel bias once at the end.
                for (int n = 0; n < 8; n++)
                {
                    output0[n] = sum0[n] + biasptr[0];
                    output1[n] = sum1[n] + biasptr[1];
                    output2[n] = sum2[n] + biasptr[2];
                    output3[n] = sum3[n] + biasptr[3];
                }
#endif // __ARM_NEON
                // advance the four output-row pointers past the 8 columns just written
                output0 += 8;
                output1 += 8;
                output2 += 8;
                output3 += 8;
            }
1093
            // Tail columns (N % 8): one output column at a time for the same
            // 4 output channels of this tile.
            for (; j < N; j++)
            {
                // bottom_tm channel layout: N/8 eight-wide packed channels first,
                // then one channel per leftover column — hence j / 8 + j % 8.
                float* vb = bottom_tm.channel(j / 8 + j % 8);
#if __ARM_NEON && __aarch64__
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);
#else
                const float* va = kernel_tm.channel(i / 4);
#endif // __ARM_NEON && __aarch64__

#if __ARM_NEON
#if __aarch64__
                // v14 lanes = {sum0, sum1, sum2, sum3}: one scalar per output
                // channel for column j, seeded from biasptr.
                asm volatile(
                    "ld1 {v14.4s}, [%13] \n" // sum0_3 initial with bias

                    "lsr w4, %w12, #2 \n" // w4 = nn = L >> 2
                    "cmp w4, #0 \n"
                    "beq 1f \n"

                    "eor v16.16b, v16.16b, v16.16b \n" // sum0
                    "eor v17.16b, v17.16b, v17.16b \n" // sum1
                    "eor v18.16b, v18.16b, v18.16b \n" // sum2
                    "eor v19.16b, v19.16b, v19.16b \n" // sum3

                    "0: \n" // for (; k+3<L; k=k+4)

                    "prfm pldl1keep, [%5, #256] \n"
                    "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%5], #64 \n" // k

                    "prfm pldl1keep, [%4, #128] \n"
                    "ld1 {v8.4s}, [%4], #16 \n" // d

                    "subs w4, w4, #1 \n"
                    "fmla v16.4s, v0.4s, v8.s[0] \n" // sum0 += (k00-k30) * a00
                    "fmla v17.4s, v1.4s, v8.s[1] \n" // sum1 += (k01-k31) * a10
                    "fmla v18.4s, v2.4s, v8.s[2] \n" // sum2 += (k02-k32) * a20
                    "fmla v19.4s, v3.4s, v8.s[3] \n" // sum3 += (k03-k33) * a30

                    "bne 0b \n"

                    // reduce the four partial accumulators into v14
                    "fadd v16.4s, v16.4s, v18.4s \n"
                    "fadd v17.4s, v17.4s, v19.4s \n"
                    "fadd v14.4s, v14.4s, v16.4s \n"
                    "fadd v14.4s, v14.4s, v17.4s \n"

                    "1: \n"

                    // remain loop
                    "and w4, %w12, #3 \n" // w4 = remain = L & 3
                    "cmp w4, #0 \n"
                    "beq 3f \n"

                    "2: \n"

                    "prfm pldl1keep, [%5, #128] \n"
                    "ld1 {v0.4s}, [%5], #16 \n"
                    "prfm pldl1keep, [%4, #32] \n"
                    "ld1r {v8.4s}, [%4], #4 \n"

                    "subs w4, w4, #1 \n"
                    // k0
                    "fmla v14.4s, v8.4s, v0.4s \n" // sum0 += (k00-k30) * a00
                    "bne 2b \n"

                    "3: \n"

                    // scatter the four per-channel sums to the four output rows
                    "st1 {v14.s}[0], [%0] \n"
                    "st1 {v14.s}[1], [%1] \n"
                    "st1 {v14.s}[2], [%2] \n"
                    "st1 {v14.s}[3], [%3] \n"

                    : "=r"(output0), // %0
                    "=r"(output1),   // %1
                    "=r"(output2),   // %2
                    "=r"(output3),   // %3
                    "=r"(vb),        // %4
                    "=r"(va)         // %5
                    : "0"(output0),
                    "1"(output1),
                    "2"(output2),
                    "3"(output3),
                    "4"(vb),
                    "5"(va),
                    "r"(L),       // %12
                    "r"(biasptr)  // %13
                    : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19");
#else
                // ARMv7 variant of the same single-column kernel; q12 holds the
                // four per-channel sums, seeded from biasptr.
                asm volatile(
                    // inch loop
                    "vld1.f32 {d24-d25}, [%13] \n"

                    "lsr r4, %12, #2 \n" // r4 = nn = L >> 2
                    "cmp r4, #0 \n"
                    "beq 1f \n"

                    "veor q8, q8, q8 \n"
                    "veor q9, q9, q9 \n"
                    "veor q10, q10, q10 \n"
                    "veor q11, q11, q11 \n"

                    "0: \n" // for(; nn != 0; nn--)
                    "pld [%5, #512] \n"
                    "vldm %5!, {d0-d7} \n" // kernel
                    "pld [%4, #128] \n"
                    "vld1.f32 {d8-d9}, [%4]! \n" // data

                    "vmla.f32 q8, q0, d8[0] \n" // (k00-k30) * a00
                    "vmla.f32 q9, q1, d8[1] \n" // (k01-k31) * a01
                    "vmla.f32 q10, q2, d9[0] \n" // (k02-k32) * a02
                    "vmla.f32 q11, q3, d9[1] \n" // (k03-k33) * a03

                    "subs r4, r4, #1 \n"
                    "bne 0b \n" // end for

                    // reduce the four partial accumulators into q12
                    "vadd.f32 q8, q8, q9 \n"
                    "vadd.f32 q10, q10, q11 \n"
                    "vadd.f32 q8, q8, q10 \n"
                    "vadd.f32 q12, q12, q8 \n"

                    "1: \n"
                    // remain loop
                    "and r4, %12, #3 \n" // r4 = remain = L & 3
                    "cmp r4, #0 \n"
                    "beq 3f \n"

                    "2: \n" // for(; remain != 0; remain--)
                    "pld [%5, #128] \n"
                    "vld1.f32 {d0-d1}, [%5]! \n"
                    "pld [%4, #32] \n"
                    "vld1.f32 {d8[],d9[]}, [%4]! \n"

                    "subs r4, r4, #1 \n"

                    "vmla.f32 q12, q0, q4 \n"
                    "bne 2b \n"

                    "3: \n" // store the result to memory
                    "vst1.f32 {d24[0]}, [%0] \n"
                    "vst1.f32 {d24[1]}, [%1] \n"
                    "vst1.f32 {d25[0]}, [%2] \n"
                    "vst1.f32 {d25[1]}, [%3] \n"

                    : "=r"(output0), // %0
                    "=r"(output1),   // %1
                    "=r"(output2),   // %2
                    "=r"(output3),   // %3
                    "=r"(vb),        // %4
                    "=r"(va)         // %5
                    : "0"(output0),
                    "1"(output1),
                    "2"(output2),
                    "3"(output3),
                    "4"(vb),
                    "5"(va),
                    "r"(L),       // %12
                    "r"(biasptr)  // %13
                    : "cc", "memory", "r4", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12");
#endif // __aarch64__
#else
                // Plain C fallback: dot product of one im2col column against the
                // 4 packed kernel rows; bias is pre-loaded into the accumulators.
                float sum0 = biasptr[0];
                float sum1 = biasptr[1];
                float sum2 = biasptr[2];
                float sum3 = biasptr[3];

                for (int k = 0; k < L; k++)
                {
                    sum0 += va[0] * vb[0];
                    sum1 += va[1] * vb[0];
                    sum2 += va[2] * vb[0];
                    sum3 += va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
#endif // __ARM_NEON
                output0++;
                output1++;
                output2++;
                output3++;
            }
1278 }
1279
        // Account for the output channels consumed by the 4-wide block above
        // (nn_outch here counts those 4-wide blocks; it is set before this point).
        remain_outch_start += nn_outch << 2;

        // Leftover output channels (fewer than 4 remain): one channel at a time.
#pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            float* output = top_blob.channel(i);

            // bias may be absent (no-bias convolution)
            const float bias0 = bias ? bias[i] : 0.f;

            int j = 0;
            // 8 output columns per iteration
            for (; j + 7 < N; j = j + 8)
            {
                const float* vb = bottom_tm.channel(j / 8);
#if __ARM_NEON && __aarch64__
                // kernel_tm channel layout (see transform above): outch/8
                // packed-8 blocks, then (outch%8)/4 packed-4 blocks, then one
                // channel per leftover output channel.
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
#else
                const float* va = kernel_tm.channel(i / 4 + i % 4);
#endif // __ARM_NEON && __aarch64__

#if __ARM_NEON
#if __aarch64__
                // v16 = sums for columns j+0..3, v17 = columns j+4..7,
                // both seeded with bias0.
                asm volatile(
                    "dup v16.4s, %w7 \n" // sum0
                    "dup v17.4s, %w7 \n" // sum0n

                    "lsr w4, %w6, #2 \n" // w4 = nn = L >> 2
                    "cmp w4, #0 \n"
                    "beq 1f \n"

                    "0: \n" // for (; k+3<L; k=k+4)

                    "prfm pldl1keep, [%2, #128] \n"
                    "ld1 {v0.4s}, [%2], #16 \n"

                    "prfm pldl1keep, [%1, #512] \n"
                    "ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%1], #64 \n" // data
                    "ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%1], #64 \n"

                    // k0
                    "fmla v16.4s, v8.4s, v0.s[0] \n" // sum0 += (a00-a70) * k00
                    "fmla v17.4s, v9.4s, v0.s[0] \n" //
                    // k1
                    "fmla v16.4s, v10.4s, v0.s[1] \n" // sum0 += (a01-a71) * k01
                    "fmla v17.4s, v11.4s, v0.s[1] \n" //
                    // k2
                    "fmla v16.4s, v12.4s, v0.s[2] \n" // sum0 += (a02-a72) * k02
                    "fmla v17.4s, v13.4s, v0.s[2] \n" //
                    // k3
                    "fmla v16.4s, v14.4s, v0.s[3] \n" // sum0 += (a03-a73) * k03
                    "fmla v17.4s, v15.4s, v0.s[3] \n" //

                    "subs w4, w4, #1 \n"
                    "bne 0b \n"

                    "1: \n"

                    // remain loop
                    "and w4, %w6, #3 \n" // w4 = remain = L & 3
                    "cmp w4, #0 \n"
                    "beq 3f \n"

                    "2: \n"
                    "prfm pldl1keep, [%2, #32] \n"
                    "ld1r {v0.4s}, [%2], #4 \n"
                    "prfm pldl1keep, [%1, #256] \n"
                    "ld1 {v8.4s, v9.4s}, [%1], #32 \n"

                    "subs w4, w4, #1 \n"
                    // k0
                    "fmla v16.4s, v0.4s, v8.4s \n" // sum0 += (a00-a70) * k00
                    "fmla v17.4s, v0.4s, v9.4s \n" //

                    "bne 2b \n"

                    "3: \n"
                    "st1 {v16.4s, v17.4s}, [%0] \n"

                    : "=r"(output), // %0
                    "=r"(vb),       // %1
                    "=r"(va)        // %2
                    : "0"(output),
                    "1"(vb),
                    "2"(va),
                    "r"(L),     // %6
                    "r"(bias0)  // %7
                    : "cc", "memory", "x4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17");
#else
                // ARMv7 variant: q8 = columns j+0..3, q9 = columns j+4..7.
                asm volatile(
                    "vdup.f32 q8, %7 \n"
                    "vdup.f32 q9, %7 \n"
                    // inch loop
                    "lsr r4, %6, #2 \n" // r4 = nn = L >> 2
                    "cmp r4, #0 \n"
                    "beq 1f \n"

                    "0: \n"

                    "pld [%1, #512] \n"
                    "vldm %1!, {d8-d15} \n"
                    "pld [%2, #128] \n"
                    "vld1.f32 {d0-d1}, [%2]! \n"

                    "vmla.f32 q8, q4, d0[0] \n"
                    "vmla.f32 q9, q5, d0[0] \n"

                    "pld [%1, #512] \n"
                    "vldm %1!, {d24-d31} \n"

                    "vmla.f32 q8, q6, d0[1] \n"
                    "vmla.f32 q9, q7, d0[1] \n"

                    "subs r4, r4, #1 \n"

                    "vmla.f32 q8, q12, d1[0] \n"
                    "vmla.f32 q9, q13, d1[0] \n"
                    "vmla.f32 q8, q14, d1[1] \n"
                    "vmla.f32 q9, q15, d1[1] \n"

                    "bne 0b \n"

                    "1: \n"
                    // remain loop
                    "and r4, %6, #3 \n" // r4 = remain = L & 3
                    "cmp r4, #0 \n"
                    "beq 3f \n"

                    "2: \n"
                    "pld [%1, #256] \n"
                    "vld1.f32 {d8-d11}, [%1]! \n"
                    "pld [%2, #32] \n"
                    "vld1.f32 {d0[],d1[]}, [%2]! \n"

                    "subs r4, r4, #1 \n"

                    "vmla.f32 q8, q4, q0 \n"
                    "vmla.f32 q9, q5, q0 \n"
                    "bne 2b \n"

                    "3: \n"
                    "vst1.f32 {d16-d19}, [%0] \n"

                    : "=r"(output), // %0
                    "=r"(vb),       // %1
                    "=r"(va)        // %2
                    : "0"(output),
                    "1"(vb),
                    "2"(va),
                    "r"(L),     // %6
                    "r"(bias0)  // %7
                    : "cc", "memory", "r4", "q0", "q4", "q5", "q6", "q7", "q8", "q9", "q12", "q13", "q14", "q15");
#endif // __aarch64__
#else
                // Plain C fallback: 1 output channel x 8 columns, unrolled over
                // 8 k steps (va packs 8 kernel values, vb packs 8x8 data values).
                float sum[8] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum[n] += va[0] * vb[n];
                        sum[n] += va[1] * vb[n + 8];
                        sum[n] += va[2] * vb[n + 16];
                        sum[n] += va[3] * vb[n + 24];
                        sum[n] += va[4] * vb[n + 32];
                        sum[n] += va[5] * vb[n + 40];
                        sum[n] += va[6] * vb[n + 48];
                        sum[n] += va[7] * vb[n + 56];
                    }

                    va += 8;
                    vb += 64;
                }

                // remainder: one k step at a time
                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum[n] += va[0] * vb[n];
                    }

                    va += 1;
                    vb += 8;
                }

                for (int n = 0; n < 8; n++)
                {
                    output[n] = sum[n] + bias0;
                }
#endif // __ARM_NEON
                output += 8;
            }

            // Tail columns: single dot product per remaining output position.
            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 8 + j % 8);
#if __ARM_NEON && __aarch64__
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);
#else
                const float* va = kernel_tm.channel(i / 4 + i % 4);
#endif // __ARM_NEON && __aarch64__

                int k = 0;
#if __ARM_NEON
                float32x4_t _sum0 = vdupq_n_f32(0.f);

                // 4 multiply-accumulates per iteration with NEON intrinsics
                for (; k + 3 < L; k += 4)
                {
                    float32x4_t _p0 = vld1q_f32(vb);
                    vb += 4;

                    float32x4_t _k0 = vld1q_f32(va);
                    va += 4;

#if __aarch64__
                    _sum0 = vfmaq_f32(_sum0, _p0, _k0);
#else
                    _sum0 = vmlaq_f32(_sum0, _p0, _k0);
#endif
                }

                // horizontal add of the 4 partial sums, plus bias
#if __aarch64__
                float sum0 = bias0 + vaddvq_f32(_sum0);
#else
                float32x2_t _ss = vadd_f32(vget_low_f32(_sum0), vget_high_f32(_sum0));
                float sum0 = bias0 + vget_lane_f32(vpadd_f32(_ss, _ss), 0);
#endif
#else
                float sum0 = bias0;
#endif // __ARM_NEON
                // remainder (or the whole loop when NEON is unavailable)
                for (; k < L; k++)
                {
                    sum0 += va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }
                output[0] = sum0;

                output++;
            }
        }
1521 }
1522 }
1523