1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file src/relay/qnn/op/dequantize.cc
22  * \brief QNN dequantize operator. Dequantize operator converts from quantized
23  * domain to unquantized domain.
24  */
25 
26 #include <tvm/relay/analysis.h>
27 #include <tvm/relay/op_attr_types.h>
28 #include <tvm/relay/qnn/attrs.h>
29 #include "../../pass/pattern_util.h"
30 #include "../util.h"
31 
namespace tvm {
namespace relay {
namespace qnn {

// Register DequantizeAttrs with TVM's node system so the attrs object can be
// created, reflected and serialized (used by MakeDequantize below).
TVM_REGISTER_NODE_TYPE(DequantizeAttrs);
37 
DequantizeRel(const Array<Type> & types,int num_inputs,const Attrs & attrs,const TypeReporter & reporter)38 bool DequantizeRel(const Array<Type>& types,
39                    int num_inputs,
40                    const Attrs& attrs,
41                    const TypeReporter& reporter) {
42   CHECK_EQ(types.size(), 2);
43   const auto* data = types[0].as<TensorTypeNode>();
44   const auto input_dtype = data->dtype;
45   CHECK(input_dtype == Int(8) || input_dtype == UInt(8) || input_dtype == Int(32))
46     << "Input type should be one of the quantized types [unit8, int8, int32] but was "
47     <<  input_dtype;
48   const Array<tvm::Expr> oshape = data->shape;
49   // assign output type, output will always be float 32.
50   reporter->Assign(types[1], TensorTypeNode::make(oshape, Float(32)));
51   return true;
52 }
53 
MakeDequantize(Expr data,double input_scale,int32_t input_zero_point)54 Expr MakeDequantize(Expr data,
55                     double input_scale,
56                     int32_t input_zero_point) {
57   auto attrs = make_node<DequantizeAttrs>();
58   attrs->input_scale = input_scale;
59   attrs->input_zero_point = input_zero_point;
60   // real_value = scale * (quantized_value - zero_point)
61   // A more detailed explanation can be found here - https://github.com/google/gemmlowp/blob/master/doc/quantization.md
62   static const Op& op = Op::Get("qnn.dequantize");
63   return CallNode::make(op, {data}, Attrs(attrs), {});
64 }
65 
/*!
 * \brief Lower qnn.dequantize into plain Relay ops.
 *
 * Implements: real_value = input_scale * (quantized_value - input_zero_point)
 * by subtracting the zero point in int32 and scaling in float32.
 *
 * \param input_tensor The quantized input expression.
 * \param attrs Dequantize attributes (scale and zero point).
 * \return The float32 dequantized expression.
 */
Expr DequantizeLower(const Expr& input_tensor,
                     const DequantizeAttrs* attrs) {
  const auto zero_point = MakeConstantScalar(Int(32), attrs->input_zero_point);
  const auto scale = MakeConstantScalar(Float(32), attrs->input_scale);
  // Shift into the centered integer domain, then rescale to float.
  const auto shifted = Subtract(Cast(input_tensor, Int(32)), zero_point);
  return Multiply(Cast(shifted, Float(32)), scale);
}
74 
DequantizeQnnCanonicalize(const Attrs & attrs,const Array<Expr> & new_args,const Array<tvm::relay::Type> & types)75 Expr DequantizeQnnCanonicalize(const Attrs& attrs,
76                                const Array<Expr>& new_args,
77                                const Array<tvm::relay::Type>& types) {
78   CHECK_EQ(new_args.size(), 1);
79   auto& data = new_args[0];
80   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
81   CHECK(dequantize_attrs != nullptr);
82   CHECK_EQ(types.size(), 2);
83   return DequantizeLower(data, dequantize_attrs);
84 }
85 
// Register the qnn.dequantize operator: one quantized input tensor, float32
// output (shape inferred by DequantizeRel), canonicalized to standard Relay
// ops by DequantizeQnnCanonicalize during the QNN legalization pass.
RELAY_REGISTER_OP("qnn.dequantize")
.describe(R"code(Dequantizes the input and produces float32 output.
The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point.
- **data**: Quantized tensor of any shape to dequantize. The input data can be of floating point
)code" TVM_ADD_FILELINE)
.set_attrs_type<DequantizeAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The tensor to dequantize.")
.set_support_level(11)
.add_type_rel("Dequantize", DequantizeRel)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);
97 
// Expose MakeDequantize to the Python frontend as relay.qnn.op._make.dequantize.
TVM_REGISTER_API("relay.qnn.op._make.dequantize")
.set_body_typed(MakeDequantize);
100 
101 }  // namespace qnn
102 }  // namespace relay
103 }  // namespace tvm
104