1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
convolution_transform_kernel_pack4to1_neon(const Mat & weight_data,Mat & weight_data_pack4to1,int num_input,int num_output,int kernel_w,int kernel_h)15 static void convolution_transform_kernel_pack4to1_neon(const Mat& weight_data, Mat& weight_data_pack4to1, int num_input, int num_output, int kernel_w, int kernel_h)
16 {
17 const int maxk = kernel_w * kernel_h;
18
19 // src = kw-kh-inch-outch
20 // dst = 4a-kw-kh-inch/4a-outch
21 Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output);
22
23 weight_data_pack4to1.create(maxk, num_input / 4, num_output, (size_t)4 * 4, 4);
24
25 for (int q = 0; q < num_output; q++)
26 {
27 const Mat k0 = weight_data_r2.channel(q);
28 Mat g0 = weight_data_pack4to1.channel(q);
29
30 for (int p = 0; p + 3 < num_input; p += 4)
31 {
32 const float* k00 = k0.row(p);
33 const float* k01 = k0.row(p + 1);
34 const float* k02 = k0.row(p + 2);
35 const float* k03 = k0.row(p + 3);
36
37 float* g00 = g0.row(p / 4);
38
39 for (int k = 0; k < maxk; k++)
40 {
41 g00[0] = k00[k];
42 g00[1] = k01[k];
43 g00[2] = k02[k];
44 g00[3] = k03[k];
45
46 g00 += 4;
47 }
48 }
49 }
50 }
51
// Convolution compute kernel for pack4 input -> pack1 output.
// bottom_blob is elempack=4 (4 input channels interleaved per element);
// top_blob is elempack=1. Weights must already be repacked by
// convolution_transform_kernel_pack4to1_neon (4a-kw-kh-inch/4a-outch layout).
// Each output element is a dot product over channels*maxk packed-4 groups,
// followed by optional bias and activation.
static void convolution_pack4to1_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_pack4to1, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int activation_type, const Mat& activation_params, const Option& opt)
{
    int w = bottom_blob.w;
    int channels = bottom_blob.c; // packed channel count (num_input / 4)

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const int maxk = kernel_w * kernel_h;

    // kernel offsets
    // Precompute, for each kernel tap, its element offset from the sliding
    // window origin inside a bottom_blob row-major channel (dilation applied).
    std::vector<int> _space_ofs(maxk);
    int* space_ofs = &_space_ofs[0];
    {
        int p1 = 0;
        int p2 = 0;
        // gap = elements to skip when the tap index wraps to the next kernel row
        int gap = w * dilation_h - kernel_w * dilation_w;
        for (int i = 0; i < kernel_h; i++)
        {
            for (int j = 0; j < kernel_w; j++)
            {
                space_ofs[p1] = p2;
                p1++;
                p2 += dilation_w;
            }
            p2 += gap;
        }
    }

    const float* bias_data_ptr = bias_data; // may be null when the layer has no bias

    // num_output
    // parallelize over output channels; iterations are independent
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        float* outptr = top_blob.channel(p);

        for (int i = 0; i < outh; i++)
        {
            for (int j = 0; j < outw; j++)
            {
                float sum = 0.f;

                if (bias_data_ptr)
                {
                    sum = bias_data_ptr[p];
                }

                // start of the repacked weights for output channel p:
                // maxk taps * channels packed groups * 4 floats each
                const float* kptr = (const float*)weight_data_pack4to1 + maxk * channels * p * 4;

                // channels
                for (int q = 0; q < channels; q++)
                {
                    const Mat m = bottom_blob.channel(q);
                    // window origin for this output pixel; *4 because each
                    // element of a pack4 row is 4 floats wide
                    const float* sptr = m.row(i * stride_h) + j * stride_w * 4;

                    for (int k = 0; k < maxk; k++) // 29.23
                    {
                        // multiply 4 packed input channels by 4 packed weights,
                        // then horizontally add the lanes into the scalar sum
                        float32x4_t _val = vld1q_f32(sptr + space_ofs[k] * 4);
                        float32x4_t _w = vld1q_f32(kptr);
                        float32x4_t _s4 = vmulq_f32(_val, _w);
#if __aarch64__
                        sum += vaddvq_f32(_s4); // dot
#else
                        // armv7 has no vaddvq: fold high into low half, then pairwise add
                        float32x2_t _ss = vadd_f32(vget_low_f32(_s4), vget_high_f32(_s4));
                        _ss = vpadd_f32(_ss, _ss);
                        sum += vget_lane_f32(_ss, 0);
#endif

                        kptr += 4;
                    }
                }

                sum = activation_ss(sum, activation_type, activation_params);

                outptr[j] = sum;
            }

            // top_blob is elempack=1, so advance one float per output pixel
            outptr += outw;
        }
    }
}
135