/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/op/requantize.cc
 * \brief QNN requantize operator.
 */

#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/qnn/attrs.h>

#include "../../transforms/infer_layout_util.h"
#include "../../transforms/pattern_util.h"
#include "../util.h"

namespace tvm {
namespace relay {
namespace qnn {

TVM_REGISTER_NODE_TYPE(RequantizeAttrs);

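/*
 * \brief Infer the correct input and output layouts for the requantize op.
 * \param attrs The requantize op attrs. The quantization axis may be updated in place.
 * \param new_in_layouts The layouts requested by a layout-altering pass, if any.
 * \param old_in_layouts The layouts of the inputs before the transformation.
 * \param old_in_types The types of the inputs before the transformation.
 * \return The inferred input layouts and output layouts.
 * \note Illustrative example: if the data layout changes from NCHW to NCHW4c and the
 *       quantization axis is the channel axis (axis = 1), the axis is remapped to the
 *       position of C among the primal axes of the new layout, and the scale and zero
 *       point inputs keep a one-dimensional channel layout "C".
 */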
Array<Array<Layout>> RequantizeInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
  RequantizeAttrs* param = const_cast<RequantizeAttrs*>(attrs.as<RequantizeAttrs>());

  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    CHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }

  Array<Layout> input_layouts, output_layouts;
  if (new_in_layouts.defined()) {
    // Adapt to the new layout. The axis has to change accordingly.
    // Record the original quantization axis and convert it to the axis in the new layout.
    CHECK_EQ(new_in_layouts.size(), 5);
    CHECK_EQ(old_in_layouts.size(), 5);

    // 1) Get the axis.
    int axis = param->axis;
    axis = (axis == -1) ? old_in_shapes[0].size() - 1 : axis;

    // 2) Collect the original axis
    std::string old_dim = old_in_layouts[0][axis].name();

    // 3) Collect the new axes by walking new_layout.
    tvm::Integer new_axis;
    std::string new_layout_string = "";
    int axis_index = 0;
    for (auto iter_var : new_in_layouts[0]->axes) {
      const auto& layout_axis = LayoutAxis::Get(iter_var);
      const std::string& layout_dim = layout_axis.name();
      if (old_dim == layout_dim) {
        new_axis = tvm::Integer(axis_index);
      }
      // Collect only the primal axis.
      if (layout_axis.IsPrimal()) {
        new_layout_string += layout_dim;
        axis_index++;
      }
    }

    // 4) Set the new axis and layout.
    Layout new_layout = Layout(new_layout_string);

    // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
    // tensors can be treated as channel layout.
    Layout channel_layout = Layout("C");
    input_layouts = {new_layout, channel_layout, channel_layout, channel_layout, channel_layout};
    output_layouts = {new_layout};
    param->axis = new_axis;
  } else if (old_in_layouts.defined()) {
    // If the new layout is undefined, set the old layout as the inferred layout.
    CHECK_EQ(old_in_layouts.size(), 5);

    Layout old_layout = old_in_layouts[0];

    // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these
    // tensors can be treated as channel layout.
    Layout channel_layout = Layout("C");
    input_layouts = {old_layout, channel_layout, channel_layout, channel_layout, channel_layout};
    output_layouts = {old_layout};
  } else {
    // Set the layouts to undef.
    Layout undef = Layout::Undef();
    input_layouts = Array<Layout>(5, undef);
    output_layouts = {undef};
  }

  return Array<Array<Layout>>{input_layouts, output_layouts};
}

// Lowering of qnn.requantize op

/*
 * \brief Lower requantize to a sequence of ops.
 * \param input_tensor The input tensor to the requantize op.
 * \param input_scale The quantization scale of the input tensor.
 * \param input_zero_point The quantization zero point of the input tensor.
 * \param output_scale The quantization scale of the output tensor.
 * \param output_zero_point The quantization zero point of the output tensor.
 * \param param The requantize op attrs.
 * \param input_shape The shape of the input tensor.
 * \param out_dtype The output dtype.
 * \return The sequence of existing Relay ops.
 * \note Requantization uses only integer computation. The real-valued scale ratio is
 *       converted to a fixed point computation by computing an output multiplier
 *       and shift. This is useful if the target device does not support floating
 *       point operations, or if they are very expensive on it.
 *
 *       The whole computation can be broken down into the following steps
 *       1) Calculate the integer multiplier and integer shift.
 *       2) Subtract the input integer zero point.
 *       3) Perform fixed point multiplication.
 *       4) Add the output zero point.
 *       5) Cast to the out_dtype.
 */
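/*
 * Worked example (illustrative, not exercised by the code below): for input_scale = 0.5 and
 * output_scale = 2.0 the real multiplier is 0.5 / 2.0 = 0.25. GetFixedPointMultiplierShift
 * expresses this as a 31-bit fixed point significand plus a power-of-two shift, roughly
 * fixed_point_multiplier = round(0.5 * 2^31) = 2^30 and shift = -1, so that x * 0.25 can be
 * computed with integer multiplies and shifts only.
 */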
Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Expr& output_scale,
                     const Expr& output_zero_point, const RequantizeAttrs* param,
                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
  auto tensor = Cast(input_tensor, DataType::Int(32));
  // 1) Subtract the input_zero_point
  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
    tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32)));
  }

  // 2) If the input and output scales are the same, we can skip the fixed point multiplication.
  // Check whether the input scale is per-tensor or per-channel. If it is per-tensor, there is a
  // single scale for the whole tensor. For per-channel (aka per-axis) quantization, there is a
  // vector of scales for the input tensor. Depending on the quantization type, the corresponding
  // fixed point multiplication routine is called.
  auto scaled_int32_t = tensor;
  float output_scale_float = GetScalarFromConstant<float>(output_scale);
  if (IsConstScalar(input_scale)) {
    // This is per-tensor quantization. Single scale.
    float input_scale_float = GetScalarFromConstant<float>(input_scale);
    double double_multiplier =
        static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
    // Skip if the input and output scales are the same.
    if (!IsEqualScalar(input_scale, output_scale)) {
      int32_t fixed_point_multiplier, shift;
      std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);

      const bool is_upward_rounding = (param->rounding == "UPWARD");

      // When using upward rounding (i.e., x.5 rounded to x+1), leverage
      // the FixedPointMultiply operator.
      scaled_int32_t =
          (is_upward_rounding
               ? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
               : FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
    }

  } else {
    // This is per-channel (per-axis) quantization.
    std::vector<double> double_multipliers;
    auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
    for (auto input_axis_scale : input_axis_scales) {
      double multiplier =
          static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
      double_multipliers.push_back(multiplier);
    }
    int axis = param->axis;
    axis = (axis == -1) ? input_shape.size() - 1 : axis;
    scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape,
                                                  axis, param->rounding);
  }

  // 3) Add the output zero point.
  auto shifted_int32_t = scaled_int32_t;
  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
    shifted_int32_t = Add(Cast(output_zero_point, DataType::Int(32)), scaled_int32_t);
  }

  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32. The fixed point
  // multiplication keeps the value in int32 range.
  if (out_dtype == DataType::Int(32)) {
    return shifted_int32_t;
  }

  auto q_min = GetQmin(out_dtype);
  auto q_max = GetQmax(out_dtype);
  auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
  return Cast(clipped_t, out_dtype);
}

/*
 * \brief Forward rewrite the requantize op.
 * \param attrs The requantize op attrs.
 * \param new_args The new mutated args to the call node.
 * \param types The types of the args and the result.
 * \return The sequence of Relay ops for the requantize op.
 * \note Lowering of the requantize operation. The requantize operator converts
 *       one quantized tensor to another quantized tensor. For the output
 *       tensor, we are provided with the output scale and zero point. The
 *       computation looks like this
 *
 * Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)
 */
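/*
 * Worked example (illustrative): with scale_input = 0.5, scale_output = 0.25, zp_input = 1,
 * zp_output = 2 and Q_input = 10, the formula gives
 *   Q_output = 2 + (0.5 / 0.25) * (10 - 1) = 2 + 2 * 9 = 20,
 * after which the result is clipped to the range of out_dtype.
 */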
Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                               const Array<tvm::relay::Type>& types) {
  CHECK_EQ(new_args.size(), 5);
  auto& quantized_data = new_args[0];
  auto& input_scale = new_args[1];
  auto& input_zero_point = new_args[2];
  auto& output_scale = new_args[3];
  auto& output_zero_point = new_args[4];
  const auto* param = attrs.as<RequantizeAttrs>();
  CHECK(param != nullptr);

  // Find input shape.
  CHECK_EQ(types.size(), 6);
  auto in_type = types[0];
  auto in_tensor_type = in_type.as<TensorTypeNode>();
  CHECK(in_tensor_type != nullptr) << "Type information missing."
                                   << " Please run infer_type pass.";
  Array<IndexExpr> input_shape = in_tensor_type->shape;

  // Find the output dtype.
  auto out_type = types[5];
  auto out_tensor_type = out_type.as<TensorTypeNode>();
  CHECK(out_tensor_type != nullptr) << "Type information missing."
                                    << " Please run infer_type pass.";
  auto out_dtype = out_tensor_type->dtype;

  // Check rounding validity.
  CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST")
      << "QNN requantize supports two rounding modes - UPWARD and "
      << "TONEAREST";
  return RequantizeLower(quantized_data, input_scale, input_zero_point, output_scale,
                         output_zero_point, param, input_shape, out_dtype);
}

/*
 * \brief Infer shape function of Requantize op.
 * \param types The types of input args.
 * \param num_inputs The number of inputs.
 * \param attrs The op attributes.
 * \param reporter The type reporter that sets the dtype and shapes.
 * \return True if the infer shape succeeded.
 */
bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 6);
  const auto* data = types[0].as<TensorTypeNode>();

  if (data == nullptr) {
    return false;
  }

  const auto in_dtype = data->dtype;
  CHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
        in_dtype == DataType::Int(32))
      << "Input type should be one of [int8, uint8, int32] but was " << in_dtype;

  const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
  int axis = requantize_attrs->axis;
  axis = (axis == -1) ? data->shape.size() - 1 : axis;
  CHECK_LT(axis, static_cast<int>(data->shape.size()))
      << "axis " << requantize_attrs->axis << " is out of range";
  CHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range";

  // Check and assign types for scale and zero points.
  AssignType(types[1], DataType::Float(32), data->shape[axis], reporter);  // input_scale
  AssignType(types[2], DataType::Int(32), data->shape[axis], reporter);    // input_zero_pt
  // For now, requantize output tensor is limited to full tensor uniform quantization.
  CHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
  CHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point

  const Array<tvm::PrimExpr> oshape = data->shape;
  // assign output type
  auto out_dtype = requantize_attrs->out_dtype;
  CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
        out_dtype == DataType::Int(32))
      << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
  reporter->Assign(types[5], TensorType(oshape, out_dtype));
  return true;
}

// Positional relay function to create qnn requantize operator
// used by frontend FFI.
Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
                    Expr output_zero_point, int axis, String rounding, DataType out_dtype) {
  auto attrs = make_object<RequantizeAttrs>();
  attrs->axis = axis;
  attrs->rounding = std::move(rounding);
  attrs->out_dtype = std::move(out_dtype);
  static const Op& op = Op::Get("qnn.requantize");
  return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point},
              Attrs(attrs), {});
}
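
// Illustrative usage (a sketch, assuming the scale and zero point constants are built with
// MakeConstantScalar from pattern_util.h, as done in RequantizeLower above):
//   Expr out = MakeRequantize(data,
//                             MakeConstantScalar(DataType::Float(32), 0.5f),   // input_scale
//                             MakeConstantScalar(DataType::Int(32), 0),        // input_zero_point
//                             MakeConstantScalar(DataType::Float(32), 0.25f),  // output_scale
//                             MakeConstantScalar(DataType::Int(32), 0),        // output_zero_point
//                             /*axis=*/-1, /*rounding=*/"UPWARD", DataType::Int(8));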

RELAY_REGISTER_OP("qnn.requantize")
    .describe(R"code(Requantize operator.
The requantize operator converts one quantized tensor to another quantized
tensor. For the output tensor, we are provided with output scale and zero
point. The computation looks like this

Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)

)code" TVM_ADD_FILELINE)
    .set_attrs_type<RequantizeAttrs>()
    .set_num_inputs(5)
    .add_argument("data", "Tensor", "The quantized input tensor.")
    .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.")
    .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
    .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.")
    .add_argument("output_zero_point", "Tensor",
                  "The quantization zero_point of the output tensor.")
    .set_support_level(11)
    .add_type_rel("Requantize", RequantizeRel)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", RequantizeInferCorrectLayout);

TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize").set_body_typed(MakeRequantize);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm