1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
// Applies one 4-tap kernel row to four consecutive input samples and adds the
// products into ten consecutive samples of a single output row.
//   outptr  first of the ten output samples that are read, updated and stored
//   _v      four consecutive input samples
//   _k      the four kernel taps belonging to this output row
static inline void deconv4x4s2_fp16sa_accumulate_row(__fp16* outptr, float16x4_t _v, float16x4_t _k)
{
    // vld2 de-interleaves eight outputs into val[0] = 0,2,4,6 and val[1] = 1,3,5,7.
    // Each input sample advances the output by two, so tap 0 of the four samples
    // lands on the even lanes and tap 1 on the odd lanes.
    float16x4x2_t _acc = vld2_f16(outptr);
    _acc.val[0] = vfma_lane_f16(_acc.val[0], _v, _k, 0);
    _acc.val[1] = vfma_lane_f16(_acc.val[1], _v, _k, 1);
    vst2_f16(outptr, _acc);

    // Same window shifted by two outputs: val[0] = 2,4,6,8 and val[1] = 3,5,7,9,
    // which is where taps 2 and 3 land. It overlaps the previous store, so it must
    // be loaded only after that store has been issued.
    _acc = vld2_f16(outptr + 2);
    _acc.val[0] = vfma_lane_f16(_acc.val[0], _v, _k, 2);
    _acc.val[1] = vfma_lane_f16(_acc.val[1], _v, _k, 3);
    vst2_f16(outptr + 2, _acc);
}

// 4x4 deconvolution with a step of two output samples per input sample in both
// directions, computed on __fp16 data with NEON half-precision intrinsics.
//
//   bottom_blob  input; its w, h and c give the input width, height and channel count
//   top_blob     output; its w and c give the output row length and channel count.
//                Each channel is overwritten with the bias and then accumulated into.
//                The writes below reach column 2*w+1 and row 2*h+1, so the blob has
//                to provide at least that much room, and rows are addressed as
//                "previous row + top_blob.w", i.e. they are treated as tightly packed.
//   _kernel      weights read as __fp16, 16 per (output channel, input channel)
//                pair, stored at offset (p * inch + q) * 16 in row-major 4x4 order
//   _bias        one __fp16 per output channel, or a null pointer for no bias
//   opt          only num_threads is used, for the OpenMP loop over output channels
static void deconv4x4s2_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
{
    const int in_w = bottom_blob.w;
    const int in_h = bottom_blob.h;
    const int in_channels = bottom_blob.c;

    const int out_w = top_blob.w;
    const int out_channels = top_blob.c;

    const __fp16* weights = _kernel;
    const __fp16* bias_data = _bias;

    // every output channel is produced independently of the others
    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < out_channels; p++)
    {
        Mat out = top_blob.channel(p);

        // start the whole channel at its bias value (zero when there is no bias)
        const __fp16 bias_value = bias_data ? bias_data[p] : 0.f;
        out.fill(bias_value);

        for (int q = 0; q < in_channels; q++)
        {
            // input samples are consumed sequentially; the pointer is never rewound
            // between rows, so it walks the channel as one contiguous run
            const __fp16* inptr = bottom_blob.channel(q);

            // the sixteen weights of this (p, q) pair, one group of four per kernel row
            const __fp16* taps0 = weights + (p * in_channels + q) * 16;
            const __fp16* taps1 = taps0 + 4;
            const __fp16* taps2 = taps0 + 8;
            const __fp16* taps3 = taps0 + 12;

            const float16x4_t _taps0 = vld1_f16(taps0);
            const float16x4_t _taps1 = vld1_f16(taps1);
            const float16x4_t _taps2 = vld1_f16(taps2);
            const float16x4_t _taps3 = vld1_f16(taps3);

            for (int i = 0; i < in_h; i++)
            {
                // input row i contributes to the four output rows starting at 2 * i
                __fp16* dst0 = out.row<__fp16>(i * 2);
                __fp16* dst1 = dst0 + out_w;
                __fp16* dst2 = dst1 + out_w;
                __fp16* dst3 = dst2 + out_w;

                int j = 0;

                // vector path: four input samples at a time, each output row gets
                // ten of its samples updated
                for (; j + 3 < in_w; j += 4)
                {
                    const float16x4_t _in = vld1_f16(inptr);

                    deconv4x4s2_fp16sa_accumulate_row(dst0, _in, _taps0);
                    deconv4x4s2_fp16sa_accumulate_row(dst1, _in, _taps1);
                    deconv4x4s2_fp16sa_accumulate_row(dst2, _in, _taps2);
                    deconv4x4s2_fp16sa_accumulate_row(dst3, _in, _taps3);

                    inptr += 4;
                    dst0 += 8;
                    dst1 += 8;
                    dst2 += 8;
                    dst3 += 8;
                }

                // scalar path for the remaining samples: one input sample is
                // scattered over a 4x4 output window
                for (; j < in_w; j++)
                {
                    const __fp16 sample = inptr[0];

                    for (int t = 0; t < 4; t++)
                    {
                        dst0[t] += sample * taps0[t];
                        dst1[t] += sample * taps1[t];
                        dst2[t] += sample * taps2[t];
                        dst3[t] += sample * taps3[t];
                    }

                    inptr++;
                    dst0 += 2;
                    dst1 += 2;
                    dst2 += 2;
                    dst3 += 2;
                }
            }
        }
    }
}
160