1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
deconv4x4s1_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)15 static void deconv4x4s1_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
16 {
17     int w = bottom_blob.w;
18     int h = bottom_blob.h;
19     int inch = bottom_blob.c;
20 
21     int outw = top_blob.w;
22     int outch = top_blob.c;
23 
24     const float* kernel = _kernel;
25     const float* bias = _bias;
26 
27     #pragma omp parallel for num_threads(opt.num_threads)
28     for (int p = 0; p < outch; p++)
29     {
30         Mat out = top_blob.channel(p);
31 
32         const float bias0 = bias ? bias[p] : 0.f;
33 
34         out.fill(bias0);
35 
36         for (int q = 0; q < inch; q++)
37         {
38             const float* img0 = bottom_blob.channel(q);
39 
40             const float* kernel0 = kernel + p * inch * 16 + q * 16;
41 
42             const float* r0 = img0;
43 
44             const float* k0 = kernel0;
45             const float* k1 = kernel0 + 4;
46             const float* k2 = kernel0 + 8;
47             const float* k3 = kernel0 + 12;
48 
49 #if __ARM_NEON
50             float32x4_t _k0 = vld1q_f32(k0);
51             float32x4_t _k1 = vld1q_f32(k1);
52             float32x4_t _k2 = vld1q_f32(k2);
53             float32x4_t _k3 = vld1q_f32(k3);
54 #endif // __ARM_NEON
55 
56             for (int i = 0; i < h; i++)
57             {
58                 float* outptr = out.row(i);
59 
60                 float* outptr0 = outptr;
61                 float* outptr1 = outptr0 + outw;
62                 float* outptr2 = outptr1 + outw;
63                 float* outptr3 = outptr2 + outw;
64 
65                 int j = 0;
66 
67 #if __ARM_NEON
68                 for (; j + 3 < w; j += 4)
69                 {
70                     float32x4_t _v = vld1q_f32(r0);
71 
72                     //
73                     float32x4_t _out00 = vld1q_f32(outptr0 + 0);
74                     _out00 = vmlaq_lane_f32(_out00, _v, vget_low_f32(_k0), 0);
75                     vst1q_f32(outptr0 + 0, _out00);
76 
77                     float32x4_t _out01 = vld1q_f32(outptr0 + 1);
78                     _out01 = vmlaq_lane_f32(_out01, _v, vget_low_f32(_k0), 1);
79                     vst1q_f32(outptr0 + 1, _out01);
80 
81                     float32x4_t _out02 = vld1q_f32(outptr0 + 2);
82                     _out02 = vmlaq_lane_f32(_out02, _v, vget_high_f32(_k0), 0);
83                     vst1q_f32(outptr0 + 2, _out02);
84 
85                     float32x4_t _out03 = vld1q_f32(outptr0 + 3);
86                     _out03 = vmlaq_lane_f32(_out03, _v, vget_high_f32(_k0), 1);
87                     vst1q_f32(outptr0 + 3, _out03);
88 
89                     //
90                     float32x4_t _out10 = vld1q_f32(outptr1 + 0);
91                     _out10 = vmlaq_lane_f32(_out10, _v, vget_low_f32(_k1), 0);
92                     vst1q_f32(outptr1 + 0, _out10);
93 
94                     float32x4_t _out11 = vld1q_f32(outptr1 + 1);
95                     _out11 = vmlaq_lane_f32(_out11, _v, vget_low_f32(_k1), 1);
96                     vst1q_f32(outptr1 + 1, _out11);
97 
98                     float32x4_t _out12 = vld1q_f32(outptr1 + 2);
99                     _out12 = vmlaq_lane_f32(_out12, _v, vget_high_f32(_k1), 0);
100                     vst1q_f32(outptr1 + 2, _out12);
101 
102                     float32x4_t _out13 = vld1q_f32(outptr1 + 3);
103                     _out13 = vmlaq_lane_f32(_out13, _v, vget_high_f32(_k1), 1);
104                     vst1q_f32(outptr1 + 3, _out13);
105 
106                     //
107                     float32x4_t _out20 = vld1q_f32(outptr2 + 0);
108                     _out20 = vmlaq_lane_f32(_out20, _v, vget_low_f32(_k2), 0);
109                     vst1q_f32(outptr2 + 0, _out20);
110 
111                     float32x4_t _out21 = vld1q_f32(outptr2 + 1);
112                     _out21 = vmlaq_lane_f32(_out21, _v, vget_low_f32(_k2), 1);
113                     vst1q_f32(outptr2 + 1, _out21);
114 
115                     float32x4_t _out22 = vld1q_f32(outptr2 + 2);
116                     _out22 = vmlaq_lane_f32(_out22, _v, vget_high_f32(_k2), 0);
117                     vst1q_f32(outptr2 + 2, _out22);
118 
119                     float32x4_t _out23 = vld1q_f32(outptr2 + 3);
120                     _out23 = vmlaq_lane_f32(_out23, _v, vget_high_f32(_k2), 1);
121                     vst1q_f32(outptr2 + 3, _out23);
122 
123                     //
124                     float32x4_t _out30 = vld1q_f32(outptr3 + 0);
125                     _out30 = vmlaq_lane_f32(_out30, _v, vget_low_f32(_k3), 0);
126                     vst1q_f32(outptr3 + 0, _out30);
127 
128                     float32x4_t _out31 = vld1q_f32(outptr3 + 1);
129                     _out31 = vmlaq_lane_f32(_out31, _v, vget_low_f32(_k3), 1);
130                     vst1q_f32(outptr3 + 1, _out31);
131 
132                     float32x4_t _out32 = vld1q_f32(outptr3 + 2);
133                     _out32 = vmlaq_lane_f32(_out32, _v, vget_high_f32(_k3), 0);
134                     vst1q_f32(outptr3 + 2, _out32);
135 
136                     float32x4_t _out33 = vld1q_f32(outptr3 + 3);
137                     _out33 = vmlaq_lane_f32(_out33, _v, vget_high_f32(_k3), 1);
138                     vst1q_f32(outptr3 + 3, _out33);
139 
140                     r0 += 4;
141                     outptr0 += 4;
142                     outptr1 += 4;
143                     outptr2 += 4;
144                     outptr3 += 4;
145                 }
146 
147 #endif // __ARM_NEON
148 
149                 for (; j < w; j++)
150                 {
151                     float val = r0[0];
152 
153                     outptr0[0] += val * k0[0];
154                     outptr0[1] += val * k0[1];
155                     outptr0[2] += val * k0[2];
156                     outptr0[3] += val * k0[3];
157 
158                     outptr1[0] += val * k1[0];
159                     outptr1[1] += val * k1[1];
160                     outptr1[2] += val * k1[2];
161                     outptr1[3] += val * k1[3];
162 
163                     outptr2[0] += val * k2[0];
164                     outptr2[1] += val * k2[1];
165                     outptr2[2] += val * k2[2];
166                     outptr2[3] += val * k2[3];
167 
168                     outptr3[0] += val * k3[0];
169                     outptr3[1] += val * k3[1];
170                     outptr3[2] += val * k3[2];
171                     outptr3[3] += val * k3[3];
172 
173                     r0++;
174                     outptr0++;
175                     outptr1++;
176                     outptr2++;
177                     outptr3++;
178                 }
179             }
180         }
181     }
182 }
183 
deconv4x4s2_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)184 static void deconv4x4s2_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
185 {
186     int w = bottom_blob.w;
187     int h = bottom_blob.h;
188     int inch = bottom_blob.c;
189 
190     int outw = top_blob.w;
191     int outch = top_blob.c;
192 
193     const float* kernel = _kernel;
194     const float* bias = _bias;
195 
196     #pragma omp parallel for num_threads(opt.num_threads)
197     for (int p = 0; p < outch; p++)
198     {
199         Mat out = top_blob.channel(p);
200 
201         const float bias0 = bias ? bias[p] : 0.f;
202 
203         out.fill(bias0);
204 
205         for (int q = 0; q < inch; q++)
206         {
207             const float* img0 = bottom_blob.channel(q);
208 
209             const float* kernel0 = kernel + p * inch * 16 + q * 16;
210 
211             const float* r0 = img0;
212 
213             const float* k0 = kernel0;
214             const float* k1 = kernel0 + 4;
215             const float* k2 = kernel0 + 8;
216             const float* k3 = kernel0 + 12;
217 
218 #if __ARM_NEON
219             float32x4_t _k0 = vld1q_f32(k0);
220             float32x4_t _k1 = vld1q_f32(k1);
221             float32x4_t _k2 = vld1q_f32(k2);
222             float32x4_t _k3 = vld1q_f32(k3);
223 #endif // __ARM_NEON
224 
225             for (int i = 0; i < h; i++)
226             {
227                 float* outptr = out.row(i * 2);
228 
229                 float* outptr0 = outptr;
230                 float* outptr1 = outptr0 + outw;
231                 float* outptr2 = outptr1 + outw;
232                 float* outptr3 = outptr2 + outw;
233 
234                 int j = 0;
235 #if __ARM_NEON
236                 for (; j + 3 < w; j += 4)
237                 {
238                     float32x4_t _v = vld1q_f32(r0);
239 
240                     // row 0
241                     float32x4x2_t _out0 = vld2q_f32(outptr0);
242                     // 0,2,4,6
243                     _out0.val[0] = vmlaq_lane_f32(_out0.val[0], _v, vget_low_f32(_k0), 0);
244                     // 1,3,5,7
245                     _out0.val[1] = vmlaq_lane_f32(_out0.val[1], _v, vget_low_f32(_k0), 1);
246                     vst2q_f32(outptr0, _out0);
247 
248                     _out0 = vld2q_f32(outptr0 + 2);
249                     // 2,4,6,8
250                     _out0.val[0] = vmlaq_lane_f32(_out0.val[0], _v, vget_high_f32(_k0), 0);
251                     // 3,5,7,9
252                     _out0.val[1] = vmlaq_lane_f32(_out0.val[1], _v, vget_high_f32(_k0), 1);
253                     vst2q_f32(outptr0 + 2, _out0);
254 
255                     // row 1
256                     float32x4x2_t _out1 = vld2q_f32(outptr1);
257                     // 0,2,4,6
258                     _out1.val[0] = vmlaq_lane_f32(_out1.val[0], _v, vget_low_f32(_k1), 0);
259                     // 1,3,5,7
260                     _out1.val[1] = vmlaq_lane_f32(_out1.val[1], _v, vget_low_f32(_k1), 1);
261                     vst2q_f32(outptr1, _out1);
262 
263                     _out1 = vld2q_f32(outptr1 + 2);
264                     // 2,4,6,8
265                     _out1.val[0] = vmlaq_lane_f32(_out1.val[0], _v, vget_high_f32(_k1), 0);
266                     // 3,5,7,9
267                     _out1.val[1] = vmlaq_lane_f32(_out1.val[1], _v, vget_high_f32(_k1), 1);
268                     vst2q_f32(outptr1 + 2, _out1);
269 
270                     // row 2
271                     float32x4x2_t _out2 = vld2q_f32(outptr2);
272                     _out2.val[0] = vmlaq_lane_f32(_out2.val[0], _v, vget_low_f32(_k2), 0);
273                     _out2.val[1] = vmlaq_lane_f32(_out2.val[1], _v, vget_low_f32(_k2), 1);
274                     vst2q_f32(outptr2, _out2);
275 
276                     _out2 = vld2q_f32(outptr2 + 2);
277                     _out2.val[0] = vmlaq_lane_f32(_out2.val[0], _v, vget_high_f32(_k2), 0);
278                     _out2.val[1] = vmlaq_lane_f32(_out2.val[1], _v, vget_high_f32(_k2), 1);
279                     vst2q_f32(outptr2 + 2, _out2);
280 
281                     // row 3
282                     float32x4x2_t _out3 = vld2q_f32(outptr3);
283                     _out3.val[0] = vmlaq_lane_f32(_out3.val[0], _v, vget_low_f32(_k3), 0);
284                     _out3.val[1] = vmlaq_lane_f32(_out3.val[1], _v, vget_low_f32(_k3), 1);
285                     vst2q_f32(outptr3, _out3);
286 
287                     _out3 = vld2q_f32(outptr3 + 2);
288                     _out3.val[0] = vmlaq_lane_f32(_out3.val[0], _v, vget_high_f32(_k3), 0);
289                     _out3.val[1] = vmlaq_lane_f32(_out3.val[1], _v, vget_high_f32(_k3), 1);
290                     vst2q_f32(outptr3 + 2, _out3);
291 
292                     r0 += 4;
293                     outptr0 += 8;
294                     outptr1 += 8;
295                     outptr2 += 8;
296                     outptr3 += 8;
297                 }
298 
299 #endif // __ARM_NEON
300 
301                 for (; j < w; j++)
302                 {
303                     float val = r0[0];
304 
305                     outptr0[0] += val * k0[0];
306                     outptr0[1] += val * k0[1];
307                     outptr0[2] += val * k0[2];
308                     outptr0[3] += val * k0[3];
309 
310                     outptr1[0] += val * k1[0];
311                     outptr1[1] += val * k1[1];
312                     outptr1[2] += val * k1[2];
313                     outptr1[3] += val * k1[3];
314 
315                     outptr2[0] += val * k2[0];
316                     outptr2[1] += val * k2[1];
317                     outptr2[2] += val * k2[2];
318                     outptr2[3] += val * k2[3];
319 
320                     outptr3[0] += val * k3[0];
321                     outptr3[1] += val * k3[1];
322                     outptr3[2] += val * k3[2];
323                     outptr3[3] += val * k3[3];
324 
325                     r0++;
326                     outptr0 += 2;
327                     outptr1 += 2;
328                     outptr2 += 2;
329                     outptr3 += 2;
330                 }
331             }
332         }
333     }
334 }
335