// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

convolution_transform_kernel_pack4_neon(const Mat & weight_data,Mat & weight_data_pack4,int num_input,int num_output,int kernel_w,int kernel_h)15 static void convolution_transform_kernel_pack4_neon(const Mat& weight_data, Mat& weight_data_pack4, int num_input, int num_output, int kernel_w, int kernel_h)
16 {
17     const int maxk = kernel_w * kernel_h;
18 
19     // src = kw-kh-inch-outch
20     // dst = 4b-4a-kw-kh-inch/4a-outch/4b
21     Mat weight_data_r2 = weight_data.reshape(maxk, num_input, num_output);
22 
23     weight_data_pack4.create(maxk, num_input / 4, num_output / 4, (size_t)4 * 16, 16);
24 
25     for (int q = 0; q + 3 < num_output; q += 4)
26     {
27         const Mat k0 = weight_data_r2.channel(q);
28         const Mat k1 = weight_data_r2.channel(q + 1);
29         const Mat k2 = weight_data_r2.channel(q + 2);
30         const Mat k3 = weight_data_r2.channel(q + 3);
31 
32         Mat g0 = weight_data_pack4.channel(q / 4);
33 
34         for (int p = 0; p + 3 < num_input; p += 4)
35         {
36             const float* k00 = k0.row(p);
37             const float* k01 = k0.row(p + 1);
38             const float* k02 = k0.row(p + 2);
39             const float* k03 = k0.row(p + 3);
40 
41             const float* k10 = k1.row(p);
42             const float* k11 = k1.row(p + 1);
43             const float* k12 = k1.row(p + 2);
44             const float* k13 = k1.row(p + 3);
45 
46             const float* k20 = k2.row(p);
47             const float* k21 = k2.row(p + 1);
48             const float* k22 = k2.row(p + 2);
49             const float* k23 = k2.row(p + 3);
50 
51             const float* k30 = k3.row(p);
52             const float* k31 = k3.row(p + 1);
53             const float* k32 = k3.row(p + 2);
54             const float* k33 = k3.row(p + 3);
55 
56             float* g00 = g0.row(p / 4);
57 
58             for (int k = 0; k < maxk; k++)
59             {
60                 g00[0] = k00[k];
61                 g00[1] = k10[k];
62                 g00[2] = k20[k];
63                 g00[3] = k30[k];
64 
65                 g00[4] = k01[k];
66                 g00[5] = k11[k];
67                 g00[6] = k21[k];
68                 g00[7] = k31[k];
69 
70                 g00[8] = k02[k];
71                 g00[9] = k12[k];
72                 g00[10] = k22[k];
73                 g00[11] = k32[k];
74 
75                 g00[12] = k03[k];
76                 g00[13] = k13[k];
77                 g00[14] = k23[k];
78                 g00[15] = k33[k];
79 
80                 g00 += 16;
81             }
82         }
83     }
84 }
85 
convolution_pack4_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & weight_data_pack4,const Mat & bias_data,int kernel_w,int kernel_h,int dilation_w,int dilation_h,int stride_w,int stride_h,int activation_type,const Mat & activation_params,const Option & opt)86 static void convolution_pack4_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_pack4, const Mat& bias_data, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h, int activation_type, const Mat& activation_params, const Option& opt)
87 {
88     int w = bottom_blob.w;
89     int channels = bottom_blob.c;
90 
91     int outw = top_blob.w;
92     int outh = top_blob.h;
93     int outch = top_blob.c;
94 
95     const int maxk = kernel_w * kernel_h;
96 
97     // kernel offsets
98     std::vector<int> _space_ofs(maxk);
99     int* space_ofs = &_space_ofs[0];
100     {
101         int p1 = 0;
102         int p2 = 0;
103         int gap = w * dilation_h - kernel_w * dilation_w;
104         for (int i = 0; i < kernel_h; i++)
105         {
106             for (int j = 0; j < kernel_w; j++)
107             {
108                 space_ofs[p1] = p2;
109                 p1++;
110                 p2 += dilation_w;
111             }
112             p2 += gap;
113         }
114     }
115 
116     const float* bias_data_ptr = bias_data;
117 
118     #pragma omp parallel for num_threads(opt.num_threads)
119     for (int p = 0; p < outch; p++)
120     {
121         float* outptr = top_blob.channel(p);
122 
123         for (int i = 0; i < outh; i++)
124         {
125             for (int j = 0; j < outw; j++)
126             {
127                 float32x4_t _sum = vdupq_n_f32(0.f);
128 
129                 if (bias_data_ptr)
130                 {
131                     _sum = vld1q_f32(bias_data_ptr + p * 4);
132                 }
133 
134                 const float* kptr = (const float*)weight_data_pack4 + maxk * channels * p * 16;
135 
136                 // channels
137                 for (int q = 0; q < channels; q++)
138                 {
139                     const Mat m = bottom_blob.channel(q);
140                     const float* sptr = m.row(i * stride_h) + j * stride_w * 4;
141 
142                     for (int k = 0; k < maxk; k++) // 29.23
143                     {
144                         float32x4_t _val = vld1q_f32(sptr + space_ofs[k] * 4);
145 
146                         float32x4_t _w0 = vld1q_f32(kptr);
147                         float32x4_t _w1 = vld1q_f32(kptr + 4);
148                         float32x4_t _w2 = vld1q_f32(kptr + 8);
149                         float32x4_t _w3 = vld1q_f32(kptr + 12);
150 
151 #if __aarch64__
152                         _sum = vmlaq_laneq_f32(_sum, _w0, _val, 0);
153                         _sum = vmlaq_laneq_f32(_sum, _w1, _val, 1);
154                         _sum = vmlaq_laneq_f32(_sum, _w2, _val, 2);
155                         _sum = vmlaq_laneq_f32(_sum, _w3, _val, 3);
156 #else
157                         _sum = vmlaq_lane_f32(_sum, _w0, vget_low_f32(_val), 0);
158                         _sum = vmlaq_lane_f32(_sum, _w1, vget_low_f32(_val), 1);
159                         _sum = vmlaq_lane_f32(_sum, _w2, vget_high_f32(_val), 0);
160                         _sum = vmlaq_lane_f32(_sum, _w3, vget_high_f32(_val), 1);
161 #endif
162 
163                         kptr += 16;
164                     }
165                 }
166 
167                 _sum = activation_ps(_sum, activation_type, activation_params);
168 
169                 vst1q_f32(outptr + j * 4, _sum);
170             }
171 
172             outptr += outw * 4;
173         }
174     }
175 }
176