/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "backend/cpu/compute/OptimizedComputer.hpp"
#include <string.h>
#include <algorithm> // std::max / std::min
#include "core/Macro.h"
#ifdef MNN_USE_NEON
#include <arm_neon.h>
#endif

namespace MNN {
namespace Optimized {

// avgpooling
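// AveragePool: quantized (uint8) average pooling. The dimension vectors are
// ordered {batch, height, width, channels}, while the tensor data itself is
// packed in groups of UNIT (4) channels (MNN's NC4HW4 layout), which is what
// the index arithmetic below assumes.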
void AveragePool(const uint8_t* input_data, const std::vector<int>& input_dims, int stride_width, int stride_height,
                 int pad_width, int pad_height, int filter_width, int filter_height, int mOutputActivationMin,
                 int mOutputActivationMax, uint8_t* output_data, const std::vector<int>& output_dims) {
    MNN_ASSERT(mOutputActivationMin < mOutputActivationMax);
    MNN_ASSERT(input_dims.at(0) == output_dims.at(0));
    MNN_ASSERT(input_dims.at(3) == output_dims.at(3));
    const int inputBatches  = input_dims.at(0);
    const int inputChannels = input_dims.at(3);
    const int inputHeight   = input_dims.at(1);
    const int inputWidth    = input_dims.at(2);
    const int outputHeight  = output_dims.at(1);
    const int outputWidth   = output_dims.at(2);

#define UNIT 4
    const int inputChannelUnits = UP_DIV(inputChannels, UNIT);
    const int inputChannelRound = ROUND_UP(inputChannels, UNIT);
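    // inputChannelUnits is the number of 4-channel groups; inputChannelRound is the
    // channel count rounded up to a multiple of UNIT, so one batch occupies
    // inputHeight * inputWidth * inputChannelRound bytes in the packed layout.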

    for (int batch = 0; batch < inputBatches; ++batch) {
        for (int out_y = 0; out_y < outputHeight; ++out_y) {
            for (int out_x = 0; out_x < outputWidth; ++out_x) {
                const int in_x_origin = (out_x * stride_width) - pad_width;
                const int in_y_origin = (out_y * stride_height) - pad_height;
                const int filter_x_start = std::max(0, -in_x_origin);
                const int filter_x_end = std::min(filter_width, inputWidth - in_x_origin);
                const int filter_y_start = std::max(0, -in_y_origin);
                const int filter_y_end = std::min(filter_height, inputHeight - in_y_origin);
                const int filter_count = (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
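                // With padding, part of the window may fall outside the input; filter_count
                // is the number of taps that actually land inside the image, and it is also
                // the divisor used for the average below.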
                uint8_t* output_ptr = output_data + batch * outputHeight * outputWidth * inputChannelRound +
                                      out_y * outputWidth * UNIT + out_x * UNIT;
#ifdef MNN_USE_NEON
                uint16_t result_sub = filter_count / 2;
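                // Adding filter_count / 2 before the integer division below rounds the
                // average to the nearest integer instead of truncating.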
                uint16x4_t min_vec = vdup_n_u16(mOutputActivationMin);
                uint16x4_t max_vec = vdup_n_u16(mOutputActivationMax);
                uint16x8_t acc_reg;
                uint16_t acc[UNIT * 2];
                const uint8_t* input_ptr = input_data + batch * inputHeight * inputWidth * inputChannelRound +
                                           in_y_origin * inputWidth * UNIT + in_x_origin * UNIT;
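                // in_x_origin / in_y_origin may be negative when the window overlaps the
                // padding, so input_ptr can point before the start of the row; the fx / fy
                // loops start at filter_x_start / filter_y_start, which keeps every actual
                // load inside the input.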

                for (int channel = 0; channel < inputChannelUnits; channel++) {
                    memset(acc, 0, UNIT * 2 * sizeof(acc[0]));
                    for (int fy = filter_y_start; fy < filter_y_end; fy++) {
                        int fx = filter_x_start;
                        acc_reg = vld1q_u16(acc);
                        // Widening-accumulate two adjacent pixels (2 * UNIT uint8 values) per
                        // iteration: lanes 0-3 take pixel fx, lanes 4-7 take pixel fx + 1.
                        for (; fx < filter_x_end - 2; fx += 2) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            uint8x8_t input_reg = vld1_u8(input_cur_ptr);
                            acc_reg = vaddw_u8(acc_reg, input_reg);
                        }
                        // Fold the two 4-lane halves back into the per-channel accumulators.
                        vst1_u16(acc, vadd_u16(vget_high_u16(acc_reg), vget_low_u16(acc_reg)));
                        // Scalar tail for the remaining columns of this row.
                        for (; fx < filter_x_end; fx++) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            for (int c = 0; c < UNIT; c++) {
                                acc[c] += input_cur_ptr[c];
                            }
                        }
                    }
                    uint8_t* output_cur_ptr = output_ptr + channel * outputHeight * outputWidth * UNIT;
                    uint16x4_t a = vdup_n_u16(0);
                    // Rounded average per channel, then clamp to the activation range.
                    for (int c = 0; c < UNIT; c++) {
                        a[c] = (acc[c] + result_sub) / filter_count;
                    }
                    a = vmin_u16(a, max_vec);
                    a = vmax_u16(a, min_vec);
                    output_cur_ptr[0] = static_cast<uint8_t>(a[0]);
                    output_cur_ptr[1] = static_cast<uint8_t>(a[1]);
                    output_cur_ptr[2] = static_cast<uint8_t>(a[2]);
                    output_cur_ptr[3] = static_cast<uint8_t>(a[3]);
                }
#else
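                // Scalar fallback: identical arithmetic to the NEON path above, handling one
                // group of UNIT channels at a time.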
                uint16_t acc[UNIT];
                const uint8_t* input_ptr = input_data + batch * inputHeight * inputWidth * inputChannelRound +
                                           in_y_origin * inputWidth * UNIT + in_x_origin * UNIT;

                for (int channel = 0; channel < inputChannelUnits; channel++) {
                    memset(acc, 0, UNIT * sizeof(acc[0]));
                    for (int fy = filter_y_start; fy < filter_y_end; fy++) {
                        for (int fx = filter_x_start; fx < filter_x_end; fx++) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            for (int c = 0; c < UNIT; c++) {
                                acc[c] += input_cur_ptr[c];
                            }
                        }
                    }
                    for (int c = 0; c < UNIT; c++) {
                        uint16_t a = (acc[c] + filter_count / 2) / filter_count;
                        a = std::max<uint16_t>(a, mOutputActivationMin);
                        a = std::min<uint16_t>(a, mOutputActivationMax);
                        output_ptr[channel * outputHeight * outputWidth * UNIT + c] = static_cast<uint8_t>(a);
                    }
                }
#endif
            }
        }
    }
}

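// Logistic: quantized (uint8) sigmoid. Each input is centered by inputZeroPoint,
// rescaled into fixed point with input_multiplier / input_left_shift, passed
// through a fixed-point logistic, and rescaled to the 0..255 output range.
// Inputs further than input_range_radius from the zero point saturate directly
// to 0 or 255.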
void Logistic(const uint8_t* input_data, const std::vector<int>& input_dims, int32_t inputZeroPoint,
              int32_t input_range_radius, int32_t input_multiplier, int input_left_shift, uint8_t* output_data,
              const std::vector<int>& output_dims) {
    int size = 1;
    for (int i = 0; i < input_dims.size(); i++) {
        size *= input_dims.at(i);
    }

    int c = 0;

#ifdef MNN_USE_NEON
    // Handle 16 values at a time
    for (; c <= size - 16; c += 16) {
        // Read input uint8 values, cast to int16 and subtract inputZeroPoint
        uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
        int16x8_t input_val_centered_0 =
            vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), vdupq_n_s16(inputZeroPoint));
        int16x8_t input_val_centered_1 =
            vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), vdupq_n_s16(inputZeroPoint));

        // Prepare the bit masks that we will use at the end to implement the logic
        // that was expressed in the scalar code with branching:
        // if (input_val_centered < -input_range_radius) {
        //     output_val = 0;
        // } else if (input_val_centered > input_range_radius) {
        //     output_val = 255;
        // } else {
        //     ...
        uint16x8_t mask_rightclamp_0 = vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
        uint16x8_t mask_rightclamp_1 = vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
        uint16x8_t mask_leftclamp_0 = vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
        uint16x8_t mask_leftclamp_1 = vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
        uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), vshrn_n_u16(mask_rightclamp_1, 8));
        uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), vshrn_n_u16(mask_leftclamp_1, 8));
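        // vshrn_n_u16(mask, 8) narrows each 16-bit all-ones / all-zeros comparison
        // result to an 8-bit 0xFF / 0x00 mask, so both clamps can be applied at the end
        // with a single OR and AND on the final uint8 vector.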

        // This performs what is expressed in the scalar code as
        // const int32 input_val_rescaled =
        //     MultiplyByQuantizedMultiplierGreaterThanOne(
        //         input_val_centered, input_multiplier, input_left_shift);
        int32x4_t input_val_rescaled_0 =
            vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_1 =
            vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_2 =
            vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_3 =
            vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), vdupq_n_s32(input_left_shift));
        input_val_rescaled_0 = vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
        input_val_rescaled_1 = vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
        input_val_rescaled_2 = vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
        input_val_rescaled_3 = vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);

        // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
        using FixedPoint4 = FixedPoint<int32x4_t, 4>;
        using FixedPoint0 = FixedPoint<int32x4_t, 0>;
        const FixedPoint4 input_val_f4_0 = FixedPoint4::FromRaw(input_val_rescaled_0);
        const FixedPoint4 input_val_f4_1 = FixedPoint4::FromRaw(input_val_rescaled_1);
        const FixedPoint4 input_val_f4_2 = FixedPoint4::FromRaw(input_val_rescaled_2);
        const FixedPoint4 input_val_f4_3 = FixedPoint4::FromRaw(input_val_rescaled_3);
        const FixedPoint0 output_val_f0_0 = logistic(input_val_f4_0);
        const FixedPoint0 output_val_f0_1 = logistic(input_val_f4_1);
        const FixedPoint0 output_val_f0_2 = logistic(input_val_f4_2);
        const FixedPoint0 output_val_f0_3 = logistic(input_val_f4_3);

        // Divide by 2^23 as in the scalar code
        int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
        int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
        int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
        int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
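        // output_val_f0_* are Q0.31 fixed-point values in [0, 1); the rounding shift by
        // 23 maps them to [0, 256], and the saturating narrows below clamp a rounded
        // 256 down to 255.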

        // Cast output values to uint8, saturating
        int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), vqmovn_s32(output_val_s32_1));
        int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), vqmovn_s32(output_val_s32_3));
        uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0), vqmovun_s16(output_val_s16_1));

        // Perform the bit-masking with the bit masks computed at the beginning,
        // see the comment there.
        output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
        output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);

        // Store back to memory
        vst1q_u8(output_data + c, output_val_u8);
    }
#endif
    // Leftover loop: handle one value at a time with scalar code.
    for (; c < size; ++c) {
        const uint8_t input_val_u8 = input_data[c];
        const int32_t input_val_centered = static_cast<int32_t>(input_val_u8) - inputZeroPoint;
        uint8_t output_val;
        if (input_val_centered < -input_range_radius) {
            output_val = 0;
        } else if (input_val_centered > input_range_radius) {
            output_val = 255;
        } else {
            const int32_t input_val_rescaled =
                MultiplyByQuantizedMultiplierGreaterThanOne(input_val_centered, input_multiplier, input_left_shift);
            const FixedPoint<int32_t, 4> input_val_f4 = FixedPoint<int32_t, 4>::FromRaw(input_val_rescaled);
            const FixedPoint<int32_t, 0> output_val_f0 = logistic(input_val_f4);
            int32_t output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
            if (output_val_s32 == 256) {
                output_val_s32 = 255;
            }
            MNN_ASSERT(output_val_s32 >= 0);
            MNN_ASSERT(output_val_s32 <= 255);
            output_val = static_cast<uint8_t>(output_val_s32);
        }
        output_data[c] = output_val;
    }
}

} // namespace Optimized
} // namespace MNN