// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

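// Requantize int32 pack4 data back to int8 with a fused ReLU: each value is
// scaled by scale_in, optionally biased, clamped at zero and rescaled by
// scale_out before the int8 conversion. The output is repacked to pack8 or pack1
// depending on top_blob.elempack.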
static void requantize_relu_pack4_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_in_data, const Mat& scale_out_data, const Mat& bias_data, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int channels = bottom_blob.c;
    int size = w * h;
    int outc = top_blob.c;
    int out_elempack = top_blob.elempack;

    int scale_in_data_size = scale_in_data.w;
    int scale_out_data_size = scale_out_data.w;
    int bias_data_size = bias_data.w;

    // int8(relu(v * scale_in) * scale_out)
    // int8_relu(v * (scale_in * scale_out))

    // int8(relu(v * scale_in + bias) * scale_out)
    // int8_relu(v * (scale_in * scale_out) + (bias * scale_out))

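    // Folding scale_out into the scale (and into the bias) is valid because the
    // quantization scale_out is non-negative, so the ReLU clamp commutes with the
    // multiply. The inner loops then only need a multiply (or fused multiply-add)
    // followed by the int8 + ReLU conversion.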
    if (out_elempack == 8)
    {
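        // pack4 -> pack8: output channel q interleaves input channels q*2 and q*2+1,
        // with the per-channel scales loaded as two float32x4 halves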
        if (bias_data_size == 0)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < outc; q++)
            {
                const int* intptr0 = bottom_blob.channel(q * 2);
                const int* intptr1 = bottom_blob.channel(q * 2 + 1);
                signed char* ptr = top_blob.channel(q);

                float32x4_t _scale_in0 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8);
                float32x4_t _scale_in1 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8 + 4);
                float32x4_t _scale_out0 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8);
                float32x4_t _scale_out1 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8 + 4);

                float32x4_t _scale0 = vmulq_f32(_scale_in0, _scale_out0);
                float32x4_t _scale1 = vmulq_f32(_scale_in1, _scale_out1);

                int i = 0;
#if __aarch64__
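                // unroll by 4 pixels: 16 int32 loaded from each input channel,
                // 32 int8 stored per iteration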
                for (; i + 3 < size; i += 4)
                {
                    float32x4_t _v00 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v01 = vcvtq_f32_s32(vld1q_s32(intptr0 + 4));
                    float32x4_t _v02 = vcvtq_f32_s32(vld1q_s32(intptr0 + 8));
                    float32x4_t _v03 = vcvtq_f32_s32(vld1q_s32(intptr0 + 12));
                    float32x4_t _v10 = vcvtq_f32_s32(vld1q_s32(intptr1));
                    float32x4_t _v11 = vcvtq_f32_s32(vld1q_s32(intptr1 + 4));
                    float32x4_t _v12 = vcvtq_f32_s32(vld1q_s32(intptr1 + 8));
                    float32x4_t _v13 = vcvtq_f32_s32(vld1q_s32(intptr1 + 12));
                    _v00 = vmulq_f32(_v00, _scale0);
                    _v01 = vmulq_f32(_v01, _scale0);
                    _v02 = vmulq_f32(_v02, _scale0);
                    _v03 = vmulq_f32(_v03, _scale0);
                    _v10 = vmulq_f32(_v10, _scale1);
                    _v11 = vmulq_f32(_v11, _scale1);
                    _v12 = vmulq_f32(_v12, _scale1);
                    _v13 = vmulq_f32(_v13, _scale1);
                    vst1_s8(ptr, float2int8relu(_v00, _v10));
                    vst1_s8(ptr + 8, float2int8relu(_v01, _v11));
                    vst1_s8(ptr + 16, float2int8relu(_v02, _v12));
                    vst1_s8(ptr + 24, float2int8relu(_v03, _v13));

                    intptr0 += 16;
                    intptr1 += 16;
                    ptr += 32;
                }
#endif // __aarch64__
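                // remaining pixels one at a time: 4 + 4 int32 in, 8 int8 out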
                for (; i < size; i++)
                {
                    float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr1));
                    _v0 = vmulq_f32(_v0, _scale0);
                    _v1 = vmulq_f32(_v1, _scale1);
                    vst1_s8(ptr, float2int8relu(_v0, _v1));

                    intptr0 += 4;
                    intptr1 += 4;
                    ptr += 8;
                }
            }
        }
        else
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < outc; q++)
            {
                const int* intptr0 = bottom_blob.channel(q * 2);
                const int* intptr1 = bottom_blob.channel(q * 2 + 1);
                signed char* ptr = top_blob.channel(q);

                float32x4_t _scale_in0 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8);
                float32x4_t _scale_in1 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8 + 4);
                float32x4_t _scale_out0 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8);
                float32x4_t _scale_out1 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8 + 4);
                float32x4_t _bias0 = bias_data_size == 1 ? vdupq_n_f32(bias_data[0]) : vld1q_f32((const float*)bias_data + q * 8);
                float32x4_t _bias1 = bias_data_size == 1 ? vdupq_n_f32(bias_data[0]) : vld1q_f32((const float*)bias_data + q * 8 + 4);

                float32x4_t _scale0 = vmulq_f32(_scale_in0, _scale_out0);
                float32x4_t _scale1 = vmulq_f32(_scale_in1, _scale_out1);
                _bias0 = vmulq_f32(_bias0, _scale_out0);
                _bias1 = vmulq_f32(_bias1, _scale_out1);

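                // same layout as the bias-free branch, with the pre-scaled bias folded
                // in via fused multiply-add (vfmaq_f32 on aarch64, vmlaq_f32 otherwise)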
                int i = 0;
#if __aarch64__
                for (; i + 3 < size; i += 4)
                {
                    float32x4_t _v00 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v01 = vcvtq_f32_s32(vld1q_s32(intptr0 + 4));
                    float32x4_t _v02 = vcvtq_f32_s32(vld1q_s32(intptr0 + 8));
                    float32x4_t _v03 = vcvtq_f32_s32(vld1q_s32(intptr0 + 12));
                    float32x4_t _v10 = vcvtq_f32_s32(vld1q_s32(intptr1));
                    float32x4_t _v11 = vcvtq_f32_s32(vld1q_s32(intptr1 + 4));
                    float32x4_t _v12 = vcvtq_f32_s32(vld1q_s32(intptr1 + 8));
                    float32x4_t _v13 = vcvtq_f32_s32(vld1q_s32(intptr1 + 12));
                    _v00 = vfmaq_f32(_bias0, _v00, _scale0);
                    _v01 = vfmaq_f32(_bias0, _v01, _scale0);
                    _v02 = vfmaq_f32(_bias0, _v02, _scale0);
                    _v03 = vfmaq_f32(_bias0, _v03, _scale0);
                    _v10 = vfmaq_f32(_bias1, _v10, _scale1);
                    _v11 = vfmaq_f32(_bias1, _v11, _scale1);
                    _v12 = vfmaq_f32(_bias1, _v12, _scale1);
                    _v13 = vfmaq_f32(_bias1, _v13, _scale1);
                    vst1_s8(ptr, float2int8relu(_v00, _v10));
                    vst1_s8(ptr + 8, float2int8relu(_v01, _v11));
                    vst1_s8(ptr + 16, float2int8relu(_v02, _v12));
                    vst1_s8(ptr + 24, float2int8relu(_v03, _v13));

                    intptr0 += 16;
                    intptr1 += 16;
                    ptr += 32;
                }
#endif // __aarch64__
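                // 2-pixel then 1-pixel tails; vmlaq_f32 stands in for vfmaq_f32 where
                // the FMA intrinsic is unavailable (armv7)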
                for (; i + 1 < size; i += 2)
                {
                    float32x4_t _v00 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v01 = vcvtq_f32_s32(vld1q_s32(intptr0 + 4));
                    float32x4_t _v10 = vcvtq_f32_s32(vld1q_s32(intptr1));
                    float32x4_t _v11 = vcvtq_f32_s32(vld1q_s32(intptr1 + 4));
#if __aarch64__
                    _v00 = vfmaq_f32(_bias0, _v00, _scale0);
                    _v01 = vfmaq_f32(_bias0, _v01, _scale0);
                    _v10 = vfmaq_f32(_bias1, _v10, _scale1);
                    _v11 = vfmaq_f32(_bias1, _v11, _scale1);
#else  // __aarch64__
                    _v00 = vmlaq_f32(_bias0, _v00, _scale0);
                    _v01 = vmlaq_f32(_bias0, _v01, _scale0);
                    _v10 = vmlaq_f32(_bias1, _v10, _scale1);
                    _v11 = vmlaq_f32(_bias1, _v11, _scale1);
#endif // __aarch64__
                    vst1_s8(ptr, float2int8relu(_v00, _v10));
                    vst1_s8(ptr + 8, float2int8relu(_v01, _v11));

                    intptr0 += 8;
                    intptr1 += 8;
                    ptr += 16;
                }
                for (; i < size; i++)
                {
                    float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr1));
#if __aarch64__
                    _v0 = vfmaq_f32(_bias0, _v0, _scale0);
                    _v1 = vfmaq_f32(_bias1, _v1, _scale1);
#else  // __aarch64__
                    _v0 = vmlaq_f32(_bias0, _v0, _scale0);
                    _v1 = vmlaq_f32(_bias1, _v1, _scale1);
#endif // __aarch64__
                    vst1_s8(ptr, float2int8relu(_v0, _v1));

                    intptr0 += 4;
                    intptr1 += 4;
                    ptr += 8;
                }
            }
        }
    }
    if (out_elempack == 1)
    {
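        // pack4 -> pack1: the four lanes of each pack4 input channel are scattered
        // to four consecutive single-element output channels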
        if (bias_data_size == 0)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const int* intptr = bottom_blob.channel(q);
                signed char* ptr0 = top_blob.channel(q * 4);
                signed char* ptr1 = top_blob.channel(q * 4 + 1);
                signed char* ptr2 = top_blob.channel(q * 4 + 2);
                signed char* ptr3 = top_blob.channel(q * 4 + 3);

                float32x4_t _scale_in = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 4);
                float32x4_t _scale_out = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 4);

                float32x4_t _scale = vmulq_f32(_scale_in, _scale_out);

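                // per pixel: scale, convert with fused ReLU, then write one lane to
                // each of the four output channel pointers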
                int i = 0;
                for (; i < size; i++)
                {
                    float32x4_t _v = vcvtq_f32_s32(vld1q_s32(intptr));
                    _v = vmulq_f32(_v, _scale);
                    int8x8_t v = float2int8relu(_v, _v);
                    ptr0[0] = vget_lane_s8(v, 0);
                    ptr1[0] = vget_lane_s8(v, 1);
                    ptr2[0] = vget_lane_s8(v, 2);
                    ptr3[0] = vget_lane_s8(v, 3);

                    intptr += 4;
                    ptr0 += 1;
                    ptr1 += 1;
                    ptr2 += 1;
                    ptr3 += 1;
                }
            }
        }
        else
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const int* intptr = bottom_blob.channel(q);
                signed char* ptr0 = top_blob.channel(q * 4);
                signed char* ptr1 = top_blob.channel(q * 4 + 1);
                signed char* ptr2 = top_blob.channel(q * 4 + 2);
                signed char* ptr3 = top_blob.channel(q * 4 + 3);

                float32x4_t _scale_in = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 4);
                float32x4_t _scale_out = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 4);
                float32x4_t _bias = bias_data_size == 1 ? vdupq_n_f32(bias_data[0]) : vld1q_f32((const float*)bias_data + q * 4);

                float32x4_t _scale = vmulq_f32(_scale_in, _scale_out);
                _bias = vmulq_f32(_bias, _scale_out);

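                // biased variant of the loop above, again using vfmaq_f32 on aarch64
                // and vmlaq_f32 elsewhere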
                int i = 0;
                for (; i < size; i++)
                {
                    float32x4_t _v = vcvtq_f32_s32(vld1q_s32(intptr));
#if __aarch64__
                    _v = vfmaq_f32(_bias, _v, _scale);
#else
                    _v = vmlaq_f32(_bias, _v, _scale);
#endif
                    int8x8_t v = float2int8relu(_v, _v);
                    ptr0[0] = vget_lane_s8(v, 0);
                    ptr1[0] = vget_lane_s8(v, 1);
                    ptr2[0] = vget_lane_s8(v, 2);
                    ptr3[0] = vget_lane_s8(v, 3);

                    intptr += 4;
                    ptr0 += 1;
                    ptr1 += 1;
                    ptr2 += 1;
                    ptr3 += 1;
                }
            }
        }
    }
}