/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/op/requantize.cc
 * \brief QNN requantize operator.
 */

#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>

#include "../../transforms/infer_layout_util.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"

namespace tvm {
namespace relay {
namespace qnn {

TVM_REGISTER_NODE_TYPE(RequantizeAttrs);

Array<Array<Layout>> RequantizeInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
  RequantizeAttrs* param = const_cast<RequantizeAttrs*>(attrs.as<RequantizeAttrs>());

  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    CHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }

  Array<Layout> input_layouts, output_layouts;
  if (new_in_layouts.defined()) {
    // Adapt to the new layout. The axis has to change: record the original
    // quantization axis and convert it to the axis in the modified layout.
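    //
    // A sketch with hypothetical layouts: if the data layout changes from
    // NCHW (axis = 1, the C dimension) to NCHW4c, the walk below keeps only
    // the primal axes N, C, H and W, so the inferred layout is NCHW and the
    // requantize axis stays at 1; with NCHW -> NHWC the axis would move from
    // 1 to 3. The scale and zero point inputs keep the 1-D "C" layout.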
    CHECK_EQ(new_in_layouts.size(), 5);
    CHECK_EQ(old_in_layouts.size(), 5);

    // 1) Get the axis.
    int axis = param->axis;
    axis = (axis == -1) ? old_in_shapes[0].size() - 1 : axis;

    // 2) Collect the original axis.
    std::string old_dim = old_in_layouts[0][axis].name();

    // 3) Collect the new axes by walking the new layout.
    tvm::Integer new_axis;
    std::string new_layout_string = "";
    int axis_index = 0;
    for (auto iter_var : new_in_layouts[0]->axes) {
      const auto& layout_axis = LayoutAxis::Get(iter_var);
      const std::string& layout_dim = layout_axis.name();
      if (old_dim == layout_dim) {
        new_axis = tvm::Integer(axis_index);
      }
      // Collect only the primal axes.
      if (layout_axis.IsPrimal()) {
        new_layout_string += layout_dim;
        axis_index++;
      }
    }

    // 4) Set the new axis and layout.
    Layout new_layout = Layout(new_layout_string);

    // Fill the layouts of the remaining input tensors - scales and zero points. The layouts of
    // these tensors can be treated as channel layout.
    Layout channel_layout = Layout("C");
    input_layouts = {new_layout, channel_layout, channel_layout, channel_layout, channel_layout};
    output_layouts = {new_layout};
    param->axis = new_axis;
  } else if (old_in_layouts.defined()) {
    // If the new layout is undefined, set the old layout as the inferred layout.
    CHECK_EQ(old_in_layouts.size(), 5);

    Layout old_layout = old_in_layouts[0];

    // Fill the layouts of the remaining input tensors - scales and zero points. The layouts of
    // these tensors can be treated as channel layout.
    Layout channel_layout = Layout("C");
    input_layouts = {old_layout, channel_layout, channel_layout, channel_layout, channel_layout};
    output_layouts = {old_layout};
  } else {
    // Set the layouts to undefined.
    Layout undef = Layout::Undef();
    input_layouts = Array<Layout>(5, undef);
    output_layouts = {undef};
  }

  return Array<Array<Layout>>{input_layouts, output_layouts};
}

// Lowering of the qnn.requantize op.

/*
 * \brief Lower requantize to a sequence of ops.
 * \param input_tensor The input tensor to the requantize op.
 * \param input_scale The quantization scale of the input tensor.
 * \param input_zero_point The quantization zero point of the input tensor.
 * \param output_scale The quantization scale of the output tensor.
 * \param output_zero_point The quantization zero point of the output tensor.
 * \param param The requantize op attrs.
 * \param input_shape The input tensor shape of the requantize op.
 * \param out_dtype The output tensor dtype.
 * \return The sequence of existing Relay ops.
 * \note Requantization using only integer computation. Here, the computation is
 * converted to a fixed point computation by computing an output multiplier
 * and shift. This is useful if the target device does not support floating
 * point computations, or supports them only at a high cost.
 *
 * The whole computation can be broken down into the following steps:
 * 1) Calculate the integer multiplier and integer shift.
 * 2) Subtract the input integer zero point.
 * 3) Perform fixed point multiplication.
 * 4) Add the output zero point.
 * 5) Cast to the out_dtype.
 */
Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Expr& output_scale,
                     const Expr& output_zero_point, const RequantizeAttrs* param,
                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
  auto tensor = Cast(input_tensor, DataType::Int(32));
  // 1) Subtract the input_zero_point.
  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
    tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32)));
  }

  // 2) If the input and output scales are the same, we can skip the fixed point multiplication.
  // Check if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single
  // scale for the whole tensor. For per-channel (aka per-axis), there is a vector of scales for
  // the input tensor. Depending on the quantization type, the matching fixed point multiplication
  // routine is called.
  auto scaled_int32_t = tensor;
  float output_scale_float = GetScalarFromConstant<float>(output_scale);
  if (IsConstScalar(input_scale)) {
    // This is per-tensor quantization. Single scale.
    float input_scale_float = GetScalarFromConstant<float>(input_scale);
    double double_multiplier =
        static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
    // Skip if the input and output scales are the same.
    if (!IsEqualScalar(input_scale, output_scale)) {
      int32_t fixed_point_multiplier, shift;
      std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
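      // A worked example with hypothetical scales: input_scale = 0.5 and
      // output_scale = 0.25 give double_multiplier = 2.0. Assuming the usual
      // Q31 significand / power-of-two exponent factorization, this would be
      // represented as fixed_point_multiplier = round(0.5 * 2^31) = 1073741824
      // with shift = 2 (since 2.0 = 0.5 * 2^2).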

      const bool is_upward_rounding = (param->rounding == "UPWARD");

      // When using upward rounding (i.e., x.5 rounded to x+1), leverage
      // the FixedPointMultiply operator.
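      // For example, 2.5 rounds to 3 and -2.5 rounds to -2 under UPWARD
      // rounding.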
      scaled_int32_t =
          (is_upward_rounding
               ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
               : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
    }

  } else {
    // This is per-channel (per-axis) quantization.
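    // For example (hypothetical scales): input_scale = [0.5, 1.0] with
    // output_scale = 0.25 yields double_multipliers = [2.0, 4.0], i.e. one
    // fixed point multiplier per slice along `axis`.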
    std::vector<double> double_multipliers;
    auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
    for (auto input_axis_scale : input_axis_scales) {
      double multiplier =
          static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
      double_multipliers.push_back(multiplier);
    }
    int axis = param->axis;
    axis = (axis == -1) ? input_shape.size() - 1 : axis;
    scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape,
                                                  axis, param->rounding);
  }

  // 3) Add the output zero point.
  auto shifted_int32_t = scaled_int32_t;
  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
    shifted_int32_t = Add(Cast(output_zero_point, DataType::Int(32)), scaled_int32_t);
  }

  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is int32. The fixed point
  // multiplication keeps the value in the int32 range.
  if (out_dtype == DataType::Int(32)) {
    return shifted_int32_t;
  }

  auto q_min = GetQmin(out_dtype);
  auto q_max = GetQmax(out_dtype);
  auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
  return Cast(clipped_t, out_dtype);
}

/*
 * \brief Forward rewrite of the requantize op.
 * \param attrs The requantize op attributes.
 * \param new_args The new mutated args to the call node.
 * \param types The types of the input arguments and the output.
 * \return The sequence of Relay ops for the requantize op.
 * \note Lowering of the requantize operation. The requantize operator converts
 * one quantized tensor to another quantized tensor. For the output
 * tensor, we are provided with the output scale and zero point. The
 * computation looks like this:
 *
 * Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
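 *
 * For example (hypothetical values): with scale_input = 0.5,
 * scale_output = 0.25, zp_input = 0 and zp_output = 10, an input value
 * Q_input = 4 requantizes to Q_output = 10 + (0.5 / 0.25) * 4 = 18.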
 */
Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                               const Array<tvm::relay::Type>& types) {
  CHECK_EQ(new_args.size(), 5);
  auto& quantized_data = new_args[0];
  auto& input_scale = new_args[1];
  auto& input_zero_point = new_args[2];
  auto& output_scale = new_args[3];
  auto& output_zero_point = new_args[4];
  const auto* param = attrs.as<RequantizeAttrs>();
  CHECK(param != nullptr);

  // Find the input shape.
  CHECK_EQ(types.size(), 6);
  auto in_type = types[0];
  auto in_tensor_type = in_type.as<TensorTypeNode>();
  CHECK(in_tensor_type != nullptr) << "Type information missing."
                                   << " Please run infer_type pass.";
  Array<IndexExpr> input_shape = in_tensor_type->shape;

  // Find the output dtype.
  auto out_type = types[5];
  auto out_tensor_type = out_type.as<TensorTypeNode>();
  CHECK(out_tensor_type != nullptr) << "Type information missing."
                                    << " Please run infer_type pass.";
  auto out_dtype = out_tensor_type->dtype;

  // Check rounding validity.
  CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
      << "QNN requantize supports two rounding modes - UPWARD and "
      << "TONEAREST";
  return RequantizeLower(quantized_data, input_scale, input_zero_point, output_scale,
                         output_zero_point, param, input_shape, out_dtype);
}

/*
 * \brief Type relation for the Requantize op.
 * \param types The types of the input args.
 * \param num_inputs The number of inputs.
 * \param attrs The op attributes.
 * \param reporter The type reporter that sets the dtype and shapes.
 * \return True if the type inference succeeded.
 */
bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 6);
  const auto* data = types[0].as<TensorTypeNode>();

  if (data == nullptr) {
    return false;
  }

  const auto in_dtype = data->dtype;
  CHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
        in_dtype == DataType::Int(32))
      << "Input type should be one of [int8, uint8, int32] but was " << in_dtype;

  const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
  int axis = requantize_attrs->axis;
  axis = (axis == -1) ? data->shape.size() - 1 : axis;
  CHECK_LT(axis, static_cast<int>(data->shape.size()))
      << "axis " << requantize_attrs->axis << " is out of range";
  CHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range";

  // Check and assign types for the scales and zero points.
  AssignType(types[1], DataType::Float(32), data->shape[axis], reporter);  // input_scale
  AssignType(types[2], DataType::Int(32), data->shape[axis], reporter);    // input_zero_pt
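  // AssignType permits each of the above to be either a scalar (per-tensor
  // quantization) or a 1-D tensor of length data->shape[axis] (per-channel
  // quantization). For example (illustrative): an NCHW input with axis = 1
  // may carry one float32 scale per channel.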
  // For now, the requantize output tensor is limited to full-tensor uniform quantization.
  CHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
  CHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point

  const Array<tvm::PrimExpr> oshape = data->shape;
  // Assign the output type.
  auto out_dtype = requantize_attrs->out_dtype;
  CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
        out_dtype == DataType::Int(32))
      << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
  reporter->Assign(types[5], TensorType(oshape, out_dtype));
  return true;
}

// Positional relay function to create the qnn requantize operator,
// used by the frontend FFI.
Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
                    Expr output_zero_point, int axis, String rounding, DataType out_dtype) {
  auto attrs = make_object<RequantizeAttrs>();
  attrs->axis = axis;
  attrs->rounding = std::move(rounding);
  attrs->out_dtype = std::move(out_dtype);
  static const Op& op = Op::Get("qnn.requantize");
  return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point},
              Attrs(attrs), {});
}
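
// A hypothetical call site (the values are illustrative only); the scales and
// zero points are Relay constants, e.g. built with MakeConstantScalar:
//   Expr out = MakeRequantize(data, MakeConstantScalar(DataType::Float(32), 0.5f),
//                             MakeConstantScalar(DataType::Int(32), 0),
//                             MakeConstantScalar(DataType::Float(32), 0.25f),
//                             MakeConstantScalar(DataType::Int(32), 10),
//                             /*axis=*/-1, "UPWARD", DataType::Int(8));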

RELAY_REGISTER_OP("qnn.requantize")
    .describe(R"code(Requantize operator.
The requantize operator converts one quantized tensor to another quantized
tensor. For the output tensor, we are provided with the output scale and zero
point. The computation looks like this:

Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)

)code" TVM_ADD_FILELINE)
    .set_attrs_type<RequantizeAttrs>()
    .set_num_inputs(5)
    .add_argument("data", "Tensor", "The quantized input tensor.")
    .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
    .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
    .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
    .add_argument("output_zero_point", "Tensor",
                  "The quantization zero_point of the output tensor.")
    .set_support_level(11)
    .add_type_rel("Requantize", RequantizeRel)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", RequantizeInferCorrectLayout);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize").set_body_typed(MakeRequantize);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm