// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

static void requantize_relu_pack8_neon(const Mat& bottom_blob, Mat& top_blob, const Mat& scale_in_data, const Mat& scale_out_data, const Mat& bias_data, const Option& opt)
{
    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int channels = bottom_blob.c;
    int size = w * h;

    int scale_in_data_size = scale_in_data.w;
    int scale_out_data_size = scale_out_data.w;
    int bias_data_size = bias_data.w;

    // int8(relu(v * scale_in) * scale_out)
    // int8_relu(v * (scale_in * scale_out))

    // int8(relu(v * scale_in + bias) * scale_out)
    // int8_relu(v * (scale_in * scale_out) + (bias * scale_out))
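    //
    // pulling scale_out inside the relu is valid because relu(x) * c == relu(x * c)
    // for any c > 0, and scale_out is assumed to be a positive quantization scale,
    // so each output needs only one fused scale (and bias) before the activation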

    if (bias_data_size == 0)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            const int* intptr = bottom_blob.channel(q);
            signed char* ptr = top_blob.channel(q);

            float32x4_t _scale_in0 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8);
            float32x4_t _scale_in1 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8 + 4);
            float32x4_t _scale_out0 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8);
            float32x4_t _scale_out1 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8 + 4);

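            // fold scale_in and scale_out into one per-lane scale so the inner
            // loops need a single multiply before the int8 + relu store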
            float32x4_t _scale0 = vmulq_f32(_scale_in0, _scale_out0);
            float32x4_t _scale1 = vmulq_f32(_scale_in1, _scale_out1);

            int i = 0;
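            // aarch64 has the larger NEON register file, so the wide unroll of
            // 4 packed pixels (32 values) per iteration is compiled only there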
#if __aarch64__
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr));
                float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr + 4));
                float32x4_t _v2 = vcvtq_f32_s32(vld1q_s32(intptr + 8));
                float32x4_t _v3 = vcvtq_f32_s32(vld1q_s32(intptr + 12));
                float32x4_t _v4 = vcvtq_f32_s32(vld1q_s32(intptr + 16));
                float32x4_t _v5 = vcvtq_f32_s32(vld1q_s32(intptr + 20));
                float32x4_t _v6 = vcvtq_f32_s32(vld1q_s32(intptr + 24));
                float32x4_t _v7 = vcvtq_f32_s32(vld1q_s32(intptr + 28));
                _v0 = vmulq_f32(_v0, _scale0);
                _v1 = vmulq_f32(_v1, _scale1);
                _v2 = vmulq_f32(_v2, _scale0);
                _v3 = vmulq_f32(_v3, _scale1);
                _v4 = vmulq_f32(_v4, _scale0);
                _v5 = vmulq_f32(_v5, _scale1);
                _v6 = vmulq_f32(_v6, _scale0);
                _v7 = vmulq_f32(_v7, _scale1);
                vst1_s8(ptr, float2int8relu(_v0, _v1));
                vst1_s8(ptr + 8, float2int8relu(_v2, _v3));
                vst1_s8(ptr + 16, float2int8relu(_v4, _v5));
                vst1_s8(ptr + 24, float2int8relu(_v6, _v7));

                intptr += 32;
                ptr += 32;
            }
#endif // __aarch64__
            for (; i + 1 < size; i += 2)
            {
                float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr));
                float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr + 4));
                float32x4_t _v2 = vcvtq_f32_s32(vld1q_s32(intptr + 8));
                float32x4_t _v3 = vcvtq_f32_s32(vld1q_s32(intptr + 12));
                _v0 = vmulq_f32(_v0, _scale0);
                _v1 = vmulq_f32(_v1, _scale1);
                _v2 = vmulq_f32(_v2, _scale0);
                _v3 = vmulq_f32(_v3, _scale1);
                vst1_s8(ptr, float2int8relu(_v0, _v1));
                vst1_s8(ptr + 8, float2int8relu(_v2, _v3));

                intptr += 16;
                ptr += 16;
            }
            for (; i < size; i++)
            {
                float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr));
                float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr + 4));
                _v0 = vmulq_f32(_v0, _scale0);
                _v1 = vmulq_f32(_v1, _scale1);
                vst1_s8(ptr, float2int8relu(_v0, _v1));

                intptr += 8;
                ptr += 8;
            }
        }
    }
    else
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            const int* intptr = bottom_blob.channel(q);
            signed char* ptr = top_blob.channel(q);

            float32x4_t _scale_in0 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8);
            float32x4_t _scale_in1 = scale_in_data_size == 1 ? vdupq_n_f32(scale_in_data[0]) : vld1q_f32((const float*)scale_in_data + q * 8 + 4);
            float32x4_t _scale_out0 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8);
            float32x4_t _scale_out1 = scale_out_data_size == 1 ? vdupq_n_f32(scale_out_data[0]) : vld1q_f32((const float*)scale_out_data + q * 8 + 4);
            float32x4_t _bias0 = bias_data_size == 1 ? vdupq_n_f32(bias_data[0]) : vld1q_f32((const float*)bias_data + q * 8);
            float32x4_t _bias1 = bias_data_size == 1 ? vdupq_n_f32(bias_data[0]) : vld1q_f32((const float*)bias_data + q * 8 + 4);

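            // fold the two scales and pre-scale the bias by scale_out so the inner
            // loops need a single multiply-add per lane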
            float32x4_t _scale0 = vmulq_f32(_scale_in0, _scale_out0);
            float32x4_t _scale1 = vmulq_f32(_scale_in1, _scale_out1);
            _bias0 = vmulq_f32(_bias0, _scale_out0);
            _bias1 = vmulq_f32(_bias1, _scale_out1);

            int i = 0;
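            // as above, the wider 4-pixel (32-value) unroll is compiled only on aarch64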
#if __aarch64__
            for (; i + 3 < size; i += 4)
            {
                float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr));
                float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr + 4));
                float32x4_t _v2 = vcvtq_f32_s32(vld1q_s32(intptr + 8));
                float32x4_t _v3 = vcvtq_f32_s32(vld1q_s32(intptr + 12));
                float32x4_t _v4 = vcvtq_f32_s32(vld1q_s32(intptr + 16));
                float32x4_t _v5 = vcvtq_f32_s32(vld1q_s32(intptr + 20));
                float32x4_t _v6 = vcvtq_f32_s32(vld1q_s32(intptr + 24));
                float32x4_t _v7 = vcvtq_f32_s32(vld1q_s32(intptr + 28));

                _v0 = vfmaq_f32(_bias0, _v0, _scale0);
                _v1 = vfmaq_f32(_bias1, _v1, _scale1);
                _v2 = vfmaq_f32(_bias0, _v2, _scale0);
                _v3 = vfmaq_f32(_bias1, _v3, _scale1);
                _v4 = vfmaq_f32(_bias0, _v4, _scale0);
                _v5 = vfmaq_f32(_bias1, _v5, _scale1);
                _v6 = vfmaq_f32(_bias0, _v6, _scale0);
                _v7 = vfmaq_f32(_bias1, _v7, _scale1);

                vst1_s8(ptr, float2int8relu(_v0, _v1));
                vst1_s8(ptr + 8, float2int8relu(_v2, _v3));
                vst1_s8(ptr + 16, float2int8relu(_v4, _v5));
                vst1_s8(ptr + 24, float2int8relu(_v6, _v7));

                intptr += 32;
                ptr += 32;
            }
#endif // __aarch64__
            for (; i + 1 < size; i += 2)
            {
                float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr));
                float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr + 4));
                float32x4_t _v2 = vcvtq_f32_s32(vld1q_s32(intptr + 8));
                float32x4_t _v3 = vcvtq_f32_s32(vld1q_s32(intptr + 12));

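                // vfmaq_f32 (fused multiply-add) is used on aarch64; armv7 NEON
                // falls back to vmlaq_f32 (multiply-accumulate) instead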
#if __aarch64__
                _v0 = vfmaq_f32(_bias0, _v0, _scale0);
                _v1 = vfmaq_f32(_bias1, _v1, _scale1);
                _v2 = vfmaq_f32(_bias0, _v2, _scale0);
                _v3 = vfmaq_f32(_bias1, _v3, _scale1);
#else  // __aarch64__
                _v0 = vmlaq_f32(_bias0, _v0, _scale0);
                _v1 = vmlaq_f32(_bias1, _v1, _scale1);
                _v2 = vmlaq_f32(_bias0, _v2, _scale0);
                _v3 = vmlaq_f32(_bias1, _v3, _scale1);
#endif // __aarch64__

                vst1_s8(ptr, float2int8relu(_v0, _v1));
                vst1_s8(ptr + 8, float2int8relu(_v2, _v3));

                intptr += 16;
                ptr += 16;
            }
            for (; i < size; i++)
            {
                float32x4_t _v0 = vcvtq_f32_s32(vld1q_s32(intptr));
                float32x4_t _v1 = vcvtq_f32_s32(vld1q_s32(intptr + 4));
#if __aarch64__
                _v0 = vfmaq_f32(_bias0, _v0, _scale0);
                _v1 = vfmaq_f32(_bias1, _v1, _scale1);
#else  // __aarch64__
                _v0 = vmlaq_f32(_bias0, _v0, _scale0);
                _v1 = vmlaq_f32(_bias1, _v1, _scale1);
#endif // __aarch64__
                vst1_s8(ptr, float2int8relu(_v0, _v1));

                intptr += 8;
                ptr += 8;
            }
        }
    }
}