/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/util.h
 * \brief Utility methods needed for quantized ops that can be shared
 */

#ifndef TVM_RELAY_QNN_UTIL_H_
#define TVM_RELAY_QNN_UTIL_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <limits>
#include <string>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {
namespace qnn {

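/*
 * \brief Gets the shape of a tensor type.
 * \param type The checked type of the expr. Must be a TensorType, i.e., the
 *        infer_type pass must have been run.
 * \return The shape of the tensor.
 */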
static inline Array<IndexExpr> get_shape(const Type& type) {
  auto input_tt = type.as<TensorTypeNode>();
  CHECK(input_tt != nullptr) << "Type information missing."
                             << " Please run infer_type pass.";
  return input_tt->shape;
}

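/*
 * \brief Gets the minimum value representable by a quantized data type,
 *        e.g. -128 for int8 and 0 for uint8.
 * \param dtype The (u)int data type, at most 32 bits wide.
 * \return The minimum value as an int32.
 */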
static inline int32_t GetQmin(const DataType& dtype) {
  CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision";
  if (dtype.is_int() || dtype.is_uint()) {
    auto* min_value = tir::as_const_int(tvm::min_value(dtype));
    CHECK(min_value != nullptr);
    return static_cast<int32_t>(min_value[0]);
  } else {
    LOG(FATAL) << "Type not supported " << dtype;
    return -1;  // To hide the warning
  }
}

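/*
 * \brief Gets the maximum value representable by a quantized data type,
 *        e.g. 127 for int8 and 255 for uint8.
 * \param dtype The (u)int data type, at most 32 bits wide.
 * \return The maximum value as an int32.
 */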
static inline int32_t GetQmax(const DataType& dtype) {
  CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision";
  if (dtype.is_int() || dtype.is_uint()) {
    auto* max_value = tir::as_const_int(tvm::max_value(dtype));
    CHECK(max_value != nullptr);
    return static_cast<int32_t>(max_value[0]);
  } else {
    LOG(FATAL) << "Type not supported " << dtype;
    return -1;  // To hide the warning
  }
}

/*
 * \brief Convert FP32 representation into fixed point representation.
 * \param double_multiplier The input FP32 number.
 * \return The pair of multiplier and shift for fixed point representation.
 * \note Converts a floating point number so that it can be represented by
 *       integers. The representation is
 *             float_number = (significand) * 2^(exponent)
 *
 *       The significand is a number between 0.5 and 1. This is represented by
 *       an integer number. For example, if it is int32, then the decimal point
 *       exists between bit 31 and 30 from LSB (or between first and second bit
 *       from the left).
 *
 *       Some examples are
 *           0.25 = (0.5) * 2^(-1)
 *           0.125 = (0.5) * 2^(-2)
 *
 *       Credit to TFLite reference implementation.
 */
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier);
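
/*
 * An illustrative sketch (assuming a Q31 significand, as in the TFLite
 * reference) of how such a decomposition can be computed with std::frexp
 * from <cmath>; the variable names below are for illustration only:
 *
 *     int exponent = 0;
 *     double significand = std::frexp(0.375, &exponent);  // 0.75 and -1
 *     // Scale the significand into a signed Q31 integer, i.e. 0.75 * 2^31.
 *     auto fixed = static_cast<int32_t>(std::round(significand * (1ll << 31)));
 *     // Now 0.375 ~= fixed * 2^(exponent - 31).
 *
 * A complete implementation would additionally handle the corner case where
 * rounding pushes the significand up to 2^31, which no longer fits in int32.
 */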

Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Expr& output_scale,
                     const Expr& output_zero_point, const RequantizeAttrs* param,
                     const Array<IndexExpr>& input_shape, const DataType& out_dtype);

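/*
 * \brief Requantizes a quantized tensor from the input scale/zero point to the
 *        output scale/zero point by building default RequantizeAttrs and
 *        delegating to RequantizeLower.
 */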
static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_shape,
                              const Expr& input_scale, const Expr& input_zero_point,
                              const Expr& output_scale, const Expr& output_zero_point,
                              const DataType& out_dtype, const std::string& rounding = "UPWARD") {
  auto attrs = make_object<RequantizeAttrs>();
  attrs->rounding = rounding;
  attrs->out_dtype = out_dtype;
  return RequantizeLower(data, input_scale, input_zero_point, output_scale, output_zero_point,
                         attrs.operator->(), input_shape, attrs->out_dtype);
}

static inline int64_t get_const_int(const tvm::PrimExpr& x) {
  auto* value_ptr = tir::as_const_int(x);
  CHECK(value_ptr) << "Expr is not a constant int";
  return value_ptr[0];
}

/*
 * \brief Fixed point multiplication between an integer tensor and a floating
 * point scalar. This implementation rounds to the nearest value when it is
 * midway between two representable values.
 * \param tensor The quantized input tensor of dtype int64.
 * \param multiplier The scalar multiplier.
 * \param input_shape Shape of the input tensor.
 * \return The sequence of Relay ops for fixed point multiplication with TONEAREST rounding.
 *
 * \note The original computation is scale_fp32 * quantized_tensor. To convert
 *       this into an integer computation, the multiplication with the fp32
 *       scalar can be replaced by multiplication with an int value and then
 *       right shifting the result. This approximates the floating point
 *       computation with a fixed point computation.
 *
 *       The fixed point multiplication consists of the following steps:
 *       1) Multiply the fixed point multiplier with the quantized tensor.
 *       2) Round the result.
 *       3) Right shift the result.
 */
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
                                 const Array<IndexExpr>& input_shape);
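
/*
 * For instance (illustrative numbers only), multiplying a quantized value of
 * 100 by a scale of 0.5 roughly proceeds as:
 *
 *     1) 0.5 is represented as the Q31 multiplier (1 << 30) with shift 0,
 *        so the product is 100 * (1 << 30) = 107374182400.
 *     2) Add a rounding term of (1 << 30): 108447924224.
 *     3) Right shift by 31: 108447924224 >> 31 = 50.
 *
 * The exact sequence of Relay ops that implements these steps may differ; the
 * numbers above only show the arithmetic.
 */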

/*
 * \brief Fixed point multiplication between an integer tensor and a floating
 * point scalar where the input tensor is per-axis/per-channel quantized.
 * \param tensor The quantized input tensor of dtype int64.
 * \param multiplier The vector of multipliers, one per channel.
 * \param input_shape Shape of the input tensor.
 * \param channel_axis The axis along which the input tensor is quantized. The
 *        default value of -1 corresponds to the last axis.
 * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the
 *        value is midway between two representable values.
 * \return The sequence of Relay ops for fixed point multiplication.
 *
 * \note The original computation is scale_fp32 * quantized_tensor. To convert
 *       this into an integer computation, the multiplication with the fp32
 *       vector can be replaced by multiplication with an int vector and then
 *       right shifting the result. This approximates the floating point
 *       computation with a fixed point computation.
 *
 *       The fixed point multiplication consists of the following steps:
 *       1) Multiply the fixed point multiplier with the quantized tensor.
 *       2) Round the result.
 *       3) Right shift the result.
 */
Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multiplier,
                                  const Array<IndexExpr>& input_shape, int channel_axis,
                                  const std::string& rounding);
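
/*
 * For instance (hypothetical values), a conv2d output in NCHW layout with three
 * output channels could be scaled with one multiplier per channel by passing
 * channel_axis = 1:
 *
 *     FixedPointMultiplyPerChannel(tensor, {0.5, 0.25, 0.125}, input_shape,
 *                                  1, "UPWARD");
 */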

/*
 * \brief Checks whether an expr type is scalar of a given data type.
 * \param expr_type The type of expr to be checked.
 * \param dtype The expected dtype.
 * \return True if the type is a scalar of given dtype.
 */
static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
  const auto* tensor_type = expr_type.as<TensorTypeNode>();
  CHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got "
                     << AsText(expr_type, false);
  CHECK_EQ(tensor_type->shape.size(), 0);
  CHECK(tensor_type->dtype == dtype) << "Expected " << dtype << " but got " << tensor_type->dtype;
  return true;
}

/*
 * \brief Checks and assigns types to scale and zero points.
 * \param expr_type The type of expr to be checked.
 * \param dtype The expected dtype.
 * \param shape The size of the channel (C) dimension of the original tensor.
 * \param reporter The type reporter of the original InferType call.
 */
static inline void AssignType(const Type& expr_type, const DataType& dtype, const IndexExpr& shape,
                              const TypeReporter& reporter) {
  // Scales/zero points can be either a const scalar or a vector with C-axis number of elements.
  const auto* tensor_type = expr_type.as<TensorTypeNode>();
  CHECK(tensor_type) << "Can assign type to Tensor type only. But got " << AsText(expr_type, false);
  const auto tensor_dtype = tensor_type->dtype;
  CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
  if (tensor_type->shape.size() != 0) {
    reporter->Assign(expr_type, TensorType({shape}, tensor_type->dtype));
  }
}

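/*
 * \brief Reads the values of a constant float32 tensor into a flat std::vector<float>.
 * \param expr The constant expr. Must be a ConstantNode holding float32 data.
 * \return The flattened values.
 */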
static inline std::vector<float> GetFloatVectorFromConstant(const Expr& expr) {
  const auto* n = expr.as<ConstantNode>();
  std::vector<float> vals;
  CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false);
  int64_t num_elems = 1;
  auto shape = n->data.Shape();
  for (size_t i = 0; i < shape.size(); i++) {
    num_elems *= shape[i];
  }
  for (int64_t i = 0; i < num_elems; i++) {
    vals.push_back(static_cast<float*>(n->data->data)[i]);
  }
  return vals;
}

}  // namespace qnn
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_QNN_UTIL_H_