1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "backend/cpu/compute/OptimizedComputer.hpp"
17 #include <string.h>
18 #include "core/Macro.h"
19 #ifdef MNN_USE_NEON
20 #include <arm_neon.h>
21 #endif
22 
23 namespace MNN {
24 namespace Optimized {
25 
26 // avgpooling
// Quantized (uint8) average pooling.
//
// Data layout: channels are packed in groups of UNIT(=4) ("NC/4HW4"): one
// spatial pixel of a channel unit is UNIT consecutive bytes, and each channel
// unit occupies its own height*width*UNIT plane.
//
// input_dims / output_dims are {batch, height, width, channel}; batch and
// channel counts must match between input and output (asserted below).
// mOutputActivationMin / mOutputActivationMax clamp the averaged result to the
// quantized activation range. The divisor is the number of filter taps that
// actually overlap the image (border-aware), and the sum is rounded to nearest
// by adding filter_count / 2 before the division.
void AveragePool(const uint8_t* input_data, const std::vector<int>& input_dims, int stride_width, int stride_height,
                 int pad_width, int pad_height, int filter_width, int filter_height, int mOutputActivationMin,
                 int mOutputActivationMax, uint8_t* output_data, const std::vector<int>& output_dims) {
    MNN_ASSERT(mOutputActivationMin < mOutputActivationMax);
    MNN_ASSERT(input_dims.at(0) == output_dims.at(0));
    MNN_ASSERT(input_dims.at(3) == output_dims.at(3));
    const int inputBatches  = input_dims.at(0);
    const int inputChannels = input_dims.at(3);
    const int inputHeight   = input_dims.at(1);
    const int inputWidth    = input_dims.at(2);
    const int outputHeight  = output_dims.at(1);
    const int outputWidth   = output_dims.at(2);

#define UNIT 4
    // Number of 4-channel units, and the channel count rounded up to a
    // multiple of UNIT (stride between batches in the packed layout).
    const int inputChannelUnits = UP_DIV(inputChannels, UNIT);
    const int inputChannelRound = ROUND_UP(inputChannels, UNIT);

    for (int batch = 0; batch < inputBatches; ++batch) {
        for (int out_y = 0; out_y < outputHeight; ++out_y) {
            for (int out_x = 0; out_x < outputWidth; ++out_x) {
                // Top-left corner of the filter window in input coordinates;
                // may be negative because of padding.
                const int in_x_origin    = (out_x * stride_width) - pad_width;
                const int in_y_origin    = (out_y * stride_height) - pad_height;
                // Clip the filter window to the valid input area. filter_count
                // is the number of in-bounds taps and serves as the divisor.
                const int filter_x_start = std::max(0, -in_x_origin);
                const int filter_x_end   = std::min(filter_width, inputWidth - in_x_origin);
                const int filter_y_start = std::max(0, -in_y_origin);
                const int filter_y_end   = std::min(filter_height, inputHeight - in_y_origin);
                const int filter_count   = (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
                uint8_t* output_ptr      = output_data + batch * outputHeight * outputWidth * inputChannelRound +
                                      out_y * outputWidth * UNIT + out_x * UNIT;
#ifdef MNN_USE_NEON
                // Rounding bias for round-to-nearest division by filter_count.
                uint16_t result_sub = filter_count / 2;
                uint16x4_t min_vec  = vdup_n_u16(mOutputActivationMin);
                uint16x4_t max_vec  = vdup_n_u16(mOutputActivationMax);
                uint16x8_t acc_reg;
                // acc holds 8 u16 lanes: lanes 0-3 carry the running per-channel
                // sums; lanes 4-7 stay zero after the memset because only the
                // folded low half is ever stored back. This lets vaddw_u8
                // consume two UNIT-pixels (8 bytes) per iteration.
                uint16_t acc[UNIT * 2];
                // Base of the (clipped) window. NOTE(review): in_x_origin /
                // in_y_origin can be negative, so this intermediate pointer may
                // point before the buffer; it is only dereferenced with
                // fx >= filter_x_start / fy >= filter_y_start offsets that bring
                // it back in bounds — confirm this is acceptable here.
                const uint8_t* input_ptr = input_data + batch * inputHeight * inputWidth * inputChannelRound +
                                           in_y_origin * inputWidth * UNIT + in_x_origin * UNIT;

                for (int channel = 0; channel < inputChannelUnits; channel++) {
                    memset(acc, 0, UNIT * 2 * sizeof(acc[0]));
                    for (int fy = filter_y_start; fy < filter_y_end; fy++) {
                        int fx  = filter_x_start;
                        // Reload current sums (lanes 4-7 are zero).
                        acc_reg = vld1q_u16(acc);
                        // Vector body: two horizontally adjacent UNIT-pixels per
                        // step, widened and accumulated into the 8 u16 lanes.
                        for (; fx < filter_x_end - 2; fx += 2) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            uint8x8_t input_reg = vld1_u8(input_cur_ptr);
                            acc_reg             = vaddw_u8(acc_reg, input_reg);
                        }
                        // Fold the two half-accumulators together and keep only
                        // the low 4 lanes (one sum per channel of the unit).
                        vst1_u16(acc, vadd_u16(vget_high_u16(acc_reg), vget_low_u16(acc_reg)));
                        // Scalar tail for the remaining 1-2 pixels of this row.
                        for (; fx < filter_x_end; fx++) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            for (int c = 0; c < UNIT; c++) {
                                acc[c] += input_cur_ptr[c];
                            }
                        }
                    }
                    uint8_t* output_cur_ptr = output_ptr + channel * outputHeight * outputWidth * UNIT;
                    uint16x4_t a            = vdup_n_u16(0);
                    // Per-lane rounded average (lane subscripting on NEON
                    // vectors is a GCC/Clang extension).
                    for (int c = 0; c < UNIT; c++) {
                        a[c] = (acc[c] + result_sub) / filter_count;
                    }
                    // Clamp to the quantized activation range, then narrow to u8.
                    a                 = vmin_u16(a, max_vec);
                    a                 = vmax_u16(a, min_vec);
                    output_cur_ptr[0] = static_cast<uint8_t>(a[0]);
                    output_cur_ptr[1] = static_cast<uint8_t>(a[1]);
                    output_cur_ptr[2] = static_cast<uint8_t>(a[2]);
                    output_cur_ptr[3] = static_cast<uint8_t>(a[3]);
                }
#else
                // Scalar fallback: identical arithmetic, one pixel at a time.
                uint16_t acc[UNIT];
                const uint8_t* input_ptr = input_data + batch * inputHeight * inputWidth * inputChannelRound +
                                           in_y_origin * inputWidth * UNIT + in_x_origin * UNIT;

                for (int channel = 0; channel < inputChannelUnits; channel++) {
                    memset(acc, 0, UNIT * sizeof(acc[0]));
                    // Accumulate every in-bounds tap of the window.
                    for (int fy = filter_y_start; fy < filter_y_end; fy++) {
                        for (int fx = filter_x_start; fx < filter_x_end; fx++) {
                            const uint8_t* input_cur_ptr = input_ptr + channel * inputHeight * inputWidth * UNIT +
                                                           fy * inputWidth * UNIT + fx * UNIT;
                            for (int c = 0; c < UNIT; c++) {
                                acc[c] += input_cur_ptr[c];
                            }
                        }
                    }
                    // Rounded average, clamped to the activation range.
                    for (int c = 0; c < UNIT; c++) {
                        uint16_t a = (acc[c] + filter_count / 2) / filter_count;
                        a          = std::max<uint16_t>(a, mOutputActivationMin);
                        a          = std::min<uint16_t>(a, mOutputActivationMax);
                        output_ptr[channel * outputHeight * outputWidth * UNIT + c] = static_cast<uint8_t>(a);
                    }
                }
#endif
            }
        }
    }
}
125 
// Quantized (uint8) logistic/sigmoid, TFLite-style.
//
// Each input byte is centered by inputZeroPoint. Centered values below
// -input_range_radius saturate to 0 and above +input_range_radius saturate to
// 255; in between, the value is rescaled by (input_multiplier,
// input_left_shift) into gemmlowp fixed point, passed through the fixed-point
// logistic, and the Q0.23-ish result is divided by 2^23 (with rounding) to
// produce the 8-bit output. Operates elementwise over the flattened tensor;
// output_dims is accepted for interface symmetry but the element count is
// taken from input_dims.
void Logistic(const uint8_t* input_data, const std::vector<int>& input_dims, int32_t inputZeroPoint,
              int32_t input_range_radius, int32_t input_multiplier, int input_left_shift, uint8_t* output_data,
              const std::vector<int>& output_dims) {
    // Total element count = product of all input dimensions.
    int size = 1;
    for (int i = 0; i < input_dims.size(); i++) {
        size *= input_dims.at(i);
    }

    int c = 0;

#ifdef MNN_USE_NEON
    // Handle 16 values at a time
    for (; c <= size - 16; c += 16) {
        // Read input uint8 values, cast to int16 and subtract inputZeroPoint
        uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
        int16x8_t input_val_centered_0 =
            vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), vdupq_n_s16(inputZeroPoint));
        int16x8_t input_val_centered_1 =
            vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), vdupq_n_s16(inputZeroPoint));

        // Prepare the bit masks that we will use at the end to implement the logic
        // that was expressed in the scalar code with branching:
        //   if (input_val_centered < -input_range_radius) {
        //     output_val = 0;
        //   } else if (input_val_centered > input_range_radius) {
        //     output_val = 255;
        //   } else {
        //     ...
        // Comparison results are all-ones/all-zeros u16 lanes; vshrn_n_u16(...,8)
        // narrows them to all-ones/all-zeros u8 lanes for the final OR/AND.
        uint16x8_t mask_rightclamp_0 = vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
        uint16x8_t mask_rightclamp_1 = vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
        uint16x8_t mask_leftclamp_0  = vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
        uint16x8_t mask_leftclamp_1  = vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
        uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), vshrn_n_u16(mask_rightclamp_1, 8));
        uint8x16_t mask_leftclamp  = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), vshrn_n_u16(mask_leftclamp_1, 8));

        // This performs what is expressed in the scalar code as
        // const int32 input_val_rescaled =
        //     MultiplyByQuantizedMultiplierGreaterThanOne(
        //         input_val_centered, input_multiplier, input_left_shift);
        int32x4_t input_val_rescaled_0 =
            vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_1 =
            vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_2 =
            vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), vdupq_n_s32(input_left_shift));
        int32x4_t input_val_rescaled_3 =
            vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), vdupq_n_s32(input_left_shift));
        // Saturating-rounding-doubling-high multiply: the fixed-point multiply
        // half of MultiplyByQuantizedMultiplierGreaterThanOne.
        input_val_rescaled_0 = vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
        input_val_rescaled_1 = vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
        input_val_rescaled_2 = vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
        input_val_rescaled_3 = vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);

        // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
        using FixedPoint4                 = FixedPoint<int32x4_t, 4>;
        using FixedPoint0                 = FixedPoint<int32x4_t, 0>;
        const FixedPoint4 input_val_f4_0  = FixedPoint4::FromRaw(input_val_rescaled_0);
        const FixedPoint4 input_val_f4_1  = FixedPoint4::FromRaw(input_val_rescaled_1);
        const FixedPoint4 input_val_f4_2  = FixedPoint4::FromRaw(input_val_rescaled_2);
        const FixedPoint4 input_val_f4_3  = FixedPoint4::FromRaw(input_val_rescaled_3);
        const FixedPoint0 output_val_f0_0 = logistic(input_val_f4_0);
        const FixedPoint0 output_val_f0_1 = logistic(input_val_f4_1);
        const FixedPoint0 output_val_f0_2 = logistic(input_val_f4_2);
        const FixedPoint0 output_val_f0_3 = logistic(input_val_f4_3);

        // Divide by 2^23 as in the scalar code
        int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
        int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
        int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
        int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);

        // Cast output values to uint8, saturating
        int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), vqmovn_s32(output_val_s32_1));
        int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), vqmovn_s32(output_val_s32_3));
        uint8x16_t output_val_u8   = vcombine_u8(vqmovun_s16(output_val_s16_0), vqmovun_s16(output_val_s16_1));

        // Perform the bit-masking with the bit masks computed at the beginning,
        // see the comment there. OR with the right-clamp mask forces saturated
        // lanes to 255; AND with the left-clamp mask forces out-of-range-low
        // lanes to 0.
        output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
        output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);

        // Store back to memory
        vst1q_u8(output_data + c, output_val_u8);
    }
