// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

deconv4x4s2_fp16sa_neon(const Mat & bottom_blob,Mat & top_blob,const Mat & _kernel,const Mat & _bias,const Option & opt)15 static void deconv4x4s2_fp16sa_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& _kernel, const Mat& _bias, const Option& opt)
16 {
17 int w = bottom_blob.w;
18 int h = bottom_blob.h;
19 int inch = bottom_blob.c;
20
21 int outw = top_blob.w;
22 int outch = top_blob.c;
23
24 const __fp16* kernel = _kernel;
25 const __fp16* bias = _bias;
26
27 #pragma omp parallel for num_threads(opt.num_threads)
28 for (int p = 0; p < outch; p++)
29 {
30 Mat out = top_blob.channel(p);
31
32 const __fp16 bias0 = bias ? bias[p] : 0.f;
33
34 out.fill(bias0);
35
36 for (int q = 0; q < inch; q++)
37 {
38 const __fp16* img0 = bottom_blob.channel(q);
39
40 const __fp16* kernel0 = kernel + p * inch * 16 + q * 16;
41
42 const __fp16* r0 = img0;
43
44 const __fp16* k0 = kernel0;
45 const __fp16* k1 = kernel0 + 4;
46 const __fp16* k2 = kernel0 + 8;
47 const __fp16* k3 = kernel0 + 12;
48
49 float16x4_t _k0 = vld1_f16(k0);
50 float16x4_t _k1 = vld1_f16(k1);
51 float16x4_t _k2 = vld1_f16(k2);
52 float16x4_t _k3 = vld1_f16(k3);
53
54 for (int i = 0; i < h; i++)
55 {
56 __fp16* outptr = out.row<__fp16>(i * 2);
57
58 __fp16* outptr0 = outptr;
59 __fp16* outptr1 = outptr0 + outw;
60 __fp16* outptr2 = outptr1 + outw;
61 __fp16* outptr3 = outptr2 + outw;
62
63 int j = 0;
64 for (; j + 3 < w; j += 4)
65 {
66 float16x4_t _v = vld1_f16(r0);
67
68 // row 0
69 float16x4x2_t _out0 = vld2_f16(outptr0);
70 // 0,2,4,6
71 _out0.val[0] = vfma_lane_f16(_out0.val[0], _v, _k0, 0);
72 // 1,3,5,7
73 _out0.val[1] = vfma_lane_f16(_out0.val[1], _v, _k0, 1);
74 vst2_f16(outptr0, _out0);
75
76 _out0 = vld2_f16(outptr0 + 2);
77 // 2,4,6,8
78 _out0.val[0] = vfma_lane_f16(_out0.val[0], _v, _k0, 2);
79 // 3,5,7,9
80 _out0.val[1] = vfma_lane_f16(_out0.val[1], _v, _k0, 3);
81 vst2_f16(outptr0 + 2, _out0);
82
83 // row 1
84 float16x4x2_t _out1 = vld2_f16(outptr1);
85 // 0,2,4,6
86 _out1.val[0] = vfma_lane_f16(_out1.val[0], _v, _k1, 0);
87 // 1,3,5,7
88 _out1.val[1] = vfma_lane_f16(_out1.val[1], _v, _k1, 1);
89 vst2_f16(outptr1, _out1);
90
91 _out1 = vld2_f16(outptr1 + 2);
92 // 2,4,6,8
93 _out1.val[0] = vfma_lane_f16(_out1.val[0], _v, _k1, 2);
94 // 3,5,7,9
95 _out1.val[1] = vfma_lane_f16(_out1.val[1], _v, _k1, 3);
96 vst2_f16(outptr1 + 2, _out1);
97
98 // row 2
99 float16x4x2_t _out2 = vld2_f16(outptr2);
100 _out2.val[0] = vfma_lane_f16(_out2.val[0], _v, _k2, 0);
101 _out2.val[1] = vfma_lane_f16(_out2.val[1], _v, _k2, 1);
102 vst2_f16(outptr2, _out2);
103
104 _out2 = vld2_f16(outptr2 + 2);
105 _out2.val[0] = vfma_lane_f16(_out2.val[0], _v, _k2, 2);
106 _out2.val[1] = vfma_lane_f16(_out2.val[1], _v, _k2, 3);
107 vst2_f16(outptr2 + 2, _out2);
108
109 // row 3
110 float16x4x2_t _out3 = vld2_f16(outptr3);
111 _out3.val[0] = vfma_lane_f16(_out3.val[0], _v, _k3, 0);
112 _out3.val[1] = vfma_lane_f16(_out3.val[1], _v, _k3, 1);
113 vst2_f16(outptr3, _out3);
114
115 _out3 = vld2_f16(outptr3 + 2);
116 _out3.val[0] = vfma_lane_f16(_out3.val[0], _v, _k3, 2);
117 _out3.val[1] = vfma_lane_f16(_out3.val[1], _v, _k3, 3);
118 vst2_f16(outptr3 + 2, _out3);
119
120 r0 += 4;
121 outptr0 += 8;
122 outptr1 += 8;
123 outptr2 += 8;
124 outptr3 += 8;
125 }
126 for (; j < w; j++)
127 {
128 __fp16 val = r0[0];
129
130 outptr0[0] += val * k0[0];
131 outptr0[1] += val * k0[1];
132 outptr0[2] += val * k0[2];
133 outptr0[3] += val * k0[3];
134
135 outptr1[0] += val * k1[0];
136 outptr1[1] += val * k1[1];
137 outptr1[2] += val * k1[2];
138 outptr1[3] += val * k1[3];
139
140 outptr2[0] += val * k2[0];
141 outptr2[1] += val * k2[1];
142 outptr2[2] += val * k2[2];
143 outptr2[3] += val * k2[3];
144
145 outptr3[0] += val * k3[0];
146 outptr3[1] += val * k3[1];
147 outptr3[2] += val * k3[2];
148 outptr3[3] += val * k3[3];
149
150 r0++;
151 outptr0 += 2;
152 outptr1 += 2;
153 outptr2 += 2;
154 outptr3 += 2;
155 }
156 }
157 }
158 }
159 }
160