1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file src/relay/qnn/op/mul.cc
22 * \brief QNN mul operator.
23 */
24 #include <tvm/relay/analysis.h>
25 #include <tvm/relay/op_attr_types.h>
26 #include <tvm/relay/qnn/attrs.h>
27
28 #include "../../transforms/pattern_util.h"
29 #include "../util.h"
30 #include "op_common.h"
31
32 namespace tvm {
33 namespace relay {
34 namespace qnn {
35
/*!
 * \brief Canonicalizes the QNN mul op.
 * \param attrs The attributes of the QNN mul op.
 * \param new_args The new mutated args to the call node.
 * \param arg_types The types of input and output.
 * \return The sequence of Relay ops for mul op.
 */
QnnMulCanonicalize(const Attrs & attrs,const Array<Expr> & new_args,const Array<tvm::relay::Type> & arg_types)43 Expr QnnMulCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
44 const Array<tvm::relay::Type>& arg_types) {
45 // Get the attrs.
46 QnnBinaryOpArguments args(new_args);
47
48 // Get the input dtype and shape.
49 QnnBinaryOpTensorType input_type(arg_types, 0);
50 // data types
51 const auto int32_dtype = DataType::Int(32);
52 const auto float32_dtype = DataType::Float(32);
53
54 /*
55 A tensor multiplication c = a * b can be written in terms of respective
56 quantized tensors, scales and zero points as
57 S_c * (Q_c - zp_c) = S_a * (Q_a - zp_a) * S_b * (Q_b - zp_b).
58
59 We can consider the product (Q_a - zp_a) * (Q_b - zp_b) as a different
60 quantized tensor of c, Q', with corresponding scale S' = S_a * S_b and zp' =
61 0. The quantized multiplication then becomes
62 Q_c = S'/S_c Q' + z_c,
63 which is essentially a requantization of tensor Q' into tensor Q_c.
64 */
65
66 auto lhs_shifted = Cast(args.lhs, int32_dtype);
67 auto rhs_shifted = Cast(args.rhs, int32_dtype);
68
69 auto zero_scalar = MakeConstantScalar(int32_dtype, 0);
70 if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) {
71 lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point);
72 }
73
74 if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) {
75 rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point);
76 }
77
78 // Create a new tensor Q'
79 auto output = Multiply(lhs_shifted, rhs_shifted);
80
81 // Get the adjusted new scale and zero points.
82 float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale);
83 float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale);
84 float new_scale_float = lhs_scale_float * rhs_scale_float;
85 auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float);
86 auto new_input_zero_point = zero_scalar;
87
88 // Requantize to get Q_c
89 output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point,
90 args.output_scale, args.output_zero_point, input_type.dtype);
91
92 return output;
93 }
94
// QNN Multiplication operator.
//
// Registers the "mul" op through the QNN_REGISTER_BINARY_OP macro (presumably
// provided by op_common.h -- confirm) and configures it with:
//   - a human-readable description,
//   - support level 11,
//   - QnnMulCanonicalize attached under the "FTVMQnnCanonicalize" attribute,
//     so that attribute resolves to the lowering function defined above.
//
// NOTE(review): the description string contains a duplicated "with". It is
// runtime text, so it is left untouched here; fix it in a separate change if
// desired.
QNN_REGISTER_BINARY_OP("mul")
    .describe("Elementwise mul with with broadcasting for quantized tensors.")
    .set_support_level(11)
    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnMulCanonicalize);
100
101 } // namespace qnn
102 } // namespace relay
103 } // namespace tvm
104