#endif
    // Leftover loop: handle one value at a time with scalar code.
    for (; c < size; ++c) {
        const uint8_t input_val_u8       = input_data[c];
        const int32_t input_val_centered = static_cast<int32_t>(input_val_u8) - inputZeroPoint;
        uint8_t output_val;
        if (input_val_centered < -input_range_radius) {
            output_val = 0;
        } else if (input_val_centered > input_range_radius) {
            output_val = 255;
        } else {
            const int32_t input_val_rescaled =
                MultiplyByQuantizedMultiplierGreaterThanOne(input_val_centered, input_multiplier, input_left_shift);
            const FixedPoint<int32_t, 4> input_val_f4  = FixedPoint<int32_t, 4>::FromRaw(input_val_rescaled);
            const FixedPoint<int32_t, 0> output_val_f0 = logistic(input_val_f4);
            int32_t output_val_s32                     = RoundingDivideByPOT(output_val_f0.raw(), 23);
            // Rounding can push logistic(+inf-ish) to 256; clamp back to the
            // uint8 maximum.
            if (output_val_s32 == 256) {
                output_val_s32 = 255;
            }
            MNN_ASSERT(output_val_s32 >= 0);
            MNN_ASSERT(output_val_s32 <= 255);
            output_val = static_cast<uint8_t>(output_val_s32);
        }
        output_data[c] = output_val;
    }
}
235 
236 } // namespace Optimized
237 } // namespace MNN
238