// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void requantize_leakyrelu_pack4_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_in_data, const Mat& scale_out_data, const Mat& bias_data, float slope, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int channels = bottom_blob.c;
    int size = w * h;
    int outc = top_blob.c;
    int out_elempack = top_blob.elempack;

    int scale_in_data_size = scale_in_data.w;
    int scale_out_data_size = scale_out_data.w;
    int bias_data_size = bias_data.w;

    // int8(leakyrelu(v * scale_in, slope) * scale_out)
    // int8_leakyrelu(v * (scale_in * scale_out), slope)

    // int8(leakyrelu(v * scale_in + bias, slope) * scale_out)
    // int8_leakyrelu(v * (scale_in * scale_out) + (bias * scale_out), slope)

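    // the fused forms above hold because leakyrelu is positively homogeneous:
    // leakyrelu(x, slope) * s == leakyrelu(x * s, slope) for s > 0, so scale_out
    // folds into the multiplier (and into the bias) before the activation

    // out_elempack == 8: two consecutive pack4 int32 input channels are packed
    // into one pack8 int8 output channel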
    if (out_elempack == 8)
    {
        if (bias_data_size == 0)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < outc; q++)
            {
                const int* intptr0 = bottom_blob.channel(q * 2);
                const int* intptr1 = bottom_blob.channel(q * 2 + 1);
                signed char* ptr = top_blob.channel(q);

                float32x4_t _scale_in0 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8);
                float32x4_t _scale_in1 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8 + 4);
                float32x4_t _scale_out0 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8);
                float32x4_t _scale_out1 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8 + 4);

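                // fold scale_out into scale_in once per channel (see the identity above)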
                float32x4_t _scale0 = vmulq_f32(_scale_in0, _scale_out0);
                float32x4_t _scale1 = vmulq_f32(_scale_in1, _scale_out1);
                float32x4_t _slope = vdupq_n_f32(slope);

                int i = 0;
#if __aarch64__
                for (; i + 3 < size; i += 4)
                {
                    float32x4_t _v00 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v01 = vcvtq_f32_s32(vld1q_s32(intptr0 + 4));
                    float32x4_t _v02 = vcvtq_f32_s32(vld1q_s32(intptr0 + 8));
                    float32x4_t _v03 = vcvtq_f32_s32(vld1q_s32(intptr0 + 12));
                    float32x4_t _v10 = vcvtq_f32_s32(vld1q_s32(intptr1));
                    float32x4_t _v11 = vcvtq_f32_s32(vld1q_s32(intptr1 + 4));
                    float32x4_t _v12 = vcvtq_f32_s32(vld1q_s32(intptr1 + 8));
                    float32x4_t _v13 = vcvtq_f32_s32(vld1q_s32(intptr1 + 12));
                    _v00 = vmulq_f32(_v00, _scale0);
                    _v01 = vmulq_f32(_v01, _scale0);
                    _v02 = vmulq_f32(_v02, _scale0);
                    _v03 = vmulq_f32(_v03, _scale0);
                    _v10 = vmulq_f32(_v10, _scale1);
                    _v11 = vmulq_f32(_v11, _scale1);
                    _v12 = vmulq_f32(_v12, _scale1);
                    _v13 = vmulq_f32(_v13, _scale1);
                    vst1_s8(ptr, float2int8leakyrelu(_v00, _v10, _slope));
                    vst1_s8(ptr + 8, float2int8leakyrelu(_v01, _v11, _slope));
                    vst1_s8(ptr + 16, float2int8leakyrelu(_v02, _v12, _slope));
                    vst1_s8(ptr + 24, float2int8leakyrelu(_v03, _v13, _slope));

                    intptr0 += 16;
                    intptr1 += 16;
                    ptr += 32;
                }
#endif // __aarch64__
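                // remaining pixels: one pack4 element from each input channel
                // becomes one pack8 int8 element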
                for (; i < size; i++)
                {
                    float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr1));
                    _v0 = vmulq_f32(_v0, _scale0);
                    _v1 = vmulq_f32(_v1, _scale1);
                    vst1_s8(ptr, float2int8leakyrelu(_v0, _v1, _slope));

                    intptr0 += 4;
                    intptr1 += 4;
                    ptr += 8;
                }
            }
        }
        else
        {
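            // same as the bias-free path above, with the per-channel bias
            // pre-multiplied by scale_out so each vector needs a single multiply-add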
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < outc; q++)
            {
                const int* intptr0 = bottom_blob.channel(q * 2);
                const int* intptr1 = bottom_blob.channel(q * 2 + 1);
                signed char* ptr = top_blob.channel(q);

                float32x4_t _scale_in0 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8);
                float32x4_t _scale_in1 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8 + 4);
                float32x4_t _scale_out0 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8);
                float32x4_t _scale_out1 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8 + 4);
                float32x4_t _bias0 = bias_data_size == 1 ? vdupq_n_f32(bias_data[0]) : vld1q_f32((const float*)bias_data + q * 8);
                float32x4_t _bias1 = bias_data_size == 1 ? vdupq_n_f32(bias_data[0]) : vld1q_f32((const float*)bias_data + q * 8 + 4);

                float32x4_t _scale0 = vmulq_f32(_scale_in0, _scale_out0);
                float32x4_t _scale1 = vmulq_f32(_scale_in1, _scale_out1);
                _bias0 = vmulq_f32(_bias0, _scale_out0);
                _bias1 = vmulq_f32(_bias1, _scale_out1);
                float32x4_t _slope = vdupq_n_f32(slope);

                int i = 0;
#if __aarch64__
                for (; i + 3 < size; i += 4)
                {
                    float32x4_t _v00 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v01 = vcvtq_f32_s32(vld1q_s32(intptr0 + 4));
                    float32x4_t _v02 = vcvtq_f32_s32(vld1q_s32(intptr0 + 8));
                    float32x4_t _v03 = vcvtq_f32_s32(vld1q_s32(intptr0 + 12));
                    float32x4_t _v10 = vcvtq_f32_s32(vld1q_s32(intptr1));
                    float32x4_t _v11 = vcvtq_f32_s32(vld1q_s32(intptr1 + 4));
                    float32x4_t _v12 = vcvtq_f32_s32(vld1q_s32(intptr1 + 8));
                    float32x4_t _v13 = vcvtq_f32_s32(vld1q_s32(intptr1 + 12));
                    _v00 = vfmaq_f32(_bias0, _v00, _scale0);
                    _v01 = vfmaq_f32(_bias0, _v01, _scale0);
                    _v02 = vfmaq_f32(_bias0, _v02, _scale0);
                    _v03 = vfmaq_f32(_bias0, _v03, _scale0);
                    _v10 = vfmaq_f32(_bias1, _v10, _scale1);
                    _v11 = vfmaq_f32(_bias1, _v11, _scale1);
                    _v12 = vfmaq_f32(_bias1, _v12, _scale1);
                    _v13 = vfmaq_f32(_bias1, _v13, _scale1);
                    vst1_s8(ptr, float2int8leakyrelu(_v00, _v10, _slope));
                    vst1_s8(ptr + 8, float2int8leakyrelu(_v01, _v11, _slope));
                    vst1_s8(ptr + 16, float2int8leakyrelu(_v02, _v12, _slope));
                    vst1_s8(ptr + 24, float2int8leakyrelu(_v03, _v13, _slope));

                    intptr0 += 16;
                    intptr1 += 16;
                    ptr += 32;
                }
#endif // __aarch64__
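                // two pixels per iteration; fused multiply-add on aarch64, vmlaq_f32 on 32-bit arm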
                for (; i + 1 < size; i += 2)
                {
                    float32x4_t _v00 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v01 = vcvtq_f32_s32(vld1q_s32(intptr0 + 4));
                    float32x4_t _v10 = vcvtq_f32_s32(vld1q_s32(intptr1));
                    float32x4_t _v11 = vcvtq_f32_s32(vld1q_s32(intptr1 + 4));
#if __aarch64__
                    _v00 = vfmaq_f32(_bias0, _v00, _scale0);
                    _v01 = vfmaq_f32(_bias0, _v01, _scale0);
                    _v10 = vfmaq_f32(_bias1, _v10, _scale1);
                    _v11 = vfmaq_f32(_bias1, _v11, _scale1);
#else  // __aarch64__
                    _v00 = vmlaq_f32(_bias0, _v00, _scale0);
                    _v01 = vmlaq_f32(_bias0, _v01, _scale0);
                    _v10 = vmlaq_f32(_bias1, _v10, _scale1);
                    _v11 = vmlaq_f32(_bias1, _v11, _scale1);
#endif // __aarch64__
                    vst1_s8(ptr, float2int8leakyrelu(_v00, _v10, _slope));
                    vst1_s8(ptr + 8, float2int8leakyrelu(_v01, _v11, _slope));

                    intptr0 += 8;
                    intptr1 += 8;
                    ptr += 16;
                }
                for (; i < size; i++)
                {
                    float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr0));
                    float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr1));
#if __aarch64__
                    _v0 = vfmaq_f32(_bias0, _v0, _scale0);
                    _v1 = vfmaq_f32(_bias1, _v1, _scale1);
#else  // __aarch64__
                    _v0 = vmlaq_f32(_bias0, _v0, _scale0);
                    _v1 = vmlaq_f32(_bias1, _v1, _scale1);
#endif // __aarch64__
                    vst1_s8(ptr, float2int8leakyrelu(_v0, _v1, _slope));

                    intptr0 += 4;
                    intptr1 += 4;
                    ptr += 8;
                }
            }
        }
    }
    if (out_elempack == 1)
    {
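        // out_elempack == 1: each pack4 int32 input channel scatters into
        // four consecutive pack1 int8 output channels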
        if (bias_data_size == 0)
        {
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const int* intptr = bottom_blob.channel(q);
                signed char* ptr0 = top_blob.channel(q * 4);
                signed char* ptr1 = top_blob.channel(q * 4 + 1);
                signed char* ptr2 = top_blob.channel(q * 4 + 2);
                signed char* ptr3 = top_blob.channel(q * 4 + 3);

                float32x4_t _scale_in = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 4);
                float32x4_t _scale_out = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 4);

                float32x4_t _scale = vmulq_f32(_scale_in, _scale_out);
                float32x4_t _slope = vdupq_n_f32(slope);

                int i = 0;
                for (; i < size; i++)
                {
                    float32x4_t _v = vcvtq_f32_s32(vld1q_s32(intptr));
                    _v = vmulq_f32(_v, _scale);
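                    // the same vector feeds both halves of the helper; only lanes 0-3 are read below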
                    int8x8_t v = float2int8leakyrelu(_v, _v, _slope);
                    ptr0[0] = vget_lane_s8(v, 0);
                    ptr1[0] = vget_lane_s8(v, 1);
                    ptr2[0] = vget_lane_s8(v, 2);
                    ptr3[0] = vget_lane_s8(v, 3);

                    intptr += 4;
                    ptr0 += 1;
                    ptr1 += 1;
                    ptr2 += 1;
                    ptr3 += 1;
                }
            }
        }
        else
        {
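            // same scatter as above, with the per-channel bias pre-multiplied by scale_out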
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                const int* intptr = bottom_blob.channel(q);
                signed char* ptr0 = top_blob.channel(q * 4);
                signed char* ptr1 = top_blob.channel(q * 4 + 1);
                signed char* ptr2 = top_blob.channel(q * 4 + 2);
                signed char* ptr3 = top_blob.channel(q * 4 + 3);

                float32x4_t _scale_in = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 4);
                float32x4_t _scale_out = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 4);
                float32x4_t _bias = bias_data_size == 1 ? vdupq_n_f32(bias_data[0]) : vld1q_f32((const float*)bias_data + q * 4);

                float32x4_t _scale = vmulq_f32(_scale_in, _scale_out);
                _bias = vmulq_f32(_bias, _scale_out);
                float32x4_t _slope = vdupq_n_f32(slope);

                int i = 0;
                for (; i < size; i++)
                {
                    float32x4_t _v = vcvtq_f32_s32(vld1q_s32(intptr));
#if __aarch64__
                    _v = vfmaq_f32(_bias, _v, _scale);
#else
                    _v = vmlaq_f32(_bias, _v, _scale);
#endif
                    int8x8_t v = float2int8leakyrelu(_v, _v, _slope);
                    ptr0[0] = vget_lane_s8(v, 0);
                    ptr1[0] = vget_lane_s8(v, 1);
                    ptr2[0] = vget_lane_s8(v, 2);
                    ptr3[0] = vget_lane_s8(v, 3);

                    intptr += 4;
                    ptr0 += 1;
                    ptr1 += 1;
                    ptr2 += 1;
                    ptr3 += 1;
                }
            }
        }
    }
}